In [None]:
import sys

from active_critic.model_src.transformer import ModelSetup, TransformerModel, generate_square_subsequent_mask
from active_critic.model_src.base_transformer import DebugTEL
import torch as th
from active_critic.utils.pytorch_utils import calcMSE

In [None]:
# Model/data configuration for the first masking and gradient-flow experiments.
ms = ModelSetup()
seq_len = 3
ntoken = 3
batch_size = 2
d_output = 4

ms.d_output = d_output
ms.nhead = 2
ms.d_hid = 10
ms.d_model = 10
ms.nlayers = 1
ms.seq_len = seq_len
ms.dropout = 0
# NOTE(review): ms.ntoken is 1 while the local `ntoken` is 3 — confirm what ModelSetup.ntoken means.
ms.ntoken = 1
ms.lr = None
ms.device = 'cpu'
device = 'cpu'


In [None]:
def make_seq_encoding_data(batch_size, seq_len, ntoken, d_out, device = 'cuda'):
    """Build a toy sequence-encoding batch.

    Returns:
        source: float tensor of ones, shape (batch_size, seq_len, ntoken).
        target: float tensor of shape (batch_size, seq_len, d_out) that is 0 at even
            time steps and 1 at odd time steps.
    """
    common = dict(dtype=th.float, device=device)
    source = th.ones([batch_size, seq_len, ntoken], **common)
    target = th.ones([batch_size, seq_len, d_out], **common)
    # Zero every second step (0, 2, 4, ...) along the time axis.
    target[:, 0::2] = 0
    return source, target

In [None]:
# Instantiate the model and build one causal mask per sample for inspection.
tm = TransformerModel(model_setup=ms)
mask = generate_square_subsequent_mask(3)
mask = mask.unsqueeze(0).repeat([2, 1, 1])
# Sample 0's mask is cleared to all zeros (presumably "nothing blocked" for an additive mask — confirm).
mask[0] = 0
mask

In [14]:
def pull_tens_to_front(src, i):
    """Shift slice ``src[i]`` left by ``i`` steps along dim 2 and pad the tail with -1.

    ``src`` is modified in place and returned; for ``i == 0`` it is returned unchanged.
    Expects at least 3 dims, e.g. the (seq, batch, seq, dim) stack produced by
    ``repeat_along_seq_td``.
    """
    if i > 0:
        # Source and destination are overlapping views of the same storage; clone the
        # source so the result does not depend on the order copy_ visits elements.
        src[i, :, :-i] = src[i, :, i:].clone()
        src[i, :, -i:] = -1

    return src

def pull_tens_to_front_sparse(src, i):
    """Replace slice ``src[i]`` by a 0/1 "hit" flag, padded with -1 after ``seq - i`` steps.

    The flag is computed per (batch, feature) and is 1 when any step ``>= i`` along
    dim 2 of ``src[i]`` equals exactly 1. It is written to the first ``seq - i`` steps;
    for ``i > 0`` the last ``i`` steps are set to -1. ``src`` is modified in place and
    returned.
    """
    width = src.shape[2] - i
    # (batch, dim) boolean: did any remaining step hit exactly 1?
    hit = src[i, :, i:].max(dim=-2).values == 1
    # Broadcast the flag over the kept steps; assignment casts bool to src's dtype.
    src[i, :, :width] = hit.unsqueeze(-2).expand(-1, width, -1)
    if i > 0:
        src[i, :, width:] = -1
    return src

def repeat_along_seq_td(src):
    """Stack ``seq`` independent copies of a (batch, seq, dim) tensor.

    Returns a new tensor of shape (seq, batch, seq, dim) with ``out[k] == src`` for
    every k (a real copy, not a view).
    """
    _, n_steps, _ = src.shape
    return src.unsqueeze(0).repeat([n_steps, 1, 1, 1])

def repeat_along_seq(src, seq_len):
    """Tile a 3-D tensor ``seq_len`` times along dim 1: (b, s, d) -> (b, s * seq_len, d)."""
    return src.repeat(1, seq_len, 1)

