**task:** train a GPT to do addition of two numbers, i.e. a+b=c. You may find it helpful to predict the digits of c in reverse order, as the typical addition algorithm (that you're hoping it learns) would proceed right to left too. You may want to modify the data loader to simply serve random problems and skip the generation of train.bin, val.bin. You may want to mask out the loss at the input positions of a+b that just specify the problem using y=-1 in the targets (see CrossEntropyLoss ignore_index).

## two approaches: seq-to-seq and autoregressive prediction
1. predicting the entire output (sum) from the inputs (13+35 for ex) --> memorizing the mapping from an input format to an output
2. predicting the next number after shifting the window to the right

In [1]:
import random
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt

In [2]:
nums = [str(i) for i in range(10)] + list("+=;")
print(len(nums))
''.join(nums)

13


'0123456789+=;'

In [3]:
vocab_size = len(nums)
torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
eval_iters = 200
eval_interval = 100
train_ratio = 0.8
max_digits = 2

block_size = 120
batch_size = 8
learning_rate = 1e-3
max_iters = 20000
dataset_size = 10000

n_embd = 32
dropout = 0.2
n_head = 4
n_layer = 4

cpu


In [4]:
stoi = { ch:i for i,ch in enumerate(nums) }
itos = { i:ch for i,ch in enumerate(nums) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

print(encode("123;+0"))
print(decode(encode("123;+0")))

[1, 2, 3, 12, 10, 0]
123;+0


In [5]:
samples = ""

for i in range(dataset_size):
    a = random.randint(10**(max_digits-1), 10**max_digits - 1)
    b = random.randint(10**(max_digits-1), 10**max_digits - 1)

    a_padded = str(a).zfill(max_digits)
    b_padded = str(b).zfill(max_digits)

    sol = str(a + b).zfill(max_digits + 1)[::-1] # pad extra values and reverse solution
    new = a_padded + "+" + b_padded + "=" + sol + ";"
    samples += new

print(samples[:50])
equations = encode(samples)

72+84=651;78+37=511;65+35=001;69+36=501;13+78=190;


only want to look at target values where the respective x sample has '=' and a complete input block (eg. 12+24) (which has length = `max_digits*2 + 1`). else, we'd be training with equation having incomplete terms.

In [6]:
split_index = int(train_ratio * len(equations))
train_set = equations[:split_index]
test_val = equations[split_index:]

half = len(test_val) // 2
test_set = test_val[:half]
val_set = test_val[half:]

In [7]:
def get_batch(data):
  data = torch.tensor(data)
  ix = torch.randint(0, len(data) - block_size, (batch_size,))
  x = torch.stack([data[i : i + block_size] for i in ix])
  y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])

  # Mask y based on the position of ';' and '=' in x
  for i in range(batch_size):
    eq_indices = [i for i, val in enumerate(x[i]) if val == encode('=')[0]]
    semicolon_indices = [i for i, val in enumerate(x[i]) if val == encode(';')[0]]
    # Initialize mask with -1
    mask = torch.ones_like(y[i]) * -1

    # Unmask y for segments between ';' and '='
    for j in range(len(eq_indices)):
        if j < len(semicolon_indices):
            if eq_indices[j] < semicolon_indices[j]:
                start = eq_indices[j] - 1
                end = semicolon_indices[j] + 1
            else:
                start = semicolon_indices[j] - 1
                end = eq_indices[j] + 1
            mask[start:end] = y[i, start:end]  # Unmask the segment between ';' and '='

    # Unmask everything after the last equal sign
    if eq_indices[-1] > semicolon_indices[-1]:
        mask[eq_indices[-1] - 1 :] = y[i, eq_indices[-1] - 1 :]

    y[i] = mask

  x, y = x.to(device), y.to(device)
  return x, y

