In [16]:
import utilities
import torch
t = utilities.Tokenizer()
import numpy as np
import random

import torch.nn as nn

In [68]:
class Attention(nn.Module):
    
    def __init__(self, Nx, c1, Ny, c2, causal_mask = False, N_heads = 16, w = 4):
        super().__init__()
        self.ln1 = nn.LayerNorm([Nx, c1])
        self.ln2 = nn.LayerNorm([Ny, c2])
        self.causal_mask = causal_mask
        if not causal_mask:
            self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads)
        else:
            self.mask = torch.triu(torch.ones((Nx, Ny)))
            self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads)
        self.ln3 = nn.LayerNorm([Nx, c1])
        self.l1 = nn.Linear(c1, c1*w)
        self.gelu = nn.GELU()
        self.l2 = nn.Linear(c1*w, c1)

    def forward(self, x, y):
        xn = self.ln1(x)
        yn = self.ln2(y)
        if self.causal_mask:
            attn, _ = self.MAH(xn, yn, yn, attn_mask=self.mask)
        else:
            attn, _ = self.MAH(xn, yn, yn)
        x = x + attn
        x = x + self.l2(self.gelu(self.l1(self.ln3(x))))
        return x

In [69]:
attentive_layer = Attention(10, 64, 16, 16, causal_mask = False, N_heads = 16, w = 4)
attentive_layer.forward(torch.randn(10, 64), torch.randn(16, 16)) 

