### Transformers
...

---

### Example 1: learning the modulo function / grokking
...

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random

# ----------------------------
# Parameters
# ----------------------------
# Vocabulary
# Token ids 0..9 are the operand values drawn by generate_batch; one extra id
# is reserved for the separator. The classifier head also emits VOCAB_SIZE
# logits, so the separator id is a (never-targeted) output class as well.
VOCAB_SIZE = 10 + 1
SEP_TOKEN = VOCAB_SIZE - 1  # separator takes the highest token id
MAX_LEN = 3                 # [a, b, SEP]; also the size of the position-embedding table
# Model
D_MODEL = 64                # embedding / residual-stream width
NUM_HEADS = 4               # attention heads per block; must divide D_MODEL
NUM_LAYERS = 2              # number of stacked TransformerBlock instances
D_FF = 128                  # hidden width of each block's feed-forward network
# Hyperparameters
BATCH_SIZE = 128            # examples sampled per optimisation step
EPOCHS = 300                # optimisation steps; each step draws one fresh random batch
LR = 3e-4                   # Adam learning rate
DEVICE = "cpu"              # device that batches and the model are moved to

# ----------------------------
# Dataset
# ----------------------------
def generate_batch(batch_size):
    """Sample a random batch of ``a % b`` problems.

    Args:
        batch_size: Number of examples to draw.

    Returns:
        A pair ``(tokens, targets)`` moved to ``DEVICE``:
        ``tokens`` has shape ``(batch_size, 3)`` and holds ``[a, b, SEP_TOKEN]``
        per row; ``targets`` has shape ``(batch_size,)`` and holds ``a % b``.
    """
    shape = (batch_size,)
    dividend = torch.randint(0, VOCAB_SIZE - 1, shape)
    # The divisor range starts at 1 so the modulo is always defined.
    divisor = torch.randint(1, VOCAB_SIZE - 1, shape)
    separator = torch.full_like(dividend, SEP_TOKEN)

    tokens = torch.stack((dividend, divisor, separator), dim=1)
    targets = dividend % divisor
    return tokens.to(DEVICE), targets.to(DEVICE)

# ----------------------------
# Multi-Head Self Attention
# ----------------------------
class MultiHeadSelfAttention(nn.Module):
    """Unmasked multi-head scaled dot-product self-attention.

    Queries, keys and values come from a single fused linear projection; the
    per-head results are concatenated and passed through an output projection.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # One matmul yields q, k and v stacked along the feature axis.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over the sequence axis of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, seq_len, width = x.shape

        # (batch, seq, 3*d_model) -> (3, batch, heads, seq, head_dim)
        fused = self.qkv(x).reshape(
            batch, seq_len, 3, self.num_heads, self.head_dim
        )
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Scale by sqrt(head_dim) before the softmax over the key axis.
        logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = torch.softmax(logits, dim=-1)

        # (batch, heads, seq, head_dim) -> (batch, seq, d_model): merge heads.
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, width)
        return self.out(merged)

# ----------------------------
# Transformer Block
# ----------------------------
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual.

    Args:
        d_model: Feature width of the residual stream.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)

        # Position-wise MLP: widen to d_ff, apply ReLU, project back to d_model.
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(widen, nn.ReLU(), narrow)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply both sub-layers to ``x`` and return a tensor of the same shape."""
        # Each sub-layer sees a normalised copy; the residual path stays raw.
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.ff(self.norm2(after_attn))

# ----------------------------
# Full Transformer Model
# ----------------------------
class TransformerModel(nn.Module):
    """Small transformer encoder that classifies a token sequence.

    Token and learned position embeddings are summed, passed through
    ``NUM_LAYERS`` transformer blocks and a final layer norm, and the
    representation at position 0 is projected to ``VOCAB_SIZE`` logits.
    """

    def __init__(self):
        super().__init__()

        self.token_emb = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.pos_emb = nn.Embedding(MAX_LEN, D_MODEL)

        self.layers = nn.ModuleList([
            TransformerBlock(D_MODEL, NUM_HEADS, D_FF)
            for _ in range(NUM_LAYERS)
        ])

        self.norm = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, VOCAB_SIZE)

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq_len)`` with
                ``seq_len <= MAX_LEN`` (the position table has ``MAX_LEN`` rows).

        Returns:
            Tensor of shape ``(batch, VOCAB_SIZE)`` holding unnormalised logits.
        """
        _, seq_len = x.shape

        # Build the position ids on the input's own device rather than the
        # module-level DEVICE, so the model keeps working after being moved
        # (e.g. ``model.to("cuda")``) without a device-mismatch error.
        positions = torch.arange(seq_len, device=x.device)
        # pos_emb(positions) is (seq_len, D_MODEL) and broadcasts over the batch.
        x = self.token_emb(x) + self.pos_emb(positions)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)

        # Use first token representation for classification. Attention is
        # unmasked, so position 0 can attend to every other position.
        return self.head(x[:, 0, :])

# ----------------------------
# Training
# ----------------------------
# Build the model and its optimiser once; every iteration below trains on a
# freshly sampled batch (there is no fixed training set).
model = TransformerModel().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    x, y = generate_batch(BATCH_SIZE)

    # Clear gradients left over from the previous step before the new pass.
    optimizer.zero_grad()

    logits = model(x)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    optimizer.step()

    # Report the loss only on every 20th iteration.
    if epoch % 20:
        continue
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# ----------------------------
# Evaluation
# ----------------------------
# Measure accuracy on one freshly sampled batch of `total` problems.
model.eval()
total = 1000

with torch.no_grad():  # no gradients are needed for scoring
    x, y = generate_batch(total)
    preds = torch.argmax(model(x), dim=-1)
    correct = int((preds == y).sum())

print(f"\nAccuracy: {correct/total:.4f}")
