In [11]:
# Add the directory two levels up ("../..") to the import path; presumably that is where the local `continous` module lives.
import sys
sys.path.append("../..")

from continous import ContinousPrediction, ContinousActionEncoder, tokenizer
import json
# Load the action catalogue. The context manager closes the file handle
# deterministically, and JSON text is read as UTF-8 regardless of the locale.
with open("actions.json", encoding="utf-8") as actions_file:
    actions = json.load(actions_file)

In [12]:

def prompt(action_typ, action):
    """Render one action as the text "<action type>: <action>"."""
    return f"{action_typ}: {action}"


# Flatten the {action type -> list of actions} mapping into one list of prompt
# strings, preserving the mapping's iteration order.
formatted_actions = [
    prompt(kind, item)
    for kind, items in actions.items()
    for item in items
]

In [13]:
# Register '[PAD]' as the tokenizer's padding token; presumably required because
# the batch call below asks for padding.
# NOTE(review): assumes `tokenizer` follows the Hugging Face tokenizer API — confirm in `continous`.
# NOTE(review): if '[PAD]' is a new vocabulary entry, the encoder's embedding table
# may need resizing — TODO confirm.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Tokenize all formatted action strings in a single batch; padding=True and
# return_tensors="pt" request padded PyTorch tensors (Hugging Face semantics assumed).
tokenized_actions = tokenizer(formatted_actions, padding=True, return_tensors="pt")

In [16]:
# Instantiate the project's action encoder; no constructor arguments are passed,
# so whatever defaults the class defines are used.
encoder = ContinousActionEncoder()

In [20]:
# Embed every tokenized action in one call; the tokenizer's output fields are
# unpacked as keyword arguments to the encoder.
# NOTE(review): presumably yields one embedding row per formatted action — confirm.
embedded_actions = encoder(**tokenized_actions)

In [21]:
embedded_actions

tensor([[ 0.0037, -0.0440, -0.0106,  ..., -0.0105, -0.0125,  0.0240],
        [-0.0176, -0.0251, -0.0212,  ...,  0.0112, -0.0173, -0.0417],
        [-0.0210, -0.0173, -0.0187,  ..., -0.0305, -0.0193,  0.0408],
        ...,
        [-0.0218, -0.0424, -0.0061,  ...,  0.0394, -0.0247, -0.0070],
        [-0.0086, -0.0513,  0.0189,  ...,  0.0417,  0.0113, -0.0583],
        [-0.0064, -0.0148,  0.0164,  ...,  0.0137,  0.0106,  0.0022]])

In [31]:
import torch
def loss(pred, target):
    """Cosine-embedding loss that treats every (pred, target) row pair as a match.

    Args:
        pred: predicted embeddings of shape (N, D).
        target: embeddings to match, same shape as ``pred``.

    Returns:
        Scalar tensor: with all labels equal to 1 this is the mean over the N
        pairs of ``1 - cos(pred[i], target[i])``.
    """
    # Every pair is labelled +1 ("should be similar"). The labels are created on
    # pred's device so the call also works when the inputs are not on the CPU.
    labels = torch.ones(pred.shape[0], device=pred.device)
    return torch.nn.functional.cosine_embedding_loss(pred, target, labels)

In [23]:
# Build the prediction head under test.
# NOTE(review): the four positional sizes are undocumented here; the first
# presumably matches the 512-wide inputs produced by random_input and the last
# the 512-wide pi_probs output — confirm against ContinousPrediction's signature.
predictor = ContinousPrediction(512, 64, 61, 512)

In [24]:
def random_input(B: int, dim: int = 512) -> torch.Tensor:
    """Return a batch of standard-normal random vectors.

    Args:
        B: number of rows (batch size).
        dim: width of each row; defaults to 512, the width used throughout
            this notebook.

    Returns:
        Tensor of shape (B, dim) drawn from N(0, 1).
    """
    return torch.randn(B, dim)

