In [16]:
import utilities
import torch
t = utilities.Tokenizer()
import numpy as np
import random

import torch.nn as nn

In [22]:
class Attention(nn.Module):
    
    def __init__(self, Nx, c1, Ny, c2, causal_mask = False, N_heads = 16, w = 4):
        super().__init__()
        self.ln1 = nn.LayerNorm([Nx, c1])
        self.ln2 = nn.LayerNorm([Ny, c2])
        if not causal_mask:
            self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads)
        else:
            mask = torch.triu(torch.ones((Nx, Ny)))
            self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads, attn_mask=mask)
        self.ln3 = nn.LayerNorm([Nx, c1])
        self.l1 = nn.Linear(c1, c1*w)
        self.gelu = nn.GELU()
        self.l2 = nn.Linear(c1*w, c1)

    def forward(self, x, y):
        xn = self.ln1(x)
        yn = self.ln2(y)
        attn, _ = self.MAH(xn, yn, yn)
        x = x + attn
        x = x + self.l2(self.gelu(self.l1(self.ln3(x))))
        return x

In [25]:
attentive_layer = Attention(10, 64, 16, 16, causal_mask = False, N_heads = 16, w = 4)
attentive_layer.forward(torch.randn(10, 64), torch.randn(16, 16)) 

tensor([[ 0.6321, -0.4318, -0.1447,  0.2664,  0.8855, -0.9413,  0.2048,  1.8881,
         -1.8704, -0.0671,  1.6914,  0.0241,  0.9246, -1.4592, -2.0389, -0.2827,
          1.2817, -0.3346, -0.0486,  0.6836, -1.4108,  0.4785, -0.6820,  0.5239,
          0.7440, -0.0505,  0.1972,  1.2665, -0.2735, -1.5424, -0.3817, -0.0351,
          0.2799,  0.6614,  0.0917, -0.6840, -0.3617,  0.5782,  0.1629,  0.3017,
         -0.3865, -0.7851, -1.1775, -1.0810,  0.8854, -0.2663, -0.2306, -1.3840,
          1.5014,  0.3405, -0.5500, -0.1443, -2.5243, -0.3444, -0.6236,  1.6203,
         -0.1134,  0.5366,  1.4482, -0.5182,  0.6706, -0.7755, -0.4147,  0.5994],
        [ 0.4875,  1.9492,  0.3515, -0.0473,  2.8478,  0.7668, -0.7632, -0.7184,
         -0.6317,  0.3455,  0.8557, -0.2615,  2.4716, -0.8772, -1.2903,  1.8257,
          1.1174, -1.8998,  1.3568,  0.9678, -1.1556, -1.4056,  1.4699, -0.9453,
         -0.1490, -1.4407, -0.0308, -0.2206,  0.1583, -0.6473, -0.9560,  1.0041,
          1.8237, -0.9852, 

In [62]:
class AttentiveModes(nn.Module):
    def __init__(self, s, c):
        super().__init__()
        self.attention = Attention(2 * s, c, 2 * s, c, N_heads = 8)
        self.s = s
        self.c = c

    def forward(self, x1, x2, x3):
        g = [x1, x2, x3]
        for m1, m2 in [(0, 1), (2, 0), (1, 2)]:
            a = torch.concatenate((g[m1], torch.transpose(g[m2], 0, 1)), axis=1)
            for i in range(self.s):
                c = self.attention(a[i, :, :], a[i, :, :])
                g[m1][i, :, :] = c[:self.s, :]
                g[m2][:, i, :] = c[self.s:, :]
        return g
    
class Torso(nn.Module):
    def __init__(self, s, c, i):
        super().__init__()
        self.l1 = nn.Linear(s, c)
        self.attentive_modes = nn.ModuleList([AttentiveModes(s, c) for _ in range(i)])
        self.s = s
        self.c = c
        self.i = i

    def forward(self, x):
        x1 = torch.permute(x, (0, 1, 2))
        x2 = torch.permute(x, (1, 2, 0))
        x3 = torch.permute(x, (2, 0, 1))

        x1 = self.l1(x1)
        x2 = self.l1(x2)
        x3 = self.l1(x3)
        
        for am in self.attentive_modes:
            x1, x2, x3 = am(x1, x2, x3)

        e = torch.reshape(torch.stack([x1, x2, x3], axis=1), (3 * self.s ** 2, self.c))    
        return e
    
class ValueHead(nn.Module):
    def __init__(self, c, d):
        super().__init__()
        self.c = c
        self.d = d
        
        self.l1 = nn.Linear(c, d)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(d, d)
        self.l3 = nn.Linear(d, d)
        self.lf = nn.Linear(d, 1)

    def forward(self, x):
        x = torch.mean(x, axis=0)
        x = self.relu(self.l1(x))
        x = self.relu(self.l2(x))
        x = self.relu(self.l3(x))
        x = self.lf(x)
        return x


In [40]:
x = torch.randn(10, 10, 64)
x.shape
torch.transpose(x, 0, 1).shape

torch.Size([10, 10, 64])

In [65]:
vh = ValueHead(64, 64)

vh.forward(torch.randn(27, 64))

tensor([0.0708], grad_fn=<AddBackward0>)

In [2]:
a = torch.tensor([[1, 0, 0],[0,0,0], [-2, -2, -2]])
b = torch.tensor([[-2,1,0],[0,1,1], [-1, 1, -1]])

In [3]:
x = t.tokenize(a)
y = t.tokenize(b)
print(a, b)
print(x, y)
w = t.detokenize(x)
x = t.detokenize(y)

print(w, x)

tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]]) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]])
tensor([63, 62,  0], dtype=torch.int32) tensor([65, 92, 41], dtype=torch.int32)
tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]], dtype=torch.int32) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]], dtype=torch.int32)


