# Attention Ain't All, Recurrent Transformers

# Setup

In [1]:
import random, math, torch, time, torchtune, datasets, sys, os, json
import torch.nn as nn
from torch.utils.data import DataLoader
from pysat.solvers import Solver
from fvcore.nn import FlopCountAnalysis

# Run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Print tensors with plain decimals instead of scientific notation
torch.set_printoptions(sci_mode=False)

# Data

In [2]:
# Dataset sizes: validation is a fifth and test a tenth of the training set
train_samples = 1000 * 512
valid_samples, test_samples = train_samples // 5, train_samples // 10

In [3]:
# Special tokens shared by every dataset vocab: sequence start, sequence stop,
# padding, and the separator between a problem and its solution
START, STOP, PAD, SEP = '<start>', '<stop>', '<pad>', '='

# Translates from a string to a pytorch tensor using a vocab
def encode(string, vocab, pad_length):
    """Tokenize ``string`` with ``vocab`` and return the token ids as a long tensor.

    The string is consumed left to right; at each step the shortest prefix that
    is an entry of ``vocab`` becomes the next token. The id list is then
    right-padded with the PAD id up to ``pad_length`` (no truncation happens
    when the encoding is already longer than ``pad_length``).

    Raises:
        Exception: if no prefix of the remaining text is in ``vocab``.
    """
    token_ids = []
    remaining = string
    while remaining:
        # Length of the shortest prefix of the remaining text found in the vocab
        prefix_len = next(
            (n for n in range(1, len(remaining) + 1) if remaining[:n] in vocab),
            None,
        )
        if prefix_len is None:
            raise Exception("Encoding error:", remaining, vocab)
        token_ids.append(vocab.index(remaining[:prefix_len]))
        remaining = remaining[prefix_len:]
    padding = [vocab.index(PAD)] * (pad_length - len(token_ids))
    return torch.tensor(token_ids + padding, dtype=torch.long)

# Translates from a pytorch tensor to a string using a vocab
def decode(tensor, vocab):
    """Turn a sequence of token ids back into text using ``vocab``.

    START and PAD tokens are dropped, and decoding ends at the first STOP
    token (which is not included in the result).
    """
    pieces = []
    for token_id in tensor:
        token = vocab[token_id]
        if token == STOP:
            break
        if token not in (START, PAD):
            pieces.append(token)
    return "".join(pieces)

In [4]:
# Arithmetic dataset generated using a probabilistic context free grammar
# Consists of integers, +, -, *, //, and %
class ArithDataset(torch.utils.data.Dataset):
    """Arithmetic problems sampled from a probabilistic context free grammar.

    Every row of ``xy`` encodes ``START + problem + SEP + answer + STOP``
    followed by PAD tokens. ``problem`` is built from integers, ``+``, ``-``,
    ``*``, ``//``, ``%`` and parentheses; ``answer`` is ``str(eval(problem))``.
    ``x_lens`` holds, per row, the index just past the SEP token and
    ``xy_lens`` the index just past the STOP token.
    """

    vocab = [START, SEP, STOP, PAD] + [str(i) for i in range(10)] + ['-', '+', '*',  '%', '//', '(', ')']

    # Probabilistic context free grammar rules
    # Dictionary where key is the token to be expanded and
    # right side is list of tuples containing rules and associated probabilities
    rules = {
        'EQ': [(['VAL', 'OP', 'VAL'], 0.5), (['(', 'VAL', 'OP', 'VAL', ')'], 0.5)],
        'VAL': [(['EQ'], 0.45), (['NUM'], 0.55)],
        'OP': [(['-'], 0.2), (['+'], 0.2), (['*'], 0.2), (['%'], 0.2), (['//'], 0.2)],
        'NUM': [([str(i), 'NUMT'], 1.0/19) for i in range(1, 10)] + [(['-', str(i), 'NUMT'], 1.0/19) for i in range(1, 10)] + [(['0'], 1.0/19)],
        'NUMT': [([str(i), 'NUMT'], 0.2/10) for i in range(10)] + [([], 0.8)],
    }

    @staticmethod
    def selectRule(left_hand):
        """Pick one expansion of ``left_hand`` according to the rule probabilities.

        Draws a single ``random.random()`` value and walks the cumulative
        distribution of ``rules[left_hand]``.

        Raises:
            Exception: if the probabilities do not cover the drawn value.
        """
        selector = random.random()
        for expansion, probability in ArithDataset.rules[left_hand]:
            selector -= probability
            if selector < 0:
                return expansion
        raise Exception("Improper rule probabilities")

    def __init__(self, samples, seed=0, min_x_len=5, max_x_len=30, min_y_len=1, max_y_len=8):
        """Generate ``samples`` distinct problems.

        Bounds relate to the number of tokens, not characters. X relates to
        the arithmetic problem (including START and SEP), Y to the solution
        (including STOP). ``seed`` re-seeds the global ``random`` module so
        generation is reproducible.
        """
        self.min_x_len = min_x_len
        self.max_x_len = max_x_len
        self.min_y_len = min_y_len
        self.max_y_len = max_y_len

        dup_check = set()
        self.xy = []

        random.seed(seed)
        while len(dup_check) < samples:
            x, y = self.generateProblem()
            if x in dup_check:
                continue
            if len(dup_check) % 100 == 0:
                sys.stdout.write(f"\r {len(dup_check) / samples * 100:.2f}% complete")
                sys.stdout.flush()
            self.xy.append(encode(START + x + SEP + y + STOP, self.vocab, max_x_len + max_y_len))
            dup_check.add(x)
        print("\r100.00% complete")

        self.xy = torch.stack(self.xy)
        # Column index of SEP / STOP in each row, plus one, gives the lengths
        self.x_lens = (self.xy == self.vocab.index(SEP)).nonzero()[:, 1] + 1
        self.xy_lens = (self.xy == self.vocab.index(STOP)).nonzero()[:, 1] + 1

        # Trim the max length
        self.xy = self.xy[:, :torch.max(self.xy_lens)]

    def generateProblem(self):
        """Sample one ``(problem, answer)`` string pair within the dataset bounds.

        Expands the grammar left to right, restarting whenever the expansion
        exceeds the token budget, falls below the minimum, divides or takes a
        modulus by zero, or yields an answer outside the length bounds.
        """
        # Subtract 2 for start and sep token
        max_tokens = self.max_x_len - 2
        min_tokens = self.min_x_len - 2
        while True:
            stack = ['EQ']
            index = 0
            while index < len(stack) and len(stack) <= max_tokens:
                if stack[index] in ArithDataset.rules:
                    stack = stack[:index] + ArithDataset.selectRule(stack[index]) + stack[index + 1:]
                else:
                    index += 1

            if len(stack) > max_tokens or len(stack) < min_tokens:
                continue

            x = ''.join(stack)
            try:
                # x consists solely of grammar terminals generated above, never
                # of external input, so eval is used as the arithmetic oracle.
                y = str(eval(x))
            except ZeroDivisionError:
                # Division or modulus by 0: discard and resample. Only this
                # error is caught so Ctrl-C can still interrupt generation.
                continue

            # Subtract 1 for stop token
            if self.min_y_len - 1 <= len(y) <= self.max_y_len - 1:
                return x, y

    def save(self, filename):
        """Write the bounds and encoded tensors to ``filename`` via ``torch.save``."""
        torch.save({attr:getattr(self, attr) for attr in ['min_x_len', 'max_x_len', 'min_y_len', 'max_y_len', 'xy', 'x_lens', 'xy_lens']}, filename)

    @classmethod
    def load(cls, filename):
        """Rebuild a dataset from a file written by ``save`` without regenerating."""
        # Bypass __init__, which would otherwise generate a new dataset
        obj = object.__new__(cls)
        obj_data = torch.load(filename)
        for attr in obj_data:
            setattr(obj, attr, obj_data[attr])
        return obj

    def __len__(self):
        return len(self.xy)

    def __getitem__(self, idx):
        """Return ``(tokens, x_len, xy_len)`` for sample ``idx``."""
        return self.xy[idx], self.x_lens[idx], self.xy_lens[idx]

In [5]:
# Build the dataset only when no cached copy exists; otherwise reuse the cache
if not os.path.exists("arith_dataset.pt"):
    print("Generating arithmetic dataset")
    arithData = ArithDataset(train_samples + valid_samples + test_samples)
    arithData.save('arith_dataset.pt')
else:
    print("Loading saved arithmetic dataset")
    arithData = ArithDataset.load("arith_dataset.pt")

# A fixed generator seed keeps the train/valid/test partition reproducible
arithSets = torch.utils.data.random_split(
    arithData,
    [train_samples, valid_samples, test_samples],
    generator=torch.Generator().manual_seed(0),
)

# Show a few decoded samples as a sanity check
print("Arithmetic Problems:")
for i in range(20):
    print(decode(arithData[i][0], arithData.vocab))

Loading saved arithmetic dataset
Arithmetic Problems:
(8*(-9+-3//-7))=-72
-9*-9*5=405
(((1//4+-7)-(3%-7))//-3)=1
(((-14//-8)//8+-8)*-3)=24
(-9--1)=-8
-91-79-(-7-0--8)=-171
98*-1=-98
5-72=-67
7+-4=3
(19+8*8-(-2//6)//-79)=83
(3%((-9//(-4//8))+17%7))=3
(6+2+-1//9%-9+-434)-9*9=-508
(4*-1)=-4
(-4--3)=-1
-7//-7=1
(1//23--73+(30--895))=998
2*7=14
0%(-7+9)=0
3%40=3
(19--9)=28


In [6]:
# Integer to hexadecimal dataset
class HexDataset(torch.utils.data.Dataset):
    """Decimal-to-hexadecimal conversion problems.

    Every row of ``xy`` encodes ``START + decimal + SEP + hex + STOP`` followed
    by PAD tokens. ``x_lens`` holds the index just past SEP and ``xy_lens`` the
    index just past STOP for each row.
    """

    vocab = [START, SEP, STOP, PAD] + [str(digit) for digit in range(10)] + list('abcdef')

    def __init__(self, samples, seed=0, min_exp=0, max_exp=10):
        """Generate ``samples`` distinct integers and their hex encodings.

        Integers are drawn as ``int(10 ** e)`` with ``e`` uniform between
        ``min_exp`` and ``max_exp``, i.e. on an exponential scale, to get a
        variety of lengths and difficulties. ``seed`` re-seeds the global
        ``random`` module.
        """
        self.min_exp = min_exp
        self.max_exp = max_exp

        # Longest possible row: digits of the largest value in both bases
        # (hex without its 0x prefix) plus 3 for start, sep, stop
        largest = int(10 ** max_exp)
        max_len = len(str(largest)) + len(hex(largest)[2:]) + 3

        seen = set()
        rows = []
        exp_span = max_exp - min_exp

        random.seed(seed)
        while len(seen) < samples:
            x = int(10 ** (exp_span * random.random() + min_exp))
            if x in seen:
                continue
            if len(seen) % 100 == 0:
                sys.stdout.write(f"\r {len(seen) / samples * 100:.2f}% complete")
                sys.stdout.flush()
            text = START + str(x) + SEP + hex(x)[2:] + STOP
            rows.append(encode(text, self.vocab, max_len))
            seen.add(x)
        print("\r100.00% complete")

        self.xy = torch.stack(rows)
        sep_id = self.vocab.index(SEP)
        stop_id = self.vocab.index(STOP)
        # Column index of SEP / STOP in each row, plus one, gives the lengths
        self.x_lens = (self.xy == sep_id).nonzero()[:, 1] + 1
        self.xy_lens = (self.xy == stop_id).nonzero()[:, 1] + 1

        # Drop the padding columns no row actually needs
        longest = torch.max(self.xy_lens)
        self.xy = self.xy[:, :longest]

    def save(self, filename):
        """Write the bounds and encoded tensors to ``filename`` via ``torch.save``."""
        state = {}
        for attr in ['min_exp', 'max_exp', 'xy', 'x_lens', 'xy_lens']:
            state[attr] = getattr(self, attr)
        torch.save(state, filename)

    @classmethod
    def load(cls, filename):
        """Rebuild a dataset from a file written by ``save`` without regenerating."""
        # Bypass __init__, which would otherwise generate a new dataset
        obj = object.__new__(cls)
        for attr, value in torch.load(filename).items():
            setattr(obj, attr, value)
        return obj

    def __len__(self):
        return len(self.xy)

    def __getitem__(self, idx):
        """Return ``(tokens, x_len, xy_len)`` for sample ``idx``."""
        return self.xy[idx], self.x_lens[idx], self.xy_lens[idx]

In [7]:
# Build the dataset only when no cached copy exists; otherwise reuse the cache
if not os.path.exists("hex_dataset.pt"):
    print("Generating hex dataset")
    hexData = HexDataset(train_samples + valid_samples + test_samples)
    hexData.save("hex_dataset.pt")
else:
    print("Loading saved hex dataset")
    hexData = HexDataset.load("hex_dataset.pt")

# A fixed generator seed keeps the train/valid/test partition reproducible
hexSets = torch.utils.data.random_split(
    hexData,
    [train_samples, valid_samples, test_samples],
    generator=torch.Generator().manual_seed(0),
)

# Show a few decoded samples as a sanity check
print("Hex Problems:")
for i in range(20):
    print(decode(hexData[i][0], hexData.vocab))

Loading saved hex dataset
Hex Problems:
278111223=1093a3f7
37979044=24383a4
16058=3eba
388=184
129642=1fa6a
11203=2bc3
68862992=41ac410
1079=437
58340=e3e4
682056=a6848
1205393518=47d8d86e
111395=1b323
658=292
36144485=2278565
1526475=174acb
319=13f
1251591369=4a99c4c9
6727516326=190fdc0a6
126523838=78a99be
1051137437=3ea7159d


In [8]:
# Generate a randomized cnf 3sat problem with specific count of clauses and variables
def generate3Sat(n_clauses, n_vars):
    """Return a random CNF 3-SAT problem as a tuple of 3-literal tuples.

    The problem has exactly ``n_clauses`` distinct clauses (distinct as sets
    of literals), each over three different variables from ``1..n_vars``
    with independently random signs, and every variable occurs in at least
    one clause (in either polarity). Clause order and literal order within
    each clause are shuffled.

    Candidate clause sets are drawn by rejection sampling until every
    variable is covered, so arguments that barely allow coverage can take
    many attempts.

    Raises:
        ValueError: if the request cannot be satisfied, i.e. the clauses
            cannot cover all variables or more distinct clauses are requested
            than exist.
    """
    # n_clauses clauses mention at most 3 * n_clauses variables, and there are
    # only 8 sign patterns for each of the C(n_vars, 3) variable triples.
    if 3 * n_clauses < n_vars or n_clauses > 8 * math.comb(n_vars, 3):
        raise ValueError(
            f"Cannot build {n_clauses} distinct 3-literal clauses covering {n_vars} variables"
        )

    while True:
        # Sets to remove duplicate clauses
        clauses = set()
        while len(clauses) < n_clauses:
            chosen_vars = random.sample(range(1, n_vars + 1), 3)
            clauses.add(frozenset(var if random.random() > 0.5 else -var for var in chosen_vars))

        # Check every variable is included at least once; abs() so that a
        # variable occurring only negated still counts as included
        covered = {abs(literal) for clause in clauses for literal in clause}
        if len(covered) == n_vars:
            break

    # random.sample needs a sequence, hence the list() conversions
    clause_order = random.sample(list(clauses), n_clauses)
    return tuple(tuple(random.sample(list(clause), 3)) for clause in clause_order)

In [9]:
# Dataset of 3-sat problems where the model predicts whether or not the problem is sat
class SatSolveDataset(torch.utils.data.Dataset):
    """3-SAT problems labelled with their satisfiability.

    Every row of ``xy`` encodes ``START + clauses + SEP + label + STOP``
    followed by PAD tokens, where ``label`` is ``sat`` or ``unsat``.
    ``x_lens`` holds the index just past SEP and ``xy_lens`` the index just
    past STOP for each row.
    """

    def __init__(self, samples, seed=0, n_vars=5):
        """Generate ``samples`` distinct problems over ``n_vars`` variables.

        ``seed`` re-seeds the global ``random`` module so generation is
        reproducible, matching the other datasets in this file.

        Raises:
            ValueError: if ``n_vars`` is below 4.
        """
        if n_vars < 4:
            raise ValueError("Must have at least 4 vars")
        self.n_vars = n_vars
        # Spaces after numbers to break ambiguity when n_vars > 9
        self.vocab = [START, SEP, STOP, PAD] + [f'{i} ' for i in range(1, n_vars + 1)] + [f'{-i} ' for i in range(1, n_vars + 1)] + [', ', 'sat', 'unsat']

        dup_check = set()
        self.xy = []

        # Max vars = num vars * max clauses per var * tokens per clause + sep, soln, stop (stop is trivial, but used in solve function later)
        max_len = n_vars * 5 * 4 + 3

        random.seed(seed)
        while len(dup_check) < samples:
            # The hardest satisfiability problems occur when there are ~ 4.26 clauses per variable
            # to create consistently hard problems with some variation, we generate problems where there are
            # between 4 and 5 clauses per variable
            n_clauses = random.randint(n_vars * 4, n_vars * 5)
            x = generate3Sat(n_clauses, n_vars)
            if x in dup_check:
                continue
            if len(dup_check) % 100 == 0:
                sys.stdout.write(f"\r {len(dup_check) / samples * 100:.2f}% complete")
                sys.stdout.flush()
            with Solver(name='g3') as solver:
                for clause in x:
                    solver.add_clause(clause)
                satisfiable = solver.solve()
            label = 'sat' if satisfiable else 'unsat'
            # Each literal is written with a trailing space, clauses joined by ', '
            problem = ', '.join(''.join(f'{var} ' for var in clause) for clause in x)
            self.xy.append(encode(START + problem + SEP + label + STOP, self.vocab, max_len))
            dup_check.add(x)
        print("\r100.00% complete")

        self.xy = torch.stack(self.xy)
        # Column index of SEP / STOP in each row, plus one, gives the lengths
        self.x_lens = (self.xy == self.vocab.index(SEP)).nonzero()[:, 1] + 1
        self.xy_lens = (self.xy == self.vocab.index(STOP)).nonzero()[:, 1] + 1

        # Trim the max length
        self.xy = self.xy[:, :torch.max(self.xy_lens)]

    def save(self, filename):
        """Write the vocab, variable count and encoded tensors to ``filename``."""
        torch.save({attr:getattr(self, attr) for attr in ['n_vars', 'vocab', 'xy', 'x_lens', 'xy_lens']}, filename)

    @classmethod
    def load(cls, filename):
        """Rebuild a dataset from a file written by ``save`` without regenerating."""
        # Bypass __init__, which would otherwise generate a new dataset
        obj = object.__new__(cls)
        obj_data = torch.load(filename)
        for attr in obj_data:
            setattr(obj, attr, obj_data[attr])
        return obj

    def __len__(self):
        return len(self.xy)

    def __getitem__(self, idx):
        """Return ``(tokens, x_len, xy_len)`` for sample ``idx``."""
        return self.xy[idx], self.x_lens[idx], self.xy_lens[idx]

In [10]:
# Build the dataset only when no cached copy exists; otherwise reuse the cache
if not os.path.exists("sat_solve_dataset.pt"):
    print("Generating sat solve dataset")
    satSolveData = SatSolveDataset(train_samples + valid_samples + test_samples)
    satSolveData.save('sat_solve_dataset.pt')
else:
    print("Loading saved sat solve dataset")
    satSolveData = SatSolveDataset.load("sat_solve_dataset.pt")

# A fixed generator seed keeps the train/valid/test partition reproducible
satSolveSets = torch.utils.data.random_split(
    satSolveData,
    [train_samples, valid_samples, test_samples],
    generator=torch.Generator().manual_seed(0),
)

# Show a few decoded samples as a sanity check
print("Satisfiability Solver Problems:")
for i in range(5):
    print(decode(satSolveData[i][0], satSolveData.vocab))

Loading saved sat solve dataset
Satisfiability Solver Problems:
-5 -2 -1 , 3 -2 4 , 1 -2 4 , -4 1 3 , -5 1 -3 , -2 -1 5 , 4 -3 1 , -4 5 2 , 1 5 2 , -4 -5 3 , 2 -1 5 , -3 -2 4 , -5 -1 -4 , 2 4 -1 , -4 5 -2 , -1 -2 4 , 4 -5 -3 , 5 -3 4 , 5 2 4 , 5 -1 3 , 1 -2 5 , -3 -1 5 =sat
2 -1 -5 , -1 5 3 , -4 -1 5 , -5 -2 -1 , -1 4 2 , -5 2 1 , 5 -2 3 , 4 -2 -3 , -1 4 -3 , 2 4 -5 , 2 1 3 , 2 1 -3 , 5 2 3 , -1 2 3 , -2 -4 1 , 2 4 3 , -3 -2 5 , 4 -2 3 , -1 -2 5 , 3 -1 -4 , 3 -5 4 =unsat
3 -1 4 , 4 1 -3 , 4 -1 -5 , -1 -5 -3 , 5 3 1 , -2 4 -3 , -5 2 -1 , -1 5 -3 , 3 4 1 , -2 -1 4 , 1 -4 -2 , -4 -5 -3 , -4 2 1 , -2 1 -5 , -2 -3 -5 , -3 -1 4 , 2 -5 -3 , 2 -1 3 , 1 5 2 , 5 -4 3 , -2 5 -3 , -2 -5 -4 , -1 -4 -2 =unsat
5 2 1 , -2 -5 -3 , 4 -3 1 , -5 2 1 , 1 -3 -5 , 4 3 2 , -2 3 -4 , -1 -2 5 , -2 5 -4 , 4 5 1 , -5 3 2 , -2 1 -4 , 5 -3 -2 , 2 1 3 , -1 -3 -2 , -1 2 3 , 3 4 -5 , 5 4 -3 , 5 -3 -4 , -2 4 5 , -5 2 -3 , 5 2 -1 , 5 2 4 , 5 -1 -3 , -1 -2 3 =unsat
4 5 3 , 1 2 -4 , 1 3 2 , -5 -3 2 , -1 -3 2 , -4 5 1 , 2 

In [11]:
# Dataset of 3-sat problems where the model predicts the one sat solution
class SingleSatDataset(torch.utils.data.Dataset):
    """3-SAT problems that have exactly one satisfying assignment.

    Every row of ``xy`` encodes ``START + clauses + SEP + assignment + STOP``
    followed by PAD tokens, where ``assignment`` is one ``True``/``False``
    token per literal of the solver's model. ``x_lens`` holds the index just
    past SEP and ``xy_lens`` the index just past STOP for each row.
    """

    def __init__(self, samples, seed=0, n_vars=6):
        """Generate ``samples`` distinct uniquely-satisfiable problems.

        ``seed`` re-seeds the global ``random`` module so generation is
        reproducible, matching the other datasets in this file.

        Raises:
            ValueError: if ``n_vars`` is below 4.
        """
        if n_vars < 4:
            raise ValueError("Must have at least 4 vars")
        self.n_vars = n_vars
        # Spaces after numbers to break ambiguity when n_vars > 9
        self.vocab = [START, SEP, STOP, PAD] + [f'{i} ' for i in range(1, n_vars + 1)] + [f'{-i} ' for i in range(1, n_vars + 1)] + [', ', 'True', 'False']

        dup_check = set()
        self.xy = []

        # Max vars = num vars * max clauses per var * tokens per clause + sep, soln, stop
        max_len = n_vars * 5 * 4 + n_vars + 2

        random.seed(seed)
        while len(dup_check) < samples:
            # The hardest satisfiability problems occur when there are ~ 4.26 clauses per variable
            # to create consistently hard problems with some variation, we generate problems where there are
            # between 4 and 5 clauses per variable
            n_clauses = random.randint(n_vars * 4, n_vars * 5)
            x = generate3Sat(n_clauses, n_vars)
            if x in dup_check:
                continue
            with Solver(name='g3') as solver:
                for clause in x:
                    solver.add_clause(clause)
                if not solver.solve():
                    # Unsatisfiable: rejected without being recorded in dup_check
                    continue
                model = solver.get_model()
                # Forbid the assignment just found; if the problem is still
                # satisfiable it has more than one solution and is rejected
                solver.add_clause([-lit for lit in model])
                if solver.solve():
                    continue
            # Each literal is written with a trailing space, clauses joined by ', '
            problem = ', '.join(''.join(f'{var} ' for var in clause) for clause in x)
            # A positive literal in the model is written as True, a negative as False
            assignment = ''.join(str(lit > 0) for lit in model)
            self.xy.append(encode(START + problem + SEP + assignment + STOP, self.vocab, max_len))
            if len(dup_check) % 100 == 0:
                sys.stdout.write(f"\r {len(dup_check) / samples * 100:.2f}% complete")
                sys.stdout.flush()
            dup_check.add(x)
        # Leading \r overwrites the in-place progress line, as the sibling datasets do
        print("\r100.00% complete")

        self.xy = torch.stack(self.xy)
        # Column index of SEP / STOP in each row, plus one, gives the lengths
        self.x_lens = (self.xy == self.vocab.index(SEP)).nonzero()[:, 1] + 1
        self.xy_lens = (self.xy == self.vocab.index(STOP)).nonzero()[:, 1] + 1

        # Trim the max length
        self.xy = self.xy[:, :torch.max(self.xy_lens)]

    def save(self, filename):
        """Write the vocab, variable count and encoded tensors to ``filename``."""
        torch.save({attr:getattr(self, attr) for attr in ['n_vars', 'vocab', 'xy', 'x_lens', 'xy_lens']}, filename)

    @classmethod
    def load(cls, filename):
        """Rebuild a dataset from a file written by ``save`` without regenerating."""
        # Bypass __init__, which would otherwise generate a new dataset
        obj = object.__new__(cls)
        obj_data = torch.load(filename)
        for attr in obj_data:
            setattr(obj, attr, obj_data[attr])
        return obj

    def __len__(self):
        return len(self.xy)

    def __getitem__(self, idx):
        """Return ``(tokens, x_len, xy_len)`` for sample ``idx``."""
        return self.xy[idx], self.x_lens[idx], self.xy_lens[idx]

In [12]:
# Build the dataset only when no cached copy exists; otherwise reuse the cache
if not os.path.exists("single_sat_dataset.pt"):
    print("Generating single sat dataset")
    singleSatData = SingleSatDataset(train_samples + valid_samples + test_samples)
    singleSatData.save("single_sat_dataset.pt")
else:
    print("Loading saved single sat dataset")
    singleSatData = SingleSatDataset.load("single_sat_dataset.pt")

# A fixed generator seed keeps the train/valid/test partition reproducible
singleSatSets = torch.utils.data.random_split(
    singleSatData,
    [train_samples, valid_samples, test_samples],
    generator=torch.Generator().manual_seed(0),
)

# Show a few decoded samples as a sanity check
print("Single Satisfiable Problems:")
for i in range(5):
    print(decode(singleSatData[i][0], singleSatData.vocab))

Loading saved single sat dataset
Single Satisfiable Problems:
4 3 6 , -4 -3 -2 , 3 -6 -4 , -1 2 -4 , -1 2 3 , 2 6 3 , 2 4 1 , -2 4 -1 , -6 3 5 , 3 1 -5 , 1 -3 -6 , -2 6 -1 , 4 -1 -6 , -2 1 6 , 6 -3 -4 , 5 -3 2 , -2 -6 -4 , 1 6 4 , -3 -2 -1 , -2 -3 -5 , 1 -6 3 , -1 -6 3 , -5 -3 1 , -3 -6 -4 =TrueFalseTrueFalseTrueFalse
4 1 5 , 6 2 3 , 4 -6 -3 , 3 -5 -1 , -4 -2 -3 , 5 -4 -6 , -6 -1 4 , 4 6 5 , 2 6 5 , 4 -3 -5 , 5 3 1 , 6 -3 -4 , -4 -6 2 , -4 -3 5 , -5 6 -4 , -2 5 -1 , 2 -6 4 , -6 5 -2 , -1 2 3 , -6 4 3 , -5 1 2 , 1 -5 4 , -2 6 3 , 6 1 -2 =FalseTrueFalseTrueTrueTrue
1 -6 -4 , 2 -5 4 , -5 -2 -1 , -1 6 -3 , -3 1 -4 , -4 2 -1 , 1 5 4 , -3 6 1 , 4 -6 -5 , -6 -5 1 , -4 2 1 , -3 5 2 , 2 -4 -6 , 3 2 6 , 4 3 -2 , -1 2 -6 , 4 1 2 , -2 -6 -5 , 5 6 -2 , -3 -5 -2 , -2 -1 -3 , -1 2 5 , 3 -2 1 , -6 -1 -3 =TrueTrueFalseTrueFalseTrue
-2 -6 -3 , 6 5 -2 , -2 5 -3 , 4 1 -3 , 3 2 4 , -6 5 -1 , -5 -2 -6 , 5 -1 4 , 4 5 -3 , -1 -2 4 , -5 2 -6 , -6 -5 -3 , 1 2 -6 , 5 -1 2 , -5 4 -6 , 6 -4 -5 , -4 -6 -2 , -2 5 -6

# Models

In [13]:
# Basic feed forward, uses GELU activation and GPT-2 initialization
class FeedFwd(nn.Module):
    """Stack of linear layers with an activation and dropout between them.

    ``dims`` lists the layer widths, so ``len(dims) - 1`` linear layers are
    created. Weights are drawn from N(0, 0.02) and biases are zeroed. The
    activation and dropout follow every layer except the last.
    """

    def __init__(self, dims, dropout=0.1, activ=nn.GELU()):
        super().__init__()
        # Create every layer first, then re-initialise them in order
        linears = [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        for linear in linears:
            nn.init.normal_(linear.weight, mean=0.0, std=0.02)
            nn.init.zeros_(linear.bias)
        self.lays = nn.ModuleList(linears)
        self.activ = activ
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the hidden layers with activation and dropout, then the final layer."""
        hidden = x
        for layer in self.lays[:-1]:
            hidden = layer(hidden)
            hidden = self.activ(hidden)
            hidden = self.drop(hidden)
        return self.lays[-1](hidden)


In [14]:
# Long short-term memory network with projection
# Utilizes orthogonal initialization for recurrent weights and GPT-2 initialization for others
class LSTM(nn.Module):
    """Long short-term memory cell with a projected short-term state.

    The long-term (cell) state has width ``d_long``; the short-term state fed
    back into the gates and the per-step output both have width ``d_short``.
    Recurrent and projection weights are orthogonally initialised, the other
    weights are drawn from N(0, 0.02), the forget-gate bias starts at one and
    the remaining biases at zero. Initial states are learned parameters.
    """

    def __init__(self, d_in, d_long, d_short):
        super().__init__()
        self.d_in, self.d_long, self.d_short = d_in, d_long, d_short

        # Learned initial long and short states
        self.long0 = nn.Parameter(torch.zeros(d_long))
        self.short0 = nn.Parameter(torch.zeros(d_short))

        # Gate layers: *_short acts on the recurrent state, *_in on the input
        self.i_short = nn.Linear(d_short, d_long, bias=False)
        self.i_in = nn.Linear(d_in, d_long)
        self.f_short = nn.Linear(d_short, d_long, bias=False)
        self.f_in = nn.Linear(d_in, d_long)
        self.g_short = nn.Linear(d_short, d_long, bias=False)
        self.g_in = nn.Linear(d_in, d_long)
        self.o_short = nn.Linear(d_short, d_long, bias=False)
        self.o_in = nn.Linear(d_in, d_long)
        # proj produces the next short state, out the value returned per step
        self.proj = nn.Linear(d_long, d_short)
        self.out = nn.Linear(d_long, d_short)

        nn.init.ones_(self.f_in.bias)
        for layer in (self.i_in, self.g_in, self.o_in, self.proj, self.out):
            nn.init.zeros_(layer.bias)

        for layer in (self.i_short, self.f_short, self.g_short, self.o_short, self.proj):
            nn.init.orthogonal_(layer.weight)

        for layer in (self.i_in, self.f_in, self.g_in, self.o_in, self.out):
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)

    def forward(self, x, x_len, hiddens = None):
        """Run the cell over ``x`` one step at a time.

        ``x`` is indexed as (batch, step, feature) and ``x_len`` gives the
        number of valid steps per batch row; states are left unchanged on steps
        at or beyond ``x_len``. ``hiddens`` is an optional ``(shorts, longs)``
        pair to resume from. Returns the per-step outputs and the final
        ``(shorts, longs)`` pair.
        """
        batch_size, sq_len = x.size(0), x.size(1)
        if hiddens is None:
            shorts = self.short0.unsqueeze(0).repeat(batch_size, 1)
            longs = self.long0.unsqueeze(0).repeat(batch_size, 1)
        else:
            shorts, longs = hiddens[0], hiddens[1]

        steps = torch.arange(sq_len, device=x.device).unsqueeze(0)
        mask = (steps < x_len.unsqueeze(1)).unsqueeze(-1)
        out = torch.zeros(batch_size, sq_len, self.d_short, device=x.device)
        for sq_idx in range(sq_len):
            x_t = x[:, sq_idx]
            live = mask[:, sq_idx]
            frozen = ~live
            i = torch.sigmoid(self.i_short(shorts) + self.i_in(x_t))
            f = torch.sigmoid(self.f_short(shorts) + self.f_in(x_t))
            g = torch.tanh(self.g_short(shorts) + self.g_in(x_t))
            o = torch.sigmoid(self.o_short(shorts) + self.o_in(x_t))
            # Rows past their length keep their previous state
            longs = (f * longs + i * g) * live + longs * frozen
            gated = o * longs
            shorts = self.proj(gated) * live + shorts * frozen
            out[:, sq_idx] = self.out(gated)
        return out, (shorts, longs)

In [15]:
# Multi-head self attention with RoPE defined in a manner that allows both parallel and auto-regressive computation
# Most negative float32; written into masked attention scores before the softmax
float_min = torch.finfo(torch.float32).min
class MHSA(nn.Module):
    """Multi-head self attention with rotary position embeddings.

    Keys and values are produced separately by ``cache_kvs`` and passed into
    ``forward``, so the same module serves both whole-sequence (parallel)
    computation and one-step-at-a-time generation with a growing cache.
    All projection weights are drawn from N(0, 0.02) with zero biases.
    """

    def __init__(self, d_model, d_sa, n_head, dropout = 0.1):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_sa = d_sa
        # Per-head width
        self.d_key = d_sa // n_head
        self.dropout = nn.Dropout(dropout)
        self.rope = torchtune.modules.RotaryPositionalEmbeddings(dim=self.d_key)

        self.q = nn.Linear(d_model, d_sa)
        self.k = nn.Linear(d_model, d_sa)
        self.v = nn.Linear(d_model, d_sa)
        self.sa_lin = nn.Linear(d_sa, d_model)
        for lin in [self.k, self.q, self.v, self.sa_lin]:
            nn.init.normal_(lin.weight, mean=0.0, std=0.02)
            nn.init.zeros_(lin.bias)

    def forward(self, x, ks, vs, mask=None, position=None):
        """Attend from ``x`` over the cached keys ``ks`` and values ``vs``.

        ``x`` is either (batch, seq, d_model) or a single element
        (batch, d_model); the output has the same leading shape as ``x``.
        ``ks`` / ``vs`` are laid out as returned by ``cache_kvs``. ``mask`` is
        an optional boolean tensor, True where attention is allowed, that
        must broadcast against the scores once a head axis is inserted at
        dim 1. ``position`` is forwarded to the rotary embedding as
        ``input_pos``.
        """
        # If we are receiving one single element in a sequence rather than a whole sequence,
        # we unsqueeze it at the beginning and resqueeze it at the end to make the tensor shapes work out
        is_sequence = x.dim() == 3
        if not is_sequence:
            x = x.unsqueeze(1)
        batch_size, sq_len = x.size(0), x.size(1)

        qs = self.rope(self.q(x).view(batch_size, sq_len, self.n_head, self.d_key), input_pos=position).transpose(1, 2)
        dots = torch.matmul(qs, ks.transpose(-1, -2)) / math.sqrt(self.d_key)
        if mask is not None:
            # Broadcast the mask over the head axis instead of repeating it per head
            dots = dots.masked_fill(~mask.unsqueeze(1), float_min)
        attn_weight = self.dropout(torch.softmax(dots, dim=-1))
        attns = torch.matmul(attn_weight, vs).transpose(1, 2).contiguous().view(batch_size, sq_len, self.d_sa)
        if not is_sequence:
            attns = attns.squeeze(1)
        return self.sa_lin(attns)

    # Caching keys and values is extremely important for autoregressive generation
    # This saves significant computation for recurrent architectures
    def cache_kvs(self, x, ks=None, vs=None, position=None):
        """Compute keys/values for ``x``, optionally appending to an existing cache.

        With no cache (``ks`` is None) the keys and values of the whole input
        are returned, each shaped (batch, n_head, seq, d_key); ``position`` is
        not used in that case. With a cache, ``x`` must hold exactly one
        sequence element, whose key (rotated at ``position``) and value are
        concatenated onto ``ks`` / ``vs`` along the sequence axis.

        Raises:
            Exception: if a cache is given and ``x`` has more than one element.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        batch_size = x.size(0)

        if ks is None:
            sq_len = x.size(1)
            ks = self.rope(self.k(x).view(batch_size, sq_len, self.n_head, self.d_key)).transpose(1, 2)
            vs = self.v(x).view(batch_size, sq_len, self.n_head, self.d_key).transpose(1, 2)
            return ks, vs

        if x.size(1) != 1:
            raise Exception("Expected only one sequence element input at a time with auto-regressive kv caching")
        new_ks = self.rope(self.k(x).view(batch_size, 1, self.n_head, self.d_key), input_pos=position).transpose(1, 2)
        new_vs = self.v(x).view(batch_size, 1, self.n_head, self.d_key).transpose(1, 2)
        return torch.cat([ks, new_ks], 2), torch.cat([vs, new_vs], 2)
        

In [16]:
# Define base model shell for embeddings, output linear, and calculating loss
class Base(nn.Module):
    """Shell around a sequence model: token embedding, output projection and loss.

    ``model`` must expose ``d_model`` and be callable as
    ``model(embedded, x_len, hiddens)`` returning ``(output, hiddens)``.
    With ``reuse_embeddings`` the output projection shares its weight with
    the embedding. With ``include_x_loss`` the loss also covers the problem
    tokens rather than only the solution tokens.
    """

    def __init__(self, model, vocab, reuse_embeddings = False, include_x_loss = False):
        super().__init__()
        self.model = model
        self.vocab = vocab
        self.d_model = model.d_model
        self.vocab_len = len(vocab)
        self.embedding = nn.Embedding(self.vocab_len, self.d_model)
        self.actor = nn.Linear(self.d_model, self.vocab_len, bias=False)
        if reuse_embeddings:
            self.actor.weight = self.embedding.weight
        self.include_x_loss = include_x_loss
        self.criteria = nn.CrossEntropyLoss(ignore_index=vocab.index(PAD))

        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.actor.weight, mean=0.0, std=0.02)

    def forward(self, x, x_len, hiddens=None):
        """Return ``(logits, hiddens)`` for the token ids ``x``.

        Raises:
            ValueError: if ``x`` is not 2-dimensional (batch_size, sq_len).
        """
        if x.dim() != 2:
            # x.shape is an attribute, not a method; calling it raised TypeError
            raise ValueError(f"Input must be of shape (batch_size, sq_len), received input shape {tuple(x.shape)}")
        output, hiddens = self.model(self.embedding(x), x_len, hiddens)
        # Logits are scaled down by sqrt(d_model)
        return self.actor(output) / (self.d_model ** 0.5), hiddens

    def calcLoss(self, xy, x_len, xy_len):
        """Cross-entropy of next-token predictions over the selected positions.

        Positions at or beyond ``xy_len`` are never scored; unless
        ``include_x_loss`` is set, positions before ``x_len`` are excluded
        as well.
        """
        sq_len = xy.size(1)
        output, _ = self(xy, xy_len)

        positions = torch.arange(sq_len, device=xy.device).unsqueeze(0)
        selected = positions < xy_len.unsqueeze(1)
        if not self.include_x_loss:
            selected = selected & (positions >= x_len.unsqueeze(1))
        # Position t is predicted from the output at t - 1, so the first
        # position has no prediction and is dropped
        selected = selected[:, 1:]

        guesses = output[:, :-1][selected]
        actual = xy[:, 1:][selected]

        return self.criteria(guesses, actual)


In [17]:
# Decoder only transformer architecture, as commonly used in generative pretrained transformers 
class DecTrans(nn.Module):
    """Decoder-only transformer: ``n_lay`` blocks of self attention then feed forward.

    Each sub-layer is wrapped in a residual connection with dropout and
    followed by a layer norm. ``forward`` supports whole-sequence computation
    (no ``hiddens``) and single-step generation from cached keys/values.
    """

    def __init__(self, d_model, d_sa, d_ffwd, n_head, n_lay, activ=nn.GELU(), dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.d_sa = d_sa
        self.d_ffwd = d_ffwd
        self.n_head = n_head
        self.n_lay = n_lay
        self.dropout = nn.Dropout(dropout)

        # ModuleList is the container for sub-modules (ParameterList is for Parameters)
        self.sas = nn.ModuleList([MHSA(d_model, d_sa, n_head, dropout=dropout) for _ in range(n_lay)])
        self.ffwds = nn.ModuleList([FeedFwd([d_model, d_ffwd, d_model], dropout=dropout, activ=activ) for _ in range(n_lay)])
        self.sa_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.ffwd_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])

        for norm in self.sa_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

        for norm in self.ffwd_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

    def forward(self, x, x_len, hiddens=None):
        """Run the stack over embedded inputs ``x`` indexed (batch, step, d_model).

        ``x_len`` gives the number of valid steps per row of ``x``. Without
        ``hiddens`` the whole sequence is processed under a causal mask. With
        ``hiddens`` (the ``(ks, vs, src_mask)`` tuple returned by a previous
        call) exactly one new step is processed and appended to the caches;
        the ``ks`` / ``vs`` lists passed in are updated in place.

        Raises:
            Exception: if ``hiddens`` is given and ``x`` holds more than one step.
        """
        sq_len = x.size(1)
        # True for the steps of x that lie within each row's length
        step_valid = torch.arange(sq_len, device=x.device).unsqueeze(0) < x_len.unsqueeze(1)
        if hiddens is None:
            ks = [None for _ in range(self.n_lay)]
            vs = [None for _ in range(self.n_lay)]
            src_mask = step_valid
            # One (sq_len, sq_len) causal mask broadcast across the batch
            causal = torch.tril(torch.ones(sq_len, sq_len, dtype=torch.bool, device=x.device))
            mask = src_mask.unsqueeze(1) & causal
            position = None
        else:
            #TODO, allow sq_len > 1 during generation
            if sq_len != 1:
                raise Exception("Must only enter one sequence element at a time during autoregressive generation")
            ks, vs, src_mask = hiddens
            src_mask = torch.cat([src_mask, step_valid], dim=1)
            # We don't need a causality mask when we are autoregressively generating, and when there are hiddens we must be autoregressively generating
            mask = src_mask.unsqueeze(1)
            # Zero-based position of the new element among the valid ones
            position = mask.sum(dim=2) - 1

        for layer in range(self.n_lay):
            ks[layer], vs[layer] = self.sas[layer].cache_kvs(x, ks[layer], vs[layer], position)
            x = self.sa_norms[layer](x + self.dropout(self.sas[layer](x, ks[layer], vs[layer], mask=mask, position=position)))
            x = self.ffwd_norms[layer](x + self.dropout(self.ffwds[layer](x)))
        return x, (ks, vs, src_mask)

In [18]:
# Multiple layer LSTM, uses residual connections and layer normalization, like the decoder transformer, but has no self attention
class MultiLayLSTM(nn.Module):
    """Stack of ``n_lay`` LSTM layers with residual connections and layer norm.

    Mirrors the decoder transformer's residual/normalisation layout but has
    no self attention.
    """

    def __init__(self, d_model, d_lstm, n_lay, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.d_lstm = d_lstm
        self.n_lay = n_lay
        self.dropout = nn.Dropout(dropout)

        # ModuleList is the container for sub-modules (ParameterList is for Parameters)
        self.lstms = nn.ModuleList([LSTM(d_model, d_lstm, d_model) for _ in range(n_lay)])
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])

    def forward(self, x, x_len, hiddens=None):
        """Run every layer over ``x`` and return ``(output, hiddens)``.

        ``hiddens`` is an optional list holding one LSTM state per layer, as
        returned by a previous call; when given, it is updated in place and
        returned.
        """
        if hiddens is None:
            hiddens = [None for _ in range(self.n_lay)]

        for layer in range(self.n_lay):
            lstm_out, hiddens[layer] = self.lstms[layer](x, x_len, hiddens[layer])
            x = self.norms[layer](x + self.dropout(lstm_out))

        return x, hiddens

In [19]:
# Long short term memory transformer, same as decoder transformer except substituting the feed forward network with an LSTM
class LSTMTrans(nn.Module):
    """Decoder transformer whose feed-forward sub-layers are replaced by LSTMs.

    Each of the ``n_lay`` blocks applies self attention then an LSTM, each
    wrapped in a residual connection with dropout and followed by a layer
    norm. The ``activ`` argument is accepted but not used by this class.
    """

    def __init__(self, d_model, d_sa, d_lstm, n_head, n_lay, activ=nn.GELU(), dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.d_sa = d_sa
        self.d_lstm = d_lstm
        self.n_head = n_head
        self.n_lay = n_lay
        self.dropout = nn.Dropout(dropout)

        # ModuleList is the container for sub-modules (ParameterList is for Parameters)
        self.sas = nn.ModuleList([MHSA(d_model, d_sa, n_head, dropout=dropout) for _ in range(n_lay)])
        self.lstms = nn.ModuleList([LSTM(d_model, d_lstm, d_model) for _ in range(n_lay)])
        self.sa_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.lstm_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])

        for norm in self.sa_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

        for norm in self.lstm_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

    def forward(self, x, x_len, hiddens=None):
        """Run the stack over embedded inputs ``x`` indexed (batch, step, d_model).

        ``x_len`` gives the number of valid steps per row of ``x``. Without
        ``hiddens`` the whole sequence is processed under a causal mask. With
        ``hiddens`` (the ``(ks, vs, src_mask, lstm_hiddens)`` tuple returned by
        a previous call) exactly one new step is processed; the lists inside
        ``hiddens`` are updated in place.

        Raises:
            Exception: if ``hiddens`` is given and ``x`` holds more than one step.
        """
        sq_len = x.size(1)
        # True for the steps of x that lie within each row's length
        step_valid = torch.arange(sq_len, device=x.device).unsqueeze(0) < x_len.unsqueeze(1)
        if hiddens is None:
            ks = [None for _ in range(self.n_lay)]
            vs = [None for _ in range(self.n_lay)]
            src_mask = step_valid
            # One (sq_len, sq_len) causal mask broadcast across the batch
            causal = torch.tril(torch.ones(sq_len, sq_len, dtype=torch.bool, device=x.device))
            mask = src_mask.unsqueeze(1) & causal
            position = None
            lstm_hiddens = [None for _ in range(self.n_lay)]
        else:
            #TODO, allow sq_len > 1 during generation
            if sq_len != 1:
                raise Exception("Must only enter one sequence element at a time during autoregressive generation")
            ks, vs, src_mask, lstm_hiddens = hiddens
            src_mask = torch.cat([src_mask, step_valid], dim=1)
            # We don't need a causality mask when we are autoregressively generating, and when there are hiddens we must be autoregressively generating
            mask = src_mask.unsqueeze(1)
            # Zero-based position of the new element among the valid ones
            position = mask.sum(dim=2) - 1

        for layer in range(self.n_lay):
            ks[layer], vs[layer] = self.sas[layer].cache_kvs(x, ks[layer], vs[layer], position)
            x = self.sa_norms[layer](x + self.dropout(self.sas[layer](x, ks[layer], vs[layer], mask=mask, position=position)))
            lstm_out, lstm_hiddens[layer] = self.lstms[layer](x, x_len, lstm_hiddens[layer])
            x = self.lstm_norms[layer](x + self.dropout(lstm_out))
        return x, (ks, vs, src_mask, lstm_hiddens)

In [20]:
# Transformer with recurrent self attention: same as the decoder transformer, except that the keys and
# values of each layer are derived from that layer's output rather than from its input
class RSATrans(nn.Module):
    def __init__(self, d_model, d_sa, d_ffwd, n_head, n_lay, activ=nn.GELU(), dropout = 0.1):
        """Build n_lay blocks of (self attention -> feed forward), each followed by a LayerNorm.

        Args:
            d_model: width of the residual stream (size of every layer's input/output).
            d_sa: passed to MHSA as its second argument (attention width).
            d_ffwd: hidden width of each FeedFwd block ([d_model, d_ffwd, d_model]).
            n_head: number of attention heads passed to MHSA.
            n_lay: number of layers.
            activ: activation module handed to every FeedFwd block.
            dropout: dropout probability used on the residual branches and passed to the sub-modules.
        """
        super(RSATrans, self).__init__()
        self.d_model = d_model
        self.d_ffwd = d_ffwd
        self.d_sa = d_sa
        self.n_head = n_head
        self.n_lay = n_lay
        self.dropout = nn.Dropout(dropout)

        # Sub-modules belong in a ModuleList; ParameterList is kept only for the bare h0 parameters
        self.ffwds = nn.ModuleList([FeedFwd([d_model, d_ffwd, d_model], dropout=dropout, activ=activ) for _ in range(n_lay)])
        self.sas = nn.ModuleList([MHSA(d_model, d_sa, n_head, dropout=dropout) for _ in range(n_lay)])

        # One learned initial "previous output" per layer, so the first token has something to attend to
        self.sa_h0s = nn.ParameterList([nn.Parameter(torch.zeros(d_model)) for _ in range(n_lay)])
        self.sa_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.ffwd_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])

        for h0 in self.sa_h0s:
            nn.init.normal_(h0, mean=0.0, std=0.02)

        for norm in self.sa_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

        for norm in self.ffwd_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

    def forward(self, x, x_len, hiddens=None):
        """Process x one sequence position at a time, caching each layer's output as keys/values.

        Args:
            x: input indexed as x[:, sq_idx]; presumably (batch, sq_len, d_model) -- TODO confirm.
            x_len: per-row number of valid positions in x, shape (batch,).
            hiddens: None for a fresh sequence, or the (ks, vs, src_mask) tuple returned by a
                previous call to continue from its cache.

        Returns:
            (out, (ks, vs, src_mask)) where out has shape (batch, sq_len, d_model).
        """
        batch_size = x.size(0)
        sq_len = x.size(1)
        out = torch.zeros((batch_size, sq_len, self.d_model), device = x.device)

        if hiddens is None:
            ks = [None for _ in range(self.n_lay)]
            vs = [None for _ in range(self.n_lay)]
            for layer in range(self.n_lay):
                # Seed every layer's cache with its learned h0, repeated for each batch row
                ks[layer], vs[layer] = self.sas[layer].cache_kvs(self.sa_h0s[layer].unsqueeze(0).repeat(batch_size, 1))
            # When no hiddens, our source mask must be one larger in the sq_len dimension in order to accommodate the h0
            src_mask = (torch.arange(sq_len + 1, device=x.device).unsqueeze(0) < (x_len + 1).unsqueeze(1))
        else:
            # This class returns a 3-tuple (there is no LSTM state), so unpack exactly three values;
            # unpacking four names here raised ValueError on every continuation call
            ks, vs, src_mask = hiddens
            src_mask = torch.cat([src_mask, (torch.arange(sq_len, device=x.device).unsqueeze(0) < x_len.unsqueeze(1))], dim=1)

        for sq_idx in range(sq_len):
            # Everything cached so far: src_mask without the entries of the not-yet-processed positions
            seen_mask = src_mask[:, :(sq_idx - sq_len)]
            # Number of valid cached entries per row (h0 counts as one)
            position = seen_mask.sum(dim=1)
            curr = x[:, sq_idx]
            for layer in range(self.n_lay):
                curr = self.sa_norms[layer](curr + self.dropout(self.sas[layer](curr, ks[layer], vs[layer], mask=seen_mask.unsqueeze(1), position=position)))
                curr = self.ffwd_norms[layer](curr + self.dropout(self.ffwds[layer](curr)))
                # Cache from the layer OUTPUT, which is what makes the attention recurrent
                ks[layer], vs[layer] = self.sas[layer].cache_kvs(curr, ks[layer], vs[layer], position)
            out[:, sq_idx] = curr

        return out, (ks, vs, src_mask)

In [21]:
# Recurrent Transformer architecture, utilizes LSTMs and recurrent self attention
class ReTrans(nn.Module):
    def __init__(self, d_model, d_sa, d_lstm, n_head, n_lay, dropout = 0.1):
        """Build n_lay blocks of (self attention -> LSTM), each followed by a LayerNorm.

        Args:
            d_model: width of the residual stream (size of every layer's input/output).
            d_sa: passed to MHSA as its second argument (attention width).
            d_lstm: passed to LSTM as its second argument (LSTM(d_model, d_lstm, d_model)).
            n_head: number of attention heads passed to MHSA.
            n_lay: number of layers.
            dropout: dropout probability used on the residual branches and passed to MHSA.
        """
        super(ReTrans, self).__init__()
        self.d_model = d_model
        self.d_lstm = d_lstm
        self.d_sa = d_sa
        self.n_head = n_head
        self.n_lay = n_lay
        self.dropout = nn.Dropout(dropout)

        # Sub-modules belong in a ModuleList; ParameterList is kept only for the bare h0 parameters
        self.lstms = nn.ModuleList([LSTM(d_model, d_lstm, d_model) for _ in range(n_lay)])
        self.sas = nn.ModuleList([MHSA(d_model, d_sa, n_head, dropout=dropout) for _ in range(n_lay)])

        # One learned initial "previous output" per layer, so the first token has something to attend to
        self.sa_h0s = nn.ParameterList([nn.Parameter(torch.zeros(d_model)) for _ in range(n_lay)])
        self.sa_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.lstm_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])

        for h0 in self.sa_h0s:
            nn.init.normal_(h0, mean=0.0, std=0.02)

        for norm in self.sa_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

        for norm in self.lstm_norms:
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

    def forward(self, x, x_len, hiddens=None):
        """Process x one sequence position at a time through attention + LSTM layers.

        Args:
            x: input indexed as x[:, sq_idx]; presumably (batch, sq_len, d_model) -- TODO confirm.
            x_len: per-row number of valid positions in x, shape (batch,).
            hiddens: None for a fresh sequence, or the (ks, vs, src_mask, lstm_hiddens) tuple
                returned by a previous call to continue from its state.

        Returns:
            (out, (ks, vs, src_mask, lstm_hiddens)) where out has shape (batch, sq_len, d_model).
        """
        batch_size = x.size(0)
        sq_len = x.size(1)
        out = torch.zeros((batch_size, sq_len, self.d_model), device = x.device)

        if hiddens is None:
            ks = [None for _ in range(self.n_lay)]
            vs = [None for _ in range(self.n_lay)]
            lstm_hiddens = [None for _ in range(self.n_lay)]
            for layer in range(self.n_lay):
                # Seed every layer's cache with its learned h0, repeated for each batch row
                ks[layer], vs[layer] = self.sas[layer].cache_kvs(self.sa_h0s[layer].unsqueeze(0).repeat(batch_size, 1))
            # When no hiddens, our source mask must be one larger in the sq_len dimension in order to accommodate the h0
            src_mask = (torch.arange(sq_len + 1, device=x.device).unsqueeze(0) < (x_len + 1).unsqueeze(1))
        else:
            ks, vs, src_mask, lstm_hiddens = hiddens
            src_mask = torch.cat([src_mask, (torch.arange(sq_len, device=x.device).unsqueeze(0) < x_len.unsqueeze(1))], dim=1)

        for sq_idx in range(sq_len):
            # Everything cached so far: src_mask without the entries of the not-yet-processed positions
            seen_mask = src_mask[:, :(sq_idx - sq_len)]
            # Number of valid cached entries per row (h0 counts as one)
            position = seen_mask.sum(dim=1)
            curr = x[:, sq_idx]
            for layer in range(self.n_lay):
                curr = self.sa_norms[layer](curr + self.dropout(self.sas[layer](curr, ks[layer], vs[layer], mask=seen_mask.unsqueeze(1), position=position)))
                # The LSTM sees a length-1 sequence; the mask entry of the current position is passed in the
                # x_len slot (True/1 when this position is valid for the row, False/0 when it is padding)
                lstm_out, lstm_hiddens[layer] = self.lstms[layer](curr.unsqueeze(1), src_mask[:, (sq_idx - sq_len)], lstm_hiddens[layer])
                curr = self.lstm_norms[layer](curr + self.dropout(lstm_out.squeeze(1)))
                # Cache from the layer OUTPUT, which is what makes the attention recurrent
                ks[layer], vs[layer] = self.sas[layer].cache_kvs(curr, ks[layer], vs[layer], position)
            out[:, sq_idx] = curr

        return out, (ks, vs, src_mask, lstm_hiddens)

# Training + Testing

In [22]:
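# Define different learning-rate schedulers; each takes the total step count and returns a multiplier function for LambdaLR.
# We only end up using linear, but decay and cosine are included in case they are needed.
# NOTE: train() defaults to cosine_scheduler, so linear has to be requested explicitly with scheduler=linear_scheduler.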
def decay_scheduler(steps):
    """Return a multiplier function: linear warmup, then exponential decay.

    The first ``steps // 10`` steps ramp linearly from 0 towards 1. After that
    the multiplier shrinks by a constant factor per step, chosen so that it
    reaches 0.05 at step ``steps``.
    """
    n_warmup = steps // 10
    # Per-step factor: applied (steps - n_warmup) times it yields 0.05
    per_step = 0.05 ** (1 / (steps - n_warmup))

    def lr_lambda(current_step):
        since_warmup = current_step - n_warmup
        if since_warmup >= 0:
            return per_step ** since_warmup
        return current_step / n_warmup

    return lr_lambda

def cosine_scheduler(steps):
    """Return a multiplier function: linear warmup, then cosine decay to a floor.

    The first ``steps // 10`` steps ramp linearly from 0 towards 1. The
    multiplier then follows half a cosine period from 1 down to ``min_lr``
    (0.05), reaching it at step ``steps``.

    Steps beyond ``steps`` stay at the floor instead of following the cosine
    back up, and ``steps == 0`` no longer divides by zero. Values for
    0 <= current_step <= steps are unchanged.
    """
    warmup_steps = steps // 10
    min_lr = 0.05
    # At least 1 so that a zero-length schedule cannot divide by zero
    decay_steps = max(1, steps - warmup_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return current_step / warmup_steps
        # Clamp so the cosine stops at its minimum rather than rising again
        elapsed = min(current_step - warmup_steps, decay_steps)
        return min_lr + (1 - min_lr) / 2 * (1 + math.cos(elapsed * math.pi / decay_steps))

    return lr_lambda

def linear_scheduler(steps):
    """Return a multiplier function: linear warmup, then linear decay to zero.

    The first ``steps // 10`` steps ramp linearly from 0 towards 1. The
    multiplier then falls linearly, reaching 0 at step ``steps``.

    Steps beyond ``steps`` return 0 rather than a negative multiplier (which
    would give a negative learning rate), and ``steps == 0`` no longer divides
    by zero. Values for 0 <= current_step <= steps are unchanged.
    """
    warmup_steps = steps // 10
    # At least 1 so that a zero-length schedule cannot divide by zero
    decay_steps = max(1, steps - warmup_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return current_step / warmup_steps
        return max(0.0, 1.0 - ((current_step - warmup_steps) / decay_steps))

    return lr_lambda


In [23]:
# Train function, learning rate scheduling is done on a batch basis instead of epoch, revert to best model
def train(model, trainset, validset, lr=0.001, batch_size=512, epochs=10, optimizer=torch.optim.AdamW, scheduler=cosine_scheduler):
    """Train model, then restore the weights of the epoch with the lowest validation loss.

    Args:
        model: module exposing calcLoss(xy, x_len, xy_len) that returns a scalar loss tensor.
        trainset: dataset yielding (xy, x_len, xy_len) items; shuffled every epoch.
        validset: dataset of the same form, evaluated after every epoch.
        lr: base learning rate given to the optimizer.
        batch_size: batch size of both data loaders.
        epochs: number of passes over trainset.
        optimizer: optimizer class, called as optimizer(model.parameters(), lr=lr).
        scheduler: factory taking the total number of optimizer steps and returning the
            multiplier function handed to LambdaLR.

    Returns:
        (train_losses, valid_losses): per-epoch mean batch losses.
    """
    train_iter = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
    valid_iter = DataLoader(validset, batch_size=batch_size, pin_memory=True)
    optim = optimizer(model.parameters(), lr=lr)
    # The schedule spans every optimizer step of the run because it is advanced once per batch
    lambdalr = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=scheduler(epochs * len(train_iter)))
    train_losses = []
    valid_losses = []
    # The best weights are kept in memory rather than in a fixed 'best_model.pth' file, so concurrent
    # runs cannot overwrite each other and a run without any checkpoint does not crash on reload
    best_valid = None
    best_state = None
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_idx, batch in enumerate(train_iter):
            xy, x_len, xy_len = batch
            # Text progress bar: prints a '=' each time the epoch progress crosses another 1/40th
            print('=' * (math.floor((batch_idx / len(train_iter)) * 40) - math.floor(((batch_idx - 1) / len(train_iter)) * 40)), end = "")
            optim.zero_grad()
            loss = model.calcLoss(xy.to(device), x_len.to(device), xy_len.to(device))
            train_loss += loss.item()
            loss.backward()
            optim.step()
            lambdalr.step()
        print(">\nEpoch {} train loss:\t{}".format(epoch, train_loss / len(train_iter)))
        train_losses.append(train_loss / len(train_iter))

        model.eval()
        with torch.no_grad():
            valid_loss = 0
            for batch in valid_iter:
                xy, x_len, xy_len = batch
                loss = model.calcLoss(xy.to(device), x_len.to(device), xy_len.to(device))
                valid_loss += loss.item()
        print("Epoch {} valid loss:\t{}".format(epoch, valid_loss / len(valid_iter)))
        valid_losses.append(valid_loss / len(valid_iter))

        # Ties go to the later epoch; a NaN best is replaced by anything that follows it
        if best_state is None or math.isnan(best_valid) or valid_losses[-1] <= best_valid:
            best_valid = valid_losses[-1]
            # state_dict() returns live references, so snapshot real copies (on CPU to spare device memory)
            best_state = {key: value.detach().to("cpu", copy=True) for key, value in model.state_dict().items()}
    # Nothing is restored when no epoch completed
    if best_state is not None:
        model.load_state_dict(best_state)
    return train_losses, valid_losses

In [24]:
# Measures the proportion of the test set whose answer the model predicts entirely correctly in one forward pass
@torch.no_grad()
def test(model, testset, batch_size=512):
    """Return the fraction of test samples whose whole answer span is predicted correctly.

    The model receives the full sequence xy (prompt and answer) in a single
    forward pass, and the prediction at position t is compared with the token
    at position t + 1. A sample counts only if every token in positions
    [x_len, xy_len) is predicted correctly.
    """
    loader = DataLoader(testset, batch_size=batch_size, pin_memory=True)
    model.eval()
    solved = 0
    for xy, x_len, xy_len in loader:
        xy, x_len, xy_len = xy.to(device), x_len.to(device), xy_len.to(device)
        out, _ = model(xy, xy_len)
        # The prediction at position t targets token t + 1, so the last output has no target
        predicted = out[:, :-1].argmax(dim=-1)
        targets = xy[:, 1:]
        token_idx = torch.arange(xy.size(1), device=device).unsqueeze(0)
        # Answer tokens live in [x_len, xy_len); drop column 0 to line up with the shifted targets
        in_answer = ((token_idx < xy_len.unsqueeze(1)) & (token_idx >= x_len.unsqueeze(1)))[:, 1:]
        hits = (predicted == targets) & in_answer
        fully_right = hits.sum(dim=1) == (xy_len - x_len)
        solved += fully_right.sum().item()
    return solved / len(testset)

In [25]:
# Generate the output for an input
@torch.no_grad()
def solve(model, x, x_len, vocab, max_y_len):
    """Greedily decode up to max_y_len tokens for every row of x.

    Args:
        model: called as model(x, x_len) for the prompt and model(token, flags, hiddens) per step.
        x: batch of prompts, indexed as (batch, position).
        x_len: per-row prompt length; the output at index x_len - 1 supplies the first generated token.
        vocab: list of tokens, used to look up the index of STOP.
        max_y_len: number of tokens generated per row.

    Returns:
        (y, y_len): y is a (batch, max_y_len) long tensor of generated token ids.
        y_len is (index of the first STOP) + 1 for rows that produced STOP before
        the last position, and max_y_len + 1 for every other row.
    """
    n_rows = x.size(0)
    stop_token = vocab.index(STOP)
    y = torch.zeros(n_rows, max_y_len, dtype=torch.long, device=x.device)
    # max_y_len doubles as the "has not stopped yet" sentinel
    y_len = torch.full((n_rows,), max_y_len, device=x.device)
    model.eval()

    # Feed the whole prompt once; the output at its last valid position gives the first token
    out, hiddens = model(x, x_len)
    row_ids = torch.arange(n_rows, device=x.device)
    y[:, 0] = out[row_ids, x_len - 1].argmax(dim=1)

    for step in range(max_y_len - 1):
        # Record the stop position the first time a row emits STOP
        newly_stopped = (y[:, step] == stop_token) & (y_len == max_y_len)
        y_len[newly_stopped] = step
        # Rows that already stopped are passed to the model as length 0 (False)
        still_running = y_len == max_y_len
        out, hiddens = model(y[:, step].unsqueeze(1), still_running, hiddens)
        y[:, step + 1] = out[:, 0].argmax(dim=1)
    return y, y_len + 1
        

# Experiments

In [26]:
def experiment(datasets, model_type, model_args, train_args=None, test_args=None):
    """Build, train and test one model, returning a dictionary describing the run.

    Args:
        datasets: (trainset, validset, testset). trainset must expose ``.dataset`` (the
            underlying dataset, whose class name and ``vocab`` are used) -- presumably the
            splits are torch Subset objects; confirm against the caller.
        model_type: model class, instantiated as model_type(**model_args) and wrapped in Base.
        model_args: keyword arguments for model_type.
        train_args: keyword arguments forwarded to train(); defaults to none.
        test_args: keyword arguments forwarded to test(); defaults to none.

    Returns:
        dict with the device, model/dataset class names, dataset sizes, model args, parameter
        count, FLOPs per input, train args, per-epoch losses, training time, test args and
        test accuracy.
    """
    # Fresh dicts per call: the arguments are stored in the returned results, so a shared
    # default dict could be modified through one result and leak into later calls
    train_args = {} if train_args is None else train_args
    test_args = {} if test_args is None else test_args
    trainset, validset, testset = datasets
    experimental_results = {}
    if device.type == "cpu":
        experimental_results["device"] = "cpu"
    else:
        experimental_results["device"] = torch.cuda.get_device_name(torch.cuda.current_device())
    experimental_results["model class"] = model_type.__name__
    experimental_results["dataset class"] = trainset.dataset.__class__.__name__
    experimental_results["dataset sizes"] = (len(trainset), len(validset), len(testset))
    model = Base(model_type(**model_args), trainset.dataset.vocab).to(device)
    experimental_results["model args"] = model_args
    experimental_results["model size"] = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    print("Parameter count", experimental_results["model size"])
    flop_xy, flop_x_len, flop_xy_len = next(iter(DataLoader(testset, batch_size=64)))
    # Note this does not catch all operations, but it catches matmul, which is by far the dominant operator
    flops = FlopCountAnalysis(model, (flop_xy.to(device), flop_xy_len.to(device)))
    # Divide by the real probe batch size: it is smaller than 64 when testset has fewer than 64 samples
    experimental_results["flops per input"] = flops.total() // flop_xy.size(0)
    print("Flops per input", experimental_results["flops per input"])
    experimental_results["train args"] = train_args
    start_time = time.perf_counter()
    train_loss, valid_loss = train(model, trainset, validset, **train_args)
    end_time = time.perf_counter()
    experimental_results["train loss"] = train_loss
    experimental_results["valid loss"] = valid_loss
    experimental_results["train time"] = end_time - start_time
    print("Train time", experimental_results["train time"])
    experimental_results["test args"] = test_args
    experimental_results["test accuracy"] = test(model, testset, **test_args)
    print("Test accuracy", experimental_results["test accuracy"])
    return experimental_results
    

In [27]:
def write_experiment(experiment_dict, name=None):
    """Save experiment_dict as experiments/<name>.json (best effort: failures are printed, not raised).

    Args:
        experiment_dict: results dictionary. Values json cannot encode natively are written as
            their ``__name__`` (this covers classes and functions such as an optimizer or scheduler).
        name: file name without extension; defaults to the current time.ctime() string.
    """
    if name is None:
        name = time.ctime()
    try:
        os.makedirs("experiments", exist_ok=True)
        # Serialize completely before opening the file, so a value that cannot be encoded
        # does not leave a truncated JSON file behind
        payload = json.dumps(experiment_dict, indent=4, default=lambda x: x.__name__)
        with open(f"experiments/{name}.json", "w") as f:
            f.write(payload)
    except Exception as err:
        # Saving is deliberately best effort, but report why it failed
        print(f"Failure saving experiment {name}: {err!r}")
    else:
        print(f"Experiment {name} saved")

In [28]:
def read_experiments():
    """Load every JSON file in the experiments directory.

    Returns:
        list of the decoded experiments, ordered by file name. Entries that cannot be opened or
        parsed are reported and skipped. The list is empty when the directory does not exist.
    """
    experiments = []
    if not os.path.isdir("experiments"):
        print("No experiments directory")
        return experiments
    # os.listdir returns entries in arbitrary order; sort for a reproducible result
    for filename in sorted(os.listdir("experiments")):
        try:
            with open(os.path.join("experiments", filename), "r") as f:
                experiments.append(json.load(f))
        except (OSError, ValueError) as err:
            # ValueError covers both malformed JSON and undecodable bytes
            print("Error loading json experiment:", filename, err)
    return experiments

In [31]:
# Example: run one ReTrans experiment on arithSets and save the returned results as JSON
write_experiment(
    experiment(
        arithSets,
        ReTrans,
        model_args={"d_model": 64, "d_sa": 64, "d_lstm": 256, "n_head": 8, "n_lay": 4},
        train_args={"lr": 0.003, "batch_size": 512, "epochs": 10},
        test_args={"batch_size": 512},
    )
)

Parameter count 729088


criteria


Flops per input 27954560
Epoch 0 train loss:	1.7162065160274507
Epoch 0 valid loss:	0.949542740881443
Epoch 1 train loss:	0.8025395053625107
Epoch 1 valid loss:	0.6668105074763298
Epoch 2 train loss:	0.6167759650945663
Epoch 2 valid loss:	0.5410297358036041
Epoch 3 train loss:	0.5311480776071549
Epoch 3 valid loss:	0.484632733464241
Epoch 4 train loss:	0.4708712833225727
Epoch 4 valid loss:	0.4404362453520298
Epoch 5 train loss:	0.42551880738139153
Epoch 5 valid loss:	0.4018432502448559
Epoch 6 train loss:	0.38846165645122527
Epoch 6 valid loss:	0.367625647932291
Epoch 7 train loss:	0.3517845411002636
Epoch 7 valid loss:	0.3433664847910404
Epoch 8 train loss:	0.32676662078499796
Epoch 8 valid loss:	0.3300132678449154
Epoch 9 train loss:	0.31128363019227984
Epoch 9 valid loss:	0.3250736986845732
Train time 8664.736939481998
Test accuracy 0.71091796875
Experiment Sat Jul  5 07:18:33 2025 saved
