In [None]:
import torch as th
import torch.nn as nn
from active_critic.model_src.state_model import *
from active_critic.utils.pytorch_utils import calcMSE

In [None]:
# Smoke test of nn.GRU(input_size=10, hidden_size=20, num_layers=2).
rnn = nn.GRU(10, 20, 2)
# Named `gru_input` rather than `input` so the builtin `input()` is not shadowed
# for the rest of the kernel session.
# Default time-first layout (batch_first=False): (seq_len=5, batch=3, features=10).
gru_input = th.randn(5, 3, 10)
# Initial hidden state: (num_layers=2, batch=3, hidden=20).
h0 = th.randn(2, 3, 20)
output, hn = rnn(gru_input, h0)

In [None]:
# Shape reference for the GRU call (inert string, never executed):
# input is (seq_len, batch_size, gru_inp_dim) and h0 is (n_layers, batch_size, hidden_dim),
# matching nn.GRU(gru_inp_dim, hidden_dim, n_layers) built in the config cell.
'''input = th.randn(seq_len, batch_size, gru_inp_dim, requires_grad=True)
h0 = th.randn(n_layers, batch_size, hidden_dim)
output, hn = rnn(input, h0)'''

In [None]:
# Smoke test of nn.LSTM(input_size=10, hidden_size=20, num_layers=2).
rnn = nn.LSTM(10, 20, 2)
# Named `lstm_input` rather than `input` so the builtin `input()` is not shadowed.
# Default time-first layout (batch_first=False): (seq_len=5, batch=3, features=10).
lstm_input = th.randn(5, 3, 10)
# Initial hidden and cell states: (num_layers=2, batch=3, hidden=20) each.
h0 = th.randn(2, 3, 20)
c0 = th.randn(2, 3, 20)
output, (hn, cn) = rnn(lstm_input, (h0, c0))

In [None]:
# Shape of the most recent recurrent output: (seq_len, batch, hidden) when run in order.
output.size()

In [None]:
def get_pred_input(actions, observations, goals):
    """Pack the predictor's per-step input as [actions | observations | goals] on the last axis."""
    packed_parts = [actions, observations, goals]
    return th.cat(packed_parts, -1)

def get_actor_input(observations, goals):
    """Pack the actor's input as [observations | goals] on the last axis."""
    actor_parts = [observations, goals]
    return th.cat(actor_parts, -1)

def split_pred_output(pred_output, action_dim, observation_dim, goal_dim):
    """Split packed ``[action | observation | goal]`` features back into three parts.

    The features live on the last axis (the layout produced by ``th.cat(..., dim=-1)``
    in get_pred_input / append_goal); all leading axes are kept unchanged.

    Args:
        pred_output: tensor whose last axis packs actions, observations and goals.
        action_dim: width of the leading action slice.
        observation_dim: width of the observation slice that follows the actions.
        goal_dim: width of the trailing goal slice (may be 0).

    Returns:
        ``(actions, observations, goals)`` views into ``pred_output``; the goal part is
        the last ``goal_dim`` features.
    """
    obs_end = action_dim + observation_dim
    # Use an explicit start index for the goals: ``-goal_dim:`` degenerates to ``0:`` when
    # goal_dim == 0 and would return the whole tensor instead of an empty slice.
    goal_start = pred_output.shape[-1] - goal_dim
    return (
        pred_output[..., :action_dim],
        pred_output[..., action_dim:obs_end],
        pred_output[..., goal_start:],
    )

def append_goal(states, goals):
    """Attach goals after the state features on the last axis: [states | goals]."""
    with_goal = [states, goals]
    return th.cat(with_goal, -1)

