<a href="https://colab.research.google.com/github/abhinaash-broski/ANAIS25/blob/main/anais_gdl_p2_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Authors: A. Lupidi (alisia.lupidi@cs.ox.ac.uk)

# ANAIS Geometric Deep Learning
## Practical 2: Transformers

*Welcome to our second practical* 🚀 \
This notebook will focus on Transformers.
In the following sections, we will:
- Part 0: a quick refresher on theory, and setup
- Part 1: building a vanilla Transformer by working through the "Attention Is All You Need" paper
- Part 2: building a modern "competitive" Transformer (GPT-2)

*Your Task:*
Implement the `# Code this!` parts in Part 1!

References:
- Attention Is All You Need (A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. Gomez, L. Kaiser, I. Polosukhin)
- Neural Machine Translation by Jointly Learning to Align and Translate (D. Bahdanau, K. Cho, Y. Bengio)



## Part 0:

In [None]:
!pip install torch
!pip install numpy



In [None]:
import torch
import numpy as np
# 3. TASK: SEQUENCE COPYING
def generate_data(batch_size, seq_len, vocab_size):
    """Create one batch for the sequence-copying task.

    Token ids 0 and 1 are reserved (PAD / SOS), so content tokens are drawn
    uniformly from ``[2, vocab_size)``.

    Returns:
        ``(src, tgt_in, tgt_out)``, each a ``(batch_size, seq_len)`` long tensor.
        ``tgt_out`` is ``src`` itself (the model must copy it). ``tgt_in`` is
        the teacher-forcing decoder input: the SOS token followed by ``src``
        shifted right by one position.
    """
    source = torch.randint(2, vocab_size, (batch_size, seq_len))
    sos_column = torch.full((batch_size, 1), 1, dtype=torch.long)
    # Prepend SOS, then keep the first seq_len columns (drops src's last token).
    decoder_input = torch.cat((sos_column, source), dim=1)[:, :seq_len]
    return source, decoder_input, source

## Part 1: Implementing Transformer from "Attention Is All You Need"

Paper: https://arxiv.org/pdf/1706.03762

Core Ideas:
- Multi-head Self-Attention
- Positional Embeddings


This is the most basic implementation of a Transformer. We will see it work on the toy task of Sequence Copying: the model is given a sequence of random numbers and must learn to output the exact same sequence. Succeeding at this shows that the model can move information from the Encoder to the Decoder and handle sequence order.

In [None]:
# 4. TRAINING PARAMETERS
# Hyperparameters for the sequence-copying task.
# NOTE: the training cell further down re-assigns VOCAB_SIZE..SEQ_LEN itself
# (with BATCH_SIZE = 32), so edit the values there when experimenting.
VOCAB_SIZE = 20     # token ids 0..19; generate_data reserves 0 and 1 (PAD / SOS)
D_MODEL = 64        # embedding / model width; must be divisible by NHEAD
NHEAD = 8           # number of attention heads
LAYERS = 3          # passed to TransformerModel as num_layers
FF_DIM = 128        # hidden width of the position-wise feed-forward network
BATCH_SIZE = 64     # sequences per generated batch
SEQ_LEN = 10        # tokens per source sequence
EPOCHS = 2000       # number of training steps (one freshly generated batch each)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# 1. POSITIONAL ENCODING
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings (Vaswani et al., 2017) to embeddings.

    Even feature indices receive ``sin(pos / 10000^(2i/d_model))`` and odd
    indices receive the matching ``cos``. The table is precomputed once for
    ``max_len`` positions and registered as a buffer. It therefore follows
    ``.to(device)`` and is saved in the ``state_dict``, but it is not trained.

    Args:
        d_model: Embedding width. Odd widths are supported; the last (even)
            column then has no cosine partner.
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i/d_model), computed as exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the cosine table to fit; without this the assignment fails.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading batch axis of size 1 so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is ``(batch, seq_len, d_model)`` with ``seq_len <= max_len``.
        """
        return x + self.pe[:, :x.size(1)]

# 2. CUSTOM TRANSFORMER SUB-COMPONENTS
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017, Sec. 3.2).

    Each head computes ``softmax(Q K^T / sqrt(d_k)) V`` on its own
    ``d_k = d_model / nhead`` slice of the projected inputs. The head outputs
    are concatenated and mixed by ``fc_out``.

    Args:
        d_model: Model width; must be divisible by ``nhead``.
        nhead: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        if d_model % nhead != 0:
            # Otherwise the .view() calls in forward could silently regroup
            # features across positions, or fail with an opaque shape error.
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")
        self.nhead = nhead
        self.d_k = d_model // nhead

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from queries ``q`` to keys ``k`` / values ``v``.

        Args:
            q: Queries, ``(batch, q_len, d_model)``.
            k: Keys, ``(batch, kv_len, d_model)``.
            v: Values, ``(batch, kv_len, d_model)``.
            mask: Optional tensor broadcastable to ``(batch, nhead, q_len, kv_len)``.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            ``(batch, q_len, d_model)`` attention output.
        """
        batch_size = q.size(0)

        # Project, then split the feature axis into heads:
        # (batch, len, d_model) -> (batch, nhead, len, d_k).
        Q = self.w_q(q).view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)
        K = self.w_k(k).view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)
        V = self.w_v(v).view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)

        # Scaled dot-product scores: (batch, nhead, q_len, kv_len). Dividing by
        # sqrt(d_k) keeps the logits' scale independent of the head width.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Broadcast mask to match (batch, nhead, q_len, kv_len); 0 = blocked.
            # A row whose keys are all blocked would softmax to NaN.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = torch.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)  # (batch, nhead, q_len, d_k)
        # Merge heads back: (batch, q_len, nhead * d_k) == (batch, q_len, d_model).
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.nhead * self.d_k)
        return self.fc_out(out)

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    ``nn.Linear`` acts on the last axis only, so the same network is applied
    independently at every sequence position.

    Args:
        d_model: Input and output width.
        dim_feedforward: Hidden width.
    """

    def __init__(self, d_model, dim_feedforward):
        super().__init__()
        expand = nn.Linear(d_model, dim_feedforward)
        contract = nn.Linear(dim_feedforward, d_model)
        self.net = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        hidden_then_projected = self.net(x)
        return hidden_then_projected

# 3. ENCODER AND DECODER LAYERS
class EncoderLayer(nn.Module):
    """One post-LN Transformer encoder layer (Vaswani et al., 2017, Sec. 3.1).

    Self-attention followed by a position-wise feed-forward network. Each
    sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``. The paper's
    dropout is not included.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead)
        self.ff = FeedForward(d_model, dim_feedforward)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode ``x`` (presumably ``(batch, src_len, d_model)``); the output has the same shape.

        ``mask`` is forwarded to self-attention; positions where it is 0 are
        not attended to.
        """
        # Residual connection around self-attention, then normalize (post-LN).
        x = self.norm1(x + self.self_attn(x, x, x, mask))
        # Residual connection around the feed-forward network, then normalize.
        x = self.norm2(x + self.ff(x))
        return x

class DecoderLayer(nn.Module):
    """One post-LN Transformer decoder layer (Vaswani et al., 2017, Sec. 3.1).

    The layer applies three sub-layers in order, each wrapped as
    ``LayerNorm(x + sublayer(x))``:

    1. masked self-attention over the target,
    2. cross-attention from the target to the encoder output,
    3. a position-wise feed-forward network.

    The paper's dropout is not included.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead)
        self.cross_attn = MultiHeadAttention(d_model, nhead)
        self.ff = FeedForward(d_model, dim_feedforward)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        """Decode target states ``x`` while attending to encoder states ``enc_out``.

        Args:
            x: Target-side states, presumably ``(batch, tgt_len, d_model)``.
            enc_out: Encoder output, presumably ``(batch, src_len, d_model)``.
            src_mask: Optional mask over encoder positions, used in cross-attention.
            tgt_mask: Optional (typically causal) mask for target self-attention.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Target self-attention: queries, keys and values all come from the target;
        # tgt_mask stops each position from seeing later ones.
        x = self.norm1(x + self.self_attn(x, x, x, tgt_mask))
        # Cross-attention: queries from the target, keys/values from the encoder.
        x = self.norm2(x + self.cross_attn(x, enc_out, enc_out, src_mask))
        x = self.norm3(x + self.ff(x))
        return x

# 4. FULL CUSTOM TRANSFORMER MODEL
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer that maps token ids to next-token logits.

    Source and target tokens share one embedding table, scaled by
    ``sqrt(d_model)`` and summed with sinusoidal positional encodings. The
    source goes through ``num_layers`` encoder layers; the target goes through
    ``num_layers`` decoder layers that attend to the encoder output.

    Args:
        vocab_size: Number of token ids (shared by source and target).
        d_model: Model width.
        nhead: Number of attention heads per attention block.
        num_layers: Number of encoder layers and, separately, of decoder layers.
        dim_feedforward: Hidden width of each feed-forward network.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        # nn.ModuleList (not a plain list) so the layers' parameters are
        # registered, moved by .to(device) and seen by the optimizer.
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_layers)]
        )

        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Return logits of shape ``(batch, tgt_len, vocab_size)``.

        Args:
            src: Source token ids, ``(batch, src_len)``.
            tgt: Decoder input token ids, ``(batch, tgt_len)``.
            src_mask: Optional mask used by encoder self-attention and decoder
                cross-attention.
            tgt_mask: Optional mask for decoder self-attention (e.g. causal).
        """
        # Embeddings are scaled by sqrt(d_model) before adding positions (paper Sec. 3.4).
        src = self.pos_encoder(self.embed(src) * math.sqrt(self.d_model))
        tgt = self.pos_encoder(self.embed(tgt) * math.sqrt(self.d_model))

        enc_out = src
        for layer in self.encoder_layers:
            enc_out = layer(enc_out, src_mask)

        dec_out = tgt
        for layer in self.decoder_layers:
            dec_out = layer(dec_out, enc_out, src_mask, tgt_mask)

        return self.fc_out(dec_out)

### Train and Eval

- Teacher Forcing: During training, we give the model the correct previous tokens (`tgt_in`) so it learns faster.

- The Causal Mask: We hide the future tokens from the decoder to avoid looking ahead.

In [None]:
# 5. DATA GENERATION & UTILS
def generate_data(batch_size, seq_len, vocab_size):
    """Return ``(src, tgt_in, tgt_out)`` for one batch of the copy task.

    ``src`` and ``tgt_out`` are the same tensor of random ids in
    ``[2, vocab_size)``; ids 0 and 1 are left free for PAD and SOS.
    ``tgt_in`` is the decoder's teacher-forcing input: SOS (id 1) followed by
    ``src`` without its last token. All three are ``(batch_size, seq_len)``.
    """
    tokens = torch.randint(2, vocab_size, (batch_size, seq_len))
    sos = tokens.new_ones((batch_size, 1))
    shifted = torch.cat([sos, tokens], dim=1).narrow(1, 0, seq_len)
    return tokens, shifted, tokens

def get_causal_mask(size):
    """Return a ``(size, size)`` float32 causal mask.

    Entry ``[i, j]`` is 1.0 when query position ``i`` may attend to key
    position ``j`` (i.e. ``j <= i``) and 0.0 otherwise.
    """
    all_allowed = torch.ones(size, size, dtype=torch.float32)
    # Zero out everything strictly above the diagonal (the "future").
    return torch.tril(all_allowed)

# 6. TRAINING SETUP
# This cell re-assigns the hyperparameters it uses so it can be re-run on its own.
# NOTE: BATCH_SIZE here (32) overrides the 64 set in the parameters cell.
VOCAB_SIZE, D_MODEL, NHEAD, LAYERS, FF_DIM = 20, 64, 8, 3, 128
BATCH_SIZE, SEQ_LEN = 32, 10
EPOCHS = 2000  # one freshly generated batch per "epoch"
LOG_EVERY = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerModel(VOCAB_SIZE, D_MODEL, NHEAD, LAYERS, FF_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss()

# 7. TRAINING LOOP
print(f"Training Custom Transformer on {device}...")
model.train()
# generate_data always returns tgt_in with SEQ_LEN columns, so the decoder's
# causal mask is identical every step: build it once.
train_tgt_mask = get_causal_mask(SEQ_LEN).to(device)
for epoch in range(EPOCHS):
    src, tgt_in, tgt_expected = generate_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
    src, tgt_in, tgt_expected = src.to(device), tgt_in.to(device), tgt_expected.to(device)

    optimizer.zero_grad()
    # Teacher forcing: the decoder sees the ground-truth previous tokens.
    output = model(src, tgt_in, tgt_mask=train_tgt_mask)

    # CrossEntropyLoss expects (N, C) logits and (N,) targets: flatten batch and time.
    loss = criterion(output.reshape(-1, VOCAB_SIZE), tgt_expected.reshape(-1))
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# 8. INFERENCE (DEMO): greedy autoregressive decoding starting from SOS.
model.eval()
with torch.no_grad():
    test_src, _, _ = generate_data(1, SEQ_LEN, VOCAB_SIZE)
    test_src = test_src.to(device)
    input_tgt = torch.ones((1, 1), dtype=torch.long, device=device)  # Start with SOS (id 1)

    for _ in range(SEQ_LEN):
        tgt_mask = get_causal_mask(input_tgt.size(1)).to(device)
        out = model(test_src, input_tgt, tgt_mask=tgt_mask)
        # Greedy choice: the most likely token at the last decoded position.
        next_token = out[:, -1].argmax(dim=-1, keepdim=True)
        input_tgt = torch.cat([input_tgt, next_token], dim=1)

prediction = input_tgt[0, 1:]  # drop the leading SOS token
print("\n--- RESULTS ---")
print(f"Source: {test_src[0].cpu().numpy()}")
print(f"Model:  {prediction.cpu().numpy()}")
print(f"Exact copy: {torch.equal(prediction, test_src[0])}")

## Part 2: Building a modern Transformer (GPT-2)
Reference: "Let's build GPT: from scratch, in code, spelled out" by A. Karpathy, https://www.youtube.com/watch?v=kCc8FmEb1nY

Running this part takes ~10 minutes.

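Innovations: moving from an Encoder–Decoder model to an Autoregressive (decoder-only) one
- Decoder-Only: We remove the Encoder and Cross-Attention. We want a single stack of decoder blocks that predicts each next character from the characters before it.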


- Pre-LayerNorm: We apply LayerNorm before the attention and feedforward layers (this is the key GPT-2 stability improvement; it keeps gradients well-behaved).

- Causal Masking: We use a triangular mask to ensure the model only looks at previous characters.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests

# --- 1. DATA PREPARATION ---
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from silently becoming training text.
response = requests.get(url, timeout=30)
response.raise_for_status()
text = response.text

# Character-level vocabulary: every distinct character in the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of integer token ids (raises KeyError on unseen characters)."""
    return [stoi[c] for c in s]


def decode(ids):
    """Map an iterable of integer token ids back to a string."""
    return ''.join(itos[i] for i in ids)


data = torch.tensor(encode(text), dtype=torch.long)
# The first 90% of the corpus is the training split; data[n:] is held out.
n = int(0.9 * len(data))
train_data = data[:n]

def get_batch(split, batch_size, block_size):
    """Sample a random batch of ``(input, next-token target)`` windows.

    Args:
        split: ``'train'`` samples from ``train_data``; any other value samples
            from the held-out tail ``data[n:]``.
        batch_size: Number of windows to sample.
        block_size: Length of each window.

    Returns:
        ``(x, y)``, long tensors of shape ``(batch_size, block_size)`` on
        ``device``. ``y[:, t]`` is the corpus token that follows ``x[:, t]``.
    """
    source = train_data if split == 'train' else data[n:]
    # Exclusive upper bound on start indices keeps start + block_size + 1 in range.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)


class Head(nn.Module):
    """One causal (masked) self-attention head.

    Args:
        head_size: Width of this head's query/key/value projections.
        n_embd: Width of the input embeddings.
        block_size: Maximum sequence length; sizes the causal-mask buffer.
    """

    def __init__(self, head_size, n_embd, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may attend only to positions <= t.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, T, n_embd)``, ``T <= block_size``.

        Returns a ``(batch, T, head_size)`` tensor.
        """
        T = x.size(1)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scaled dot-product scores: (batch, T, T).
        scores = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        return weights @ v


class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` independent causal heads, concatenate them, and project back to ``n_embd``.

    Args:
        num_heads: Number of heads.
        head_size: Width of each head.
        n_embd: Input and output width.
        block_size: Maximum sequence length (forwarded to each head).
    """

    def __init__(self, num_heads, head_size, n_embd, block_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size) for _ in range(num_heads)])
        # The concatenated width is num_heads * head_size; this equals n_embd when
        # n_embd divides evenly, but sizing it this way also handles other cases.
        self.proj = nn.Linear(num_heads * head_size, n_embd)

    def forward(self, x):
        return self.proj(torch.cat([h(x) for h in self.heads], dim=-1))

class Block(nn.Module):
    """Pre-LN GPT block: causal self-attention, then an MLP.

    Each sub-layer is applied as ``x + f(LayerNorm(x))``.

    Args:
        n_embd: Embedding width.
        n_head: Number of attention heads (each of width ``n_embd // n_head``).
        block_size: Maximum sequence length.
    """

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd, block_size)
        hidden = 4 * n_embd
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Normalize before each sub-layer (Pre-LN); the residual path stays un-normalized.
        attended = self.sa(self.ln1(x))
        x = x + attended
        mlp_out = self.ffwd(self.ln2(x))
        return x + mlp_out

class GPTModel(nn.Module):
    """Decoder-only, character-level GPT.

    Token embeddings and learned position embeddings are summed, passed through
    ``n_layer`` pre-LN blocks and a final LayerNorm, then projected to
    vocabulary logits.

    Args:
        vocab_size: Number of token ids.
        n_embd: Embedding width.
        n_head: Attention heads per block.
        n_layer: Number of blocks.
        block_size: Maximum context length (size of the position-embedding table).
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: ``(B, T)`` token ids with ``T <= block_size``.
            targets: Optional ``(B, T)`` next-token ids.

        Returns:
            ``(logits, loss)``: logits are ``(B, T, vocab_size)``; loss is the
            mean cross-entropy, or ``None`` when ``targets`` is ``None``.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)
        # Build position ids on idx's own device (not a global), so the model
        # works wherever it and its input live.
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb  # (T, n_embd) broadcasts over the batch
        x = self.blocks(x)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if targets is not None:
            # Flatten using this model's own vocabulary width, not a global.
            loss = F.cross_entropy(logits.reshape(B * T, -1), targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, block_size=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: ``(B, T)`` prompt token ids.
            max_new_tokens: Number of tokens to sample.
            block_size: Context window to condition on. Defaults to the
                model's own ``block_size``.

        Returns:
            ``(B, T + max_new_tokens)`` token ids.
        """
        if block_size is None:
            block_size = self.block_size
        for _ in range(max_new_tokens):
            # Crop the context so position ids stay within the embedding table.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx



# --- 3. TRAINING ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model size: embedding width, attention heads, number of blocks, context length.
n_embd = 128
n_head = 4
n_layer = 4
block_size = 64

model = GPTModel(vocab_size, n_embd, n_head, n_layer, block_size)
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

def generate_sample(model, title="Sample"):
    """Print a 100-token sample from ``model`` under a banner showing ``title``.

    The model is switched to eval mode for sampling. Afterwards it is restored
    to whichever mode (train/eval) it was in before the call, even if sampling
    raises.
    """
    was_training = model.training
    model.eval()
    try:
        # Start with token id 0 as a one-token context (presumably '\n', the
        # smallest character in the sorted vocabulary; confirm for other corpora).
        context = torch.zeros((1, 1), dtype=torch.long, device=device)
        # Sampling never backpropagates, so skip building an autograd graph.
        with torch.no_grad():
            generated = decode(model.generate(context, 100, block_size)[0].tolist())
    finally:
        model.train(was_training)
    print(f"\n{'='*20} {title} {'='*20}")
    print(generated)
    print(f"{'='*50}\n")

# 1. Show the model state before it has learned anything
print("System: Initializing... capturing baseline (Epoch 0 / Random Weights)")
generate_sample(model, title="PRE-TRAINING (GIBBERISH)")

print("Starting training. Watch the output transition from gibberish to structure...")

# Each "epoch" is a fixed number of randomly sampled batches, not a full pass over the data.
steps_per_epoch = 500 # Increased steps slightly to see more improvement
for epoch in range(10):
    total_loss = 0

    for step in range(steps_per_epoch):
        xb, yb = get_batch('train', 32, block_size)

        # Forward pass, then a standard AdamW update.
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value

        # Optional: Print a mini-progress bar every 100 steps
        if not step % 100:
            print(f"Epoch {epoch+1} | Step {step}/{steps_per_epoch} | Current Loss: {loss_value:.4f}")

    # 2. Show progress after each epoch
    avg_loss = total_loss / steps_per_epoch
    generate_sample(model, title=f"END OF EPOCH {epoch+1} (Avg Loss: {avg_loss:.4f})")

print("Training Complete!")