In [1]:
import utilities
import torch
t = utilities.Tokenizer()
import numpy as np
import random

import torch.nn as nn

from collections import namedtuple

In [3]:
class Attention(nn.Module):
    
    def __init__(self, c1, c2, causal_mask = False, N_heads = 16, w = 4, device = torch.device('cuda')):
        super().__init__()
        self.ln1 = nn.LayerNorm([c1])  # [Nx, c1]
        self.ln2 = nn.LayerNorm([c2])  # [Ny, c2]
        self.causal_mask = causal_mask
        self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads, batch_first=True)
        self.ln3 = nn.LayerNorm([c1])  # [Nx, c1]
        self.l1 = nn.Linear(c1, c1*w)
        self.gelu = nn.GELU()
        self.l2 = nn.Linear(c1*w, c1)
        self.device = device

    def forward(self, x, y):
        xn = self.ln1(x)
        yn = self.ln2(y)
        if self.causal_mask:
            mask = torch.triu(torch.ones(x.shape[1], y.shape[1], dtype=bool), diagonal=1).to(self.device)
            attn = self.MAH(xn, yn, yn, attn_mask = mask)[0]
        else:
            attn = self.MAH(xn, yn, yn)[0]
        x = x + attn
        x = x + self.l2(self.gelu(self.l1(self.ln3(x))))
        return x

class AttentiveModes(nn.Module):
    def __init__(self, s, c):
        super().__init__()
        self.attention = Attention(c, c, N_heads = 8)
        self.s = s
        self.c = c

    def forward(self, x1, x2, x3):
        g = [x1, x2, x3]
        for m1, m2 in [(0, 1), (2, 0), (1, 2)]:
            a = torch.concatenate((g[m1], torch.transpose(g[m2], 1, 2)), axis=2)
            aflat = a.flatten(0,1)
            c = self.attention(aflat, aflat).reshape_as(a)
            g[m1] = c[:, :, :self.s, :]
            g[m2] = c[:, :, self.s:, :].transpose(1,2)
        return g


class Torso(nn.Module):
    def __init__(self, s, c, i):
        super().__init__()
        self.l1 = nn.Linear(s, c)
        self.attentive_modes = nn.ModuleList([AttentiveModes(s, c) for _ in range(i)])
        self.s = s
        self.c = c
        self.i = i

    def forward(self, x):
        x1 = torch.permute(x, (0, 1, 2, 3))
        x2 = torch.permute(x, (0, 2, 3, 1))
        x3 = torch.permute(x, (0, 3, 1, 2))

        x1 = self.l1(x1)
        x2 = self.l1(x2)
        x3 = self.l1(x3)
        
        for am in self.attentive_modes:
            x1, x2, x3 = am(x1, x2, x3)

        e = torch.reshape(torch.stack([x1, x2, x3], axis=2), (-1, 3 * self.s ** 2, self.c))
        return e