def build_seqence(observations:th.Tensor, actions:th.Tensor, goals:th.Tensor, actor:StateModel, predictor:nn.GRU, projector:StateModel, seq_len:int):
    """Roll out a packed [action | observation | goal] sequence from an observed prefix.

    Steps:
      1. the actor proposes the action for the most recent observation;
      2. the predictor runs over the real prefix (teacher forcing);
      3. the predictor is fed its own projected outputs until ``seq_len`` steps exist.

    Args:
        observations: observed prefix, indexed [0:step] on dim 0 (time-first).
        actions: actions so far, indexed [0:step-1] on dim 0, i.e. one fewer step than
            ``observations``; the actor supplies the missing last one.
        goals: goals aligned with ``observations`` on dim 0.
        actor: called on [observation | goal] of the latest step to produce the next action.
        predictor: recurrent model run over the packed per-step inputs.
        projector: maps predictor outputs back to [action | observation] features.
        seq_len: desired length of the returned sequence.

    Returns:
        Tensor of packed steps along dim 0. Element 0 is the real first input; every later
        element is predicted. Its length is ``max(seq_len, observations.shape[0] + 1)``.
        NOTE(review): if the prefix already has ``seq_len`` or more steps, the result is
        longer than ``seq_len``.
    """
    # The actor chooses the action at the latest observed step, so actions and
    # observations end up with the same number of steps.
    actor_input = get_actor_input(observations=observations[-1:], goals=goals[-1:])
    action = actor.forward(actor_input)
    actions = th.cat((actions, action), dim=0)

    # Teacher-forced pass over the real prefix; `hn` carries the recurrent state forward.
    # Call the module (not .forward) so any registered hooks run.
    seq_input = get_pred_input(actions=actions, observations=observations, goals=goals)
    hidden, hn = predictor(seq_input)
    predicted = append_goal(states=projector.forward(hidden), goals=goals)

    # The first embedding is not predicted (it is the real input, including its action);
    # each later step is the prediction made from the step before it.
    steps = [seq_input[:1], predicted]
    last_step = predicted[-1:]
    last_goal = goals[-1:]

    # Autoregressive rollout: feed every prediction back in until seq_len steps exist.
    for _ in range(seq_len - observations.shape[0] - 1):
        next_hidden, hn = predictor(last_step, hn)
        last_step = append_goal(projector.forward(next_hidden), goals=last_goal)
        steps.append(last_step)

    # Concatenate once rather than re-copying the whole sequence on every iteration.
    return th.cat(steps, dim=0)

In [239]:
th.manual_seed(0)

# Problem sizes.
obs_dim = 1
act_dim = 2
batch_size = 3
n_layers = 4
seq_len = 10
goal_dim = obs_dim
# The predictor consumes [action | observation | goal] packed on the last axis.
gru_inp_dim = obs_dim + act_dim + goal_dim
hidden_dim = 200

# Fall back to CPU so "Restart & Run All" also works on machines without CUDA.
device = 'cuda' if th.cuda.is_available() else 'cpu'

# Toy tensors, time-first: (seq_len, batch_size, feature_dim).
# NOTE(review): requires_grad on observations/actions presumably exists to inspect input gradients.
observations = th.zeros(seq_len, batch_size, obs_dim, requires_grad=True, device=device)
actions = th.ones(seq_len, batch_size, act_dim, requires_grad=True, device=device)
goals = th.ones(seq_len, batch_size, goal_dim, device=device)
achieved_goals = th.zeros(seq_len, batch_size, goal_dim, device=device)

# Actor: maps [observation | goal] to an action of width act_dim.
# NOTE(review): assumes StateModelArgs.arch lists layer widths ending with the output width.
sml = StateModelArgs()
sml.arch = [10, act_dim]
sml.device = device
sml.lr = 1e-2
actor = StateModel(args=sml)

# Projector: maps predictor outputs back to [action | observation]; the goal is re-appended later.
sml_proj = StateModelArgs()
sml_proj.arch = [10, gru_inp_dim - goal_dim]
sml_proj.device = device
sml_proj.lr = 1e-2
projector = StateModel(args=sml_proj)


predictor = nn.GRU(gru_inp_dim, hidden_dim, n_layers, device=device)
# NOTE(review): only the predictor's parameters are registered here; actor and projector
# are not updated by this optimizer — confirm that is intended.
optimizer = th.optim.Adam(predictor.parameters(), lr=1e-2)