In [28]:
# Forward a batch of 4 random inputs; the predictor's result is unpacked as a
# pair (pi_probs, values).
pi_probs, values = predictor(random_input(4))

In [29]:
pi_probs.shape

torch.Size([4, 512])

In [32]:
# Cosine loss between the 4 predicted vectors and the first four action embeddings.
# NOTE(review): the predictions come from random inputs, so this pairing looks
# like a smoke test rather than a meaningful target assignment — confirm.
loss(pi_probs, embedded_actions[0:4])

tensor(0.9964, grad_fn=<MeanBackward0>)

In [120]:
# Toy set of four 3-dimensional candidate vectors, shape (4, 3).
t1 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.5, 0.5]])
t1.shape

torch.Size([4, 3])

In [130]:
t1

tensor([[1.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 1.0000],
        [0.0000, 0.5000, 0.5000]])

In [121]:
# Toy query tensor of shape (2, 3, 3): two groups of three 3-dimensional vectors
# (presumably batch x slots x dim — confirm).
t2 = torch.tensor([[[0.0, 1.0, 0.0], [0.0, 0.1, 0.5], [.8, 0.2, 0.0]], [[0.0, .5, 0.5], [0.0, 0.3, .6], [1.0, 0.0, 0.0]]])

In [122]:
t2

tensor([[[0.0000, 1.0000, 0.0000],
         [0.0000, 0.1000, 0.5000],
         [0.8000, 0.2000, 0.0000]],

        [[0.0000, 0.5000, 0.5000],
         [0.0000, 0.3000, 0.6000],
         [1.0000, 0.0000, 0.0000]]])

In [158]:
t2.shape

torch.Size([2, 3, 3])

In [123]:
# Cosine similarity computed along the last (feature) dimension.
cos_sim = torch.nn.CosineSimilarity(dim=-1)

In [133]:
# Compare every row of t1 with each slot of t2, one slot at a time.
# t1.unsqueeze(1) has shape (4, 1, 3) and broadcasts against the (2, 3) slice,
# so each per-slot result is (4, 2); stacking on the last axis gives (4, 2, 3).
per_slot_sims = []
for slot in range(t2.shape[1]):
    slot_vectors = t2[:, slot, :]
    per_slot_sims.append(cos_sim(slot_vectors, t1.unsqueeze(1)))
sims = torch.stack(per_slot_sims, dim=-1)

In [134]:
sims

tensor([[[0.0000, 0.0000, 0.9701],
         [0.0000, 0.0000, 1.0000]],

        [[1.0000, 0.1961, 0.2425],
         [0.7071, 0.4472, 0.0000]],

        [[0.0000, 0.9806, 0.0000],
         [0.7071, 0.8944, 0.0000]],

        [[0.7071, 0.8321, 0.1715],
         [1.0000, 0.9487, 0.0000]]])

In [135]:
sims.shape

torch.Size([4, 2, 3])

In [146]:
sims.argmax(dim=0)

tensor([[1, 2, 0],
        [3, 3, 0]])

In [147]:
# For each (group, slot) position take the index of the t1 row with the highest
# similarity (argmax over dim 0, the t1 axis) and gather those rows from t1.
selected = t1[sims.argmax(dim=0)]

In [148]:
selected

tensor([[[0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000],
         [1.0000, 0.0000, 0.0000]],

        [[0.0000, 0.5000, 0.5000],
         [0.0000, 0.5000, 0.5000],
         [1.0000, 0.0000, 0.0000]]])

In [178]:
# For every (batch, slot) vector of t2, gather the rows of t1 that differ from
# its best match, then concatenate those rows per batch entry.
no_select = []
for batch_idx in range(t2.shape[0]):
    rejected_rows = []
    for slot_idx in range(t2.shape[1]):
        best_match = selected[batch_idx][slot_idx]
        print(f"Best match for {t2[batch_idx,slot_idx]} is {best_match}")
        # A row is kept when at least one of its components differs from the best match.
        differs = (t1 != best_match).any(dim=-1)
        not_selected = t1[differs]
        print("not selected", not_selected.shape)
        rejected_rows.append(not_selected)
    no_select_batch = torch.cat(rejected_rows, dim=0)
    print("no select batch", no_select_batch.shape)
    no_select.append(no_select_batch)

