
# Mini Transformer Training: Cross-Entropy, Perplexity, and Attention

This notebook accompanies **Module 2: Transformer Math & Training Objectives**.

**What you'll do:**
- Load a tiny text dataset from Hugging Face
- Tokenize with the GPT-2 tokenizer
- Train a small Transformer (2 layers, 2 heads) on next-token prediction
- Plot the **loss curve**
- Compute **perplexity**



## 0. Setup & Installs
If you're on Google Colab, uncomment and run the install line in the cell below to install dependencies. Then **Restart Runtime** if prompted.


In [None]:

# If running on Colab, uncomment the last line to install/upgrade dependencies.
# %pip (rather than !pip) installs into the environment the running kernel uses.
# %pip install -q datasets transformers torch matplotlib


## 1. Imports & Reproducibility

In [None]:

import math, random, os, sys, time
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from datasets import load_dataset
from transformers import AutoTokenizer

# Reproducibility: seed Python's RNG and PyTorch (CPU and all CUDA devices)
# from a single value so reruns start from the same random state.
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print("Torch:", torch.__version__, "| CUDA:", cuda_available, "| Device:", device)



## 2. Load a Tiny Text Dataset

We'll use **WikiText-2 (raw)** and load only the first 1% of its training split to keep things fast.


In [None]:

# Load only the first 1% of the WikiText-2 (raw) training split.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")

# Preview the first row that actually contains text. Rows in the raw split can
# be blank (empty lines between paragraphs), and previewing one of those would
# print nothing useful. Falls back to an empty string if every row is blank.
preview_text = next((row["text"] for row in dataset if row["text"].strip()), "")
print(preview_text[:200].replace("\n", " "))
print("Num samples:", len(dataset))


## 3. Tokenize with GPT-2

In [None]:

