In [1]:
import sys
sys.path.append('C:/Users/Hendrik/PycharmProjects') # go to parent dir

from ActiveCritic.model_src.transformer import ModelSetup, TransformerModel, CriticTransformer
from ActiveCritic.model_src.base_transformer import DebugTEL
import torch as th
from ActiveCritic.utils.pytorch_utils import calcMSE

In [2]:
ms = ModelSetup()
seq_len = 3
ntoken = 3
batch_size = 2
d_output = 4

ms.d_output = d_output
ms.nhead = 1
ms.d_hid = 10
ms.d_model = 10
ms.nlayers = 1
ms.seq_len = seq_len
ms.dropout = 0
ms.ntoken = 1
ms.lr = None
ms.device = 'cpu'
ms.optimizer_class = th.optim.AdamW
ms.optimizer_kwargs = {}
ms.model_class:TransformerModel = TransformerModel
device = 'cpu'


In [3]:
def make_seq_encoding_data(batch_size, seq_len, ntoken, d_out, device = 'cuda'):
    inpt_seq = th.ones([batch_size,seq_len,ntoken], dtype=th.float, device=device)
    outpt_seq = th.ones([batch_size,seq_len,d_out], dtype=th.float, device=device)
    outpt_seq[:,::2] = 0
    return inpt_seq, outpt_seq

In [4]:
def make_mask_data(batch_size, seq_len, ntoken, device = 'cuda'):
    mask = generate_square_subsequent_mask(seq_len).to(device)
    inpt_seq = th.ones([batch_size,seq_len,ntoken], dtype=th.float, device=device)
    inpt_seq[0,-1,0] = 0
    outpt_seq = th.ones_like(inpt_seq)
    outpt_seq[0] = 0
    return inpt_seq, outpt_seq, mask

In [5]:
inpt_seq, outpt_seq = make_seq_encoding_data(batch_size, seq_len, ntoken, d_out = d_output, device=device)


In [6]:
def make_no_conflict_part_data(batch_size, seq_len, obs_dim, d_out, device='cpu'):
    inpt = th.ones([batch_size, seq_len, obs_dim])
    inpt[0, 1:] = 2
    actions = th.arange(seq_len)
    actions = actions.reshape([1,-1,1]).repeat([batch_size, 1, d_out])
    actions[0] += 1

    res, actions = make_part_observed(inpt, actions)
    return res, actions
    

In [7]:
def make_part_observed(inpt, actions):
    #batch, seq, dim
    inpt = inpt.unsqueeze(1).repeat([1,inpt.shape[1], 1, 1])
    inpt = inpt.permute([0,3,1,2])
    res = th.triu(inpt)
    res = res.permute([0,2,3,1]).reshape([-1,seq_len, obs_dim])
    
    rep_actions = actions.repeat([1,seq_len,1]).reshape([-1,seq_len, d_output])
    return res, rep_actions

In [41]:
def make_conflict_part_data(batch_size, seq_len, obs_dim, d_out, device='cpu'):
    inpt = th.ones([batch_size, seq_len, obs_dim])
    inpt[0, :1] = 2
    actions = th.arange(seq_len)
    actions = actions.reshape([1,-1,1]).repeat([batch_size, 1, d_out])
    actions[0, :1] += 1

    res, actions = make_part_observed(inpt, actions)
    rew = th.ones([batch_size * seq_len, seq_len], device = device)
    
    return res, actions, rew

In [165]:
def make_conflict_part_data_neg(batch_size, seq_len, obs_dim, d_out, device='cpu'):
    inpt = th.ones([batch_size, seq_len, obs_dim])
    inpt[0, -1:] = 2
    actions = th.arange(seq_len)
    actions = actions.reshape([1,-1,1]).repeat([batch_size, 1, d_out])
    actions[0, -1:] += 1

    res, actions = make_part_observed(inpt, actions)
    rew = th.ones([batch_size * seq_len, seq_len], device = device)
    scale = th.arange(seq_len).reshape([1,-1]).repeat([rew.shape[0], 1])
    rew = rew * scale / (seq_len - 1)
    rew[:,-1:] = 0
    
    res_copy = res.clone()
    res[:seq_len] = res_copy[seq_len:]
    res[seq_len:] = res_copy[:seq_len]
    
    return res, actions, rew

In [166]:
def make_sequence(obs, acts):
    result = th.cat((obs, acts), dim=-1)
    return result

In [167]:
obs_dim = 3


In [168]:
obs_nc, act_nc = make_no_conflict_part_data(batch_size, seq_len, obs_dim, d_output, device='cpu')
obs_c, act_c, rew_c_pos = make_conflict_part_data(batch_size, seq_len, obs_dim, d_output, device='cpu')
obs_c_n, act_c_n, rew_c_neg = make_conflict_part_data_neg(batch_size, seq_len, obs_dim, d_output, device='cpu')

In [169]:
rew_c_neg

tensor([[0.0000, 0.5000, 0.0000],
        [0.0000, 0.5000, 0.0000],
        [0.0000, 0.5000, 0.0000],
        [0.0000, 0.5000, 0.0000],
        [0.0000, 0.5000, 0.0000],
        [0.0000, 0.5000, 0.0000]])