class PolicyHead(nn.Module):
    # Currently assumes our implemented tokenization scheme
    # That is, Nstesp = s and Nlogits = range^3
    def __init__(self, Nsteps, elmnt_range, s, c, Nfeatures = 64, Nheads = 16, Nlayers = 2, device = torch.device('cuda')):
        super().__init__()
        self.Nlayers = Nlayers
        self.Nlogits = (elmnt_range[1]-elmnt_range[0]+1)**3
        self.tokenizer = utilities.Tokenizer(elmnt_range)
        self.Nsteps = Nsteps
        self.Nfeatures = Nfeatures
        self.Nheads = Nheads
        self.device = device

        self.tok_embedding = nn.Embedding(self.Nlogits+1, Nfeatures * Nheads)  #In principle more efficient than forming one-hot vectors and matrix multplying
        self.START_TOK = self.Nlogits
        self.pos_embedding = nn.Embedding(Nsteps, Nfeatures * Nheads)

        # I figure if we are keeping the weights in the LayerNorm, we might as well have
        #   a different one for each layer, but idk really
        self.ln1 = nn.ModuleList([nn.LayerNorm([Nfeatures * Nheads]) for _ in range(Nlayers)])  # [Nsteps, Nfeatures * Nheads]
        self.dropout = nn.Dropout(p=0.1)
        self.self_attention = nn.ModuleList([Attention(Nfeatures * Nheads, Nfeatures * Nheads, causal_mask=True, N_heads=Nheads) for _ in range(Nlayers)])
        self.ln2 = nn.ModuleList([nn.LayerNorm([Nfeatures * Nheads]) for _ in range(Nlayers)])
        self.cross_attention = nn.ModuleList([Attention(Nfeatures * Nheads, c, N_heads=Nheads) for _ in range(Nlayers)])
        
        self.relu = nn.ReLU()
        self.lfinal = nn.Linear(Nfeatures * Nheads, self.Nlogits)

    def predict_logits(self, a, e):   # Assumes a is in tokenized, not one-hot form
        x = self.tok_embedding(a)
        positions = torch.arange(a.shape[1]).repeat((a.shape[0], 1)).to(self.device)
        x = x + self.pos_embedding(positions)
        for i in range(self.Nlayers):
            x = self.ln1[i](x)
            c = self.self_attention[i](x, x)
            c = self.dropout(c)  # Does not run if in evaluation mode
            x = x + c
            x = self.ln2[i](x)
            c = self.cross_attention[i](x, e)
            c = self.dropout(c)
            x = x + c
        o = self.lfinal(self.relu(x))
        return o    # Don't need x bc we are not feeding it to the value head
    
    def forward(self, e, **kwargs):
        if self.training:
            g = kwargs['g']
            #I'm not entirely sure this is right -- need to think on tokens and what the null character is
            #g = torch.cat((torch.tensor([0]), g))
            #Not working at the moment, going to stick with this and not shifting, but maybe there's a shift or something needed?
            # a = nn.functional.one_hot(g, self.Nlogits).float()
            # o, z = self.predict_logits(a, e)
            # return o, z
            a = torch.concatenate((torch.tensor(self.START_TOK).repeat(g.shape[0], 1).to(self.device), g[:, :-1].to(self.device)), axis=1).to(self.device)
            return self.predict_logits(a, e)
        
        else:
            Nsamples = kwargs['Nsamples']
            #a = torch.zeros((Nsamples, self.Nsteps)).long()
            a = [[self.START_TOK] for _ in range(Nsamples)]
            p = torch.ones(Nsamples)
            #z = torch.zeros((Nsamples, self.Nsteps, self.Nfeatures * self.Nheads))
            #Don't care about exporting Z anymore
            for j in range(Nsamples):
                for i in range(self.Nsteps):
                    # encoded = nn.functional.one_hot(a[j, :], self.Nlogits)
                    # o, _ = self.predict_logits(encoded.float(), e)
                    o = self.predict_logits(torch.tensor([a[j]]).to(self.device), e)
                    probs = torch.softmax(o[0, i, :], -1).to('cpu')
                    tok = torch.multinomial(probs, num_samples=1).item()
                    a[j].append(tok)
                    p[j] *= probs[tok]
            
            actions = self.tokenizer.batch_detokenize(torch.tensor(a)[:,1:])
            probs = p/p.sum()

            return namedtuple('Policy', 'actions probs')(actions, probs)
            

class ValueHead(nn.Module):
    def __init__(self, c, d):
        super().__init__()
        self.c = c
        self.d = d
        
        self.l1 = nn.Linear(c, d)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(d, d)
        self.l3 = nn.Linear(d, d)
        self.lf = nn.Linear(d, 1)

    def forward(self, x):
        x = torch.mean(x, axis=1)
        x = self.relu(self.l1(x))
        x = self.relu(self.l2(x))
        x = self.relu(self.l3(x))
        x = self.lf(x)
        return x


## Before implementing heads, read up on Torch head/transformer modules and how they work further. Unclear to me if their transformers do what we want. 
## Also, need to be careful with setting up training vs acting

class AlphaTensor184(nn.Module):
    def __init__(self, s, c, d, elmnt_range, Nsteps, Nsamples, N_policy_features = 48, N_policy_heads = 12, torso_iterations = 8):
        super().__init__()
        self.s = s
        self.c = c
        self.Nlogits = elmnt_range[1]-elmnt_range[0]+1
        self.Nsteps = Nsteps
        self.Nsamples = Nsamples
        
        self.torso = Torso(s, c, torso_iterations)
        self.value_head = ValueHead(c, d) 
        self.policy_head = PolicyHead(Nsteps, elmnt_range, s, c, Nfeatures=N_policy_features, Nheads=N_policy_heads)
    
    def forward(self, x, g=None):
        e = self.torso(x)
        q = self.value_head(e)
        if g == None:  # Inference
            assert(not(self.training))
            policy = self.policy_head(e, Nsamples=self.Nsamples)
            return (q, policy)
        else: # Training
            assert(self.training)
            logits = self.policy_head(e, g=g)
            return (q, logits)

