# Lesson 12: Capstone Ablation Harness

Research engineering is about running controlled experiments with clean, reusable code and careful measurement.

In this notebook we:
- build a single GPT-like model with toggles (norm, position, FFN)
- keep the data split fixed and seeds controlled
- run short ablations and collect results in a table
- pick the best config and do a quick inference demo

The goal is not SOTA accuracy; it is a small, trustworthy experiment loop.


In [None]:
import math
import time
import random
import re
from dataclasses import dataclass, asdict
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd


def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices) with ``seed``."""
    # Same order as always: Python, NumPy, torch CPU, then every CUDA device.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# Ask cuDNN-backed ops for deterministic kernels and turn off per-shape autotuning,
# so repeated runs of the same config behave the same way.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Module-level device string that later cells use for models and batches.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Device:", device)
print("Torch:", torch.__version__)

## Dataset and tokenizer

We prefer `wikitext-2-raw-v1` and fall back to a tiny dataset if the download is blocked.
The tokenizer is intentionally simple (word and punctuation based) so you can see every step.


In [None]:
def load_text_dataset(allow_download=True):
    """Return ``(train_text, val_text, dataset_name)`` for language-model training.

    Prefers WikiText-2 (raw) from the Hugging Face ``datasets`` package. If the
    package is missing or the download fails (e.g. no network), the reason is
    printed and a small deterministic synthetic corpus is returned instead, so
    the rest of the notebook can still run end to end.

    Args:
        allow_download: if False, skip the WikiText attempt and return the
            synthetic fallback directly (useful offline and in tests).

    Returns:
        Tuple of (train_text, val_text, dataset_name). ``val_text`` is never
        empty for a multi-line corpus: if the source has no validation split,
        the tail of the training text is held out.
    """
    if allow_download:
        try:
            from datasets import load_dataset
            ds = load_dataset("wikitext", "wikitext-2-raw-v1")
        except Exception as err:  # ImportError, network/OS errors, hub errors, ...
            print(f"Could not load wikitext-2-raw-v1 ({type(err).__name__}: {err}); using fallback corpus.")
        else:
            train_text = "\n".join(ds["train"]["text"])
            if "validation" in ds:
                val_text = "\n".join(ds["validation"]["text"])
            else:
                # An empty validation stream would make batch sampling fail later.
                train_text, val_text = _split_lines(train_text, val_fraction=0.1)
            return train_text, val_text, "wikitext-2-raw-v1"
    return _fallback_text_dataset()


def _split_lines(text, val_fraction=0.1):
    """Split ``text`` by line into (head, tail), with about ``val_fraction`` of the lines in the tail."""
    lines = text.split("\n")
    n_val = int(len(lines) * val_fraction)
    cut = max(1, len(lines) - n_val)
    return "\n".join(lines[:cut]), "\n".join(lines[cut:])


def _fallback_text_dataset(n_lines=2000, seed=0):
    """Build a deterministic synthetic corpus of random subject/verb/object sentences."""
    rng = random.Random(seed)  # local RNG: leaves the global seed state untouched
    subjects = ["The model", "A researcher", "The optimizer", "Each token", "The dataset", "A small network"]
    verbs = ["predicts", "updates", "measures", "attends to", "compresses", "ignores"]
    objects = ["the next word", "its gradients", "the validation loss", "earlier tokens",
               "a short sequence", "the learning rate"]
    lines = [
        f"{rng.choice(subjects)} {rng.choice(verbs)} {rng.choice(objects)}."
        for _ in range(n_lines)
    ]
    train_text, val_text = _split_lines("\n".join(lines), val_fraction=0.1)
    return train_text, val_text, "synthetic-fallback"

# Load the corpus once; the tokenizer cell below reuses train_text and val_text.
train_text, val_text, dataset_name = load_text_dataset()
print("Dataset:", dataset_name)
print(f"Train chars: {len(train_text)} Val chars: {len(val_text)}")
print("Sample:", " ".join(train_text[:200].split("\n")))

### Tokenizer and vocab

We use a basic regex tokenizer and a capped vocabulary. Unknown tokens map to `<unk>`.
This is much simpler than BPE, but it keeps the pipeline transparent.


In [None]:
TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)  # a run of word characters, or one non-space symbol


def basic_tokenize(text):
    """Split ``text`` into word tokens and single punctuation/symbol tokens; whitespace is dropped."""
    return [match.group(0) for match in TOKEN_PATTERN.finditer(text)]


def build_vocab(tokens, vocab_size, min_freq=1):
    """Build string<->id lookup tables from a token stream.

    The four special tokens always take ids 0-3 (``<pad>``, ``<unk>``, ``<bos>``,
    ``<eos>``). The remaining slots are filled with corpus tokens seen at least
    ``min_freq`` times, most frequent first (ties keep first-seen order, as
    ``Counter.most_common`` does), until ``vocab_size`` entries exist.

    Args:
        tokens: iterable of token strings.
        vocab_size: maximum vocabulary size, specials included.
        min_freq: minimum count for a corpus token to be kept.

    Returns:
        (stoi, itos): dicts mapping token -> id and id -> token. Ids are exactly
        ``0 .. len(stoi) - 1``.

    Raises:
        ValueError: if ``vocab_size`` cannot hold the special tokens.
    """
    specials = ["<pad>", "<unk>", "<bos>", "<eos>"]
    if vocab_size < len(specials):
        raise ValueError(f"vocab_size must be at least {len(specials)} to hold the special tokens")
    counts = Counter(tokens)
    # A special string occurring in the corpus must not get a second id: that would
    # remap it in stoi and leave gaps / out-of-range ids in the id space.
    for tok in specials:
        counts.pop(tok, None)
    vocab_tokens = specials + [tok for tok, freq in counts.most_common() if freq >= min_freq]
    vocab_tokens = vocab_tokens[:vocab_size]
    stoi = {tok: i for i, tok in enumerate(vocab_tokens)}
    itos = {i: tok for tok, i in stoi.items()}
    return stoi, itos


def encode(text, stoi):
    """Tokenize ``text`` and map each token to its id; tokens missing from ``stoi`` become ``<unk>``."""
    pieces = basic_tokenize(text)
    if not pieces:
        return []
    unk_id = stoi["<unk>"]
    return [stoi.get(piece, unk_id) for piece in pieces]


def decode(ids, itos):
    """Turn ids back into readable text.

    Tokens are joined with single spaces. The space is then removed before
    closing punctuation and after opening brackets. Ids missing from ``itos``
    render as ``<unk>``.
    """
    text = " ".join(itos.get(token_id, "<unk>") for token_id in ids)
    for mark in (".", ",", "!", "?", ":", ";", ")", "]", "'"):
        text = text.replace(f" {mark}", mark)
    for mark in ("(", "["):
        text = text.replace(f"{mark} ", mark)
    return text


VOCAB_SIZE = 8000  # requested cap on vocabulary size, special tokens included
train_tokens = basic_tokenize(train_text)
stoi, itos = build_vocab(train_tokens, VOCAB_SIZE)
VOCAB_SIZE = len(stoi)  # actual size; smaller than the cap when the corpus has fewer distinct tokens

# Encode both splits with the vocabulary built from the training split only.
train_ids, val_ids = (
    torch.tensor(encode(split_text, stoi), dtype=torch.long)
    for split_text in (train_text, val_text)
)

print("Vocab size:", VOCAB_SIZE)
print("Train tokens:", train_ids.numel(), "Val tokens:", val_ids.numel())
print("Token sample:", train_tokens[:20])
print("Decode sample:", decode(train_ids[:30].tolist(), itos))

## Config system

A small dataclass captures hyperparameters and ablation toggles.
We keep this separate so it is easy to log, compare, and reuse.


In [None]:
@dataclass
class Config:
    """Hyperparameters and ablation toggles for one training run.

    Toggle values are validated on construction. The model code selects
    variants by exact string comparison, so a typo such as ``"RoPE"`` for
    ``positional`` would otherwise silently train a model with no positional
    signal at all.

    Raises:
        ValueError: if ``norm_type``, ``positional`` or ``ffn_type`` is not one
            of the supported values.
    """

    name: str = "baseline"
    seed: int = 1337
    vocab_size: int = VOCAB_SIZE  # default captured from the tokenizer cell when this class is defined
    block_size: int = 128
    batch_size: int = 32
    max_steps: int = 300
    eval_interval: int = 100
    eval_iters: int = 50
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0

    d_model: int = 128
    n_heads: int = 4
    n_layers: int = 2
    ffn_mult: float = 4.0
    dropout: float = 0.1

    norm_type: str = "layernorm"  # "layernorm" or "rmsnorm"
    positional: str = "learned"   # "learned", "rope", or "none" (no positional signal)
    ffn_type: str = "gelu"        # "gelu" or "swiglu"
    rope_base: int = 10000

    def __post_init__(self):
        allowed_values = {
            "norm_type": ("layernorm", "rmsnorm"),
            "positional": ("learned", "rope", "none"),
            "ffn_type": ("gelu", "swiglu"),
        }
        for attr, allowed in allowed_values.items():
            value = getattr(self, attr)
            if value not in allowed:
                raise ValueError(f"Config.{attr} must be one of {allowed}, got {value!r}")


def print_config(cfg):
    """Print a header line with the config name, then one indented ``key: value`` line per field."""
    header = "Config: " + str(cfg.name)
    body = [f"  {key}: {value}" for key, value in asdict(cfg).items()]
    print("\n".join([header, *body]))

## Model with toggles

We build a small GPT-like decoder-only Transformer.
All three toggles are wired into this single implementation.

PyTorch note: each `nn.Module` stores parameters, and `forward` describes how tensors flow.
We use shapes `(B, T, C)` for batch, time (sequence length), and channels (hidden size).


In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-channel gain.

    Unlike LayerNorm there is no mean subtraction and no bias:
    ``y = weight * x / sqrt(mean(x**2) + eps)``.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = x.square().mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * inv_rms)


def make_norm(norm_type, dim):
    """Return a fresh normalization layer over ``dim`` channels.

    ``"layernorm"`` gives ``nn.LayerNorm`` and ``"rmsnorm"`` gives ``RMSNorm``.
    Any other value raises ``ValueError``.
    """
    if norm_type not in ("layernorm", "rmsnorm"):
        raise ValueError(f"Unknown norm_type: {norm_type}")
    return nn.LayerNorm(dim) if norm_type == "layernorm" else RMSNorm(dim)


def build_rope_cache(seq_len, head_dim, base=10000, device=None):
    """Precompute RoPE rotation tables.

    Channel pair i (channels 2i and 2i+1) rotates at frequency
    ``base ** (-2i / head_dim)``. Returns ``(cos, sin)``, each of shape
    ``(seq_len, head_dim // 2)``, where row t holds the angles for position t.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()


def apply_rope(x, cos, sin):
    """Rotate interleaved channel pairs of ``x`` by position-dependent angles (RoPE).

    Channels (0, 1), (2, 3), ... are treated as 2-D points. The pair at
    position t is rotated by the angle whose cosine/sine are stored in row t of
    ``cos``/``sin``. The output keeps the input's channel layout, so a zero
    angle leaves ``x`` unchanged.

    Args:
        x: tensor whose last two dims are (T, head_dim) with head_dim even; in
           this file it is called with (B, n_heads, T, head_dim).
        cos, sin: tables of shape (max_T, head_dim // 2) with max_T >= T, e.g.
           from ``build_rope_cache``; only the first T rows are used.

    Returns:
        Tensor with the same shape (and dtype) as ``x``.
    """
    T = x.size(-2)
    # (T, head_dim // 2) broadcasts against any leading dims of x.
    cos = cos[:T].to(x.dtype)
    sin = sin[:T].to(x.dtype)
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    # Re-interleave so output channels 2i / 2i+1 hold the rotated pair from input channels 2i / 2i+1.
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional rotary position embeddings.

    Args:
        d_model: model width C; must be divisible by ``n_heads``.
        n_heads: number of attention heads; each head has ``d_model // n_heads`` channels.
        block_size: maximum sequence length; sizes the causal mask (and the RoPE
            tables when ``use_rope`` is set).
        dropout: probability for the single ``nn.Dropout`` module, which is applied
            both to the attention weights and to the output projection.
        use_rope: if True, rotate queries and keys with RoPE before the dot product.
        rope_base: base of the RoPE frequency schedule (passed to ``build_rope_cache``).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``, or if RoPE is
            requested with an odd per-head dimension.
    """

    def __init__(self, d_model, n_heads, block_size, dropout, use_rope, rope_base=10000):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        head_dim = d_model // n_heads
        # RoPE rotates channel pairs, so each head needs an even channel count.
        if use_rope and head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.use_rope = use_rope
        # One fused projection produces queries, keys and values (no bias).
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Causal mask prevents attention from seeing future tokens
        mask = torch.tril(torch.ones(block_size, block_size))
        # Shape (1, 1, block_size, block_size) broadcasts over batch and heads. As a
        # registered buffer it is stored in state_dict and moves with the module.
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size))

        if use_rope:
            # Tables are built on CPU here; as buffers they follow the module to its device.
            cos, sin = build_rope_cache(block_size, head_dim, base=rope_base, device=torch.device("cpu"))
            self.register_buffer("cos_cached", cos)
            self.register_buffer("sin_cached", sin)

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, C) and return a (B, T, C) tensor.

        T must not exceed ``block_size``: the mask (and RoPE tables) only cover
        that many positions.
        """
        B, T, C = x.shape
        qkv = self.qkv(x)
        # Last dim of qkv is laid out as [q | k | v], each further split into heads.
        qkv = qkv.view(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.transpose(1, 2)  # (B, n_heads, T, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Positions enter only through q and k; values are not rotated.
        if self.use_rope:
            q = apply_rope(q, self.cos_cached, self.sin_cached)
            k = apply_rope(k, self.cos_cached, self.sin_cached)

        # Scaled dot-product scores: (B, n_heads, T, T).
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Future positions (mask == 0) get -inf so softmax assigns them zero weight.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        y = self.dropout(y)
        return y


class MLP(nn.Module):
    """Position-wise feed-forward network with a GELU or SwiGLU nonlinearity.

    Hidden width is ``int(d_model * ffn_mult)``. For SwiGLU, ``fc1`` is twice
    as wide because it produces both the gate and value halves. At the same
    ``ffn_mult``, the SwiGLU variant therefore has more parameters than the
    GELU variant. Dropout is applied once, to the output.

    Raises:
        ValueError: if ``ffn_type`` is neither ``"gelu"`` nor ``"swiglu"``.
    """

    def __init__(self, d_model, ffn_mult, dropout, ffn_type):
        super().__init__()
        hidden_dim = int(d_model * ffn_mult)
        if ffn_type not in ("gelu", "swiglu"):
            raise ValueError(f"Unknown ffn_type: {ffn_type}")
        self.ffn_type = ffn_type
        fc1_width = hidden_dim if ffn_type == "gelu" else 2 * hidden_dim
        self.fc1 = nn.Linear(d_model, fc1_width)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.fc1(x)
        if self.ffn_type == "gelu":
            h = F.gelu(h)
        else:
            gate, value = h.chunk(2, dim=-1)
            h = F.silu(gate) * value
        return self.dropout(self.fc2(h))


class Block(nn.Module):
    """Pre-norm decoder block: ``x + Attn(Norm1(x))`` followed by ``x + MLP(Norm2(x))``.

    The norm type, positional scheme (RoPE when ``config.positional == "rope"``)
    and FFN type all come from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.norm1 = make_norm(config.norm_type, width)
        self.attn = CausalSelfAttention(
            width,
            config.n_heads,
            config.block_size,
            config.dropout,
            use_rope=(config.positional == "rope"),
            rope_base=config.rope_base,
        )
        self.norm2 = make_norm(config.norm_type, width)
        self.mlp = MLP(width, config.ffn_mult, config.dropout, config.ffn_type)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out


class GPT(nn.Module):
    """Small decoder-only Transformer language model with ablation toggles.

    ``config`` supplies the sizes (vocab_size, block_size, d_model, n_layers, ...)
    and the toggles (norm_type, positional, ffn_type). A learned position table
    is created only when ``config.positional == "learned"``; the RoPE variant is
    handled inside each block's attention. The first forward call prints the
    input, token-embedding and logit shapes once, as a teaching aid.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = None
        if config.positional == "learned":
            self.pos_emb = nn.Embedding(config.block_size, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.norm_f = make_norm(config.norm_type, config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self._printed_shapes = False

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Normalization layers keep their own default initialisation.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if ``targets`` is given, the mean cross-entropy loss.

        Args:
            idx: (B, T) token ids with T <= ``config.block_size``.
            targets: optional (B, T) ids of the next token at each position.
                Need not be contiguous (e.g. a column slice of a larger tensor).

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size); loss is None when
            ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``config.block_size``.
        """
        _, T = idx.shape
        if T > self.config.block_size:
            raise ValueError(
                f"Sequence length exceeds block_size ({T} > {self.config.block_size})"
            )

        tok = self.tok_emb(idx)
        x = tok
        if self.pos_emb is not None:
            pos = torch.arange(T, device=idx.device)
            x = x + self.pos_emb(pos)[None, :, :]  # broadcast position table over the batch
        x = self.drop(x)

        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.norm_f(x))

        loss = None
        if targets is not None:
            # Flatten to (B*T, vocab) vs (B*T,). reshape (not view) also accepts non-contiguous inputs.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        if not self._printed_shapes:
            print("Input idx:", idx.shape)
            print("Token emb:", tok.shape)
            print("Logits:", logits.shape)
            self._printed_shapes = True

        return logits, loss

## Training and evaluation

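We use short runs with a fixed random seed and a consistent train/val split.
`get_batch` samples random contiguous windows from the token stream; the targets are the same windows shifted one token to the right.
`estimate_loss` draws its batches from its own fixed-seed generator, so every evaluation scores the same validation batches and losses are comparable across steps and runs.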


In [None]:
def get_batch(data, config, generator):
    """Sample a batch of (input, next-token target) windows from a 1-D id stream.

    Args:
        data: 1-D tensor of token ids with at least ``config.block_size + 1`` elements.
        config: provides ``block_size`` and ``batch_size``.
        generator: CPU ``torch.Generator`` that controls which windows are drawn.

    Returns:
        (x, y), each of shape (batch_size, block_size) and moved to the
        module-level ``device``. ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: if ``data`` is too short to hold a single window plus target.
    """
    block = config.block_size
    # A start i is valid while i + 1 + block <= numel, i.e. i <= numel - block - 1.
    # torch.randint's upper bound is exclusive, so pass numel - block.
    high = data.numel() - block
    if high <= 0:
        raise ValueError(
            f"need at least block_size + 1 = {block + 1} tokens, got {data.numel()}"
        )
    starts = torch.randint(0, high, (config.batch_size,), generator=generator)
    # (batch, block) index grid: row b is starts[b], starts[b] + 1, ..., starts[b] + block - 1.
    offsets = starts.to(data.device)[:, None] + torch.arange(block, device=data.device)
    x = data[offsets]
    y = data[offsets + 1]
    return x.to(device), y.to(device)


def estimate_loss(model, data, config):
    """Return the mean loss of ``model`` over ``config.eval_iters`` batches drawn from ``data``.

    Batches come from a dedicated generator seeded with ``config.seed + 1234``,
    so every call scores the same windows and results are comparable across
    steps and runs. The model is put in eval mode (dropout off) with gradients
    disabled. Its previous train/eval mode is restored afterwards, even if
    evaluation raises.

    Raises:
        ValueError: if ``config.eval_iters`` is not positive.
    """
    if config.eval_iters <= 0:
        raise ValueError("eval_iters must be positive")
    was_training = model.training
    model.eval()
    eval_gen = torch.Generator().manual_seed(config.seed + 1234)
    total = 0.0
    try:
        with torch.no_grad():
            for _ in range(config.eval_iters):
                xb, yb = get_batch(data, config, eval_gen)
                _, loss = model(xb, yb)
                total += loss.item()
    finally:
        model.train(was_training)
    return total / config.eval_iters


def train_model(config, train_data, val_data):
    """Train a freshly initialised GPT and evaluate it on ``val_data``.

    Global RNGs are seeded from ``config.seed`` before the model is built, and
    training batches come from a generator seeded with the same value. Train
    and validation loss are logged at step 1 and every ``config.eval_interval``
    steps.

    Returns:
        (model, final_val_loss, runtime): ``runtime`` is wall-clock seconds for
        the whole training loop, including the periodic evaluations inside it.
        On CUDA the device is synchronized before each timer read, so pending
        kernels are counted.
    """
    print_config(config)
    set_seed(config.seed)
    model = GPT(config).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print("Parameters:", num_params)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    # Dedicated generator: batch order depends only on config.seed, not on how much
    # global RNG the model init or dropout consumed.
    train_gen = torch.Generator().manual_seed(config.seed)
    on_cuda = torch.device(device).type == "cuda"

    def _sync():
        # CUDA launches are asynchronous; wait so the timer measures finished work.
        if on_cuda:
            torch.cuda.synchronize()

    _sync()
    start = time.perf_counter()
    for step in range(1, config.max_steps + 1):
        xb, yb = get_batch(train_data, config, train_gen)
        _, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()  # compute gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()  # update weights

        if step == 1 or step % config.eval_interval == 0:
            val_loss = estimate_loss(model, val_data, config)
            print(f"step {step:4d} train_loss {loss.item():.4f} val_loss {val_loss:.4f}")

    _sync()
    runtime = time.perf_counter() - start
    final_val_loss = estimate_loss(model, val_data, config)
    return model, final_val_loss, runtime

## Ablation runs

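We run three short experiments to see how the toggles behave.
Note that runs 1 and 2 change two toggles at once (norm and FFN), so their gap cannot be attributed to either toggle alone; runs 2 and 3 differ only in the positional encoding.
The SwiGLU runs also have more FFN parameters than the GELU run at the same `ffn_mult`; the `params` column makes this visible.
Feel free to add more configs or increase `max_steps` later.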


In [None]:
# Each run sets only its three toggles; every other field keeps its Config default.
configs = [
    Config(name=run_name, norm_type=norm, positional=pos, ffn_type=ffn)
    for run_name, norm, pos, ffn in (
        ("LN + learned + GELU", "layernorm", "learned", "gelu"),
        ("RMS + learned + SwiGLU", "rmsnorm", "learned", "swiglu"),
        ("RMS + RoPE + SwiGLU", "rmsnorm", "rope", "swiglu"),
    )
]

results = []
models = []
for i, cfg in enumerate(configs):
    print("\n--- Running", cfg.name, "---")
    model, val_loss, runtime = train_model(cfg, train_ids, val_ids)
    models.append(model)
    ppl = math.exp(val_loss)
    params = sum(p.numel() for p in model.parameters())
    # Keyword order below fixes the column order of results_df.
    results.append(dict(
        config_id=i,
        name=cfg.name,
        norm=cfg.norm_type,
        positional=cfg.positional,
        ffn=cfg.ffn_type,
        val_loss=val_loss,
        perplexity=ppl,
        runtime_sec=runtime,
        params=params,
    ))

results_df = pd.DataFrame(results).sort_values("val_loss")
print(results_df.to_string(index=False))
results_df

## Inference test

We pick the best config by validation loss and generate a short sample.


In [None]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=50):
    """Autoregressively extend ``idx`` by ``max_new_tokens`` tokens.

    Args:
        model: module exposing ``config.block_size``; calling it on a (B, T) id
            tensor must return ``(logits, loss)`` with logits of shape (B, T, vocab).
        idx: (B, T) LongTensor prompt with T >= 1.
        max_new_tokens: number of tokens to append.
        temperature: softmax temperature; ``<= 0`` selects greedy (argmax) decoding.
        top_k: sample only from the k most likely tokens; ``None`` or ``<= 0``
            disables this filter. Ignored for greedy decoding.

    Returns:
        (B, T + max_new_tokens) tensor: the prompt followed by the new tokens.

    Raises:
        ValueError: if the prompt is empty.
    """
    if idx.size(1) == 0:
        raise ValueError("generate() needs a prompt with at least one token")
    model.eval()
    block_size = model.config.block_size
    for _ in range(max_new_tokens):
        # Only the most recent block_size tokens fit in the model's context.
        logits, _ = model(idx[:, -block_size:])
        logits = logits[:, -1, :]
        if temperature <= 0:
            next_id = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if top_k is not None and top_k > 0:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx


# results_df is sorted by val_loss, so its first row is the best run.
best_id = int(results_df["config_id"].iloc[0])
best_config, best_model = configs[best_id], models[best_id]
print("Best config:", best_config.name)

# Reseed so the sampled continuation is reproducible for this config.
set_seed(best_config.seed)
prompt = "The meaning of life is"
prompt_ids = torch.tensor([encode(prompt, stoi)], dtype=torch.long, device=device)
output_ids = generate(best_model, prompt_ids, max_new_tokens=40, temperature=1.0, top_k=50)
print(decode(output_ids[0].tolist(), itos))

## Scaling notes

- Larger sweeps: wrap the config list in a grid search and log to a CSV or parquet file.
- Better logging: integrate Weights & Biases or TensorBoard for metrics and plots.
- More compute: move to multi-GPU with `torch.distributed` or use gradient accumulation.
- Reliability: add checkpoints, save model states, and log the dataset hash.


## Exercises

1) Add a fourth run that changes only the positional encoding and compare.
2) Increase `max_steps` to 1000 and see if the ranking changes.
3) Try a character-level tokenizer and compare perplexity.
4) Replace the attention with PyTorch's `scaled_dot_product_attention` and time it.