# Fixed sequence length used when truncating and padding tokenized text.
MAX_LEN = 64

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The GPT-2 tokenizer ships without a padding token; when none is configured,
# reuse the end-of-text token as the pad symbol so rows can be padded to one length.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(batch):
    """Tokenize the ``"text"`` entries of ``batch`` into fixed-length sequences.

    Texts longer than ``MAX_LEN`` tokens are truncated and shorter ones are
    padded up to ``MAX_LEN``. Returns whatever the tokenizer call returns.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": MAX_LEN,
    }
    return tokenizer(batch["text"], **encode_options)

N_SAMPLES = 256  # cap on the number of rows kept, so the demo stays fast

# Drop blank rows before tokenizing: a blank row becomes MAX_LEN pad tokens,
# which would take up a slot in the small subset without supplying any targets.
dataset_text = dataset.filter(lambda example: len(example["text"].strip()) > 0)
tokenized = dataset_text.map(tokenize_fn, batched=True, remove_columns=dataset_text.column_names)

# Slice first, then build tensors, so only the rows that are kept get converted.
subset = tokenized[:N_SAMPLES]
input_ids = torch.tensor(subset["input_ids"])
attn_mask = torch.tensor(subset["attention_mask"])
assert input_ids.shape == attn_mask.shape and input_ids.size(0) > 0, "no non-blank rows to train on"
print("input_ids shape:", tuple(input_ids.shape), "| attention_mask shape:", tuple(attn_mask.shape))
# Pad reuses an existing token (EOS), so the tokenizer's base vocabulary size
# still covers every id that can appear in input_ids.
vocab_size = tokenizer.vocab_size


## 4. Create Mini Batches

In [None]:

BATCH_SIZE = 16

def iterate_batches(x, m, batch_size=BATCH_SIZE):
    """Yield consecutive ``(x_chunk, m_chunk)`` pairs sliced along dimension 0.

    Both inputs are cut at the same offsets, in order and without shuffling.
    The final pair is shorter when ``x.size(0)`` is not a multiple of
    ``batch_size``.
    """
    offsets = range(0, x.size(0), batch_size)
    yield from ((x[lo:lo + batch_size], m[lo:lo + batch_size]) for lo in offsets)

# Sanity check: pull one batch and show the shape of its token-id tensor.
first_batch = next(iterate_batches(input_ids, attn_mask))
first_ids = first_batch[0]
print("Example first batch sizes:", first_ids.shape)


## 5. Define a Tiny Transformer (with positional embeddings)

In [None]:

class TinyTransformer(nn.Module):
    """Small causal language model built on ``nn.TransformerEncoder``.

    Token embeddings and learned positional embeddings are summed, passed
    through ``nlayers`` encoder layers under a causal attention mask, and
    projected to one logit per vocabulary entry at every position.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and the
            output projection.
        d_model: Width of the embeddings and hidden states.
        nhead: Attention heads per encoder layer.
        nlayers: Number of stacked encoder layers.
        max_len: Number of positions in the positional embedding table, i.e.
            the longest sequence ``forward`` accepts.
        pad_id: Stored as ``self.pad_id``; ``forward`` does not consult it.
    """

    def __init__(self, vocab_size, d_model=128, nhead=2, nlayers=2, max_len=MAX_LEN, pad_id=0):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=nlayers)
        # Output projection from hidden states to vocabulary logits.
        self.decoder = nn.Linear(d_model, vocab_size)
        self.pad_id = pad_id

    def causal_mask(self, T, device):
        """Return a ``(T, T)`` boolean mask for left-to-right attention.

        Entries strictly above the diagonal are ``True`` (masked out), so
        position ``t`` cannot attend to any position after ``t``.
        """
        return torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

    def forward(self, x):
        """Compute next-token logits for a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape ``(B, T)``.

        Returns:
            ``(logits, hidden)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``hidden`` is the encoder output of
            shape ``(B, T, d_model)``.

        Raises:
            ValueError: If ``T`` exceeds the number of positions in the
                positional embedding table.
        """
        _, T = x.shape
        max_positions = self.pos.num_embeddings
        if T > max_positions:
            # Fail with a clear message instead of an opaque embedding index error.
            raise ValueError(
                f"sequence length {T} exceeds the model's max_len of {max_positions}"
            )
        pos = torch.arange(T, device=x.device)
        # The (T, d_model) positional embeddings broadcast over the batch dimension.
        h = self.embed(x) + self.pos(pos)
        # Only the causal mask is applied; no key-padding mask is passed.
        # With right-padded inputs, real tokens precede the pads, so the causal
        # mask already keeps real positions from attending to padding.
        mask = self.causal_mask(T, x.device)
        out = self.encoder(h, mask=mask)
        logits = self.decoder(out)
        return logits, out  # logits and the hidden states they were projected from


## 6. Train the Model and Plot Loss Curve

In [None]:

from torch.optim import AdamW

model = TinyTransformer(vocab_size=vocab_size).to(device)
optimizer = AdamW(model.parameters(), lr=2e-3)
EPOCHS = 5
IGNORE_INDEX = -100  # target value that F.cross_entropy is told to skip

losses = []  # mean loss per real target token, one entry per epoch

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0  # summed cross-entropy over real target tokens this epoch
    n_tokens = 0      # number of real target tokens this epoch
    for xb, mb in iterate_batches(input_ids, attn_mask):
        xb, mb = xb.to(device), mb.to(device)
        logits, _ = model(xb)
        # Next-token prediction: the output at position t is scored against token t+1.
        # Padding is identified through the attention mask rather than the pad id,
        # because pad is aliased to EOS and a genuine EOS target must still count.
        targets = xb[:, 1:].masked_fill(mb[:, 1:] == 0, IGNORE_INDEX)
        logits_flat = logits[:, :-1, :].reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        n_valid = int((targets_flat != IGNORE_INDEX).sum().item())
        if n_valid == 0:
            # A batch with no real targets would give a 0/0 mean (NaN) loss; skip it.
            continue

        # Sum, then divide by the real-token count: the optimised value is the
        # per-token mean, and the epoch average below uses the same denominator.
        loss_sum = F.cross_entropy(
            logits_flat, targets_flat, ignore_index=IGNORE_INDEX, reduction="sum"
        )
        loss = loss_sum / n_valid
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += float(loss_sum.item())
        n_tokens += n_valid

    avg_loss = total_loss / max(n_tokens, 1)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{EPOCHS} - Avg token loss: {avg_loss:.4f}")

# Draw the per-epoch values collected in `losses`, one marker per epoch.
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(losses, marker='o')
ax.set_title("Training Loss Curve")
ax.set_xlabel("Epoch")
ax.set_ylabel("Average Token Loss")
plt.show()


## 7. Evaluate Perplexity

In [None]:

# Score the model on text it was not trained on: a slice of the WikiText-2
# validation split, tokenized the same way as the training rows.
val_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation[:5%]")
# Blank rows would tokenize to padding only and contribute no targets.
val_dataset = val_dataset.filter(lambda example: len(example["text"].strip()) > 0)
val_tokenized = val_dataset.map(tokenize_fn, batched=True, remove_columns=val_dataset.column_names)
val_rows = val_tokenized[:32]
val_ids = torch.tensor(val_rows["input_ids"])
val_mask = torch.tensor(val_rows["attention_mask"])

model.eval()
total_nll = 0.0  # summed negative log-likelihood over real target tokens
n_tokens = 0     # number of real target tokens scored
with torch.no_grad():
    # Evaluate in mini-batches so the (rows, positions, vocab) logits stay small.
    for xb, mb in iterate_batches(val_ids, val_mask):
        xb, mb = xb.to(device), mb.to(device)
        logits, _ = model(xb)
        # Mask padded targets via the attention mask (pad is aliased to EOS, so
        # comparing against the pad id would also discard genuine EOS targets).
        targets = xb[:, 1:].masked_fill(mb[:, 1:] == 0, -100)
        logits_flat = logits[:, :-1, :].reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)
        batch_nll = F.cross_entropy(logits_flat, targets_flat, ignore_index=-100, reduction="sum")
        total_nll += float(batch_nll.item())
        n_tokens += int((targets_flat != -100).sum().item())

# Perplexity is exp of the mean per-token cross-entropy; max(..., 1) avoids 0/0
# if no real target token was found.
loss = torch.tensor(total_nll / max(n_tokens, 1))
ppl = torch.exp(loss)
print(f"Validation loss: {loss:.4f} | Perplexity: {ppl:.2f}")
