In [1]:
import torch as th
import torch.nn as nn
from active_critic.utils.pytorch_utils import calcMSE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MLPNetwork(th.nn.Module):
    """Fully connected feed-forward network with ReLU between hidden layers.

    ``arch`` lists the layer widths from input to output. For example
    ``[3, 3, 2]`` builds ``Linear(3, 3) -> ReLU -> Linear(3, 2)``. No
    activation is applied after the last linear layer.
    """

    def __init__(self, arch:list[int]) -> None:
        """Build the layer stack.

        Args:
            arch: Layer widths, input size first and output size last. It must
                hold at least two entries, because the final layer is built
                from ``arch[-2]`` and ``arch[-1]``.
        """
        super().__init__()
        modules = []
        # Hidden layers: every consecutive width pair except the last one,
        # each followed by a ReLU. Layers are created in input-to-output order
        # so parameter initialisation consumes the RNG in the same order as
        # before.
        for in_features, out_features in zip(arch[:-2], arch[1:-1]):
            modules.append(th.nn.Linear(in_features, out_features))
            modules.append(th.nn.ReLU())
        # Output layer, deliberately without a trailing activation.
        modules.append(th.nn.Linear(arch[-2], arch[-1]))
        self.layers = th.nn.Sequential(*modules)

    def forward(self, inpt:th.Tensor) -> th.Tensor:
        """Run ``inpt`` through all layers and return the raw output.

        Args:
            inpt: Tensor whose last dimension equals ``arch[0]``.

        Returns:
            Tensor whose last dimension equals ``arch[-1]``.
        """
        # Call the container itself rather than ``.forward`` so that hooks
        # registered on ``self.layers`` are executed.
        return self.layers(inpt)


In [3]:
# Smoke-test network: widths 3 -> 3 -> 2, i.e. Linear(3, 3) -> ReLU -> Linear(3, 2).
test_nw = MLPNetwork([3, 3, 2])

In [4]:
# Random batch of 2 samples with 3 input features and 2 target values each.
# No seed is set, so the loss displayed below changes from run to run.
inpt = th.rand(2, 3)
outpt = th.rand(2, 2)

# Call the module itself instead of ``.forward`` so registered hooks are run.
result = test_nw(inpt)
loss = calcMSE(result, outpt)
loss

tensor(0.5408, grad_fn=<MeanBackward0>)

In [5]:
from turtle import forward


class StateModel(nn.Module):
    """An ``MLPNetwork`` bundled with its own AdamW optimizer.

    ``arch`` and ``lr`` are stored on the instance so that ``reset`` can
    rebuild both the network and the optimizer from scratch.
    """

    def __init__(self, arch, lr) -> None:
        super().__init__()
        self.arch, self.lr = arch, lr
        self.reset()

    def reset(self):
        """Replace the network with a newly constructed one and recreate its optimizer."""
        fresh_net = MLPNetwork(arch=self.arch)
        self.model = fresh_net
        self.optimizer = th.optim.AdamW(
            fresh_net.parameters(),
            lr=self.lr,
            weight_decay=0,
        )

    def forward(self, inpt):
        """Return the wrapped network's output for ``inpt``."""
        return self.model.forward(inpt)

    def optimizer_step(self, inpt, label):
        """Take one optimizer step on the ``calcMSE`` loss between prediction and ``label``.

        Returns a dict holding the detached loss under the key ``'Loss '``
        (the key includes a trailing space).
        """
        prediction = self.model.forward(inpt)
        mse = calcMSE(prediction, label)

        self.optimizer.zero_grad()
        mse.backward()
        self.optimizer.step()

        return {'Loss ': mse.detach()}