In [None]:
# build_seqence expects one fewer action than observations (the actor supplies the newest one).
# Passing all seq_len actions made get_pred_input concatenate seq_len + 1 actions with seq_len
# observations, which raises a size-mismatch error, so drop the last action.
# NOTE(review): with the full 10-step prefix the rollout is 11 steps long, longer than the
# seq_len=6 requested here — pass a shorter prefix if a 6-step rollout was intended.
outputs = build_seqence(observations=observations, actions=actions[:-1], goals=goals, actor=actor, predictor=predictor, projector=projector, seq_len=6)

In [None]:
# Unpack the rollout into its action / observation / goal slices.
out_actions, out_observations, out_goals = split_pred_output(
    outputs,
    action_dim=act_dim,
    observation_dim=obs_dim,
    goal_dim=goal_dim,
)

In [None]:
# MSE of rollout step 4 against an all-ones tensor shaped like a single step
# (presumably a probe to check that gradients flow through the rollout).
loss = calcMSE(
    outputs[4],
    th.ones_like(outputs[0]),
)

In [None]:
# Backpropagate the probe loss (same call Tensor.backward() makes internally).
th.autograd.backward(loss)

In [240]:
optimizer = th.optim.Adam(predictor.parameters(), lr=1e-2)

# Target for the goal objective: all-ones, shaped like the last observation step.
goal_observations = th.ones_like(observations[-1:])


def step_goal_prediction(observations:th.Tensor, actions:th.Tensor, goals:th.Tensor, seq_len:int, actor:StateModel, projector:StateModel, predictor:nn.GRU):
    """Take one optimizer step pushing the final predicted observation toward the goal.

    Rolls out ``seq_len`` steps with build_seqence and minimises the MSE between the last
    predicted observation and the module-level ``goal_observations``.

    Uses the module-level ``optimizer``, ``goal_observations``, ``act_dim``, ``obs_dim``
    and ``goal_dim``. Only the parameters owned by ``optimizer`` receive gradients.

    Returns:
        The observation loss computed before the optimizer step.
    """
    outputs = build_seqence(observations=observations, actions=actions, goals=goals, actor=actor, predictor=predictor, projector=projector, seq_len=seq_len)
    _, out_observations, _ = split_pred_output(outputs, action_dim=act_dim, observation_dim=obs_dim, goal_dim=goal_dim)

    loss_observations = calcMSE(out_observations[-1:], goal_observations)
    optimizer.zero_grad()
    # Restrict backprop to the parameters this optimizer owns. A plain backward() would also
    # accumulate .grad on every other leaf in the graph (e.g. actor/projector parameters and
    # the requires_grad observation/action tensors), and nothing here ever clears those.
    trained_params = [p for group in optimizer.param_groups for p in group['params']]
    loss_observations.backward(inputs=trained_params)
    optimizer.step()
    return loss_observations

In [241]:
def step_prediction(observations:th.Tensor, actions:th.Tensor, goals:th.Tensor, seq_len:int, actor:StateModel, projector:StateModel, predictor:nn.GRU, step:int):
    """Take one supervised optimizer step on a rollout conditioned on the first ``step + 1`` steps.

    The rollout sees observations/goals ``[:step + 1]`` and actions ``[:step]``. The
    predicted actions and observations of the whole rollout are then regressed onto the
    full ``actions`` and ``observations`` tensors.

    Uses the module-level ``optimizer``, ``act_dim``, ``obs_dim`` and ``goal_dim``. Only
    the parameters owned by ``optimizer`` receive gradients.

    NOTE(review): the rollout has ``max(seq_len, step + 2)`` steps, so it only lines up
    with full-length ``actions``/``observations`` when ``step + 1 < seq_len``.

    Returns:
        ``(loss_actions, loss_observations)`` computed before the optimizer step.
    """
    outputs = build_seqence(observations=observations[:step+1], actions=actions[:step], goals=goals[:step+1], actor=actor, predictor=predictor, projector=projector, seq_len=seq_len)
    out_actions, out_observations, _ = split_pred_output(outputs, action_dim=act_dim, observation_dim=obs_dim, goal_dim=goal_dim)

    loss_actions = calcMSE(out_actions, actions)
    loss_observations = calcMSE(out_observations, observations)
    loss = loss_actions + loss_observations

    optimizer.zero_grad()
    # Restrict backprop to the optimizer's parameters so .grad does not silently pile up on
    # other leaves in the graph (actor/projector parameters, requires_grad input tensors).
    trained_params = [p for group in optimizer.param_groups for p in group['params']]
    loss.backward(inputs=trained_params)
    optimizer.step()
    return loss_actions, loss_observations

