In [1]:
import torch as th
import torch.nn as nn
from active_critic.model_src.state_model import StateModel, StateModelArgs
from active_critic.model_src.whole_sequence_model import WholeSequenceModel, WholeSequenceModelArgs
from active_critic.model_src.transformer import TransformerModel, ModelSetup
from active_critic.utils.pytorch_utils import generate_square_subsequent_mask, build_tf_horizon_mask
from typing import Union

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
a = th.tensor([1,2,3.], requires_grad=True)
a = a.clamp(1.5,2.5)

In [3]:
a

tensor([1.5000, 2.0000, 2.5000], grad_fn=<ClampBackward1>)

In [4]:
def build_seq(
    embeddings:th.Tensor,
    actions:th.Tensor,
    seq_len:int,
    goal_state:th.Tensor, 
    goal_emb_acts:th.Tensor, 
    tf_mask:th.Tensor,
    actor:StateModel, 
    predictor: Union[WholeSequenceModel, StateModel] ):
    init_embedding = embeddings.detach().clone()
    for i in range(seq_len - 1):
        embeddings = embeddings.detach()
        if actions is None or actions.shape[1] == i:
            actions = actor.forward(embeddings)
        act_emb = th.cat((actions[:,:i+1], embeddings), dim=-1)
        next_embedings = predictor.forward(inputs=act_emb, tf_mask=tf_mask[:i+1, :i+1])
        embeddings = th.cat((init_embedding.clone(), next_embedings), dim=1)
    if actions.shape[1] == seq_len - 1:
        actions = th.cat((actions, actor.forward(embeddings[:,-1:])), dim=1)
    return embeddings, actions


In [5]:
batch_size = 3
embedding_size = 20
action_dim = 10
layers = 2
seq_len = 1
ac_seq_len = 4

ms = ModelSetup()
ms.d_hid=20
ms.d_model = 20
ms.d_output=embedding_size
ms.device='cpu'
ms.dropout=0
ms.nhead=1
ms.nlayers=2
ms.seq_len = ac_seq_len+1

wsma = WholeSequenceModelArgs()
wsma.model_setup = ms
wsma.optimizer_class = th.optim.Adam
wsma.lr = 1e-3
wsma.name = 'test'

predictor = WholeSequenceModel(args=wsma)

sma = StateModelArgs()
sma.arch = [10,action_dim]
sma.device = 'cpu'
sma.lr = 1e-3
action_model = StateModel(args=sma)

zero_actions = th.zeros([batch_size, ac_seq_len+1, action_dim], dtype=th.float32, requires_grad=True)

init_embeddng = th.ones([batch_size, 1, embedding_size], requires_grad=True, dtype=th.float32)
konstant_embedding = th.ones([batch_size, ac_seq_len+1, embedding_size])
goal_embeddng = th.zeros([batch_size, embedding_size], requires_grad=True, dtype=th.float32)
goal_emb_act = th.cat((th.zeros_like(konstant_embedding[:,-1]), th.zeros_like(zero_actions[:,-1])), dim=1)
tf_mask = generate_square_subsequent_mask(ac_seq_len+1)


In [6]:
build_tf_horizon_mask(4, 4, 'cpu')

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [98]:
new_embeddings, new_actions = build_seq(actor=action_model, predictor=predictor, seq_len=ac_seq_len, embeddings=init_embeddng, tf_mask=tf_mask, actions=None, goal_state=None, goal_emb_acts=None)

In [100]:
new_embeddings.shape

torch.Size([3, 4, 20])

In [None]:
params_list = []
params_list += list(action_model.parameters())
params_list+= list(predictor.parameters())

optimizer = th.optim.Adam(params_list, lr=1e-3)

In [None]:
tf_mask = generate_square_subsequent_mask(ac_seq_len+1)
for i in range(200):
    embeddings, const_actions = build_seq(actor=action_model, predictor=predictor, seq_len=ac_seq_len, embeddings=init_embeddng, tf_mask=tf_mask)
    optimizer.zero_grad()
    emb_loss = ((th.cat((embeddings[:,-1], const_actions[:,-1]), dim=1) - goal_emb_act)**2).mean()
    emb_loss.backward()
    optimizer.step()

    embeddings, actions = build_seq(actor=action_model, predictor=predictor, seq_len=ac_seq_len, embeddings=init_embeddng, tf_mask=tf_mask, actions=zero_actions)
    optimizer.zero_grad()
    konst_loss = ((embeddings - konstant_embedding)**2).mean()
    konst_loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print('______________________')
        print(f'kl: {konst_loss}')
        print(f'el: {emb_loss}')


In [None]:
embeddings, actions = build_seq(actor=action_model, predictor=predictor, seq_len=ac_seq_len, embeddings=init_embeddng, tf_mask=tf_mask, actions=zero_actions)


In [None]:
zero_actions = th.zeros([batch_size, ac_seq_len+1, action_dim], dtype=th.float32, requires_grad=True)
act_opt = th.optim.Adam([zero_actions], lr=1e-2)

for i in range(1000):
    embeddings, actions = build_seq(actor=action_model, predictor=predictor, seq_len=ac_seq_len, embeddings=init_embeddng, tf_mask=tf_mask, actions=zero_actions)
    act_opt.zero_grad()
    act_loss = ((th.cat((embeddings[:,-1], actions[:,-1]), dim=1) - goal_emb_act)**2).mean()
    act_loss.backward()
    act_opt.step()
    if i % 10 == 0:
        print(act_loss)

In [84]:
zero_actions - const_actions

tensor([[[ 0.5164,  0.6334,  0.1692,  0.4370,  0.2631, -0.8581,  0.0290,
          -0.0294,  0.1332,  0.0104],
         [ 0.4611,  1.1811, -0.9004,  0.3003,  0.1788, -1.0660,  0.1180,
           0.0105, -0.7085, -0.1959],
         [ 0.4386,  1.2712, -0.9857,  0.1112,  0.4578, -1.5414,  0.4657,
          -0.1728, -1.0671, -0.3941],
         [-0.1232,  0.2503,  0.8597,  0.1416,  0.7080, -0.5630, -0.1443,
          -0.4430,  0.4682, -0.4535],
         [ 0.0407, -0.0156,  0.0124, -0.0049,  0.0053, -0.0305, -0.0894,
          -0.0148, -0.0085,  0.0436]],

        [[ 0.5164,  0.6334,  0.1692,  0.4370,  0.2631, -0.8581,  0.0290,
          -0.0294,  0.1332,  0.0104],
         [ 0.4611,  1.1811, -0.9004,  0.3003,  0.1788, -1.0660,  0.1180,
           0.0105, -0.7085, -0.1959],
         [ 0.4386,  1.2712, -0.9857,  0.1112,  0.4578, -1.5414,  0.4657,
          -0.1728, -1.0671, -0.3941],
         [-0.1232,  0.2503,  0.8597,  0.1416,  0.7080, -0.5630, -0.1443,
          -0.4430,  0.4682, -0.4535],