In [76]:
import torch as th
import torch.nn as nn
from active_critic.model_src.state_model import StateModel, StateModelArgs

In [77]:
def build_sequence(rnn_predictor: th.nn.GRU, action_model: "StateModel", embedding: th.Tensor, seq_len: int, actions=None):
    """Roll a recurrent predictor forward for ``seq_len`` steps.

    At every step the predictor consumes ``actions[step]`` together with the
    current hidden state and produces the next hidden state. Whenever the
    supplied actions run out, ``action_model`` proposes the next action from the
    newest hidden state. The returned action sequence therefore always holds at
    least ``seq_len + 1`` entries (one more than the predictor consumed).

    Args:
        rnn_predictor: recurrent module called as ``rnn_predictor(action, hidden)``
            returning ``(output, new_hidden)``; only ``new_hidden`` is used.
        action_model: model whose ``forward(inpt=...)`` receives the hidden state
            flattened to ``(batch, layers * hidden)``.
            NOTE(review): assumed to return one action per batch row,
            ``(batch, action_dim)`` — confirm against StateModel.
        embedding: initial hidden state in ``nn.GRU``'s ``(layers, batch, hidden)``
            layout.
        seq_len: number of predictor steps to run.
        actions: optional pre-set actions stacked along dim 0, each entry shaped
            like the predictor input (``(1, batch, action_dim)`` for a
            single-timestep GRU call). ``None`` or an empty tensor means every
            action is proposed by ``action_model``.

    Returns:
        ``(embeddings, actions)``:
            ``embeddings`` stacks the initial state and every predicted state,
            shape ``(seq_len + 1, layers, batch, hidden)``.
            ``actions`` stacks the supplied actions plus any proposed ones along
            dim 0.
    """

    def _propose_action(state: th.Tensor) -> th.Tensor:
        # Flatten (layers, batch, hidden) -> (batch, layers * hidden) so each batch
        # row sees every layer, then add the length-1 time axis the GRU input needs.
        flat_state = state.transpose(0, 1).reshape(state.shape[1], -1)
        return action_model.forward(inpt=flat_state).unsqueeze(0)

    # Accumulate in lists and stack once at the end: repeated th.cat in the loop
    # would copy the whole growing sequence on every step.
    states = [embedding]
    step_actions = [] if actions is None else list(actions)
    if not step_actions:
        # Indexing step_actions[0] below needs at least one action.
        step_actions.append(_propose_action(embedding))

    for step in range(seq_len):
        _, embedding = rnn_predictor(step_actions[step], embedding)
        states.append(embedding)
        if len(step_actions) == step + 1:
            # Supplied actions exhausted: let the actor pick the next one.
            step_actions.append(_propose_action(embedding))

    return th.stack(states), th.stack(step_actions)

In [78]:
# Seed torch before building any module so the GRU / StateModel initialisation,
# and hence the logged training losses, are reproducible on Restart & Run All.
SEED = 0
th.manual_seed(SEED)

batch_size = 3
embedding_size = 20
action_dim = 10
layers = 2
seq_len = 1  # NOTE(review): not referenced by any visible cell
ac_seq_len = 4  # number of predictor steps per rolled-out sequence


rnn = nn.GRU(action_dim, embedding_size, layers)

sma = StateModelArgs()
sma.arch = [10, action_dim]  # NOTE(review): layer-width semantics are defined by StateModel — confirm
sma.device = 'cpu'
sma.lr = 1e-3
action_model = StateModel(args=sma)

# ac_seq_len + 1 all-zero actions: one per GRU step plus a spare, so build_sequence
# never has to append actor-proposed actions. The axis of size 1 is the GRU's
# time dimension (nn.GRU default layout is (seq, batch, features)).
zero_actions = th.zeros([ac_seq_len + 1, 1, batch_size, action_dim], dtype=th.float32)

# Hidden states use nn.GRU's (num_layers, batch, hidden_size) layout.
# NOTE(review): init_embeddng requires grad, but nothing below optimizes it, so its
# .grad only accumulates — keep only if it is meant to be optimized later.
init_embeddng = th.ones([layers, batch_size, embedding_size], requires_grad=True, dtype=th.float32)
# Target for every state of the zero-action rollout (initial state included).
konstant_embedding = th.ones([ac_seq_len+1, layers, batch_size, embedding_size])
# Fixed regression target for the final state of the actor-driven rollout. It does
# not require grad: nothing optimizes it, and requiring grad would only make every
# backward() accumulate an unused .grad on it.
goal_embeddng = th.zeros([layers, batch_size, embedding_size], dtype=th.float32)

# Smoke rollout under the all-zero actions; also checks the shapes the training loop relies on.
embeddings, actions = build_sequence(rnn_predictor=rnn, action_model=action_model, embedding=init_embeddng, seq_len=ac_seq_len, actions=zero_actions)
assert embeddings.shape == (ac_seq_len + 1, layers, batch_size, embedding_size)
assert embeddings.shape == konstant_embedding.shape
assert actions.shape == zero_actions.shape

In [79]:
class SuperModel(nn.Module):
    """Bundle an actor and a predictor into a single ``nn.Module``.

    No ``forward`` is defined, so the wrapper is not callable as a model. It
    exposes the two sub-models as ``actor`` and ``predictor`` and, via
    ``nn.Module`` attribute registration, lets one ``parameters()`` call cover
    both (for sub-models that are themselves ``nn.Module`` instances).
    """

    def __init__(self, action_model, prediction_model) -> None:
        super(SuperModel, self).__init__()
        # Registering the actor first puts its parameters first in parameters()/state_dict().
        self.actor = action_model
        self.predictor = prediction_model
