# Training GPT for Modular Arithmetic

In [2]:
import torch
from zeptogpt.model import SimpleGPT

[0, 832, 28, 325, 135, 385, 248, 357, 981, 192, 939]


In [61]:
# Vocabulary, encoder and decoder.

# Our modular arithmetic is in a world with only 0....MOD numbers
MOD = 10

stoi = {}
for i in range(MOD):
    stoi[str(i)] = i

# The mathematical operators we want to support
stoi['+'] = MOD

stoi['='] = MOD+1

# Special tokens
stoi['<'] = MOD+2
stoi['>'] = MOD+3

# Padding
stoi['.'] = MOD+4

vocab = list(stoi.keys())
vocab_size = len(stoi)

encode = lambda x: [stoi[s] for s in x]

itos = {v:k for k, v in stoi.items()}
decode = lambda x: ''.join([itos[i] for i in x])

print(decode(encode('<5+4+8=7>')))

<5+4+8=7>


In [406]:
from typing import Iterator
from torch.utils.data import IterableDataset, DataLoader

import random

class ModularArithmeticDataset(IterableDataset):
    def __init__(self, block_size, validation_size):
        super().__init__()
        self.block_size = block_size
        self.skip_exprs = [self.generate_expr() for _ in range(validation_size)]
        self.is_validation_mode = False
    
    def validation_mode(self):
        self.is_validation_mode = True
    
    def training_mode(self):
        self.is_validation_mode = False
    
    def generate_expr(self):
        num_terms = random.randint(1, self.block_size // 2 - 2)
        numbers = [random.randint(0, MOD-1) for _ in range(num_terms)]
        total = sum(numbers) % MOD
        expr = '<' + '+'.join(map(str, numbers)) + f'={total}' + '>'
        expr = expr + '.' * (self.block_size - len(expr))
        return expr
    
    def prepare_input(self, expr):
        return torch.tensor(encode(expr))

    def prepare_target(self, expr):
        expr = expr[1:] + '.'
        equal_pos = expr.index('=')
        expr = '.' * (equal_pos + 1) + expr[equal_pos + 1:]
        return torch.tensor(encode(expr))

    def __iter__(self):
        if self.is_validation_mode:
            for skip_expr in self.skip_exprs:
                yield self.prepare_input(skip_expr), self.prepare_target(skip_expr)
            return
        while True:
            expr = self.generate_expr()
            if expr in self.skip_exprs:
                continue
            yield self.prepare_input(expr), self.prepare_target(expr)
dataset = ModularArithmeticDataset(16, 0)
dataloader = DataLoader(dataset, 4)
for ix, (input, target) in enumerate(dataloader):
    print(input)
    print(target)
    break

tensor([[12,  7, 10,  4, 10,  8, 10,  0, 10,  8, 11,  7, 13, 14, 14, 14],
        [12,  9, 10,  7, 10,  5, 10,  4, 11,  5, 13, 14, 14, 14, 14, 14],
        [12,  1, 10,  0, 10,  3, 10,  1, 10,  1, 10,  0, 11,  6, 13, 14],
        [12,  2, 10,  7, 10,  9, 10,  6, 10,  3, 10,  4, 11,  1, 13, 14]])
tensor([[14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  7, 13, 14, 14, 14, 14],
        [14, 14, 14, 14, 14, 14, 14, 14,  5, 13, 14, 14, 14, 14, 14, 14],
        [14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  6, 13, 14, 14],
        [14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  1, 13, 14, 14]])


In [407]:
from tqdm import tqdm

# Hyperparamters
num_iterations = 10000
report_every_n = 1000
eval_size = 100
block_size = 16
batch_size = 32
embed_dim = 16
num_heads = 8
num_decoder_layers = 4

model = SimpleGPT(vocab_size, embed_dim, block_size, num_heads, num_decoder_layers)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=stoi['.'])
optimizer = torch.optim.AdamW(model.parameters())

dataset = ModularArithmeticDataset(block_size, eval_size)
dataloader = DataLoader(dataset, batch_size)

@torch.no_grad
def estimate_loss():
    losses = torch.zeros(eval_size)
    model.eval()
    dataset.validation_mode()
    for i in range(eval_size):
        inputs, targets = next(iter(dataloader))
        logits = model(inputs)
        B, T, C = logits.shape
        loss = loss_fn(logits.view(B*T, C), targets.view(B*T))
        losses[i] = loss.item()
    model.train()
    dataset.training_mode()
    return losses.mean()

for i in tqdm(range(num_iterations)):
    optimizer.zero_grad()
    inputs, targets = next(iter(dataloader))
    logits = model(inputs)
    B, T, C = logits.shape
    loss = loss_fn(logits.view(B*T, C), targets.view(B*T))
    loss.backward()
    optimizer.step()
    if i % report_every_n == 0 or i == num_iterations - 1:
        print(f"Train Loss={loss.item()}, Test Loss={estimate_loss()}")

  0%|          | 4/10000 [00:01<36:27,  4.57it/s]  

Train Loss=3.7799081802368164, Test Loss=3.6419904232025146


 10%|█         | 1007/10000 [00:37<11:31, 13.01it/s]

Train Loss=1.0614112615585327, Test Loss=1.2852228879928589


 20%|██        | 2005/10000 [01:14<12:07, 10.99it/s]

Train Loss=1.124786615371704, Test Loss=1.388229250907898


 30%|███       | 3004/10000 [01:50<14:47,  7.88it/s]

Train Loss=1.0683794021606445, Test Loss=1.423374891281128


 40%|████      | 4003/10000 [02:43<15:23,  6.50it/s]

Train Loss=0.9192647337913513, Test Loss=1.3892085552215576


 50%|█████     | 5006/10000 [03:23<09:34,  8.69it/s]

Train Loss=0.7025317549705505, Test Loss=1.3223774433135986


 60%|██████    | 6005/10000 [04:00<05:53, 11.31it/s]

Train Loss=0.5251529812812805, Test Loss=1.1339017152786255


 70%|███████   | 7006/10000 [04:39<05:07,  9.72it/s]

Train Loss=0.4093388617038727, Test Loss=1.0924144983291626


 80%|████████  | 8006/10000 [05:19<03:33,  9.34it/s]

Train Loss=0.2749190926551819, Test Loss=1.002109169960022


 90%|█████████ | 9004/10000 [05:58<01:43,  9.62it/s]

Train Loss=0.2245638519525528, Test Loss=1.238279938697815


100%|██████████| 10000/10000 [06:38<00:00, 25.11it/s]

Train Loss=0.14825193583965302, Test Loss=1.3029029369354248





In [417]:
import torch.nn.functional as F

def generate(model, input):
    input = torch.tensor([encode(input)])
    while True:
        logits = model(input)[:,-1,:]
        prob = F.softmax(logits, dim=-1)
        pred = torch.multinomial(prob, num_samples=1)
        input = torch.cat((input, pred), dim=1)
        if pred.item() == stoi['>']:
            break
    return decode(input[0].tolist())


print(generate(model, "<1+1="))

<1+1=2>