In [34]:
attn_modes = AttentiveModes(4, 64)
print(attn_modes(torch.ones(64,4,4,64), torch.ones(64,4,4,64), torch.ones(64,4,4,64))[0].shape)
torso = Torso(4,64,2)
print(torso(torch.ones(64,4,4,4)).shape)

torch.Size([64, 4, 4, 64])
torch.Size([64, 48, 64])


In [4]:
torch.cuda.empty_cache()

t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
print(f"Total memory: {t}")
print(f"Reserved memory: {r}")
print(f"Allocated memory: {a}")

alphaTensor184 = AlphaTensor184(s = 4, c = 64, d = 48, elmnt_range=(-2, 2), Nsteps=4, Nsamples=32, torso_iterations=5)

t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
print(f"Total memory: {t}")
print(f"Reserved memory: {r}")
print(f"Allocated memory: {a}")


pytorch_total_params = sum(p.numel() for p in alphaTensor184.parameters())

print(pytorch_total_params)



Total memory: 8585216000
Reserved memory: 0
Allocated memory: 0
Total memory: 8585216000
Reserved memory: 0
Allocated memory: 0
15188478


In [5]:
alphaTensor184.to("cuda")

t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
print(f"Total memory: {t}")
print(f"Reserved memory: {r}")
print(f"Allocated memory: {a}")


alphaTensor184.train()

Total memory: 8585216000
Reserved memory: 67108864
Allocated memory: 60779008