In [164]:
seq_n = make_sequence(obs_c_n,act_c_n)
seq_p = make_sequence(obs_c,act_c)

In [83]:
seq = th.cat((seq_n, seq_p), dim=0)
rew = th.cat((rew_c_pos, rew_c_neg), dim=0)

In [86]:
ms = ModelSetup()
seq_len = 3
ntoken = 3 + 4
batch_size = 2
d_output = 1

ms.d_output = d_output
ms.nhead = 1
ms.d_hid = 10
ms.d_model = 10
ms.nlayers = 1
ms.seq_len = seq_len
ms.dropout = 0
ms.ntoken = 1
ms.lr = None
ms.device = 'cpu'
ms.optimizer_class = th.optim.AdamW
ms.optimizer_kwargs = {}
ms.model_class:TransformerModel = TransformerModel
device = 'cpu'


In [88]:
inpt = seq

model = TransformerModel(model_setup=ms).to(device)
with th.no_grad():
    answer = model.forward(inpt)
optimizer = th.optim.Adam(params=model.parameters(), lr=1e-3)
loss = 0
for i in range(3000):
    result = model.forward(inpt)
    loss = calcMSE(result, rew)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [89]:
loss

tensor(0.0556, grad_fn=<MeanBackward0>)

In [94]:
inpt[0]

tensor([[1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 2., 2., 2., 2.]])

In [108]:
answer = model.forward(inpt[5].unsqueeze(0))

In [109]:
answer

tensor([[[0.4987],
         [1.0001],
         [0.9994]]], grad_fn=<AddBackward0>)

In [107]:
rew[5]

tensor([1., 1., 1.])

In [111]:
inpt[9]

tensor([[1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 2., 2., 2., 2.]])

In [24]:
result, attention = model.forward(res_c[0].unsqueeze(0), return_attention=True)

In [25]:
result

tensor([[[1.0152, 0.9970, 1.0143, 0.9915],
         [1.9825, 1.9594, 1.9742, 1.9546],
         [2.9630, 2.9960, 2.9902, 3.0279],
         [3.9759, 3.9846, 3.9950, 3.9762],
         [4.9925, 4.9972, 4.9863, 4.9976]]], grad_fn=<AddBackward0>)

In [26]:
attention

tensor([[[1.8680e-05, 4.4999e-03, 9.9548e-01, 1.9782e-08, 2.7214e-10],
         [9.6357e-06, 3.2713e-03, 9.9672e-01, 6.2094e-09, 6.4015e-11],
         [1.2000e-05, 3.4134e-03, 9.9657e-01, 8.6596e-09, 1.0958e-10],
         [7.7636e-04, 2.9140e-02, 9.7008e-01, 7.3897e-06, 4.3202e-07],
         [1.4408e-03, 3.8925e-02, 9.5961e-01, 2.1836e-05, 1.7041e-06]]])

In [None]:
result, attention = model.forward(res[4].unsqueeze(0), return_attention=True)

In [None]:
attention

In [None]:
res[4]

In [None]:
from ActiveCritic.model_src.transformer import generate_square_subsequent_mask

In [None]:
mask = generate_square_subsequent_mask(seq_len).to('cuda')

In [None]:
inpt_seq = th.ones([2,seq_len,1], dtype=th.float, device='cuda')
inpt_seq[0,-1,0] = 0

outpt_seq = th.ones_like(inpt_seq)
outpt_seq[0] = 0

In [None]:
inpt_seq, outpt_seq, mask = make_mask_data(batch_size=batch_size, seq_len=seq_len, ntoken=ntoken)


In [None]:
model = TransformerModel(model_setup=ms).to('cuda')
with th.no_grad():
    model.forward(inpt_seq)
optimizer = th.optim.Adam(params=model.parameters(), lr=1e-3)
loss = 0
for i in range(1000):
    result = model.forward(inpt_seq, mask=None)
    loss = calcMSE(result, outpt_seq)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
loss

In [None]:
from ActiveCritic.model_src.transformer import CriticTransformer

In [None]:
ms = ModelSetup()
seq_len = 6
d_output = 1
ms.d_output = d_output
ms.nhead = 1
ms.d_hid = 10
ms.d_model = 10
ms.nlayers = 2
ms.seq_len = seq_len
ms.dropout = 0
ms.ntoken = 1
ms.lr = None
ms.device = 'cuda'
ms.optimizer_class = th.optim.AdamW
ms.optimizer_kwargs = {}
ms.d_result = 1
ms.model_class:TransformerModel = CriticTransformer

In [None]:
inpt_seq = th.ones([2,seq_len,4], dtype=th.float, device='cuda')
inpt_seq[0,-1,0] = 0

outpt_seq = th.ones([2,1], dtype=th.float, device='cuda')
outpt_seq[0] = 0

In [None]:
model = CriticTransformer(model_setup=ms).to('cuda')
with th.no_grad():
    a = model.forward(inpt_seq)

In [None]:
a.shape

In [None]:
model = CriticTransformer(model_setup=ms).to('cuda')
with th.no_grad():
    model.forward(inpt_seq)
optimizer = th.optim.Adam(params=model.parameters(), lr=1e-3)
loss = 0
for i in range(2000):
    result = model.forward(inpt_seq)
    loss = calcMSE(result, outpt_seq)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
loss