In [6]:
class ActiveCritic(th.nn.Module):
    """Bundles four ``StateModel`` networks used to roll out and score a trajectory.

    As used by the methods of this class:
      * ``embedding_model``:  observation -> embedding
      * ``reward_model``:     embedding -> reward
      * ``action_model``:     embedding -> action
      * ``prediction_model``: concatenated (action, embedding) -> next embedding
    """

    def __init__(self, embedding_model:StateModel, reward_model:StateModel, action_model:StateModel, prediction_model:StateModel) -> None:
        super().__init__()
        self.embedding_model = embedding_model
        self.reward_model = reward_model
        self.action_model = action_model
        self.prediction_model = prediction_model
        # Subtracted from ``horizon`` in build_sequence; never modified in the code shown here.
        self.current_step = 0

    def predict(self, observation:th.Tensor, horizon:int):
        """Embed ``observation``.

        NOTE(review): this looks unfinished. The embedding is computed and then
        discarded, ``horizon`` is unused and the method returns ``None``.
        """
        embedding = self.embedding_model.forward(observation)

    def step_predict(self, embedding:th.Tensor, action:th.Tensor):
        """Predict the next embedding and the reward of that next embedding.

        Args:
            embedding: Current embedding.
            action: Action taken from ``embedding``. It must be concatenable
                with ``embedding`` along the last dimension.

        Returns:
            ``(next_embedding, next_reward)``. The reward is evaluated on the
            predicted next embedding, not on the input embedding.
        """
        # The action comes first in the concatenation; the prediction model's input
        # layout depends on this order.
        state_action = th.cat((action, embedding), dim=-1)
        next_embedding = self.prediction_model.forward(state_action)
        next_reward = self.reward_model.forward(next_embedding)
        return next_embedding, next_reward

    def build_sequence(self, observation:th.Tensor, horizon:int, actions:th.Tensor = None):
        """Roll the models forward from ``observation`` and collect the trajectory.

        The initial step is always produced by the embedding and action models.
        After that the loop runs ``horizon - self.current_step - 1`` times.

        Args:
            observation: Input to the embedding model.
            horizon: Total sequence length the rollout aims for.
            actions: Optional pre-set actions. ``actions[i]`` is used in loop
                iteration ``i``, i.e. for the step after the initial one. When
                ``None``, actions come from the action model.

        Returns:
            Tuple of four lists:
              * ``embeddings_seq``: detached embeddings with ``requires_grad``
                set, one per step including the initial one.
              * ``action_dependend_embeddings_seq``: predicted embeddings still
                attached to the autograd graph, one per loop iteration (so one
                fewer than ``embeddings_seq``).
              * ``actions_seq``: actions with ``requires_grad`` set, one per step.
              * ``rewards_seq``: rewards, one per step.
        """
        embeddings_seq = []
        action_dependend_embeddings_seq = []
        actions_seq = []
        rewards_seq = []

        # Cut the graph to the embedding model and make the embedding a leaf
        # that can receive (and be optimized by) its own gradient.
        embedding = self.embedding_model.forward(observation).detach()
        embedding.requires_grad = True
        reward = self.reward_model.forward(embedding)
        # Same treatment for the first action: detached leaf with gradients enabled.
        action = self.action_model.forward(embedding).detach()
        action.requires_grad = True
        embeddings_seq.append(embedding)
        actions_seq.append(action)
        rewards_seq.append(reward)

        for i in range(horizon - self.current_step - 1):
            # ``embedding`` and ``reward`` here are still connected to the previous
            # step's embedding and action leaves through the prediction model.
            embedding, reward = self.step_predict(embedding, action)
            if actions is None:
                action = self.action_model.forward(embedding).detach()
            else:
                action = actions[i]
            # Keep the graph-attached prediction before detaching it below.
            action_dependend_embeddings_seq.append(embedding)
            embedding = embedding.detach()
            embedding.requires_grad = True
            # NOTE(review): for caller-supplied ``actions`` this flips the flag on
            # whatever ``actions[i]`` returns; setting requires_grad is only allowed on
            # leaf tensors — TODO confirm callers pass tensors where that holds.
            action.requires_grad = True
            embeddings_seq.append(embedding)
            actions_seq.append(action)
            rewards_seq.append(reward)
        

        return embeddings_seq, action_dependend_embeddings_seq, actions_seq, rewards_seq
            
    def optimize_sequence_step(self, sequence:list[th.Tensor], lr:float):
        """Run Adam updates on the sequence's embedding and action tensors.

        The target for every reward is an all-ones tensor shaped like the last
        entry of ``rewards_n1``. A new Adam optimizer is created for each update.

        NOTE(review): several points need confirming before relying on this:
          * ``sequence`` is unpacked into five items, while ``build_sequence``
            returns four, so its output cannot be passed in directly.
          * The loop starts at ``i = len(sequence[0]) - 1`` and indexes
            ``e_n1[i]``; at ``i == 0`` the expression ``e_n1[i-1]`` wraps around
            to the last element.
          * No ``zero_grad`` is called, so gradients accumulate across the
            ``backward`` calls.
          * The ``print`` calls look like leftover debugging output.

        Args:
            sequence: Five sequences ``(e_n0, e_n1, actions_n1, rewards_n0, rewards_n1)``.
            lr: Learning rate for the per-step Adam optimizers.

        Returns:
            The same ``sequence`` object that was passed in.
        """
        e_n0, e_n1, actions_n1, rewards_n0, rewards_n1 = sequence
        goal_label = th.ones_like(rewards_n1[-1])
        
        # First update: move the last embedding/action towards the goal reward.
        r_last_optimizer = th.optim.Adam([e_n0[-1], actions_n1[-1]], lr=lr)
        reward_loss = calcMSE(goal_label, rewards_n1[-1])
        reward_loss.backward()
        r_last_optimizer.step()

        # Walk backwards through the sequence, last step first.
        for i in range(len(sequence[0])-1, -1, -1):
            print(i)
            print(e_n0[i])
            print(e_n1[i])
            print(actions_n1[i])
            r_optimizer = th.optim.Adam([e_n0[i], actions_n1[i]], lr=lr)

            # Reward term plus a consistency term between e_n1[i-1] and e_n0[i].
            loss_reward = calcMSE(rewards_n0[i], goal_label)
            loss_embedding = calcMSE(e_n1[i-1], e_n0[i])
            loss = loss_embedding + loss_reward
            loss.backward()
            r_optimizer.step()
            
        return sequence