def generate_square_subsequent_mask_dense(seq_len, rewards):
    """Placeholder for a dense-reward variant of ``generate_square_subsequent_mask``.

    Not implemented. It previously ignored both arguments, built an unused random
    tensor and returned ``None``. It now fails loudly so a caller cannot mistake it
    for a mask. For masks derived from -1-padded rewards, see
    ``generate_partial_observed_mask``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('generate_square_subsequent_mask_dense is not implemented yet')


def make_dense_seq_encoding_data(actions, obsv, rewards):
    """Turn (batch, seq, dim) tensors into one dense sequence per time step.

    Each input is stacked ``seq`` times with ``repeat_along_seq_td`` giving a
    (seq, batch, seq, dim) layout. For stack ``step``:
      * the observation at ``step`` is broadcast over every position,
      * the actions are shifted left by ``step`` with -1 padding (``pull_tens_to_front``),
      * the rewards become a 0/1 hit flag with -1 padding (``pull_tens_to_front_sparse``).

    Returns:
        (actions, obsv, rewards), each flattened to (seq * batch, seq, dim).
    """
    act_stack = repeat_along_seq_td(actions)
    obs_stack = repeat_along_seq_td(obsv)
    rew_stack = repeat_along_seq_td(rewards)
    for step in range(obs_stack.shape[0]):
        # Broadcast the observation seen at ``step`` over all positions of this stack.
        obs_stack[step] = obs_stack[step, :, step].unsqueeze(1)

        act_stack = pull_tens_to_front(act_stack, step)
        rew_stack = pull_tens_to_front_sparse(rew_stack, step)

    def _flatten(t):
        # Merge the (stack, batch) leading dims into one sample dim.
        return t.reshape([-1, t.shape[-2], t.shape[-1]])

    return _flatten(act_stack), _flatten(obs_stack), _flatten(rew_stack)

def generate_partial_observed_mask(reward, nheads):
    """Build an additive attention mask that hides padded (reward == -1) time steps.

    Args:
        reward: tensor of shape (N, S, 1) or (N, S). Entries equal to -1 mark padding
            (as written by ``pull_tens_to_front_sparse``).
        nheads: number of attention heads; the mask is repeated once per head.

    Returns:
        attention_mask: float tensor of shape (N * nheads, S, S) on ``reward.device``.
            It holds -inf in every key column that is padded for that sample and 0
            elsewhere. Rows ``n * nheads`` ... ``n * nheads + nheads - 1`` belong to
            sample n.
        result_mask: bool tensor of shape (N, S), True at non-padded positions.
    """
    # Drop only the trailing feature axis; a bare squeeze() would also collapse
    # N == 1 or S == 1 and break the indexing below.
    flat = reward.squeeze(-1) if reward.dim() == 3 else reward
    pad = flat == -1
    result_mask = ~pad

    n_samples, steps = pad.shape
    attention_mask = th.zeros([n_samples, steps, steps], device=reward.device)
    # A padded key is hidden from every query row: broadcast (N, 1, S) over the rows.
    attention_mask = attention_mask.masked_fill(pad.unsqueeze(1), -float('inf'))
    # One copy per head, sample-major; replaces the old hard-coded factor of 2,
    # which only worked for nheads == 2.
    attention_mask = attention_mask.repeat_interleave(nheads, dim=0)
    return attention_mask, result_mask

# Gradient-flow check: build the per-step (dense) dataset, mask padded steps and
# inspect which observation entries receive gradient.
# NOTE(review): no random seed is set, so the printed gradients differ between runs.
obsv = th.rand([2,3,4], requires_grad=True)
act = th.rand([2,3,2])
rewards = th.rand([2,3,1])

# Force one reward to exactly 1 so pull_tens_to_front_sparse registers a hit for sample 0.
rewards[0,0,0] = 1
a, o, r = make_dense_seq_encoding_data(actions=act, obsv=obsv, rewards=rewards)
attention_mask, result_mask = generate_partial_observed_mask(r, ms.nhead)
# Fresh leaf tensor so the gradient is reported w.r.t. the dense inputs themselves.
o_req = o.detach()
o_req.requires_grad = True
print(f'o_req: {o_req.shape}')
print(f'attention_mask: {attention_mask.shape}')
result = tm.forward(src=o_req, mask=attention_mask)
# Only non-padded output positions contribute to the loss.
result[result_mask].mean().backward()
o_req.grad

o_req: torch.Size([6, 3, 4])
attention_mask: torch.Size([12, 3, 3])
src: torch.Size([6, 3, 10])
mask: torch.Size([12, 3, 3])


tensor([[[ 0.0289, -0.0063,  0.0144, -0.0289],
         [ 0.0194,  0.0019,  0.0216, -0.0212],
         [ 0.0169,  0.0015,  0.0225, -0.0183]],

        [[ 0.0124, -0.0052,  0.0108, -0.0006],
         [ 0.0115, -0.0010,  0.0121, -0.0034],
         [ 0.0112, -0.0003,  0.0115, -0.0053]],

        [[ 0.0133,  0.0013,  0.0081, -0.0046],
         [ 0.0131,  0.0015,  0.0087, -0.0050],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0174,  0.0020,  0.0126,  0.0090],
         [ 0.0190,  0.0018,  0.0105,  0.0027],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0123, -0.0005,  0.0078, -0.0144],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0080, -0.0123,  0.0179,  0.0036],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]])

In [None]:
# NOTE(review): obsv here has a single time step (shape [2, 1, 4]) while act/rewards
# still have 3, and rep_obsv (the repeated version) is never used. As a result, the
# dense data built in the next cell has o of shape [2, 1, 4] but a/r with 6 samples
# of length 3. Confirm whether rep_obsv was meant to be passed instead.
obsv = th.rand([2,1,4], requires_grad=True)
rep_obsv = repeat_along_seq(src=obsv, seq_len=3)

In [None]:
rewards[0,0,0] = 1
a, o, r = make_dense_seq_encoding_data(actions=act, obsv=obsv, rewards=rewards)

In [None]:
# generate_partial_observed_mask requires the head count to size the per-head mask.
attention_mask, result_mask = generate_partial_observed_mask(r, ms.nhead)

In [None]:
# Repeat the gradient check on the freshly built dense data.
o_req = o.detach()
o_req.requires_grad = True

In [None]:
result = tm.forward(src=o_req, mask=attention_mask)
result[result_mask].mean().backward()

In [None]:
o_req.grad

In [None]:
# Step-by-step reconstruction of the mask: (sample, step) indices of padded (-1) rewards.
args = th.argwhere(r.squeeze() == -1)

In [None]:
args

In [None]:
# One copy of each padded (sample, key) index per query row. The counts are derived
# from the data instead of hard-coding 3 query rows and 18 = 6 * 3 index pairs.
restructured_args = args.repeat([1, r.shape[1]]).reshape([-1, 2])

In [None]:
# Query-row index 0..S-1 for each repeated pair, once per padded (sample, key) entry.
exp_ind = th.arange(r.shape[1]).repeat(args.shape[0])

In [None]:
# (sample, query, key) triples: every query row is blocked from each padded key.
full_ind = th.cat((restructured_args[:, :1], exp_ind.unsqueeze(1), restructured_args[:, 1:]), dim=-1)

In [None]:
full_ind.shape

In [None]:
# Additive mask with -inf at the blocked (sample, query, key) positions, 0 elsewhere.
mask = th.zeros([r.shape[0], r.shape[1], r.shape[1]])
mask[tuple(full_ind.T)] = -float('inf')

In [None]:
mask

In [None]:
# NOTE(review): this raises on a fresh kernel. args[:, :1] is 2-D while exp_ind is
# 1-D, so th.cat cannot join them, and exp_args is never used afterwards. Fix or remove.
exp_args = th.cat((args[:, :1], exp_ind))

In [None]:
mask.shape

In [None]:
mask

In [None]:
r

In [None]:
# NOTE(review): the three `obsv` display cells below are identical duplicates.
obsv

In [None]:
obsv

In [None]:
obsv
In [None]:
def make_partial_sequence(seq):
    """Placeholder: not implemented yet.

    The definition previously had no body, which is a SyntaxError and made this cell
    fail.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('make_partial_sequence is not implemented yet')