AlphaTensor184(
  (torso): Torso(
    (l1): Linear(in_features=4, out_features=64, bias=True)
    (attentive_modes): ModuleList(
      (0-4): 5 x AttentiveModes(
        (attention): Attention(
          (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (MAH): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (l1): Linear(in_features=64, out_features=256, bias=True)
          (gelu): GELU(approximate='none')
          (l2): Linear(in_features=256, out_features=64, bias=True)
        )
      )
    )
  )
  (value_head): ValueHead(
    (l1): Linear(in_features=64, out_features=48, bias=True)
    (relu): ReLU()
    (l2): Linear(in_features=48, out_features=48, bias=True)
    (l3): Linear(in_features=48, out_features=48, bias=True)
    (lf): L

In [6]:
print(alphaTensor184)

from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(alphaTensor184)



AlphaTensor184(
  (torso): Torso(
    (l1): Linear(in_features=3, out_features=64, bias=True)
    (attentive_modes): ModuleList(
      (0-5): 6 x AttentiveModes(
        (attention): Attention(
          (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (MAH): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (l1): Linear(in_features=64, out_features=256, bias=True)
          (gelu): GELU(approximate='none')
          (l2): Linear(in_features=256, out_features=64, bias=True)
        )
      )
    )
  )
  (value_head): ValueHead(
    (l1): Linear(in_features=64, out_features=64, bias=True)
    (relu): ReLU()
    (l2): Linear(in_features=64, out_features=64, bias=True)
    (l3): Linear(in_features=64, out_features=64, bias=True)
    (lf): L

47038974

In [111]:
vh = PolicyHead(10, 64, 3, 64)

vh.train()
vh.eval()

vh.forward(torch.randn(3 * 3 * 3, 64), Nsamples = 32)

tensor(0.0117, grad_fn=<SelectBackward0>)
tensor(0.0025, grad_fn=<SelectBackward0>)
tensor(0.0162, grad_fn=<SelectBackward0>)
tensor(0.0144, grad_fn=<SelectBackward0>)
tensor(0.0142, grad_fn=<SelectBackward0>)
tensor(0.0059, grad_fn=<SelectBackward0>)
tensor(0.0518, grad_fn=<SelectBackward0>)
tensor(0.0255, grad_fn=<SelectBackward0>)
tensor(0.0087, grad_fn=<SelectBackward0>)
tensor(0.0165, grad_fn=<SelectBackward0>)
tensor(0.0085, grad_fn=<SelectBackward0>)
tensor(0.0104, grad_fn=<SelectBackward0>)
tensor(0.0245, grad_fn=<SelectBackward0>)
tensor(0.0692, grad_fn=<SelectBackward0>)
tensor(0.0134, grad_fn=<SelectBackward0>)
tensor(0.0196, grad_fn=<SelectBackward0>)
tensor(0.0106, grad_fn=<SelectBackward0>)
tensor(0.0098, grad_fn=<SelectBackward0>)
tensor(0.0346, grad_fn=<SelectBackward0>)
tensor(0.0364, grad_fn=<SelectBackward0>)
tensor(0.0501, grad_fn=<SelectBackward0>)
tensor(0.0221, grad_fn=<SelectBackward0>)
tensor(0.0286, grad_fn=<SelectBackward0>)
tensor(0.0161, grad_fn=<SelectBack

(tensor([[ 8, 14, 29, 11, 11, 16, 62, 52, 43, 42],
         [26, 27,  1,  4, 33, 30, 63,  7, 48, 48],
         [10, 60, 19, 42, 31, 10, 39, 28,  1, 11],
         [59, 39, 39, 62, 25, 62,  3, 60, 21,  0],
         [40, 22, 45, 39, 13, 48,  1,  4, 48, 30],
         [ 4, 47, 29, 21,  6, 16, 59, 61, 63, 31],
         [ 3, 62,  3,  1, 41, 12, 22,  6, 52, 52],
         [10, 19, 16, 62, 58, 63, 12,  4, 40, 33],
         [47, 39, 57, 30, 10,  0, 31, 52, 44, 61],
         [ 4, 60,  4,  4, 25,  0, 51, 57, 15, 33],
         [63, 48, 10, 51, 18,  1,  4, 21, 15, 47],
         [54, 10, 39, 50,  3, 54, 48, 30, 61, 54],
         [ 6, 38, 37,  4, 36,  3, 48, 47,  4,  3],
         [19, 38,  7, 38,  9, 31, 48, 31, 47, 29],
         [39, 12, 29, 50, 63, 10,  4, 48, 44, 35],
         [45, 16, 48, 57, 45, 57, 41, 38, 19, 30],
         [55, 44, 51, 58,  4,  3,  4, 33, 21, 62],
         [62,  6,  7, 40, 44, 19,  8, 18, 46, 24],
         [47, 40, 42, 22, 14, 10,  3, 21,  5, 10],
         [10, 45, 27,  4, 30, 6

In [2]:
a = torch.tensor([[1, 0, 0],[0,0,0], [-2, -2, -2]])
b = torch.tensor([[-2,1,0],[0,1,1], [-1, 1, -1]])

In [3]:
x = t.tokenize(a)
y = t.tokenize(b)
print(a, b)
print(x, y)
w = t.detokenize(x)
x = t.detokenize(y)

print(w, x)

tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]]) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]])
tensor([63, 62,  0], dtype=torch.int32) tensor([65, 92, 41], dtype=torch.int32)
tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]], dtype=torch.int32) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]], dtype=torch.int32)


In [4]:
def generate_sample_r1(S: int, vals: list[int], factor_dist: list[float]):
    nonzero = False
    while not nonzero:
        t = np.random.choice(vals, size=(3, S), p=factor_dist)
        m = np.tensordot(np.tensordot(t[0, :], t[1, :], axes=0), t[2, :], axes=0)
        assert m.shape == (S, S, S)
        nonzero = np.any(m)
    return t, m

In [5]:
S = 3
vals = [-2, -1, 0, 1, 2]
factor_dist = [0.1, 0.2, 0.4, 0.2, 0.1]

t, m = generate_sample_r1(S, vals, factor_dist)

print(f"tensors {t}")
print(f"result {m}")

tensors [[-1  0  0]
 [ 1  1  0]
 [ 1  0  2]]
result [[[-1  0 -2]
  [-1  0 -2]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]]


In [12]:
def main(S: int, r_limit: int, factor_dist: dict, N: int, seed: int = None):
    if seed is not None:
        random.seed(seed)

    low, high = min(factor_dist.keys()), max(factor_dist.keys())
    vals = list(factor_dist.keys())
    dist = [factor_dist[i] for i in factor_dist.keys()]

    
    tokenizer = utilities.Tokenizer(range=(low, high))

    SAR_pairs = []
    
    for i in range(N):
        R = random.randint(1, r_limit)
        T = torch.zeros((S, S, S), dtype=torch.int)
        reward = 0
        for j in range(R):
            sample, m = generate_sample_r1(S, vals, dist)
            T += torch.from_numpy(m)
            tokens = tokenizer.tokenize(torch.from_numpy(sample.T))
            reward += -1
            SAR_pairs.append((T, tokens, reward))

    return SAR_pairs
            

