# Attention Ain't All, Recurrent Transformer (ReTran), ACL 2025

# Setup

In [None]:
import random, math, torch, time, warnings, torch_optimizer, torchtune, datasets
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from matplotlib.ticker import ScalarFormatter, LogLocator

# Global notebook configuration.
# Use a larger default font for every matplotlib figure below.
plt.rcParams.update({'font.size':20})
# All models and batches are moved to this device; a CUDA device is required.
device = torch.device("cuda")
# Print tensors with plain decimals instead of scientific notation.
torch.set_printoptions(sci_mode=False)

def timer(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and measure how long the call takes.

    Returns:
        A ``(elapsed_seconds, result)`` tuple, where the elapsed time is
        measured with ``time.perf_counter`` and ``result`` is whatever
        ``func`` returned.
    """
    began = time.perf_counter()
    outcome = func(*args, **kwargs)
    return time.perf_counter() - began, outcome

# Dataset

In [None]:
#Vocab
# Special tokens occupy the first four indices, followed by the ten digits
# and finally the operator / parenthesis characters.
PAD = '<pad>'
START = '<start>'
STOP = '<stop>'
SEP = '='
vocab = [START, SEP, STOP, PAD]
vocab += [str(digit) for digit in range(10)]
vocab += ['-', '+', '*', '(', ')']


def encodeVocab(string, pad_length):
    """Translate a string into a tensor of vocab indices.

    The result is START, one index per character of ``string``, STOP, and
    then ``pad_length - len(string)`` PAD entries (none when that
    difference is not positive).

    Raises:
        ValueError: if a character of ``string`` is not in ``vocab``.
    """
    token_ids = [vocab.index(START)]
    token_ids.extend(vocab.index(char) for char in string)
    token_ids.append(vocab.index(STOP))
    token_ids.extend([vocab.index(PAD)] * (pad_length - len(string)))
    return torch.tensor(token_ids, dtype=torch.long)


def decodeVocab(tensor):
    """Translate a tensor of vocab indices back into a string.

    START and PAD tokens are dropped, and decoding ends at the first STOP
    token (which is not included in the output).
    """
    pieces = []
    for token_id in tensor:
        token = vocab[token_id]
        if token == STOP:
            break
        if token != START and token != PAD:
            pieces.append(token)
    return "".join(pieces)

In [None]:
#PROBLEM GENERATION
#Rule table of the context free grammar:
#maps a left-hand symbol to a list of (right_hand, probability) pairs
rules = {}


def addRule(left_hand, right_hand, probability):
    """Register one production ``left_hand -> right_hand`` with its probability.

    Productions for the same left-hand symbol are kept in insertion order.
    """
    rules.setdefault(left_hand, []).append((right_hand, probability))

#Create a probabilistic free grammar to generate math problems
# An equation is two values joined by an operator, optionally parenthesised.
addRule('EQ', ['VAL', 'OP', 'VAL'], 0.5)
addRule('EQ', ['(', 'VAL', 'OP', 'VAL', ')'], 0.5)
# A value is either a nested equation (recursion) or a literal number.
addRule('VAL', ['EQ'], 0.46)
addRule('VAL', ['NUM'], 0.54)
addRule('OP', ['+'], 0.35)
addRule('OP', ['-'], 0.35)
addRule('OP', ['*'], 0.3)
# A number is either the literal 0 or a non-zero head followed by a digit tail.
addRule('NUM', ['0'], 0.05)
addRule('NUM', ['NUMH', 'NUMT'], 0.95)
# Head digit 1-9, optionally negated: 18 equally likely productions.
for i in range(1, 10):
    addRule('NUMH', [str(i)], 1.0 / 18)
    addRule('NUMH', ['-', str(i)], 1.0 / 18)

# Tail: append one more digit (0.02 per digit, 0.2 in total) or stop (0.8).
for i in range(10):
    addRule('NUMT', [str(i), 'NUMT'], 0.02)
addRule('NUMT', [], 0.8)

def selectRule(left_hand):
    """Choose a random rule to expand ``left_hand``.

    Draws one uniform random number and walks the productions registered
    for ``left_hand`` in order, subtracting each probability until the
    remainder drops below zero; that production's right-hand side is
    returned.

    Raises:
        Exception: if the probabilities are exhausted without a selection.
    """
    remaining = random.random()
    for right_hand, probability in rules[left_hand]:
        remaining -= probability
        if remaining < 0:
            return right_hand
    raise Exception("Improper rule probabilities")

def generateProblem(min_in_len, max_in_len, min_out_len, max_out_len):
    """Generate a problem using the context free grammar within length bounds.

    Repeatedly expands the start symbol 'EQ' left to right until a problem
    is produced whose text length lies in ``[min_in_len, max_in_len]`` and
    whose evaluated answer, as a string, has a length in
    ``[min_out_len, max_out_len]``.

    Returns:
        A ``(problem, solution)`` pair of strings.
    """
    symbol_cap = max_in_len + 2
    while True:
        symbols = ['EQ']
        pos = 0
        # Expansion stops early once the symbol list outgrows the cap; such
        # attempts are rejected by the length check below.
        while pos < len(symbols) and len(symbols) <= symbol_cap:
            # A '-' directly after '+' or '-' is folded into that operator
            # ("--" becomes "+", "+-" becomes "-").
            if pos > 0 and symbols[pos] == '-':
                previous = symbols[pos - 1]
                if previous == '-':
                    del symbols[pos]
                    symbols[pos - 1] = '+'
                elif previous == '+':
                    del symbols[pos]
                    symbols[pos - 1] = '-'
            if symbols[pos] in rules:
                # Replace the non-terminal in place with a sampled expansion.
                symbols[pos:pos + 1] = selectRule(symbols[pos])
            else:
                pos += 1

        problem = ''.join(symbols)
        if not (min_in_len <= len(problem) <= max_in_len):
            continue
        # NOTE: eval is applied only to text produced by the grammar above.
        solution = str(eval(problem))
        if min_out_len <= len(solution) <= max_out_len:
            return problem, solution

In [None]:
#Decoder only arithmetic dataset
class ArithDataset(torch.utils.data.Dataset):
    """Dataset of unique, randomly generated arithmetic problems.

    Each item is a tuple ``(tokens, problem_len, combined_len)``:
    ``tokens`` is the encoded string "problem=solution", ``problem_len`` is
    ``len(problem) + 2`` and ``combined_len`` is
    ``len(problem) + len(solution) + 3``.

    Note that construction reseeds Python's global ``random`` module with
    ``seed``.
    """

    def __init__(self, samples, seed=0, min_in_len=1, max_in_len=20, min_out_len=1, max_out_len=6):
        self.combined = []
        self.problem_lens = []
        self.combined_lens = []

        seen_problems = set()
        pad_length = max_in_len + max_out_len + 1
        random.seed(seed)
        # Keep sampling until `samples` distinct problems have been collected.
        while len(seen_problems) < samples:
            prob, soln = generateProblem(min_in_len=min_in_len,
                                         max_in_len=max_in_len,
                                         min_out_len=min_out_len,
                                         max_out_len=max_out_len)
            if prob in seen_problems:
                continue
            seen_problems.add(prob)
            n_prob = len(prob)
            self.combined.append(encodeVocab(prob + SEP + soln, pad_length))
            self.problem_lens.append(n_prob + 2)
            self.combined_lens.append(n_prob + len(soln) + 3)

    def __len__(self):
        return len(self.combined)

    #Both the problem length and the total length are returned;
    #the problem length is later used to trim the loss
    def __getitem__(self, idx):
        return (self.combined[idx],
                self.problem_lens[idx],
                self.combined_lens[idx])

In [None]:
#Create datasets
# Validation and test sets are 1/5 and 1/10 of the training set size.
train_samples = 2048 * 1280
valid_samples = train_samples // 5
test_samples = train_samples // 10
# Length bounds (in characters) for generated problems and their answers.
dset_args = {"min_in_len":8, "max_in_len":30, "min_out_len":0, "max_out_len":5}

# Generate every sample once, then split with a fixed generator seed so the
# train / valid / test partition is reproducible.
data = ArithDataset(train_samples + valid_samples + test_samples, **dset_args)
trainset, validset, testset = torch.utils.data.random_split(data, [train_samples, valid_samples, test_samples], generator=torch.Generator().manual_seed(42))

# Show a few decoded training examples as a sanity check.
print("Problems:")
for i in range(20):
    print(decodeVocab(trainset[i][0]))


In [None]:
#Graph out problem lengths
# Fraction of the dataset at each stored problem length (item[1], which is
# the character length of the problem plus 2).
prob_lens = {}
for problem in data:
    prob_lens[problem[1]] = prob_lens.get(problem[1], 0) + 1
for key in prob_lens:
    prob_lens[key] /= len(data)
lens = sorted(prob_lens)

plt.figure(figsize=(10, 6))
plt.tick_params(axis='both', which='both', length=10, width=2)

# The x values are derived from the observed lengths themselves (minus the
# 2 extra tokens), so x and y always pair up even if some length between the
# minimum and maximum never occurs.
plt.plot([l - 2 for l in lens], [prob_lens[l] * 100 for l in lens], linewidth=3)
plt.title('Problem Length Distribution')
plt.ylabel('Percentage of Problems')
plt.ylim(0, 9)
plt.xticks(range(min(lens) - 2, max(lens) - 1, 2))
plt.xlabel('Problem Length')
plt.show()

# Models and Training

In [None]:
#Define basic model shells
class Basic(nn.Module):
    """Token-level shell around a sequence model.

    Embeds integer tokens, runs them through ``model`` and projects the
    result to vocabulary logits. It also owns the cross-entropy loss.

    Args:
        model: module called as ``model(embedded, seq_len)``; the first
            element of its return value is used as per-position features.
        d_model: embedding width; logits are divided by ``sqrt(d_model)``.
        vocab_len: number of tokens in the vocabulary.
        reuse_embeddings: if true, the output projection shares its weight
            with the embedding table.
        pad_idx: target index excluded from the loss. ``None`` falls back to
            ``-100``, the default ``ignore_index`` of ``CrossEntropyLoss``.
        include_x_loss: if true, the loss covers every next-token
            prediction; otherwise only positions from ``x_len`` onwards.
    """

    def __init__(self, model, d_model, vocab_len, reuse_embeddings = False, pad_idx=None, include_x_loss = False):
        super().__init__()
        self.model = model
        self.d_model = d_model
        self.vocab_len = vocab_len
        self.embedding = nn.Embedding(vocab_len, d_model)
        self.actor = nn.Linear(d_model, vocab_len, bias=False)
        if reuse_embeddings:
            self.actor.weight = self.embedding.weight
        self.include_x_loss = include_x_loss
        # ignore_index must be an int: passing None through would only fail
        # later, when the loss is first evaluated.
        self.criteria = nn.CrossEntropyLoss(ignore_index=-100 if pad_idx is None else pad_idx)

    def forward(self, seq, seq_len):
        """Return scaled vocabulary logits for every position of ``seq``."""
        output, _ = self.model(self.embedding(seq), seq_len)
        return self.actor(output) / (self.d_model ** 0.5)

    def calcLoss(self, xy, x_len, xy_len):
        """Cross-entropy of next-token predictions for a batch.

        Args:
            xy: (batch, seq) tensor of token indices.
            x_len: per-sample index of the first token to be predicted.
            xy_len: per-sample length passed to the model as ``seq_len``.
        """
        # Follow the module's own parameters rather than a notebook global,
        # so the loss works wherever the model has been moved.
        xy = xy.to(self.embedding.weight.device)
        batch_size = xy.size(0)
        # The prediction at position t is scored against the token at t + 1.
        output = self(xy, xy_len)[:, :-1]
        if self.include_x_loss:
            guesses = torch.reshape(output, (batch_size * (xy.size(1) - 1), self.vocab_len))
            actual = torch.flatten(xy[:, 1:])
            return self.criteria(guesses, actual)
        # Only score the positions x_len, x_len + 1, ... of each sample,
        # up to the longest remaining span in the batch.
        max_y_len = int(torch.max(xy_len - x_len))
        rows = torch.arange(batch_size).unsqueeze(1)
        ranges = torch.arange(max_y_len) + x_len.unsqueeze(1)
        actions = torch.reshape(output[rows, ranges - 1], (batch_size * max_y_len, self.vocab_len))
        actual = torch.flatten(xy[rows, ranges])
        return self.criteria(actions, actual)


In [None]:
#Several learning-rate schedules are defined below; results vary in
#interesting ways with the one chosen. Linear was the most reliable for us.
def decay_scheduler(steps):
    """Build an LR multiplier with linear warmup and exponential decay.

    The first ``steps // 10`` steps ramp linearly from 0 towards 1; after
    that the multiplier shrinks by a constant factor per step, chosen so the
    decay spans a factor of 0.05 over the remaining steps.
    """
    warmup_steps = steps // 10
    decay = 0.05 ** (1 / (steps - warmup_steps))

    def lr_lambda(current_step):
        past_warmup = current_step - warmup_steps
        if past_warmup < 0:
            return current_step / warmup_steps
        return decay ** past_warmup

    return lr_lambda

def cosine_scheduler(steps):
    """Build an LR multiplier with linear warmup and cosine annealing.

    The first ``steps // 10`` steps ramp linearly from 0 towards 1; the
    multiplier then follows half a cosine period from 1 down to 0.05.
    """
    warmup_steps = steps // 10
    min_lr = 0.05

    def lr_lambda(current_step):
        if current_step >= warmup_steps:
            return min_lr + (1 - min_lr) / 2 * (1 + math.cos((current_step - warmup_steps) * math.pi / (steps - warmup_steps)))
        return current_step / warmup_steps

    return lr_lambda

def linear_scheduler(steps):
    """Build an LR multiplier with linear warmup and linear decay.

    The first ``steps // 10`` steps ramp linearly from 0 towards 1; the
    multiplier then falls linearly, reaching 0 at ``steps``.
    """
    warmup_steps = steps // 10

    def lr_lambda(current_step):
        if current_step >= warmup_steps:
            return 1.0 - ((current_step - warmup_steps) / (steps - warmup_steps))
        return current_step / warmup_steps

    return lr_lambda


In [None]:
#Train function
#The LR schedule advances once per batch rather than once per epoch
#Uses the LAMB optimizer
#The best model by validation loss is restored at the end
def train(model, lr, epochs=10, scheduler=linear_scheduler):
    """Train ``model`` on ``trainset`` and return the per-epoch validation losses.

    After each epoch the mean validation loss is recorded; whenever it is
    the lowest seen so far the weights are saved to 'best_model.pth', and
    those weights are loaded back into ``model`` before returning.

    Args:
        model: object exposing ``calcLoss(xy, x_len, xy_len)``.
        lr: base learning rate handed to LAMB.
        epochs: number of passes over the training set.
        scheduler: factory taking the total step count and returning an LR
            multiplier function.
    """
    train_iter = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
    valid_iter = DataLoader(validset, batch_size=batch_size, pin_memory=True)
    n_train = len(train_iter)
    n_valid = len(valid_iter)
    optim = torch_optimizer.Lamb(model.parameters(), lr=lr)
    lambdalr = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=scheduler(epochs * n_train))

    losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_idx, batch in enumerate(train_iter):
            # Text progress bar: emit the '=' characters newly earned by
            # this batch on a 40-character scale.
            bar_now = math.floor((batch_idx / n_train) * 40)
            bar_before = math.floor(((batch_idx - 1) / n_train) * 40)
            print('=' * (bar_now - bar_before), end = "")
            optim.zero_grad()
            loss = model.calcLoss(*batch)
            train_loss += loss.item()
            loss.backward()
            optim.step()
            lambdalr.step()
        print(">\nEpoch {} train loss:\t{}".format(epoch, train_loss / n_train))

        model.eval()
        valid_loss = 0
        with torch.no_grad():
            for batch in valid_iter:
                valid_loss += model.calcLoss(*batch).item()
        epoch_valid_loss = valid_loss / n_valid
        print("Epoch {} valid loss:\t{}".format(epoch, epoch_valid_loss))

        losses.append(epoch_valid_loss)
        # Checkpoint whenever this epoch is the best so far.
        if epoch_valid_loss == min(losses):
            torch.save(model.state_dict(), 'best_model.pth')
    model.load_state_dict(torch.load('best_model.pth', weights_only=True))
    return losses

In [None]:
#Testing
def solve(model, x, x_len, max_out_len=6):
    """Greedily fill in the positions of ``x`` from ``x_len`` onwards.

    Runs ``max_out_len - 1`` decoding steps. At each step the model is
    called with a length one larger than before, and its argmax prediction
    at the last visible position is written into the next position.

    Returns:
        A copy of ``x`` on ``device`` with the decoded tokens written in.
    """
    model.eval()
    with torch.no_grad():
        xy = x.detach().clone().to(device)
        xy_len = x_len.detach().clone() - 1
        rows = torch.arange(xy.size(0))
        for _ in range(max_out_len - 1):
            xy_len += 1
            predicted = model(xy, xy_len).argmax(2)
            xy[rows, xy_len] = predicted[rows, xy_len - 1]
        return xy

def test(model):
    """Return the fraction of ``testset`` problems that ``model`` solves exactly.

    Each problem's greedy decoding (see ``solve``) is compared, as a decoded
    string, with the reference sequence.
    """
    test_iter = DataLoader(testset, batch_size=batch_size, pin_memory=True)
    num_correct = 0
    num_seen = 0
    for xy, x_len, xy_len in test_iter:
        guesses = [decodeVocab(row) for row in solve(model, xy, x_len, torch.max(xy_len - x_len))]
        actual = [decodeVocab(row) for row in xy]
        # Pair up what is actually in the batch: the final batch may hold
        # fewer than batch_size samples.
        num_correct += sum(guess == truth for guess, truth in zip(guesses, actual))
        num_seen += len(actual)
    # Divide by the true sample count, not batches * batch_size.
    return num_correct / num_seen if num_seen else 0.0

In [None]:
#Basic feed forward for GPT, uses GELU activation
class FeedFwd(nn.Module):
    """Stack of linear layers with activation and dropout between them.

    ``dims`` lists the layer widths, so ``len(dims) - 1`` linear layers are
    created. Weights use Kaiming-uniform initialisation and biases start at
    zero. The last layer is applied without activation or dropout.
    """

    def __init__(self, dims, dropout=0.1, activ=nn.GELU()):
        super().__init__()
        widths = list(zip(dims[:-1], dims[1:]))
        # All layers are created first and re-initialised afterwards.
        stack = [nn.Linear(d_in, d_out) for d_in, d_out in widths]
        for linear in stack:
            nn.init.kaiming_uniform_(linear.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(linear.bias)
        self.lays = nn.ModuleList(stack)
        self.activ = activ
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        n_hidden = len(self.lays) - 1
        for idx in range(n_hidden):
            x = self.lays[idx](x)
            x = self.activ(x)
            x = self.drop(x)
        return self.lays[-1](x)


In [None]:
#Define own LSTM in order to better manage projection and initialization
class LSTM(nn.Module):
    """Single-layer LSTM with a projected short-term state.

    The long-term (cell) state has width ``d_long``; the short-term state
    fed back into the gates and the per-step output both have width
    ``d_short``. Initial states are learned parameters starting at zero.

    Args:
        d_in: input feature width.
        d_long: width of the long-term state and of all gates.
        d_short: width of the short-term state and of the output.
        init_kai: Kaiming-uniform input/output weights when true, otherwise
            normal with std 0.02.
    """

    def __init__(self, d_in, d_long, d_short, init_kai = True):
        super().__init__()
        self.d_in = d_in
        self.d_long = d_long
        self.d_short = d_short

        self.long0 = nn.Parameter(torch.zeros(d_long))
        self.short0 = nn.Parameter(torch.zeros(d_short))
        self.i_short = nn.Linear(d_short, d_long, bias=False)
        self.i_in = nn.Linear(d_in, d_long)
        self.f_short = nn.Linear(d_short, d_long, bias=False)
        self.f_in = nn.Linear(d_in, d_long)
        self.g_short = nn.Linear(d_short, d_long, bias=False)
        self.g_in = nn.Linear(d_in, d_long)
        self.o_short = nn.Linear(d_short, d_long, bias=False)
        self.o_in = nn.Linear(d_in, d_long)
        self.proj = nn.Linear(d_long, d_short)
        self.out = nn.Linear(d_long, d_short)

        # Forget-gate bias starts at one, every other bias at zero.
        nn.init.ones_(self.f_in.bias)
        for lin in (self.i_in, self.g_in, self.o_in, self.proj, self.out):
            nn.init.zeros_(lin.bias)

        # Recurrent and projection weights are orthogonal.
        for lin in (self.i_short, self.f_short, self.g_short, self.o_short, self.proj):
            nn.init.orthogonal_(lin.weight)

        input_side = (self.i_in, self.f_in, self.g_in, self.o_in, self.out)
        for lin in input_side:
            if init_kai:
                nn.init.kaiming_uniform_(lin.weight)
            else:
                nn.init.normal_(lin.weight, mean=0.0, std=0.02)

    def forward(self, x, x_len, hiddens = None):
        """Run the recurrence over the second dimension of ``x``.

        Steps at positions ``>= x_len`` leave both states unchanged. Returns
        ``(out, (shorts, longs))`` with one output row per step.
        """
        batch_size, sq_len = x.size(0), x.size(1)
        if hiddens is None:
            shorts = self.short0.unsqueeze(0).repeat(batch_size, 1)
            longs = self.long0.unsqueeze(0).repeat(batch_size, 1)
        else:
            shorts, longs = hiddens[0], hiddens[1]

        mask = (torch.arange(sq_len).unsqueeze(0) < x_len.unsqueeze(1)).unsqueeze(-1).to(x.device)
        out = torch.zeros(batch_size, sq_len, self.d_short, device=x.device)
        for sq_idx in range(sq_len):
            step_in = x[:, sq_idx]
            live = mask[:, sq_idx]
            i = torch.sigmoid(self.i_short(shorts) + self.i_in(step_in))
            f = torch.sigmoid(self.f_short(shorts) + self.f_in(step_in))
            g = torch.tanh(self.g_short(shorts) + self.g_in(step_in))
            o = torch.sigmoid(self.o_short(shorts) + self.o_in(step_in))
            # Masked-out samples carry their previous states forward.
            longs = (f * longs + i * g) * live + longs * ~live
            shorts = self.proj(o * longs) * live + shorts * ~live
            out[:, sq_idx] = self.out(o * longs)
        return out, (shorts, longs)

In [None]:
#Define own multi-head self attention.
#Necessary to work nice with recurrence + autoregressive caching
float_min = torch.finfo(torch.float32).min
class MHSA(nn.Module):
    """Multi-head attention with a batched path and a step-by-step path.

    ``parallel`` attends from a whole source sequence to a whole memory
    sequence. ``autoreg`` attends from a single new item to lists of cached
    keys and values, which ``reg_kvs`` appends to.

    Args:
        d_model: input and output feature width.
        d_sa: total width of the query/key/value projections.
        n_head: number of heads; each head has width ``d_sa // n_head``.
        dropout: dropout rate applied to the attention weights.
        rope: apply rotary positional embeddings to queries and keys.
    """

    def __init__(self, d_model, d_sa, n_head, dropout = 0.1, rope=False):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_sa = d_sa
        self.d_key = d_sa // n_head
        self.q = nn.Linear(d_model, d_sa)
        self.k = nn.Linear(d_model, d_sa)
        self.v = nn.Linear(d_model, d_sa)
        for lin in (self.k, self.q, self.v):
            nn.init.zeros_(lin.bias)

        self.dropout = nn.Dropout(dropout)
        self.sa_lin = nn.Linear(d_sa, d_model)
        self.rope = torchtune.modules.RotaryPositionalEmbeddings(dim=self.d_key) if rope else None

    def parallel(self, src, mem, mask=None):
        """Attend from every position of ``src`` to the positions of ``mem``.

        ``mask`` is a boolean (batch, src_len, mem_len) tensor; entries that
        are False are excluded from the attention.
        """
        batch_size = src.size(0)
        src_sq_len = src.size(1)
        mem_sq_len = mem.size(1)
        ks = self.k(mem).view(batch_size, mem_sq_len, self.n_head, self.d_key)
        qs = self.q(src).view(batch_size, src_sq_len, self.n_head, self.d_key)
        if self.rope is not None:
            ks = self.rope(ks)
            qs = self.rope(qs)
        # Move the head dimension in front of the sequence dimension.
        ks = ks.transpose(1, 2)
        qs = qs.transpose(1, 2)
        vs = self.v(mem).view(batch_size, mem_sq_len, self.n_head, self.d_key).transpose(1, 2)

        dots = torch.matmul(qs, ks.transpose(-1, -2)) / math.sqrt(self.d_key)
        if mask is not None:
            # The same mask applies to every head.
            dots = dots.masked_fill(~mask.unsqueeze(1), float_min)
        attn_weight = self.dropout(torch.softmax(dots, dim=-1))
        attns = torch.matmul(attn_weight, vs).transpose(1, 2).contiguous().view(batch_size, src_sq_len, self.d_sa)
        return self.sa_lin(attns)

    def autoreg(self, incoming, ks, vs, mask=None):
        """Attend from one new item per sample to the cached keys/values.

        ``ks`` and ``vs`` are lists with one (batch, n_head, d_key) tensor
        per cached position. With rotary embeddings, the query is placed at
        position ``len(ks)``.
        """
        batch_size = incoming.size(0)
        if self.rope is not None:
            position = torch.tensor(len(ks)).unsqueeze(0).repeat(batch_size, 1)
            qs = self.rope(self.q(incoming).view(batch_size, 1, self.n_head, self.d_key), input_pos=position).squeeze(1)
        else:
            qs = self.q(incoming).view(batch_size, self.n_head, self.d_key)
        # Stacking puts the cached positions on the leading dimension.
        dots = torch.sum(qs * torch.stack(ks), dim=-1) / math.sqrt(self.d_key)
        if mask is not None:
            dots[~mask] = float_min
        # Normalise over the cached positions (the leading dimension).
        attn_weight = self.dropout(torch.softmax(dots, dim=-3))
        mixed = torch.sum(attn_weight.unsqueeze(-1) * torch.stack(vs), dim=0)
        return self.sa_lin(mixed.view(batch_size, self.d_sa))

    def reg_kvs(self, incoming, ks, vs):
        """Append the key and value of ``incoming`` to the caches in place.

        With rotary embeddings, the key is placed at position ``len(ks)``
        (its index in the cache).
        """
        batch_size = incoming.size(0)
        if self.rope is not None:
            position = torch.tensor(len(ks)).unsqueeze(0).repeat(batch_size, 1)
            new_key = self.rope(self.k(incoming).view(batch_size, 1, self.n_head, self.d_key), input_pos=position).squeeze(1)
        else:
            new_key = self.k(incoming).view(batch_size, self.n_head, self.d_key)
        ks.append(new_key)
        vs.append(self.v(incoming).view(batch_size, self.n_head, self.d_key))
        

In [None]:
#Recurrent Transformer
class ReTran(nn.Module):
    """Recurrent Transformer: attention over cached states, then an LSTM.

    Each layer attends from the current item to that layer's cached
    keys/values, then applies an LSTM step in place of a feed-forward
    block. Both sub-steps use a residual connection and LayerNorm. The
    sequence is processed one position at a time.

    Args:
        ffwd_dim: long-term state width of each layer's LSTM.
        d_sa: total attention projection width.
        d_model: feature width of inputs and outputs.
        n_head: number of attention heads.
        n_lay: number of layers.
        dropout: dropout rate for attention weights and residual branches.
    """

    def __init__(self, ffwd_dim, d_sa, d_model, n_head, n_lay, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_lay = n_lay
        # Submodules belong in nn.ModuleList, the container documented for
        # registering modules; nn.ParameterList is kept for real parameters
        # only. State-dict keys are unchanged.
        self.lstms = nn.ModuleList([LSTM(d_model, ffwd_dim, d_model, init_kai=False) for _ in range(n_lay)])
        self.sas = nn.ModuleList([MHSA(d_model, d_sa, n_head, rope=True, dropout=dropout) for _ in range(n_lay)])
        for sa in self.sas:
            sa.apply(_init_gpt2)
        # Learned initial item per layer, cached before the first token so
        # the first attention step has something to attend to.
        self.sa_h0s = nn.ParameterList([nn.Parameter(torch.zeros(d_model)) for _ in range(n_lay)])
        self.sa_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.ffwd_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, x_len):
        """Process ``x`` position by position.

        Returns ``(out, (out, lstm_hiddens, ks, vs))`` where ``out`` holds
        the last layer's output at every position.
        """
        batch_size = x.size(0)
        sq_len = x.size(1)
        pad_mask = torch.arange(sq_len).unsqueeze(0) < x_len.unsqueeze(1)
        lstm_hiddens = [None for _ in range(self.n_lay)]
        out = torch.zeros((batch_size, sq_len, self.d_model), device = x.device)
        ks = [[] for _ in range(self.n_lay)]
        vs = [[] for _ in range(self.n_lay)]

        for layer in range(self.n_lay):
            self.sas[layer].reg_kvs(self.sa_h0s[layer].unsqueeze(0).repeat(batch_size, 1), ks[layer], vs[layer])

        for sq_idx in range(sq_len):
            curr = x[:, sq_idx]
            for layer in range(self.n_lay):
                curr = self.sa_norms[layer](self.dropout(self.sas[layer].autoreg(curr, ks[layer], vs[layer])) + curr)
                # The per-sample boolean mask for this step is passed as the
                # LSTM's length argument: False freezes that sample's state.
                lstm_out, lstm_hiddens[layer] = self.lstms[layer](curr.unsqueeze(1), pad_mask[:, sq_idx], lstm_hiddens[layer])
                curr = self.ffwd_norms[layer](self.dropout(lstm_out.squeeze(1)) + curr)
                # Cache this layer's output for later positions to attend to.
                self.sas[layer].reg_kvs(curr, ks[layer], vs[layer])
            out[:, sq_idx] = curr

        return out, (out, lstm_hiddens, ks, vs)

In [None]:
#GPT definition
class GPT(nn.Module):
    """Decoder-only Transformer with post-LayerNorm residual blocks.

    Args:
        d_ffwd: hidden width of each feed-forward block.
        d_sa: total attention projection width.
        d_model: feature width of inputs and outputs.
        n_head: number of attention heads.
        n_lay: number of layers.
        activ: activation module used inside the feed-forward blocks.
        dropout: dropout rate for attention weights and residual branches.
    """

    def __init__(self, d_ffwd, d_sa, d_model, n_head, n_lay, activ=nn.GELU(), dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_lay = n_lay
        # Submodules belong in nn.ModuleList, the container documented for
        # registering modules. State-dict keys are unchanged.
        self.sas = nn.ModuleList([MHSA(d_model, d_sa, n_head, dropout=dropout, rope=True) for _ in range(n_lay)])
        self.ffwds = nn.ModuleList([FeedFwd([d_model, d_ffwd, d_model], dropout=dropout, activ=activ) for _ in range(n_lay)])
        self.sa_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.ffwd_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_lay)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, x_len):
        """Run all layers over ``x`` with a causal, length-limited mask.

        Returns ``(x, (x, src_mask))`` where ``src_mask`` marks the positions
        below ``x_len`` for each sample.
        """
        sq_len = x.size(1)
        src_mask = (torch.arange(sq_len).unsqueeze(0) < x_len.unsqueeze(1)).to(x.device)
        # Built once on the target device and broadcast over the batch,
        # instead of materialising a (batch, seq, seq) tensor on the CPU.
        causal = torch.ones(sq_len, sq_len, device=x.device).tril().bool()
        # A key is visible if it is not in the future and lies within x_len.
        mask = src_mask.unsqueeze(1) & causal
        for layer in range(self.n_lay):
            x = self.sa_norms[layer](x + self.dropout(self.sas[layer].parallel(x, x, mask=mask)))
            x = self.ffwd_norms[layer](x + self.dropout(self.ffwds[layer](x)))
        return x, (x, src_mask)

#Uses initialization from GPT-2, centered around 0 with small std
def _init_gpt2(module):
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        if module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
            
def init_sa_lin(model):
    """Re-initialise every attention output projection of ``model``.

    Each parameter whose name ends in 'sa_lin.weight' is redrawn from a
    normal distribution whose std is 0.02 scaled down by
    ``sqrt(2 * model.n_lay)``.
    """
    scaled_std = 0.02/math.sqrt(2 * model.n_lay)
    targets = [param for name, param in model.named_parameters() if name.endswith('sa_lin.weight')]
    for param in targets:
        torch.nn.init.normal_(param, mean = 0.0, std=scaled_std)

In [None]:
class MultiLayLSTM(nn.Module):
    """Stack of ``LSTM`` layers with dropout between layers.

    Args:
        ffwd_dim: long-term state width of each LSTM layer.
        d_model: feature width of inputs and outputs.
        n_lay: number of stacked layers.
        dropout: dropout rate applied to every layer's output but the last.
    """

    def __init__(self, ffwd_dim, d_model, n_lay, dropout = 0.1):
        super().__init__()
        self.n_lay = n_lay
        self.d_model = d_model
        # Submodules belong in nn.ModuleList, the container documented for
        # registering modules. State-dict keys are unchanged.
        self.lstms = nn.ModuleList([LSTM(d_model, ffwd_dim, d_model) for _ in range(n_lay)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, x_len):
        """Process ``x`` position by position through every layer.

        Returns ``(out, lstm_hiddens)``: the last layer's output at every
        position and the final hidden state of each layer.
        """
        batch_size = x.size(0)
        sq_len = x.size(1)
        pad_mask = torch.arange(sq_len).unsqueeze(0) < x_len.unsqueeze(1)
        lstm_hiddens = [None for _ in range(self.n_lay)]
        out = torch.zeros((batch_size, sq_len, self.d_model), device = x.device)

        for sq_idx in range(sq_len):
            curr = x[:, sq_idx]
            for layer in range(self.n_lay):
                # The per-sample boolean mask for this step is passed as the
                # LSTM's length argument: False freezes that sample's state.
                lstm_out, lstm_hiddens[layer] = self.lstms[layer](curr.unsqueeze(1), pad_mask[:, sq_idx], lstm_hiddens[layer])
                if layer != self.n_lay - 1:
                    curr = self.dropout(lstm_out.squeeze(1))
                else:
                    curr = lstm_out.squeeze(1)
            out[:, sq_idx] = curr

        return (out, lstm_hiddens)

# Experiments

In [None]:
#Large batch size due to LAMB
# batch_size is read as a global by train, test and test_lens.
batch_size = 2048
# Dropout rate passed to every model constructed in the cells below.
dropout = 0.1

In [None]:
# Per-run ReTran results, one entry appended per experiment cell:
# parameter count, best validation loss, test accuracy, training time.
retran_params = []
retran_losses = []
retran_acc = []
retran_time = []

In [None]:
#Small retran
# Build, train and evaluate the smallest ReTran configuration, recording its
# parameter count, training time, best validation loss and test accuracy.
d_model = 16
d_ffwd = 64
d_sa = 16
n_lay = 3
n_head = 8
basic_retran = Basic(ReTran(d_ffwd, d_sa, d_model, n_head, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
# Rescale the attention output projections by depth.
init_sa_lin(basic_retran.model)
retran_params.append(sum(p.numel() for p in basic_retran.parameters()))
print("Total small retran params: ", retran_params[-1])
# timer returns (elapsed seconds, per-epoch validation losses).
basic_retran_time, basic_retran_valid_losses = timer(train, basic_retran, lr=0.008)
retran_time.append(basic_retran_time)
retran_losses.append(min(basic_retran_valid_losses))
retran_acc.append(test(basic_retran))
print("Proportion small retran test case correct: ", retran_acc[-1])
print("Time", retran_time[-1])

In [None]:
#Medium retran
# Same procedure as the small ReTran cell, with wider layers and one more
# layer. Note that basic_retran is rebound to the new model.
d_model = 32
d_ffwd = 128
d_sa = 32
n_lay = 4
n_head = 8
basic_retran = Basic(ReTran(d_ffwd, d_sa, d_model, n_head, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
init_sa_lin(basic_retran.model)
retran_params.append(sum(p.numel() for p in basic_retran.parameters()))
print("Total medium retran params: ", retran_params[-1])
basic_retran_time, basic_retran_valid_losses = timer(train, basic_retran, lr=0.008)
retran_time.append(basic_retran_time)
retran_losses.append(min(basic_retran_valid_losses))
retran_acc.append(test(basic_retran))
print("Proportion medium retran test case correct: ", retran_acc[-1])
print("Time", retran_time[-1])

In [None]:
#Large retran
# Largest ReTran configuration. When the notebook is run top to bottom,
# basic_retran still refers to this model in the per-length evaluation.
d_model = 64
d_ffwd = 256
d_sa = 64
n_lay = 6
n_head = 8
basic_retran = Basic(ReTran(d_ffwd, d_sa, d_model, n_head, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
init_sa_lin(basic_retran.model)
retran_params.append(sum(p.numel() for p in basic_retran.parameters()))
print("Total large retran params: ", retran_params[-1])
basic_retran_time, basic_retran_valid_losses = timer(train, basic_retran, lr=0.008)
retran_time.append(basic_retran_time)
retran_losses.append(min(basic_retran_valid_losses))
retran_acc.append(test(basic_retran))
print("Proportion large retran test case correct: ", retran_acc[-1])
print("Time", retran_time[-1])

In [None]:
# Per-run GPT results, one entry appended per experiment cell:
# parameter count, best validation loss, test accuracy, training time.
gpt_params = []
gpt_losses = []
gpt_acc = []
gpt_time = []

In [None]:
#Small GPT
# Build, train and evaluate the smallest GPT configuration, recording its
# parameter count, training time, best validation loss and test accuracy.
d_model = 32
d_ffwd = 128
d_sa = 32
n_lay = 3
n_head = 8
basic_gpt = Basic(GPT(d_ffwd, d_sa, d_model, n_head, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
# GPT-2 style initialisation for the whole model, then the depth-scaled
# re-initialisation of the attention output projections.
basic_gpt.apply(_init_gpt2)
init_sa_lin(basic_gpt.model)
gpt_params.append(sum(p.numel() for p in basic_gpt.parameters()))
print("Total small GPT params: ", gpt_params[-1])
#We performed a grid search to find optimal lr for each architecture.
basic_gpt_time, basic_gpt_valid_losses = timer(train, basic_gpt, lr=0.01)
gpt_time.append(basic_gpt_time)
gpt_losses.append(min(basic_gpt_valid_losses))
gpt_acc.append(test(basic_gpt))
print("Proportion small GPT test case correct: ", gpt_acc[-1])
print("Time", gpt_time[-1])

In [None]:
#Medium GPT
# Same procedure as the small GPT cell, with wider layers and one more
# layer. Note that basic_gpt is rebound to the new model.
d_model = 64
d_ffwd = 256
d_sa = 64
n_lay = 4
n_head = 8
basic_gpt = Basic(GPT(d_ffwd, d_sa, d_model, n_head, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
basic_gpt.apply(_init_gpt2)
init_sa_lin(basic_gpt.model)
gpt_params.append(sum(p.numel() for p in basic_gpt.parameters()))
print("Total medium GPT params: ", gpt_params[-1])
basic_gpt_time, basic_gpt_valid_losses = timer(train, basic_gpt, lr=0.01)
gpt_time.append(basic_gpt_time)
gpt_losses.append(min(basic_gpt_valid_losses))
gpt_acc.append(test(basic_gpt))
print("Proportion medium GPT test case correct: ", gpt_acc[-1])
print("Time", gpt_time[-1])

In [None]:
#Large GPT
# Largest GPT configuration. When the notebook is run top to bottom,
# basic_gpt still refers to this model in the per-length evaluation.
d_model = 128
d_ffwd = 512
d_sa = 128
n_lay = 6
n_head = 8
basic_gpt = Basic(GPT(d_ffwd, d_sa, d_model, n_head, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
basic_gpt.apply(_init_gpt2)
init_sa_lin(basic_gpt.model)
gpt_params.append(sum(p.numel() for p in basic_gpt.parameters()))
print("Total large GPT params: ", gpt_params[-1])
basic_gpt_time, basic_gpt_valid_losses = timer(train, basic_gpt, lr=0.01)
gpt_time.append(basic_gpt_time)
gpt_losses.append(min(basic_gpt_valid_losses))
gpt_acc.append(test(basic_gpt))
print("Proportion large GPT test case correct: ", gpt_acc[-1])
print("Time", gpt_time[-1])

In [None]:
# Per-run LSTM results, one entry appended per experiment cell:
# parameter count, best validation loss, test accuracy, training time.
lstm_params = []
lstm_losses = []
lstm_acc = []
lstm_time = []

In [None]:
#Small LSTM
# Build, train and evaluate the smallest stacked-LSTM configuration,
# recording its parameter count, best validation loss, training time and
# test accuracy. The LSTM layers keep the initialisation done in their own
# constructor; no extra init pass is applied here.
d_model = 16
d_ffwd = 64
n_lay = 3
basic_lstm = Basic(MultiLayLSTM(d_ffwd, d_model, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
lstm_params.append(sum(p.numel() for p in basic_lstm.parameters()))
print("Total small lstm params: ", lstm_params[-1])
basic_lstm_time, basic_lstm_valid_losses = timer(train, basic_lstm, lr=0.006)
lstm_losses.append(min(basic_lstm_valid_losses))
lstm_time.append(basic_lstm_time)
lstm_acc.append(test(basic_lstm))
print("Proportion small lstm test case correct: ", lstm_acc[-1])
print("Time", lstm_time[-1])

In [None]:
#Medium LSTM
# Same procedure as the small LSTM cell, with wider layers and one more
# layer. Note that basic_lstm is rebound to the new model.
d_model = 32
d_ffwd = 128
n_lay = 4
basic_lstm = Basic(MultiLayLSTM(d_ffwd, d_model, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
lstm_params.append(sum(p.numel() for p in basic_lstm.parameters()))
print("Total medium lstm params: ", lstm_params[-1])
basic_lstm_time, basic_lstm_valid_losses = timer(train, basic_lstm, lr=0.006)
lstm_losses.append(min(basic_lstm_valid_losses))
lstm_time.append(basic_lstm_time)
lstm_acc.append(test(basic_lstm))
print("Proportion medium lstm test case correct: ", lstm_acc[-1])
print("Time", lstm_time[-1])

In [None]:
#Large LSTM
# Largest stacked-LSTM configuration. When the notebook is run top to
# bottom, basic_lstm still refers to this model in the per-length evaluation.
d_model = 64
d_ffwd = 256
n_lay = 6
basic_lstm = Basic(MultiLayLSTM(d_ffwd, d_model, n_lay, dropout=dropout), d_model, len(vocab), False, vocab.index(PAD)).to(device)
lstm_params.append(sum(p.numel() for p in basic_lstm.parameters()))
print("Total large lstm params: ", lstm_params[-1])
basic_lstm_time, basic_lstm_valid_losses = timer(train, basic_lstm, lr=0.006)
lstm_losses.append(min(basic_lstm_valid_losses))
lstm_time.append(basic_lstm_time)
lstm_acc.append(test(basic_lstm))
print("Proportion large lstm test case correct: ", lstm_acc[-1])
print("Time", lstm_time[-1])

# Results

In [None]:
#We used the disabled code below to test GPT-4o performance
#You can either trust us and leave it disabled, or try it out for yourself
#Running it requires properly setting up an OpenAI API account at https://platform.openai.com/
# Recorded GPT-4o accuracy in percent; drawn as a reference line in the
# accuracy plot.
gpto_perf = 55.68

# The evaluation script is kept inside a string so that it is not executed.
# To run it, remove the `_ = """` line and the closing `"""` line.
_ = """
from openai import OpenAI
from pydantic import BaseModel

client = OpenAI()
class MathAnswer(BaseModel):
    answer: int

gpto_total = 10000
gpto_correct = 0
for prob in range(gpto_total):
    xy = decodeVocab(testset[prob][0])
    split = xy.index('=') + 1
    x = xy[:split]
    y = xy[split:]
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        temperature = 0.0,
        messages=[
            {"role": "system", "content": "Answer the math question."},
            {"role": "user", "content": str(x)},
        ],
        response_format=MathAnswer,
    )
    ans = completion.choices[0].message.parsed.answer
    if y == str(ans):
        gpto_correct += 1
    if prob % 100 == 0:
        print("Through", prob, "problems,", gpto_correct / (prob + 1) * 100, "% correct")
