In [9]:
import torch as th
import torch.nn as nn
from active_critic.model_src.state_model import StateModel, StateModelArgs
from active_critic.model_src.whole_sequence_model import WholeSequenceModel, WholeSequenceModelArgs
from active_critic.model_src.transformer import TransformerModel, ModelSetup
from active_critic.utils.pytorch_utils import generate_square_subsequent_mask

In [73]:
def build_seq(actor: 'StateModel', predictor: 'WholeSequenceModel', seq_len: int, embeddings: th.Tensor, tf_mask: th.Tensor, actions: th.Tensor = None):
    """Autoregressively roll out ``seq_len`` predictor steps from one initial embedding.

    At step ``i`` the predictor receives the first ``i + 1`` actions concatenated
    (along the last dim) with the current embedding sequence, and its output is
    appended after the (fixed) initial embedding, so the sequence grows by one
    step per iteration.

    Args:
        actor: object whose ``forward(embeddings)`` returns an action tensor. It
            is only called when ``actions`` is None or once the supplied actions
            are exhausted.
        predictor: object called as ``forward(inputs=..., tf_mask=...)``.
        seq_len: number of rollout steps.
        embeddings: initial embedding with exactly one step along dim 1
            (``(batch, 1, emb_dim)`` in this notebook).
        tf_mask: square mask with side length >= ``seq_len``; its top-left
            ``(i + 1, i + 1)`` slice is used at step ``i``.
        actions: optional fixed action sequence. As soon as
            ``actions.shape[1] == i`` the whole tensor is replaced by
            ``actor.forward`` on the current embeddings, and the actor is used
            for every remaining step.

    Returns:
        ``(embeddings, actions)``. ``embeddings`` has ``seq_len + 1`` steps along
        dim 1. ``actions`` is the supplied tensor if it was never exhausted,
        otherwise the actor output from the last step.

    Raises:
        ValueError: if ``embeddings`` does not have exactly one step along dim 1,
            or ``tf_mask`` is smaller than ``seq_len`` in either dimension.
    """
    if embeddings.shape[1] != 1:
        raise ValueError(
            f'build_seq expects a single initial embedding step, got shape {tuple(embeddings.shape)}')
    if tf_mask.shape[0] < seq_len or tf_mask.shape[1] < seq_len:
        raise ValueError(
            f'tf_mask of shape {tuple(tf_mask.shape)} is too small for seq_len={seq_len}')

    init_embedding = embeddings.detach().clone()
    for i in range(seq_len):
        # Detaching the step input (and the initial embedding) means gradients from
        # the returned tensor only flow through the final predictor call and the
        # actions used in it, not through earlier rollout steps.
        embeddings = embeddings.detach()
        # Fall back to the actor when no actions were supplied or they ran out.
        if actions is None or actions.shape[1] == i:
            actions = actor.forward(embeddings)
        act_emb = th.cat((actions[:, :i + 1], embeddings), dim=-1)
        next_embeddings = predictor.forward(inputs=act_emb, tf_mask=tf_mask[:i + 1, :i + 1])
        # th.cat allocates a fresh tensor, so init_embedding needs no per-step clone.
        embeddings = th.cat((init_embedding, next_embeddings), dim=1)
    return embeddings, actions


In [74]:
# --- Hyper-parameters ---------------------------------------------------------
batch_size = 3
embedding_size = 20
action_dim = 10
layers = 2
seq_len = 1  # NOTE(review): unused below -- rollouts pass seq_len=ac_seq_len explicitly
ac_seq_len = 4  # number of rollout steps per build_seq call

# Seed torch's global RNG before the models are constructed; presumably their
# weight init draws from it, making the logged losses reproducible -- confirm.
th.manual_seed(0)

# --- Predictor: WholeSequenceModel configured through a transformer ModelSetup -
ms = ModelSetup()
ms.d_hid = 20
ms.d_model = 20
ms.d_output = embedding_size
ms.device = 'cpu'
ms.dropout = 0
ms.nhead = 1
ms.nlayers = layers  # was a hardcoded 2 that duplicated `layers`
ms.seq_len = ac_seq_len + 1  # same length as the rollout mask built from ac_seq_len + 1

wsma = WholeSequenceModelArgs()
wsma.model_setup = ms
wsma.optimizer_class = th.optim.Adam
wsma.lr = 1e-3
wsma.name = 'test'