In [242]:
# Alternate the goal-reaching objective with the supervised reconstruction objective,
# reporting every 10th iteration.
step = 2
for i in range(100):
    loss_goal_pred = step_goal_prediction(
        observations=observations[:1],
        actions=actions[:0],
        goals=goals[:1],
        seq_len=seq_len,
        actor=actor,
        projector=projector,
        predictor=predictor,
    )
    loss_actions, loss_observations = step_prediction(
        observations=observations,
        actions=actions,
        goals=achieved_goals,
        seq_len=seq_len,
        actor=actor,
        projector=projector,
        predictor=predictor,
        step=step,
    )
    if i % 10:
        continue
    print(
        '_______________________________',
        f'loss_goal_pred: {loss_goal_pred}',
        f'loss_actions: {loss_actions}',
        f'loss_observations: {loss_observations}',
        sep='\n',
    )
    

_______________________________
loss_goal_pred: 1.1433781385421753
loss_actions: 1.01786208152771
loss_observations: 0.0013858922757208347
_______________________________
loss_goal_pred: 0.010416515171527863
loss_actions: 0.2416079342365265
loss_observations: 0.03544744849205017
_______________________________
loss_goal_pred: 0.026089003309607506
loss_actions: 0.06860804557800293
loss_observations: 0.0030568591319024563
_______________________________
loss_goal_pred: 4.359491867944598e-05
loss_actions: 0.00461193872615695
loss_observations: 0.0004593618505168706
_______________________________
loss_goal_pred: 9.138700988842174e-06
loss_actions: 0.002148514613509178
loss_observations: 6.573241989826784e-05
_______________________________
loss_goal_pred: 1.8145840385841439e-10
loss_actions: 0.001031401101499796
loss_observations: 6.757899973308668e-05
_______________________________
loss_goal_pred: 1.7180255440507608e-07
loss_actions: 0.0006839908310212195
loss_observations: 4.1074159526

In [249]:
# Supervised-style rollout conditioned on the first step + 1 observations and achieved goals.
outputs = build_seqence(
    observations=observations[:step + 1],
    goals=achieved_goals[:step + 1],
    actions=actions[:step],
    actor=actor,
    projector=projector,
    predictor=predictor,
    seq_len=seq_len,
)
out_actions_copy, out_observations, out_goals = split_pred_output(
    outputs, goal_dim=goal_dim, observation_dim=obs_dim, action_dim=act_dim
)

In [250]:
# Predicted actions of the rollout conditioned on the first `step + 1` observations.
out_actions_copy

tensor([[[1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000]],

        [[0.9649, 0.9631],
         [0.9649, 0.9631],
         [0.9649, 0.9631]],

        [[1.0258, 1.0228],
         [1.0258, 1.0228],
         [1.0258, 1.0228]],

        [[1.0056, 1.0041],
         [1.0056, 1.0041],
         [1.0056, 1.0041]],

        [[0.9929, 0.9951],
         [0.9929, 0.9951],
         [0.9929, 0.9951]],

        [[0.9985, 0.9963],
         [0.9985, 0.9963],
         [0.9985, 0.9963]],

        [[1.0040, 1.0098],
         [1.0040, 1.0098],
         [1.0040, 1.0098]],

        [[0.9971, 1.0006],
         [0.9971, 1.0006],
         [0.9971, 1.0006]],

        [[1.0066, 0.9994],
         [1.0066, 0.9994],
         [1.0066, 0.9994]],

        [[0.9993, 1.0016],
         [0.9993, 1.0016],
         [0.9993, 1.0016]]], device='cuda:0', grad_fn=<SliceBackward0>)