In [20]:
def get_batch2(data):
    data = torch.tensor(data)
    ix = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])

    eq_token = encode('=')[0]
    semicolon_token = encode(';')[0]

    # Mask y based on the position of ';' and '=' in x
    for i in range(batch_size):
        # Initialize mask with -1
        mask = torch.ones_like(y[i]) * -1

        equation_start = 0
        for j in range(block_size):
            token = x[i, j].item()

            # If semicolon or end of sequence is encountered, process the equation
            if token == semicolon_token or j == block_size - 1:
                equation = x[i, equation_start : j + 1].tolist()
                equation_start = j + 1  # Update start for the next equation

                # Find the '=' position in the equation
                if eq_token in equation:
                    eq_idx = equation.index(eq_token)
                    solution_start = eq_idx + 1

                    # Unmask all digits after '=' in y
                    print(decode(equation))
                    for k in range(solution_start, len(equation)):
                      if encode('0')[0] <= equation[k] <= encode('9')[0]:  # Check if token is a digit
                          print(equation_start, equation[k])
                          mask[equation_start - len(equation) + k] = y[i, equation_start - len(equation) + k]

        # Apply the updated mask to y
        y[i] = mask

    x, y = x.to(device), y.to(device)

    return x, y

In [24]:
data = encode("2+61=321;27+69=690;2")
block_size = 10
batch_size = 1

# Generate batch
a, b = get_batch2(data)

print(decode(a[0][:20].tolist()))  # Decoded input sequence
print(b[0][:20])                  # Masked target sequence

2+61=321;2
tensor([-1, -1,  1, 11, -1,  2,  1, 12, -1, -1])


In [None]:
@torch.no_grad()
def estimate_loss():
  out = {}
  m.eval()

  losses = torch.zeros(eval_iters)
  for k in range(eval_iters):
      X, y = get_batch(train_set)
      logits, loss = m(X, y)
      losses[k] = loss.item()
  out['train'] = losses.mean()

  losses = torch.zeros(eval_iters)
  for k in range(eval_iters):
      X, y = get_batch(val_set)
      logits, loss = m(X, y)
      losses[k] = loss.item()
  out['val'] = losses.mean()

  m.train()

  return out

In [None]:
class CausalSelfAttention(nn.Module):
  """ multiple heads of self-attention in parallel """

  def __init__(self, num_heads, head_size):
      super().__init__()
      self.proj = nn.Linear(n_embd, n_embd)
      self.dropout = nn.Dropout(dropout)

      self.key = nn.Linear(n_embd, head_size*n_head, bias=False)
      self.query = nn.Linear(n_embd, head_size*n_head, bias=False)
      self.value = nn.Linear(n_embd, head_size*n_head, bias=False)
      self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
      self.head_size = head_size

  def forward(self, x):
      B,T,C = x.shape

      k = self.key(x).view(B, n_head, T, C // n_head)
      q = self.query(x).view(B, n_head, T, C // n_head)

      wei = q @ k.transpose(-2, -1) * C**-0.5
      wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')).unsqueeze(0).unsqueeze(0) # wei.shape (1, 1, T, T)
      wei = F.softmax(wei, dim=-1)
      wei = self.dropout(wei)

      v = self.value(x).view(B, n_head, T, C // n_head)
      out = wei @ v # (1, 1, T, T) @ (B, nH, T, hS) ---> (B, nH, T, hS)
      out = out.transpose(1, 2).contiguous().view(B, T, C)
      out = self.dropout(self.proj(out))

      return out

class FeedForward(nn.Module):
  def __init__(self, n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
        nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = CausalSelfAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [None]:
class BigramLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=device)
    self.position_embedding_table = nn.Embedding(block_size, n_embd, device=device)
    self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx) # (B, T, C)
    pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
    x = tok_emb + pos_emb
    x = self.blocks(x) # (B,T,C)
    x = self.ln_f(x) # (B,T,C)
    logits = self.lm_head(x) # (B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets, ignore_index=-1)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]
        logits, loss = self(idx_cond)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1) # (B, C)
        idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

model = BigramLanguageModel()
m = model.to(device)

In [None]:
try:
  m.load_state_dict(torch.load("model.pt", map_location=torch.device(device)))
  print("Model loaded successfully!")

except FileNotFoundError:
  print("Model file not found. Starting with a new model.")
  optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

  losses = []

  for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets w/o backprop (hence no_grad)
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch(train_set)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
  torch.save(m.state_dict(), "model.pt")

In [None]:
idx_start = encode("57+24=") # 57 + 24 = 81 = reverse(180)
idx = torch.tensor([idx_start], dtype=torch.long, device=device)

generated = m.generate(idx, max_new_tokens=20)
print("Generated Output:", decode(generated[0].tolist()))

Generated Output: 0+1;62=390;24+90;77=4


working gpt adder (not mine) - https://colab.research.google.com/drive/1AQ0a8lomUsMkZ2QoUwIPyYKz1PMyu1VS