In [15]:
torch.load("data/Sar_pairs_3_1000_123456.pt")

[(tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([61, 62, 82], dtype=torch.int32),
  -1),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 63, 37], dtype=torch.int32),
  -2),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 62, 63], dtype=torch.int32),
  -3),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
      

In [14]:
print(main(3, 10, {-2: .1, -1 : .2, 0: 0.4, 1: 0.2, 2: 0.1}, 10, seed=0))

[(tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([109,  62,  28], dtype=torch.int32), -1), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([38, 58, 56], dtype=torch.int32), -2), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([93, 37, 65], dtype=torch.int32), -3), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   

In [2]:
t.batch_detokenize(torch.tensor([[0,2,1]]))

tensor([[[-2, -2, -2],
         [ 0, -2, -2],
         [-1, -2, -2]]], dtype=torch.int32)

In [6]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import itertools

class ActionDataset(Dataset):
    def __init__(self, pregen_files, max_pregen, max_selfplay, selfplay_files = None):
        self.max_pregen = max_pregen
        self.max_selfplay = max_selfplay
        l = [torch.load(file) for file in pregen_files]
        self.pregen_actions = list(itertools.chain.from_iterable(l))[:self.max_pregen]
        self.selfplay_actions = []
        if selfplay_files != None:
            l = [torch.load(file) for file in selfplay_files]
            self.selfplay_actions = list(itertools.chain.from_iterable(l))[:self.max_selfplay]

    def __len__(self):
        return len(self.pregen_actions) + len(self.selfplay_actions)

    def __getitem__(self, idx):
        if idx < self.max_pregen:
            return self.pregen_actions[idx]
        else:
            return self.selfplay_actions[idx - self.max_pregen]
        
    def add_selfplay_actions(self, actions):
        self.selfplay_actions = self.selfplay_actions + actions
        if len(self.selfplay_actions) > self.max_selfplay:
            self.selfplay_actions = self.selfplay_actions[len(self.selfplay_actions) - self.max_selfplay:]

dataset = ActionDataset(["data/Sar_pairs_4_100000_123456.pt"], 500000, 10000)

len(dataset.pregen_actions)


500000

In [7]:
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from architecture import *
from tqdm import tqdm

def loss_fn(pred_logits, true_tokens, pred_value, true_value, val_weight=1.0, device = 'cuda'):
    policy_loss = nn.functional.cross_entropy(pred_logits.reshape(-1, pred_logits.shape[-1]), true_tokens.flatten().type(torch.LongTensor).to(device))
    value_loss = (torch.abs(pred_value - true_value)).mean()
    return policy_loss + val_weight*value_loss

def loss_reporter(pred_logits, true_tokens, pred_value, true_value, val_weight=1.0, device = 'cuda'):
    policy_loss = nn.functional.cross_entropy(pred_logits.reshape(-1, pred_logits.shape[-1]), true_tokens.flatten().type(torch.LongTensor).to(device))
    value_loss = (torch.abs(pred_value - true_value)).mean()
    return policy_loss, value_loss

def train(model, dataset, epochs, batch_size = 1024, lr=0.001, device = 'cuda'):
    model.train()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        running_loss = 0.0
        rpol, rval = 0.0, 0.0
        for batch in tqdm(dataloader):
            optimizer.zero_grad()
            states, actions, values = batch
            states = states.to(device).float()
            actions = actions.to(device)
            values = values.to(device).float()
            pred_value, pred_logits = model(states, g=actions)
            loss = loss_fn(pred_logits, actions, pred_value, values, val_weight=.33)
            pol_loss, val_loss = loss_reporter(pred_logits, actions, pred_value, values, val_weight=.33)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            rpol += pol_loss
            rval += val_loss
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}, Policy Loss: {rpol/len(dataloader)}, Value Loss: {rval/len(dataloader)}')

In [14]:
alphaTensor184.to('cuda')   
train(alphaTensor184, dataset, 5, batch_size=1024, lr=0.05)