In [None]:
# NOTE(review): obsv is the [2, 1, 4] tensor from the rep_obsv cell, while mask is the
# [r.shape[0], 3, 3] mask built above. These shapes do not line up (and result[:, -2]
# needs at least 2 steps), so this cell likely fails. Confirm the intended inputs.
result = tm.forward(src=obsv, mask=mask)
loss = ((result[:, -2])**2).mean()
loss.backward()


In [None]:
obsv.grad

In [None]:
result

In [None]:
def make_mask_data(batch_size, seq_len, ntoken, device = 'cuda'):
    """Toy data for checking causal masking.

    Returns:
        source: ones of shape (batch_size, seq_len, ntoken) except source[0, -1, 0] = 0.
        target: same shape; all ones except sample 0, which is all zeros.
        causal: ``generate_square_subsequent_mask(seq_len)`` moved to ``device``.
    """
    causal = generate_square_subsequent_mask(seq_len).to(device)
    shape = (batch_size, seq_len, ntoken)
    source = th.ones(shape, dtype=th.float, device=device)
    source[0, -1, 0] = 0
    target = th.ones_like(source)
    target[0].zero_()
    return source, target, causal

In [None]:
# Toy encoding batch (all-ones input, target alternating 0/1 over time) built from the current globals.
inpt_seq, outpt_seq = make_seq_encoding_data(batch_size, seq_len, ntoken, d_out = d_output, device=device)