Best match for tensor([0., 1., 0.]) is tensor([0., 1., 0.])
not selected torch.Size([3, 3])
Best match for tensor([0.0000, 0.1000, 0.5000]) is tensor([0., 0., 1.])
not selected torch.Size([3, 3])
Best match for tensor([0.8000, 0.2000, 0.0000]) is tensor([1., 0., 0.])
not selected torch.Size([3, 3])
no select batch torch.Size([9, 3])
Best match for tensor([0.0000, 0.5000, 0.5000]) is tensor([0.0000, 0.5000, 0.5000])
not selected torch.Size([3, 3])
Best match for tensor([0.0000, 0.3000, 0.6000]) is tensor([0.0000, 0.5000, 0.5000])
not selected torch.Size([3, 3])
Best match for tensor([1., 0., 0.]) is tensor([1., 0., 0.])
not selected torch.Size([3, 3])
no select batch torch.Size([9, 3])


In [179]:
# Stack the per-batch tensors of non-selected rows along a new leading axis.
# This requires every entry of no_select to have the same shape.
non_targets = torch.stack(no_select, dim=0)

In [180]:
non_targets.shape

torch.Size([2, 9, 3])

In [181]:
# Random test matrices for the einsum sanity check. A dedicated seeded generator
# makes the values identical on every run without touching the global RNG state.
rng = torch.Generator().manual_seed(0)
a = torch.rand(10, 128, generator=rng)
q = torch.rand(5, 128, generator=rng)

In [186]:
# All pairwise dot products: sims[i, j] = sum_k a[i, k] * q[j, k].
# NOTE(review): this rebinds `sims`, which earlier held the t1/t2 cosine
# similarities; re-running an earlier cell that reads `sims` after this one
# will see this matrix instead.
sims = torch.einsum('ik, jk -> ij', a, q)

In [187]:
sims.shape

torch.Size([10, 5])

In [193]:
# Reference result for the einsum above, built from explicit dot products:
# one column per row of q, each column holding that row's dot product with every row of a.
columns = []
for qj in q:
    column = torch.stack([torch.dot(ai, qj) for ai in a])
    columns.append(column)
comp = torch.stack(columns, dim=-1)

In [195]:
# Element-wise difference between the explicit dot products and the einsum result;
# values near zero indicate the two computations agree (presumably up to
# floating-point rounding).
comp - sims

tensor([[-7.6294e-06,  5.7220e-06,  0.0000e+00, -1.9073e-06, -3.8147e-06],
        [ 0.0000e+00, -3.8147e-06, -3.8147e-06,  3.8147e-06, -1.9073e-06],
        [-3.8147e-06,  0.0000e+00,  0.0000e+00,  3.8147e-06,  3.8147e-06],
        [ 0.0000e+00,  0.0000e+00,  3.8147e-06,  3.8147e-06,  0.0000e+00],
        [-3.8147e-06,  0.0000e+00,  3.8147e-06, -5.7220e-06,  5.7220e-06],
        [ 1.5259e-05, -1.9073e-06,  0.0000e+00, -5.7220e-06,  1.9073e-06],
        [-1.1444e-05, -3.8147e-06, -3.8147e-06,  1.5259e-05, -3.8147e-06],
        [ 0.0000e+00, -1.9073e-06,  0.0000e+00,  0.0000e+00, -3.8147e-06],
        [-3.8147e-06,  0.0000e+00,  0.0000e+00,  3.8147e-06,  0.0000e+00],
        [ 0.0000e+00, -1.9073e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00]])