100%|██████████| 489/489 [07:08<00:00,  1.14it/s]


Epoch 1, Loss: 3.7409712142008216, Policy Loss: 2.682469606399536, Value Loss: 3.207577705383301


100%|██████████| 489/489 [07:19<00:00,  1.11it/s]


Epoch 2, Loss: 3.74065669670183, Policy Loss: 2.6823582649230957, Value Loss: 3.2069599628448486


100%|██████████| 489/489 [07:26<00:00,  1.10it/s]


Epoch 3, Loss: 3.7406905355629014, Policy Loss: 2.6824002265930176, Value Loss: 3.2069432735443115


  3%|▎         | 14/489 [00:04<02:25,  3.26it/s]


KeyboardInterrupt: 

In [9]:
## Self-play is meant to control the self-play loop
from architecture import *
from mcts import *
from utilities import *
from tensorgame import *
import torch
import numpy as np
from tqdm import tqdm

# Probably more convenient for state to be a tensor of ints
#       so that we don't risk floating point inaccuracies when
#       checking whether the state equals zero
canonical = torch.zeros(4, 4, 4, dtype=torch.long)
canonical[0, 0, 0] = 1
canonical[0, 1, 1] = 1
canonical[1, 2, 0] = 1
canonical[1, 3, 1] = 1
canonical[2, 0, 2] = 1
canonical[2, 1, 3] = 1
canonical[3, 2, 3] = 1
canonical[3, 3, 3] = 1

def self_play(model, S: int, canonical, n_plays, num_samples = 8, num_sim = 8, identifier=1, max_actions = 12,
              cob_entries = torch.tensor([-1, 0, 1]), cob_probs = torch.tensor([.05, .9, .05]), device='cuda'):
    model.eval()

    # Build a set of target tensors
    targets = [canonical] * n_plays
    bases_changes = [None] * n_plays
    for i, state in enumerate(targets):
        cob = change_of_basis(S, cob_entries, cob_probs)
        targets[i] = apply_COB(state, cob)
        bases_changes[i] = cob
    # Play the game using MCTS and model for each tensor

    # Not including cob in successful_trajectories anymore b/c we can just
    #       immediately perform the reverse change of basis (see below)
    successful_trajectories = [] # Tuples of (Initial State, [Actions], Final State)
    SAR_pairs = [] # Tuples of (State, Action, Reward)

    total_reward = 0

    for i, target in tqdm(enumerate(targets)):
        ## Need to expand this
        # Avoding separate root = TensorGame(target, max_actions)
        #     for clarity: We want to work with mcts.root, which is
        #     updated with mcts.search_and_play, rather than root.
        mcts = MCTS(TensorGame(target, max_actions), model, device=device)

        # Storing all rewards; see comment below
        rewards = []
        states = []
        actions = []
        for _ in range(max_actions):
            states.append(mcts.root.state)
            # search_and_play already calls search internally
            r, action = mcts.search_and_play(num_samples, num_sim)
            actions.append(action)
            rewards.append(r)

            if mcts.root.done():  # Already considers i == max_actions - 1
                break
        
        rewards[-1] += mcts.root.terminal_reward()
        
        # I think the value should be the sum of the suffix of the list of rewards
        #      rather than smearing the total reward equally over all actions.  For
        #      instance, I think the last station-action pair should have reward
        #      -1 + terminal rather than (-n + terminal)/n = -1 + terminal/n, where
        #      n = len(actions)
        SAR = []
        value = 0
        for (state, action, reward) in zip(reversed(states), reversed(actions), reversed(rewards)):
            value += reward
            # Maybe we could consider canonicalizing the action by sorting or smnth
            SAR.append((state, action, value))

        SAR_pairs += SAR

        total_reward += reward

        if mcts.root.is_zero():
            orig_actions = [action @ bases_changes[i] for action in actions]
            successful_trajectories.append((target, orig_actions))
    
    torch.save(successful_trajectories, f"data/successful_trajectories_{identifier}.pt")
    torch.save(SAR_pairs, f"data/SAR_pairs_sp_{identifier}.pt")

    return successful_trajectories, total_reward / n_plays


In [10]:
alphaTensor184.load_state_dict(torch.load("models/model_1_test.pt"))
self_play(alphaTensor184, 4, canonical, 100, identifier=1)

14it [18:07, 77.65s/it]


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