In [None]:
def make_no_conflict_part_data(batch_size, seq_len, obs_dim, d_out, device='cpu'):
    """Partially observed toy data where sample 0 differs from the others from step 1 on.

    Observations are ones, with sample 0 set to 2 at every step >= 1. Actions are the
    step index repeated over ``d_out``, shifted by +1 for sample 0 at every step. Both
    are expanded with ``make_part_observed``.

    Returns:
        (res, actions) as returned by ``make_part_observed``.
    """
    # Create the tensors on ``device``; the parameter used to be accepted but ignored.
    inpt = th.ones([batch_size, seq_len, obs_dim], device=device)
    inpt[0, 1:] = 2
    actions = th.arange(seq_len, device=device)
    actions = actions.reshape([1, -1, 1]).repeat([batch_size, 1, d_out])
    actions[0] += 1

    res, actions = make_part_observed(inpt, actions)
    return res, actions
    

In [None]:
def make_part_observed(inpt, actions):
    """Expand every sequence into ``seq`` partially observed copies.

    Args:
        inpt: observations of shape (batch, seq, obs_dim).
        actions: actions of shape (batch, seq, act_dim).

    Returns:
        res: (batch * seq, seq, obs_dim). Row ``b * seq + k`` is ``inpt[b]`` with
            steps ``< k`` zeroed and steps ``>= k`` kept.
        rep_actions: (batch * seq, seq, act_dim). Row ``b * seq + k`` is ``actions[b]``.
    """
    # Sizes come from the inputs. The previous version read the notebook globals
    # seq_len / obs_dim / d_output and broke whenever they did not match the data.
    steps, obs_width = inpt.shape[1], inpt.shape[2]
    act_width = actions.shape[-1]

    # (batch, k, step, dim): one copy of every sequence per cut-off k.
    copies = inpt.unsqueeze(1).repeat([1, steps, 1, 1])
    # triu acts on the last two dims, so move (k, step) there: keeps step >= k.
    masked = th.triu(copies.permute([0, 3, 1, 2]))
    res = masked.permute([0, 2, 3, 1]).reshape([-1, steps, obs_width])

    rep_actions = actions.repeat([1, steps, 1]).reshape([-1, steps, act_width])
    return res, rep_actions

In [None]:
def make_conflict_part_data(batch_size, seq_len, obs_dim, d_out, device='cpu'):
    """Partially observed toy data where sample 0 differs only at the first step.

    Observations are ones, with sample 0 set to 2 at step 0. Actions are the step index
    repeated over ``d_out``, shifted by +1 for sample 0 at step 0. Both are expanded
    with ``make_part_observed``; the reward target is all ones.

    Returns:
        (res, actions, rew) where ``rew`` has shape (batch_size * seq_len, seq_len).
    """
    # Create every tensor on ``device``; previously only ``rew`` honoured it, which
    # mixed devices for device != 'cpu'.
    inpt = th.ones([batch_size, seq_len, obs_dim], device=device)
    inpt[0, :1] = 2
    actions = th.arange(seq_len, device=device)
    actions = actions.reshape([1, -1, 1]).repeat([batch_size, 1, d_out])
    actions[0, :1] += 1

    res, actions = make_part_observed(inpt, actions)
    rew = th.ones([batch_size * seq_len, seq_len], device=device)

    return res, actions, rew