In [7]:
# Dimensions of the toy problem.
action_dim, obsv_dim, emb_dim, rew_dim = 2, 3, 3, 1
seq_len, batch_size = 3, 2

# One StateModel per role. Each is an MLP with a single hidden layer of width 10,
# trained with learning rate 1e-2. The construction order is kept fixed so that
# parameter initialisation consumes the RNG in the same order.
embedding_model = StateModel(lr=1e-2, arch=[obsv_dim, 10, emb_dim])
reward_model = StateModel(lr=1e-2, arch=[emb_dim, 10, rew_dim])
action_model = StateModel(lr=1e-2, arch=[emb_dim, 10, action_dim])
# The prediction model consumes the concatenated (action, embedding) vector.
prediction_model = StateModel(lr=1e-2, arch=[emb_dim + action_dim, 10, emb_dim])

ac = ActiveCritic(
    embedding_model=embedding_model,
    reward_model=reward_model,
    action_model=action_model,
    prediction_model=prediction_model,
)

In [8]:
# Manual one-step rollout: embed an all-ones observation, then predict the next
# embedding and its reward for an all-ones action.
obsv = th.ones(batch_size, 1, obsv_dim)
action = th.ones(batch_size, 1, action_dim)
embedding = embedding_model.forward(obsv)
# ``embedding`` is rebound to the predicted next embedding here.
embedding, reward = ac.step_predict(embedding, action)

In [9]:
# Roll the models forward from `obsv`. Assuming ac.current_step is still 0, this
# yields the initial step plus seq_len - 1 predicted steps.
sequence = ac.build_sequence(obsv, horizon=seq_len)
# Tuple layout: (embeddings_seq, action_dependend_embeddings_seq, actions_seq, rewards_seq)

In [10]:
# embeddings_seq: the detached, grad-enabled embeddings, one per step.
sequence[0]

[tensor([[[-0.1659, -0.1835, -0.0809]],
 
         [[-0.1659, -0.1835, -0.0809]]], requires_grad=True),
 tensor([[[ 0.1325, -0.1063, -0.0314]],
 
         [[ 0.1325, -0.1063, -0.0314]]], requires_grad=True),
 tensor([[[ 0.1625, -0.1167,  0.0453]],
 
         [[ 0.1625, -0.1167,  0.0453]]], requires_grad=True)]

In [11]:
# action_dependend_embeddings_seq: predicted embeddings still attached to the autograd graph.
sequence[1]

[tensor([[[ 0.1325, -0.1063, -0.0314]],
 
         [[ 0.1325, -0.1063, -0.0314]]], grad_fn=<ViewBackward0>),
 tensor([[[ 0.1625, -0.1167,  0.0453]],
 
         [[ 0.1625, -0.1167,  0.0453]]], grad_fn=<ViewBackward0>)]

In [12]:
# rewards_seq[1]: the reward produced in the first loop iteration of build_sequence.
rewards_r1 = sequence[-1][1]
# actions_seq[-1]: the last action of the sequence.
actions_2 = sequence[2][-1]

In [13]:
# Compare the step-1 reward against an all-ones target and backpropagate.
# retain_graph=True keeps the graph so backward can be called on it again.
goal_rewards = th.ones_like(rewards_r1)
loss_rewards = calcMSE(goal_rewards, rewards_r1)
loss_rewards.backward(retain_graph=True)

In [20]:
# Gradient on the detached step-1 embedding leaf.
# NOTE(review): build_sequence computes rewards_seq[1] from the non-detached
# predicted embedding, so this leaf is not part of that loss's graph — presumably
# why nothing is displayed. sequence[0][0] and sequence[2][0] look like the leaves
# that feed rewards_seq[1]. This cell was also executed out of order.
sequence[0][1].grad

In [15]:
# actions_seq[0]: the initial action produced by the action model.
actions_0 = sequence[2][0]

In [16]:
# embeddings_seq[2]: the detached embedding at index 2 of the sequence.
embedding_2 = sequence[0][2]

In [17]:
# Display the embedding selected above.
embedding_2

tensor([[[ 0.1625, -0.1167,  0.0453]],

        [[ 0.1625, -0.1167,  0.0453]]], requires_grad=True)