In [4]:
def generate_sample_r1(S: int, vals: list[int], factor_dist: list[float]):
    nonzero = False
    while not nonzero:
        t = np.random.choice(vals, size=(3, S), p=factor_dist)
        m = np.tensordot(np.tensordot(t[0, :], t[1, :], axes=0), t[2, :], axes=0)
        assert m.shape == (S, S, S)
        nonzero = np.any(m)
    return t, m

In [5]:
S = 3
vals = [-2, -1, 0, 1, 2]
factor_dist = [0.1, 0.2, 0.4, 0.2, 0.1]

t, m = generate_sample_r1(S, vals, factor_dist)

print(f"tensors {t}")
print(f"result {m}")

tensors [[-1  0  0]
 [ 1  1  0]
 [ 1  0  2]]
result [[[-1  0 -2]
  [-1  0 -2]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]]


In [12]:
def main(S: int, r_limit: int, factor_dist: dict, N: int, seed: int = None):
    if seed is not None:
        random.seed(seed)

    low, high = min(factor_dist.keys()), max(factor_dist.keys())
    vals = list(factor_dist.keys())
    dist = [factor_dist[i] for i in factor_dist.keys()]

    
    tokenizer = utilities.Tokenizer(range=(low, high))

    SAR_pairs = []
    
    for i in range(N):
        R = random.randint(1, r_limit)
        T = torch.zeros((S, S, S), dtype=torch.int)
        reward = 0
        for j in range(R):
            sample, m = generate_sample_r1(S, vals, dist)
            T += torch.from_numpy(m)
            tokens = tokenizer.tokenize(torch.from_numpy(sample.T))
            reward += -1
            SAR_pairs.append((T, tokens, reward))

    return SAR_pairs
            

In [15]:
torch.load("data/Sar_pairs_3_1000_123456.pt")

[(tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([61, 62, 82], dtype=torch.int32),
  -1),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 63, 37], dtype=torch.int32),
  -2),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 62, 63], dtype=torch.int32),
  -3),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
      

In [14]:
print(main(3, 10, {-2: .1, -1 : .2, 0: 0.4, 1: 0.2, 2: 0.1}, 10, seed=0))

[(tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([109,  62,  28], dtype=torch.int32), -1), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([38, 58, 56], dtype=torch.int32), -2), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([93, 37, 65], dtype=torch.int32), -3), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   