In [None]:
def make_conflict_part_data_neg(batch_size, seq_len, obs_dim, d_out, device='cpu'):
    """Toy data where sample 0 differs at the last step, with ramped rewards.

    Observations are ones, with sample 0 set to 2 at the final step. Actions are the
    step index repeated over ``d_out``, shifted by +1 for sample 0 at the final step.
    Both go through ``make_part_observed``. The reward ramps linearly from 0 to 1 over
    the sequence, and its last step is then zeroed. Finally the observation blocks of
    the two samples are swapped (the actions are not).

    Returns:
        (res, actions, rew) with ``rew`` of shape (batch_size * seq_len, seq_len).

    Raises:
        ValueError: if ``batch_size != 2``; the block swap is only defined for two
            samples and previously failed with an opaque shape error.
    """
    if batch_size != 2:
        raise ValueError(f'make_conflict_part_data_neg expects batch_size == 2, got {batch_size}')

    inpt = th.ones([batch_size, seq_len, obs_dim], device=device)
    inpt[0, -1:] = 2
    actions = th.arange(seq_len, device=device)
    actions = actions.reshape([1, -1, 1]).repeat([batch_size, 1, d_out])
    actions[0, -1:] += 1

    res, actions = make_part_observed(inpt, actions)
    rew = th.ones([batch_size * seq_len, seq_len], device=device)
    # Keep ``scale`` on the same device as ``rew``; a CPU tensor here failed on cuda.
    scale = th.arange(seq_len, device=device).reshape([1, -1]).repeat([rew.shape[0], 1])
    rew = rew * scale / (seq_len - 1)
    rew[:, -1:] = 0

    # Swap the observation blocks of sample 0 and sample 1.
    res_copy = res.clone()
    res[:seq_len] = res_copy[seq_len:]
    res[seq_len:] = res_copy[:seq_len]

    return res, actions, rew

In [None]:
def make_sequence(obs, acts):
    """Join observations and actions feature-wise (along the last dim)."""
    return th.cat([obs, acts], dim=-1)

In [None]:
# Observation width for the toy datasets below.
obs_dim = 3


In [None]:
# Three toy datasets: sample 0 modified from step 1 on / only at step 0 / only at the last step.
obs_nc, act_nc = make_no_conflict_part_data(batch_size, seq_len, obs_dim, d_output, device='cpu')
obs_c, act_c, rew_c_pos = make_conflict_part_data(batch_size, seq_len, obs_dim, d_output, device='cpu')
obs_c_n, act_c_n, rew_c_neg = make_conflict_part_data_neg(batch_size, seq_len, obs_dim, d_output, device='cpu')

In [None]:
rew_c_neg

In [None]:
# Model inputs: observation features followed by action features on each step.
seq_n = make_sequence(obs_c_n,act_c_n)
seq_p = make_sequence(obs_c,act_c)

In [None]:
# Training set: keep inputs and targets in the same (neg, pos) order so each sequence
# is paired with the reward generated for it (rew previously used (pos, neg) order).
seq = th.cat((seq_n, seq_p), dim=0)
rew = th.cat((rew_c_neg, rew_c_pos), dim=0)

In [None]:
# Configuration for the reward-regression experiment on the toy sequences.
ms = ModelSetup()
seq_len = 3
# obs_dim (3) + action width (4, d_output of the previous config) = features per step in `seq`.
ntoken = 3 + 4
batch_size = 2
d_output = 1

ms.d_output = d_output
ms.nhead = 1
ms.d_hid = 10
ms.d_model = 10
ms.nlayers = 1
ms.seq_len = seq_len
ms.dropout = 0
ms.ntoken = 1
ms.lr = None
ms.device = 'cpu'
# NOTE(review): the training loop below builds th.optim.Adam directly, so these two settings are not used there.
ms.optimizer_class = th.optim.AdamW
ms.optimizer_kwargs = {}
ms.model_class:TransformerModel = TransformerModel
device = 'cpu'


In [None]:
# Fit the transformer to the per-step reward targets of the toy sequences.
# NOTE(review): no random seed is set, so model initialisation and the final loss vary between runs.
inpt = seq

model = TransformerModel(model_setup=ms).to(device)
with th.no_grad():
    # Prediction of the untrained model, kept for comparison.
    answer = model.forward(inpt)