predictor = WholeSequenceModel(args=wsma)

# --- Actor --------------------------------------------------------------------
sma = StateModelArgs()
sma.arch = [10, action_dim]
sma.device = 'cpu'
sma.lr = 1e-3
action_model = StateModel(args=sma)

# --- Fixed inputs / targets ---------------------------------------------------
# None of these tensors is handed to the optimizer, so they must not require
# grad. With requires_grad=True, every backward() through a rollout accumulated
# .grad on zero_actions and goal_embeddng that nothing ever read or zeroed.
zero_actions = th.zeros([batch_size, ac_seq_len + 1, action_dim], dtype=th.float32)

init_embeddng = th.ones([batch_size, 1, embedding_size], dtype=th.float32)
konstant_embedding = th.ones([batch_size, ac_seq_len + 1, embedding_size])
goal_embeddng = th.zeros([batch_size, embedding_size], dtype=th.float32)


In [77]:
# Roll out ac_seq_len steps driven by the fixed zero action sequence.
tf_mask = generate_square_subsequent_mask(ac_seq_len + 1)
embeddings, actions = build_seq(
    actor=action_model,
    predictor=predictor,
    seq_len=ac_seq_len,
    embeddings=init_embeddng,
    tf_mask=tf_mask,
    actions=zero_actions,
)


In [78]:
# Gradient smoke test: a loss on the last rolled-out embedding of the first
# batch element must backpropagate into the predictor. The actor is not checked
# because the rollout above used the supplied zero_actions and never called it.
smoke_loss = ((embeddings[0, -1] - th.ones_like(embeddings[0, -1])) ** 2).mean()
smoke_loss.backward()
assert any(p.grad is not None for p in predictor.parameters()), 'no gradient reached the predictor'

In [79]:
# A single Adam instance over the actor's and the predictor's parameters, so
# every optimizer.step() below updates both models together.
params_list = [*action_model.parameters(), *predictor.parameters()]
optimizer = th.optim.Adam(params_list, lr=1e-3)

In [81]:
N_TRAIN_STEPS = 100  # outer iterations; each does one goal update and one constant update
LOG_EVERY = 10       # print both losses every LOG_EVERY iterations


def _optimizer_step(opt: th.optim.Optimizer, loss: th.Tensor) -> None:
    """Apply one update for ``loss``: clear old grads, backprop, step."""
    opt.zero_grad()
    loss.backward()
    opt.step()


tf_mask = generate_square_subsequent_mask(ac_seq_len + 1)
for i in range(N_TRAIN_STEPS):
    # Phase 1: actor-generated actions; pull the final embedding towards the goal.
    embeddings, actions = build_seq(actor=action_model, predictor=predictor, seq_len=ac_seq_len,
                                    embeddings=init_embeddng, tf_mask=tf_mask)
    emb_loss = ((embeddings[:, -1] - goal_embeddng) ** 2).mean()
    _optimizer_step(optimizer, emb_loss)

    # Phase 2: fixed zero actions; keep every rolled-out step close to the constant embedding.
    embeddings, actions = build_seq(actor=action_model, predictor=predictor, seq_len=ac_seq_len,
                                    embeddings=init_embeddng, tf_mask=tf_mask, actions=zero_actions)
    konst_loss = ((embeddings - konstant_embedding) ** 2).mean()
    _optimizer_step(optimizer, konst_loss)

    if i % LOG_EVERY == 0:
        print('______________________')
        # .item() gives the same text as formatting the 0-dim tensor directly.
        print(f'kl: {konst_loss.item()}')
        print(f'el: {emb_loss.item()}')


______________________
kl: 0.0006082796608097851
el: 9.607563697500154e-05
______________________
kl: 0.0005140891298651695
el: 6.922691682120785e-05
______________________
kl: 0.00047120917588472366
el: 4.0036913560470566e-05
______________________
kl: 0.00044673739466816187
el: 6.001058500260115e-05
______________________
kl: 0.00041722905007191
el: 3.6524284951156005e-05
______________________
kl: 0.0003933057887479663
el: 5.6502649385947734e-05
______________________
kl: 0.00037595638423226774
el: 3.781697523663752e-05
______________________
kl: 0.0003434489481151104
el: 7.229750917758793e-05
______________________
kl: 0.0003429613425396383
el: 7.22010518074967e-05
______________________
kl: 0.00032454082975164056
el: 3.0326085834531114e-05