tensor([[ 8.6339e-01,  2.1609e+00,  9.1776e-01, -2.6134e+00, -7.0656e-02,
          4.5290e-01,  4.0118e-01, -2.8243e-01, -2.4631e-01, -4.3023e-01,
          1.6653e+00,  1.3829e+00,  3.2515e-02,  4.5046e-01,  1.4140e-01,
          1.8600e-01,  1.1458e+00,  9.6733e-02,  1.4028e+00, -1.5345e+00,
          9.9876e-01, -1.0317e-02,  2.9677e-01,  7.9878e-01,  6.3228e-01,
         -2.6228e-01,  2.7505e-01, -7.6220e-02, -8.1162e-02, -1.0798e+00,
         -2.6195e-01, -1.2498e+00, -1.2078e+00,  1.0001e+00, -3.0643e-01,
         -1.1021e+00, -1.8784e-01, -6.9430e-01, -6.2286e-01,  4.2635e-01,
         -6.1999e-01, -5.3332e-01,  1.3667e+00, -1.0432e+00, -1.9445e-01,
          8.6308e-01,  6.3100e-01,  2.5688e-02, -2.5792e-01, -3.1098e-01,
         -1.3929e+00, -2.1365e+00,  7.7157e-02,  6.8993e-01,  1.4393e+00,
          9.0598e-01, -8.1687e-02,  4.9912e-01,  3.0330e+00, -1.0907e+00,
          2.8486e-01,  1.5091e+00,  8.2846e-03,  2.2945e-01],
        [ 2.4506e-01,  3.8950e-01, -5.4670e-01, -9

In [112]:
class AttentiveModes(nn.Module):
    def __init__(self, s, c):
        super().__init__()
        self.attention = Attention(2 * s, c, 2 * s, c, N_heads = 8)
        self.s = s
        self.c = c

    def forward(self, x1, x2, x3):
        g = [x1, x2, x3]
        for m1, m2 in [(0, 1), (2, 0), (1, 2)]:
            a = torch.concatenate((g[m1], torch.transpose(g[m2], 0, 1)), axis=1)
            for i in range(self.s):
                c = self.attention(a[i, :, :], a[i, :, :])
                g[m1][i, :, :] = c[:self.s, :]
                g[m2][:, i, :] = c[self.s:, :]
        return g
    
class Torso(nn.Module):
    def __init__(self, s, c, i):
        super().__init__()
        self.l1 = nn.Linear(s, c)
        self.attentive_modes = nn.ModuleList([AttentiveModes(s, c) for _ in range(i)])
        self.s = s
        self.c = c
        self.i = i

    def forward(self, x):
        x1 = torch.permute(x, (0, 1, 2))
        x2 = torch.permute(x, (1, 2, 0))
        x3 = torch.permute(x, (2, 0, 1))

        x1 = self.l1(x1)
        x2 = self.l1(x2)
        x3 = self.l1(x3)
        
        for am in self.attentive_modes:
            x1, x2, x3 = am(x1, x2, x3)

        e = torch.reshape(torch.stack([x1, x2, x3], axis=1), (3 * self.s ** 2, self.c))    
        return e
    
class ValueHead(nn.Module):
    def __init__(self, c, d):
        super().__init__()
        self.c = c
        self.d = d
        
        self.l1 = nn.Linear(c, d)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(d, d)
        self.l3 = nn.Linear(d, d)
        self.lf = nn.Linear(d, 1)

    def forward(self, x):
        x = torch.mean(x, axis=0)
        x = self.relu(self.l1(x))
        x = self.relu(self.l2(x))
        x = self.relu(self.l3(x))
        x = self.lf(x)
        return x
    
class Torso(nn.Module):
    def __init__(self, s, c, i):
        super().__init__()
        self.l1 = nn.Linear(s, c)
        self.attentive_modes = nn.ModuleList([AttentiveModes(s, c) for _ in range(i)])
        self.s = s
        self.c = c
        self.i = i

    def forward(self, x):
        x1 = torch.permute(x, (0, 1, 2))
        x2 = torch.permute(x, (1, 2, 0))
        x3 = torch.permute(x, (2, 0, 1))

        x1 = self.l1(x1)
        x2 = self.l1(x2)
        x3 = self.l1(x3)
        
        for am in self.attentive_modes:
            x1, x2, x3 = am(x1, x2, x3)

        e = torch.reshape(torch.stack([x1, x2, x3], axis=1), (3 * self.s ** 2, self.c))

class PolicyHead(nn.Module):
    def __init__(self, Nsteps, Nlogits, s, c, Nfeatures = 64, Nheads = 16, Nlayers = 2):
        super().__init__()
        self.Nlayers = Nlayers
        self.Nlogits = Nlogits
        self.Nsteps = Nsteps
        self.Nfeatures = Nfeatures
        self.Nheads = Nheads


        self.l1 = nn.Linear(Nlogits, Nfeatures * Nheads)
        self.ln = nn.LayerNorm([Nsteps, Nfeatures * Nheads])
        self.lookup = nn.Parameter(torch.empty((Nsteps, Nfeatures * Nheads)))
        nn.init.normal_(self.lookup, mean=0, std=1)

        self.ln = nn.LayerNorm([Nsteps, Nfeatures * Nheads])
        self.dropout = nn.Dropout(p=0.1)
        self.self_attention = nn.ModuleList([Attention(Nsteps, Nfeatures * Nheads, Nsteps, Nfeatures * Nheads, causal_mask=True, N_heads=Nheads) for _ in range(Nlayers)])
        self.cross_attention = nn.ModuleList([Attention(Nsteps, Nfeatures * Nheads, 3 * s ** 2, c, N_heads=Nheads) for _ in range(Nlayers)])
        
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(Nfeatures * Nheads, Nlogits)

    def predict_logits(self, a, e):
        x = self.l1(a)
        # x  = x + Learnable Position Encoding

        for i in range(self.Nlayers):
            x = self.ln(x)
            c = self.self_attention[i](x, x)
            c = self.dropout(c)
            x = x + c
            x = self.ln(x)
            c = self.cross_attention[i](x, e)
            c = self.dropout(c)
            x = x + c
        o = self.l2(self.relu(x))
        return o, x
    
    def forward(self, e, **kwargs):
        if self.training:
            g = kwargs['g']
            #I'm not entirely sure this is right -- need to think on tokens and what the null character is
            #g = torch.cat((torch.tensor([0]), g))
            #Not working at the moment, going to stick with this
            a = nn.functional.one_hot(g, self.Nlogits).float()
            o, z = self.predict_logits(a, e)
            return o, z
        
        else:
            Nsamples = kwargs['Nsamples']
            a = torch.zeros((Nsamples, self.Nsteps)).long()
            p = torch.ones(Nsamples)
            #z = torch.zeros((Nsamples, self.Nsteps, self.Nfeatures * self.Nheads))
            #Don't care about exporting Z anymore
            for j in range(Nsamples):
                for i in range(self.Nsteps):
                    encoded = nn.functional.one_hot(a[j, :], self.Nlogits)
                    o, _ = self.predict_logits(encoded.float(), e)
                    probs = torch.softmax(o[i, :], 0)
                    a[j, i] = torch.multinomial(probs, num_samples=1)
                    p = p * probs[a[j, i]]

            return a, p
        
class AlphaTensor184(nn.Module):
    def __init__(self, s, c, d, Nlogits, Nsteps, Nsamples, torso_iterations = 8):
        super().__init__()
        self.s = s
        self.c = c
        self.Nlogits = Nlogits
        self.Nsteps = Nsteps
        self.Nsamples = Nsamples
        
        self.torso = Torso(s, c, torso_iterations)
        self.value_head = ValueHead(c, d) 
        self.policy_head = PolicyHead(Nsteps, Nlogits, s, c)
    
    def forward(self, x):
        if self.training:
            pass
        else:
            pass


In [115]:
alphaTensor184 = AlphaTensor184(3, 64, 64, 125, 3, 20)

pytorch_total_params = sum(p.numel() for p in alphaTensor184.parameters())

print(pytorch_total_params)

47200126


In [111]:
vh = PolicyHead(10, 64, 3, 64)

vh.train()
vh.eval()

vh.forward(torch.randn(3 * 3 * 3, 64), Nsamples = 32)

tensor(0.0117, grad_fn=<SelectBackward0>)
tensor(0.0025, grad_fn=<SelectBackward0>)
tensor(0.0162, grad_fn=<SelectBackward0>)
tensor(0.0144, grad_fn=<SelectBackward0>)
tensor(0.0142, grad_fn=<SelectBackward0>)
tensor(0.0059, grad_fn=<SelectBackward0>)
tensor(0.0518, grad_fn=<SelectBackward0>)
tensor(0.0255, grad_fn=<SelectBackward0>)
tensor(0.0087, grad_fn=<SelectBackward0>)
tensor(0.0165, grad_fn=<SelectBackward0>)
tensor(0.0085, grad_fn=<SelectBackward0>)
tensor(0.0104, grad_fn=<SelectBackward0>)
tensor(0.0245, grad_fn=<SelectBackward0>)
tensor(0.0692, grad_fn=<SelectBackward0>)
tensor(0.0134, grad_fn=<SelectBackward0>)
tensor(0.0196, grad_fn=<SelectBackward0>)
tensor(0.0106, grad_fn=<SelectBackward0>)
tensor(0.0098, grad_fn=<SelectBackward0>)
tensor(0.0346, grad_fn=<SelectBackward0>)
tensor(0.0364, grad_fn=<SelectBackward0>)
tensor(0.0501, grad_fn=<SelectBackward0>)
tensor(0.0221, grad_fn=<SelectBackward0>)
tensor(0.0286, grad_fn=<SelectBackward0>)
tensor(0.0161, grad_fn=<SelectBack

(tensor([[ 8, 14, 29, 11, 11, 16, 62, 52, 43, 42],
         [26, 27,  1,  4, 33, 30, 63,  7, 48, 48],
         [10, 60, 19, 42, 31, 10, 39, 28,  1, 11],
         [59, 39, 39, 62, 25, 62,  3, 60, 21,  0],
         [40, 22, 45, 39, 13, 48,  1,  4, 48, 30],
         [ 4, 47, 29, 21,  6, 16, 59, 61, 63, 31],
         [ 3, 62,  3,  1, 41, 12, 22,  6, 52, 52],
         [10, 19, 16, 62, 58, 63, 12,  4, 40, 33],
         [47, 39, 57, 30, 10,  0, 31, 52, 44, 61],
         [ 4, 60,  4,  4, 25,  0, 51, 57, 15, 33],
         [63, 48, 10, 51, 18,  1,  4, 21, 15, 47],
         [54, 10, 39, 50,  3, 54, 48, 30, 61, 54],
         [ 6, 38, 37,  4, 36,  3, 48, 47,  4,  3],
         [19, 38,  7, 38,  9, 31, 48, 31, 47, 29],
         [39, 12, 29, 50, 63, 10,  4, 48, 44, 35],
         [45, 16, 48, 57, 45, 57, 41, 38, 19, 30],
         [55, 44, 51, 58,  4,  3,  4, 33, 21, 62],
         [62,  6,  7, 40, 44, 19,  8, 18, 46, 24],
         [47, 40, 42, 22, 14, 10,  3, 21,  5, 10],
         [10, 45, 27,  4, 30, 6

In [2]:
a = torch.tensor([[1, 0, 0],[0,0,0], [-2, -2, -2]])
b = torch.tensor([[-2,1,0],[0,1,1], [-1, 1, -1]])

In [3]:
x = t.tokenize(a)
y = t.tokenize(b)
print(a, b)
print(x, y)
w = t.detokenize(x)
x = t.detokenize(y)

print(w, x)

tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]]) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]])
tensor([63, 62,  0], dtype=torch.int32) tensor([65, 92, 41], dtype=torch.int32)
tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]], dtype=torch.int32) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]], dtype=torch.int32)