print("Final", gpto_correct / gpto_total * 100, "% correct")
gpto_perf = gpto_correct / gpto_total * 100
"""

In [None]:
# Test accuracy (in percent) against parameter count for each architecture,
# with the recorded GPT-4o accuracy as a horizontal reference line.
plt.figure(figsize=(10, 6))
plt.tick_params(axis='both', which='both', length=10, width=2)

plt.plot(retran_params, [acc * 100 for acc in retran_acc], label='ReTran', marker='x', markersize=15, linewidth=3)
plt.plot(lstm_params, [acc * 100 for acc in lstm_acc], label="LSTM", marker='x', markersize=15, linewidth=3)
plt.plot(gpt_params, [acc * 100 for acc in gpt_acc], label='GPT', marker='x', markersize=15, linewidth=3)
# The reference line spans the smallest to the largest parameter count.
plt.plot([min(retran_params + lstm_params + gpt_params), max(retran_params + lstm_params + gpt_params)], [gpto_perf, gpto_perf], label='GPT-4o', linestyle=':', linewidth=3)

plt.title('Test Accuracy over Parameter Count')

plt.xscale('log')
plt.ylim(0, 100)
plt.ylabel('Test Accuracy (Percent)')
plt.xlabel('Parameter Count')

plt.legend()
plt.show()

In [None]:
#Break out test accuracy by problem length
def test_lens(model, dataset):
    """Return ``model``'s exact-match accuracy on ``dataset`` per problem length.

    Returns:
        A dict mapping each stored problem length (the second element of a
        dataset item) to the fraction of problems of that length whose
        greedy decoding matches the reference sequence.
    """
    test_iter = DataLoader(dataset, batch_size=batch_size, pin_memory=True)
    num_correct = {}
    total = {}
    for xy, x_len, xy_len in test_iter:
        guesses = [decodeVocab(row) for row in solve(model, xy, x_len, torch.max(xy_len - x_len))]
        actual = [decodeVocab(row) for row in xy]
        # Iterate over what is actually in the batch: the final batch may
        # hold fewer than batch_size samples.
        for length, guess, truth in zip(x_len.tolist(), guesses, actual):
            total[length] = total.get(length, 0) + 1
            num_correct[length] = num_correct.get(length, 0) + (guess == truth)
    return {length: num_correct[length] / total[length] for length in num_correct}

In [None]:
#Tests the last trained model of each architecture.
#If you run the notebook linearly, this will be the largest from each architecture
#If desired, run cells in a different order to compare different sized models
# Per-length accuracy on the in-distribution test split.
gpt_test_lens = test_lens(basic_gpt, testset)
retran_test_lens = test_lens(basic_retran, testset)
lstm_test_lens = test_lens(basic_lstm, testset)

In [None]:
#Generate out of distribution data
# Problems of 31-40 characters, longer than the 30-character maximum used
# for the training data.
long_data = ArithDataset(2048 * 64, min_in_len=31, max_in_len=40, min_out_len=0, max_out_len=5)

In [None]:
#Add in out of distribution data
# Merge the per-length accuracies for the long problems into the existing
# dicts. The length ranges of the two datasets do not overlap, so no
# in-distribution entry is overwritten.
gpt_test_lens |= test_lens(basic_gpt, long_data)
retran_test_lens |= test_lens(basic_retran, long_data)
lstm_test_lens |= test_lens(basic_lstm, long_data)
# Sorted list of all evaluated lengths, used as the x axis of the next plot.
lens = sorted(list(gpt_test_lens.keys()))

In [None]:
# Test accuracy (in percent) against problem length for the last trained
# model of each architecture.
plt.figure(figsize=(10, 6))
plt.tick_params(axis='both', which='both', length=10, width=2)

# Stored lengths include 2 extra tokens; subtract them to plot the
# character length of the problem.
plt.plot([l-2 for l in lens], [retran_test_lens[l] * 100 for l in lens], label='Large ReTran', linewidth=3)
plt.plot([l-2 for l in lens], [lstm_test_lens[l] * 100 for l in lens], label='Large LSTM', linewidth=3)
plt.plot([l-2 for l in lens], [gpt_test_lens[l] * 100 for l in lens], label='Large GPT', linewidth=3)

plt.title('Test Accuracy over Problem Lengths')
plt.ylabel('Test Accuracy (Percent)')
plt.xlabel('Problem length')
plt.xticks(range(min(lens) - 2, max(lens) - 1, 2))
# Shade the lengths beyond the 30-character training maximum.
plt.axvspan(30.5, max(lens)-2, color='gray', alpha=0.3, label='Out of distribution')

plt.ylim(0, 100)
plt.legend()
plt.show()

In [None]:
# Best validation loss against parameter count for each architecture, on
# log-log axes.
plt.figure(figsize=(10, 6))
plt.tick_params(axis='both', which='both', length=10, width=2)
plt.plot(retran_params, retran_losses, label='ReTran', marker='x', markersize=15, linewidth=3)
plt.plot(lstm_params, lstm_losses, label="LSTM", marker='x', markersize=15, linewidth=3)
plt.plot(gpt_params, gpt_losses, label='GPT', marker='x', markersize=15, linewidth=3)

plt.title('Final Validation Losses over Parameter Count')

plt.yscale('log')
plt.xscale('log')

# Label the y-axis ticks with plain decimals rather than powers of ten.
formatter = ScalarFormatter()
formatter.set_scientific(False)
plt.gca().yaxis.set_major_formatter(formatter)
plt.gca().yaxis.set_minor_formatter(formatter)


plt.ylabel('Final Validation Loss')
plt.xlabel('Parameter Count')
plt.legend()
plt.show()

Copyright © 2025 anonymous

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.