optimizer = th.optim.Adam(params=model.parameters(), lr=1e-3)
loss = 0
for i in range(3000):
    result = model.forward(inpt)
    loss = calcMSE(result, rew)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
loss

In [None]:
inpt[0]

In [None]:
# Prediction for a single training sequence (overwrites the untrained `answer`).
answer = model.forward(inpt[5].unsqueeze(0))

In [None]:
answer

In [None]:
rew[5]

In [None]:
inpt[9]

In [None]:
# NOTE(review): neither `res_c` nor `res` is defined at notebook level anywhere above
# (`res` only exists as a local inside the data helpers), so these cells raise
# NameError on a fresh kernel. They were probably meant to index `inpt`/`seq` — confirm.
result, attention = model.forward(res_c[0].unsqueeze(0), return_attention=True)

In [None]:
result

In [None]:
attention

In [None]:
result, attention = model.forward(res[4].unsqueeze(0), return_attention=True)

In [None]:
attention

In [None]:
res[4]

In [None]:
# NOTE(review): these cells hard-code 'cuda' although `device` and ms.device are 'cpu'
# in the configuration above, so they need a GPU to run.
mask = generate_square_subsequent_mask(seq_len).to('cuda')

In [None]:
inpt_seq = th.ones([2,seq_len,1], dtype=th.float, device='cuda')
inpt_seq[0,-1,0] = 0

outpt_seq = th.ones_like(inpt_seq)
outpt_seq[0] = 0

In [None]:
# Replaces the tensors built in the previous two cells. It uses the current `ntoken`
# (3 + 4), so inpt_seq has 7 features here rather than 1.
inpt_seq, outpt_seq, mask = make_mask_data(batch_size=batch_size, seq_len=seq_len, ntoken=ntoken)


In [None]:
# NOTE(review): the loop passes mask=None, so the causal `mask` built above is never used.
model = TransformerModel(model_setup=ms).to('cuda')
with th.no_grad():
    model.forward(inpt_seq)
optimizer = th.optim.Adam(params=model.parameters(), lr=1e-3)
loss = 0
for i in range(1000):
    result = model.forward(inpt_seq, mask=None)
    loss = calcMSE(result, outpt_seq)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
loss

In [None]:
# Same package path as the imports at the top of the notebook (active_critic, not ActiveCritic).
from active_critic.model_src.transformer import CriticTransformer

In [None]:
# Configuration for the CriticTransformer experiment (runs on 'cuda').
ms = ModelSetup()
seq_len = 6
d_output = 1
ms.d_output = d_output
ms.nhead = 1
ms.d_hid = 10
ms.d_model = 10
ms.nlayers = 2
ms.seq_len = seq_len
ms.dropout = 0
ms.ntoken = 1
ms.lr = None
ms.device = 'cuda'
ms.optimizer_class = th.optim.AdamW
ms.optimizer_kwargs = {}
# Presumably the width of the critic's per-sequence result; targets below are (2, 1) — confirm.
ms.d_result = 1
ms.model_class:TransformerModel = CriticTransformer

In [None]:
# Two all-ones sequences of 4 features; sample 0 differs only in its last step's first feature.
inpt_seq = th.ones([2,seq_len,4], dtype=th.float, device='cuda')
inpt_seq[0,-1,0] = 0

# One scalar target per sequence: 0 for sample 0, 1 for sample 1.
outpt_seq = th.ones([2,1], dtype=th.float, device='cuda')
outpt_seq[0] = 0

In [None]:
# Probe the untrained critic's output shape.
model = CriticTransformer(model_setup=ms).to('cuda')
with th.no_grad():
    a = model.forward(inpt_seq)

In [None]:
a.shape

In [None]:
# Train a fresh critic to separate the two sequences.
# NOTE(review): ms.optimizer_class is AdamW, but plain Adam is built here; the warm-up forward's result is discarded.
model = CriticTransformer(model_setup=ms).to('cuda')
with th.no_grad():
    model.forward(inpt_seq)
optimizer = th.optim.Adam(params=model.parameters(), lr=1e-3)
loss = 0
for i in range(2000):
    result = model.forward(inpt_seq)
    loss = calcMSE(result, outpt_seq)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
loss