sm = SuperModel(prediction_model=rnn, action_model=action_model)  # one handle for both trainable sub-models

In [81]:
N_TRAIN_STEPS = 100  # outer iterations; each performs two optimizer updates
LOG_EVERY = 10

optimizer = th.optim.Adam(sm.parameters(), lr=1e-3)


def _optimize(loss: th.Tensor) -> None:
    """Apply one optimizer update for ``loss``, clearing the previous gradients first."""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


for i in range(N_TRAIN_STEPS):
    # Phase 1: the actor proposes every action; pull the final predicted state toward goal_embeddng.
    embeddings, actions = build_sequence(rnn_predictor=sm.predictor, action_model=sm.actor, embedding=init_embeddng, seq_len=ac_seq_len)
    emb_loss = ((embeddings[-1] - goal_embeddng) ** 2).mean()
    _optimize(emb_loss)

    # Phase 2: under all-zero actions, keep every predicted state close to konstant_embedding.
    # Uses sm's sub-models (not the globals rnn / action_model) so both phases always
    # train exactly the modules the optimizer was built over.
    embeddings, actions = build_sequence(rnn_predictor=sm.predictor, action_model=sm.actor, embedding=init_embeddng, seq_len=ac_seq_len, actions=zero_actions)
    konst_loss = ((embeddings - konstant_embedding) ** 2).mean()
    _optimize(konst_loss)

    if i % LOG_EVERY == 0:
        print('______________________')
        print(f'kl: {konst_loss.item()}')
        print(f'el: {emb_loss.item()}')


______________________
kl: 0.057582270354032516
el: 0.018223732709884644
______________________
kl: 0.023938482627272606
el: 0.009393999353051186
______________________
kl: 0.01228262111544609
el: 0.006575061473995447
______________________
kl: 0.008100227452814579
el: 0.0036845074500888586
______________________
kl: 0.00581239303573966
el: 0.0020863250829279423
______________________
kl: 0.004309098236262798
el: 0.0012225396931171417
______________________
kl: 0.003308013780042529
el: 0.0007025115774013102
______________________
kl: 0.0026115956716239452
el: 0.0003985613875556737
______________________
kl: 0.002111358568072319
el: 0.0002234279818367213
______________________
kl: 0.0017417585477232933
el: 0.00012453195813577622


In [68]:
# Re-run the actor-driven rollout with the trained sub-models for inspection.
embeddings, actions = build_sequence(
    rnn_predictor=sm.predictor,
    action_model=sm.actor,
    embedding=init_embeddng,
    seq_len=ac_seq_len,
)


In [70]:
# Final predicted state of the rollout; phase 1 of training pulls it toward goal_embeddng (all zeros).
embeddings[-1, ...]

tensor([[[0.0031, 0.0030, 0.0022, 0.0030, 0.0017, 0.0039, 0.0044, 0.0009,
          0.0061, 0.0050, 0.0030, 0.0046, 0.0042, 0.0028, 0.0036, 0.0030,
          0.0060, 0.0021, 0.0034, 0.0041],
         [0.0031, 0.0030, 0.0022, 0.0030, 0.0017, 0.0039, 0.0044, 0.0009,
          0.0061, 0.0050, 0.0030, 0.0046, 0.0042, 0.0028, 0.0036, 0.0030,
          0.0060, 0.0021, 0.0034, 0.0041],
         [0.0031, 0.0030, 0.0022, 0.0030, 0.0017, 0.0039, 0.0044, 0.0009,
          0.0061, 0.0050, 0.0030, 0.0046, 0.0042, 0.0028, 0.0036, 0.0030,
          0.0060, 0.0021, 0.0034, 0.0041]],

        [[0.0028, 0.0075, 0.0073, 0.0049, 0.0088, 0.0068, 0.0065, 0.0082,
          0.0075, 0.0073, 0.0051, 0.0032, 0.0037, 0.0052, 0.0057, 0.0067,
          0.0082, 0.0034, 0.0088, 0.0086],
         [0.0028, 0.0075, 0.0073, 0.0049, 0.0088, 0.0068, 0.0065, 0.0082,
          0.0075, 0.0073, 0.0051, 0.0032, 0.0037, 0.0052, 0.0057, 0.0067,
          0.0082, 0.0034, 0.0088, 0.0086],
         [0.0028, 0.0075, 0.0073, 0.0049, 0

In [75]:
# Difference between the actor's actions at steps 1 and 2 (non-zero means the actions differ across steps).
th.sub(actions[1], actions[2])

tensor([[[-0.1383, -0.5198,  0.4141, -0.0764, -0.4537, -0.3518, -0.6914,
          -0.0708, -0.5698, -0.4783],
         [-0.1383, -0.5198,  0.4141, -0.0764, -0.4537, -0.3518, -0.6914,
          -0.0708, -0.5698, -0.4783],
         [-0.1383, -0.5198,  0.4141, -0.0764, -0.4537, -0.3518, -0.6914,
          -0.0708, -0.5698, -0.4783]]], grad_fn=<SubBackward0>)