In [246]:
# MSE between predicted observations and the goal slots of the same rollout.
# NOTE(review): execution counts in this section are out of order; re-run top to bottom to
# confirm which rollout `out_observations`/`out_goals` come from.
calcMSE(out_observations, out_goals)

tensor(2.0903e-05, device='cuda:0', grad_fn=<MeanBackward0>)

In [247]:
# Goal-conditioned rollout from the first observation only (no actions given yet).
outputs = build_seqence(
    observations=observations[:1],
    goals=goals[:1],
    actions=actions[:0],
    actor=actor,
    projector=projector,
    predictor=predictor,
    seq_len=seq_len,
)
out_actions, out_observations, out_goals = split_pred_output(
    outputs, goal_dim=goal_dim, observation_dim=obs_dim, action_dim=act_dim
)

In [251]:
# Actions of the most recent rollout (in file order: the goal-conditioned one started from observations[:1]).
out_actions

tensor([[[-0.0735,  0.0123],
         [-0.0735,  0.0123],
         [-0.0735,  0.0123]],

        [[ 0.0575, -0.2760],
         [ 0.0575, -0.2760],
         [ 0.0575, -0.2760]],

        [[ 0.1737, -0.2472],
         [ 0.1737, -0.2472],
         [ 0.1737, -0.2472]],

        [[ 0.2110, -0.1459],
         [ 0.2110, -0.1459],
         [ 0.2110, -0.1459]],

        [[ 0.2674, -0.0408],
         [ 0.2674, -0.0408],
         [ 0.2674, -0.0408]],

        [[ 0.3151,  0.0707],
         [ 0.3151,  0.0707],
         [ 0.3151,  0.0707]],

        [[ 0.3262,  0.1571],
         [ 0.3262,  0.1571],
         [ 0.3262,  0.1571]],

        [[ 0.3051,  0.1798],
         [ 0.3051,  0.1798],
         [ 0.3051,  0.1798]],

        [[ 0.2326,  0.0655],
         [ 0.2326,  0.0655],
         [ 0.2326,  0.0655]],

        [[ 0.1769, -0.0340],
         [ 0.1769, -0.0340],
         [ 0.1769, -0.0340]]], device='cuda:0', grad_fn=<SliceBackward0>)

In [248]:
# Last predicted observation of the most recent rollout.
out_observations[-1]

tensor([[1.0000],
        [1.0000],
        [1.0000]], device='cuda:0', grad_fn=<SelectBackward0>)

In [233]:
# Last predicted observation again (duplicate inspection cell).
out_observations[-1]

tensor([[1.0000],
        [1.0000],
        [1.0000]], device='cuda:0', grad_fn=<SelectBackward0>)

In [225]:
# MSE of the last predicted observation against the goal slots of every rollout step.
# NOTE(review): a single step is compared with a whole sequence, so this relies on
# broadcasting inside calcMSE — confirm that is intended.
calcMSE(out_observations[-1], out_goals)

tensor(0.0104, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
# Last predicted observation of the most recent rollout.
out_observations[-1]

In [None]:
# Rollout from the first observation only, conditioned on the achieved (all-zero) goals.
outputs = build_seqence(
    observations=observations[:1],
    goals=achieved_goals[:1],
    actions=actions[:0],
    actor=actor,
    projector=projector,
    predictor=predictor,
    seq_len=seq_len,
)
out_actions, out_observations, out_goals = split_pred_output(
    outputs, goal_dim=goal_dim, observation_dim=obs_dim, action_dim=act_dim
)


In [None]:
# Last predicted observation of the achieved-goal rollout above.
out_observations[-1]

In [None]:
# Goal slots of the rollout (re-appended unchanged from the conditioning goals).
out_goals

In [None]:
# Full predicted observation trajectory of the rollout.
out_observations