In [4]:
def generate_sample_r1(S: int, vals: list[int], factor_dist: list[float]):
    nonzero = False
    while not nonzero:
        t = np.random.choice(vals, size=(3, S), p=factor_dist)
        m = np.tensordot(np.tensordot(t[0, :], t[1, :], axes=0), t[2, :], axes=0)
        assert m.shape == (S, S, S)
        nonzero = np.any(m)
    return t, m

In [5]:
S = 3
vals = [-2, -1, 0, 1, 2]
factor_dist = [0.1, 0.2, 0.4, 0.2, 0.1]

t, m = generate_sample_r1(S, vals, factor_dist)

print(f"tensors {t}")
print(f"result {m}")

tensors [[-1  0  0]
 [ 1  1  0]
 [ 1  0  2]]
result [[[-1  0 -2]
  [-1  0 -2]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]]


In [12]:
def main(S: int, r_limit: int, factor_dist: dict, N: int, seed: int = None):
    if seed is not None:
        random.seed(seed)

    low, high = min(factor_dist.keys()), max(factor_dist.keys())
    vals = list(factor_dist.keys())
    dist = [factor_dist[i] for i in factor_dist.keys()]

    
    tokenizer = utilities.Tokenizer(range=(low, high))

    SAR_pairs = []
    
    for i in range(N):
        R = random.randint(1, r_limit)
        T = torch.zeros((S, S, S), dtype=torch.int)
        reward = 0
        for j in range(R):
            sample, m = generate_sample_r1(S, vals, dist)
            T += torch.from_numpy(m)
            tokens = tokenizer.tokenize(torch.from_numpy(sample.T))
            reward += -1
            SAR_pairs.append((T, tokens, reward))

    return SAR_pairs
            

In [15]:
torch.load("data/Sar_pairs_3_1000_123456.pt")

[(tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([61, 62, 82], dtype=torch.int32),
  -1),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 63, 37], dtype=torch.int32),
  -2),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 62, 63], dtype=torch.int32),
  -3),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
      

In [14]:
print(main(3, 10, {-2: .1, -1 : .2, 0: 0.4, 1: 0.2, 2: 0.1}, 10, seed=0))

[(tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([109,  62,  28], dtype=torch.int32), -1), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([38, 58, 56], dtype=torch.int32), -2), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([93, 37, 65], dtype=torch.int32), -3), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   