# Chapter 12: Training Your Model

> "I have not failed. I've just found 10,000 ways that won't work."
> — **Thomas Edison**, Inventor

---

## What You'll Learn

- How next-token prediction teaches models to write
- The 5-step training recipe at the heart of most neural network training
- How loss curves reveal whether your model is learning
- How to prevent overfitting and know when to stop
- Saving checkpoints so you never lose progress
- The thrill of watching your model transform from gibberish to coherence

---

## Setup

First, let's install required packages:

In [None]:
# Install required packages
!pip install -q torch transformers tqdm matplotlib

In [None]:
# ===== IMPORTS =====
import math
import urllib.request
from functools import partial
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ===== REPRODUCIBILITY =====
def set_seed(seed=42):
    """Seed PyTorch's CPU and CUDA random number generators for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

## 1. Model Components from Chapters 10-11

Next, let's bring in the MiniGPT model we built in Chapters 10 and 11.

In [None]:
# ===== MULTI-HEAD ATTENTION (from Chapter 10) =====

class MultiHeadAttention(nn.Module):
    """Efficient multi-head attention (batches all heads together)."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch, seq, d_model = x.shape

        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch, seq, 3, self.num_heads, self.d_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = attn_weights @ V
        attn_output = attn_output.transpose(1, 2).reshape(batch, seq, d_model)

        return self.out_proj(attn_output), attn_weights

print("MultiHeadAttention defined!")

In [None]:
# ===== FEEDFORWARD NETWORK (from Chapter 10) =====

class FeedForward(nn.Module):
    """Position-wise feedforward network."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

print("FeedForward defined!")

In [None]:
# ===== TRANSFORMER BLOCK (from Chapter 10) =====

class TransformerBlock(nn.Module):
    """Complete Transformer block (pre-norm style like GPT-2)."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, attn_weights = self.attn(self.ln1(x), mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout(ffn_out)
        return x, attn_weights

print("TransformerBlock defined!")

In [None]:
# ===== GPT CONFIG (from Chapter 11) =====

@dataclass
class GPTConfig:
    """Configuration for MiniGPT model."""
    vocab_size: int = 50257
    max_seq_len: int = 1024
    embed_dim: int = 768
    num_heads: int = 12
    num_layers: int = 12
    d_ff: int = 3072
    dropout: float = 0.1

    def __post_init__(self):
        assert self.embed_dim % self.num_heads == 0, \
            f"embed_dim ({self.embed_dim}) must be divisible by num_heads ({self.num_heads})"

print("GPTConfig defined!")

In [None]:
# ===== MINIGPT MODEL (from Chapter 11) =====

class MiniGPT(nn.Module):
    """A minimal GPT-style language model."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # Embeddings
        self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)
        self.pos_embed = nn.Embedding(config.max_seq_len, config.embed_dim)
        self.dropout = nn.Dropout(config.dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=config.embed_dim,
                num_heads=config.num_heads,
                d_ff=config.d_ff,
                dropout=config.dropout
            )
            for _ in range(config.num_layers)
        ])

        # Final layer norm and LM head
        self.ln_f = nn.LayerNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

        # Weight tying
        self.lm_head.weight = self.token_embed.weight

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embed.weight, std=0.02)
        nn.init.normal_(self.pos_embed.weight, std=0.02)

    def forward(self, token_ids, return_attention=False):
        batch, seq = token_ids.shape
        device = token_ids.device

        tok_emb = self.token_embed(token_ids)
        positions = torch.arange(seq, device=device)
        pos_emb = self.pos_embed(positions)
        x = self.dropout(tok_emb + pos_emb)

        mask = torch.tril(torch.ones(seq, seq, device=device))

        attention_weights = []
        for block in self.blocks:
            x, attn = block(x, mask)
            if return_attention:
                attention_weights.append(attn)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_attention:
            return logits, attention_weights
        return logits

print("MiniGPT class defined!")

## 2. Download Dataset

We'll use TinyShakespeare, a corpus small enough to train on in minutes.

In [None]:
# Download TinyShakespeare
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
urllib.request.urlretrieve(url, "shakespeare.txt")

with open("shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

print(f"Dataset size: {len(text):,} characters")
print(f"\nSample:\n{text[:500]}")

## 3. Create Dataset and DataLoader

In [None]:
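class TextDataset(Dataset):
    """Dataset of fixed-size, non-overlapping character chunks."""

    def __init__(self, text, chunk_size=256):
        # Step through the text in strides of chunk_size. The +1 keeps the
        # final full chunk when len(text) is an exact multiple of chunk_size;
        # a trailing remainder shorter than chunk_size is dropped.
        self.chunks = [
            text[i:i + chunk_size]
            for i in range(0, len(text) - chunk_size + 1, chunk_size)
        ]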

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return self.chunks[idx]


def collate_fn(batch, tokenizer, max_length=128):
    """Tokenize and pad a batch of text strings."""
    encoded = tokenizer(
        batch,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    return encoded["input_ids"], encoded["attention_mask"]


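# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 defines no padding token, so we reuse EOS for padding. Because the loss
# ignores pad_token_id, any genuine EOS tokens in the data are ignored as well.
tokenizer.pad_token = tokenizer.eos_token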

# Create dataset
dataset = TextDataset(text, chunk_size=256)
print(f"Number of chunks: {len(dataset)}")

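To see what fixed-size chunking does to the tail of a text, here is a minimal standalone sketch in plain Python. The helper name `chunk_text` is hypothetical and not part of the chapter's code; it only illustrates the idea behind `TextDataset`: keep every full chunk and drop a remainder that is shorter than `chunk_size`.

```python
def chunk_text(text, chunk_size):
    """Split text into non-overlapping chunks of exactly chunk_size characters.

    Hypothetical helper for illustration only. A trailing remainder
    shorter than chunk_size is dropped.
    """
    num_full = len(text) // chunk_size
    return [text[i * chunk_size:(i + 1) * chunk_size] for i in range(num_full)]


sample = "abcdefghijk"          # 11 characters
chunks = chunk_text(sample, 4)  # two full chunks; the remainder "ijk" is dropped
print(chunks)
```
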
In [None]:
# Create train/val split
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

# Create DataLoaders
collate = partial(collate_fn, tokenizer=tokenizer, max_length=128)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate
)

# Check one batch
input_ids, attention_mask = next(iter(train_loader))
print(f"Batch input_ids shape: {input_ids.shape}")

## 4. Create Model

In [None]:
# Small config for fast training
config = GPTConfig(
    vocab_size=50257,
    max_seq_len=128,
    embed_dim=256,
    num_heads=4,
    num_layers=4,
    d_ff=1024,
    dropout=0.1
)

model = MiniGPT(config).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

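It helps to know where those parameters live. The sketch below recomputes the count by hand from the layer shapes defined earlier in this chapter; the function name `count_minigpt_params` is hypothetical. It assumes bias-free attention projections, biased feedforward layers, LayerNorm with a weight and a bias, and an output head tied to the token embedding, so its result should match the number printed above.

```python
def count_minigpt_params(vocab_size, max_seq_len, embed_dim, num_layers, d_ff):
    """Count trainable parameters for the MiniGPT layout used in this chapter.

    Hypothetical helper. The tied lm_head shares the token embedding
    matrix, so it contributes no parameters of its own.
    """
    token_embed = vocab_size * embed_dim
    pos_embed = max_seq_len * embed_dim

    attn = embed_dim * 3 * embed_dim + embed_dim * embed_dim   # qkv_proj + out_proj, no bias
    ffn = (embed_dim * d_ff + d_ff) + (d_ff * embed_dim + embed_dim)  # fc1 + fc2 with bias
    layer_norms = 2 * (2 * embed_dim)                           # ln1 + ln2, weight and bias
    per_block = attn + ffn + layer_norms

    final_norm = 2 * embed_dim
    return token_embed + pos_embed + num_layers * per_block + final_norm


total = count_minigpt_params(
    vocab_size=50257, max_seq_len=128, embed_dim=256, num_layers=4, d_ff=1024
)
print(f"{total:,}")  # 16,054,016
```

Roughly 80% of these parameters sit in the token embedding table (50,257 × 256), which is why the vocabulary size dominates the footprint of a model this small.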
## 5. The "Before" State: Untrained Model

In [None]:
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=30, temperature=1.0):
    """Generate text with temperature control."""
    model.eval()
    device = next(model.parameters()).device

    token_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    for _ in range(max_new_tokens):
        # Crop to the context window so position indices stay in range
        logits = model(token_ids[:, -model.config.max_seq_len:])
        next_logits = logits[:, -1, :] / temperature
        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_ids = torch.cat([token_ids, next_token], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(token_ids[0])


# Generate from untrained model
print("BEFORE TRAINING (random weights):")
print("="*50)
prompt = "The king"
print(f"Prompt: '{prompt}'")
print(f"Output: {generate(model, tokenizer, prompt)}")
print("\n(Random gibberish - the model hasn't learned anything yet!)")

## 6. Training Setup

In [None]:
# Training hyperparameters
num_epochs = 3
learning_rate = 3e-4
warmup_steps = 100

total_steps = len(train_loader) * num_epochs
print(f"Total training steps: {total_steps}")

# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

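To build intuition for what the scheduler does, here is a plain-Python sketch of a linear warmup followed by a linear decay. It is an illustration of the schedule's shape, not the library's code; `linear_warmup_decay` is a hypothetical name and the `total_steps` value is made up.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """Learning-rate multiplier: linear ramp from 0 to 1, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    remaining = total_steps - step
    return max(0.0, remaining / max(1, total_steps - warmup_steps))


base_lr = 3e-4
warmup_steps = 100
total_steps = 1000  # hypothetical; the real value depends on the dataset size

for step in [0, 50, 100, 550, 1000]:
    lr = base_lr * linear_warmup_decay(step, warmup_steps, total_steps)
    print(f"step {step:4d}: lr = {lr:.2e}")
```

The learning rate starts at zero, reaches its peak of `3e-4` exactly when warmup ends, and falls back to zero at the final step. Starting small gives the optimizer's running statistics time to settle before large updates are applied.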
## 7. Training and Evaluation Functions

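The training function below shifts each sequence by one position so that every token is asked to predict the token that follows it. Here is the same idea on a plain Python list; `shift_for_next_token` is a hypothetical helper, and the token strings stand in for integer IDs.

```python
def shift_for_next_token(token_ids):
    """Return (inputs, targets) where targets[i] is the token that follows inputs[i]."""
    inputs = token_ids[:-1]   # everything except the last token
    targets = token_ids[1:]   # everything except the first token
    return inputs, targets


# Hypothetical token strings standing in for integer IDs
tokens = ["To", " be", " or", " not", " to", " be"]
inputs, targets = shift_for_next_token(tokens)

for current, target in zip(inputs, targets):
    print(f"after {current!r:>7} predict {target!r}")
```
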
In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, device, clip_norm=1.0):
    """Train for one epoch using the 5-step recipe."""
    model.train()
    total_loss = 0

    progress = tqdm(dataloader, desc="Training")
    for input_ids, attention_mask in progress:
        input_ids = input_ids.to(device)

        # Shift for language modeling
        inputs = input_ids[:, :-1]
        targets = input_ids[:, 1:]

        # ===== THE 5-STEP RECIPE =====
        optimizer.zero_grad(set_to_none=True)       # 1. Zero gradients
        logits = model(inputs)                       # 2. Forward pass
        loss = F.cross_entropy(                      # 3. Compute loss
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=tokenizer.pad_token_id
        )
        loss.backward()                              # 4. Backward pass
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()                             # 5. Update weights
        scheduler.step()

        total_loss += loss.item()
        progress.set_postfix(loss=f"{loss.item():.4f}")

    return total_loss / len(dataloader)


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Evaluate model and compute perplexity."""
    model.eval()
    total_loss = 0
    total_tokens = 0

    for input_ids, attention_mask in dataloader:
        input_ids = input_ids.to(device)
        inputs = input_ids[:, :-1]
        targets = input_ids[:, 1:]

        logits = model(inputs)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=tokenizer.pad_token_id,
            reduction='sum'
        )

        mask = (targets != tokenizer.pad_token_id)
        total_loss += loss.item()
        total_tokens += mask.sum().item()

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity

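Perplexity is just the exponential of the average per-token loss, which gives us a handy reference point: a model that guesses uniformly over GPT-2's 50,257-token vocabulary has a loss of about 10.82 and a perplexity equal to the vocabulary size. The sketch below checks this in plain Python; the helper name and the example losses are hypothetical.

```python
import math


def perplexity_from_token_losses(token_losses):
    """Perplexity = exp(mean negative log-likelihood per token)."""
    avg_loss = sum(token_losses) / len(token_losses)
    return math.exp(avg_loss)


# A model that spreads probability uniformly over the vocabulary
# pays -log(1 / vocab_size) on every token.
vocab_size = 50257
uniform_loss = -math.log(1 / vocab_size)
uniform_ppl = perplexity_from_token_losses([uniform_loss] * 10)
print(f"uniform loss:       {uniform_loss:.2f}")
print(f"uniform perplexity: {uniform_ppl:.0f}")

# Hypothetical per-token losses for a partly trained model
example_ppl = perplexity_from_token_losses([4.0, 5.0, 6.0])
print(f"example perplexity: {example_ppl:.1f}")
```
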
In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, path):
    """Save training checkpoint."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, path)
    print(f"Checkpoint saved: {path}")

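Saving is only half of checkpointing. The sketch below shows one way to restore a checkpoint and continue training; `load_checkpoint` is a hypothetical helper that is not part of the chapter's code, and it assumes the dictionary keys written by `save_checkpoint` above. A tiny `nn.Linear` model and a `StepLR` scheduler stand in for MiniGPT and the warmup scheduler so that the example runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR


def load_checkpoint(path, model, optimizer, scheduler, device="cpu"):
    """Restore training state in place and return the epoch index to resume from.

    Hypothetical helper; expects the keys written by save_checkpoint.
    """
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["epoch"] + 1  # the saved epoch is already finished


def make_training_objects():
    """Tiny stand-ins for the real model, optimizer, and scheduler."""
    model = nn.Linear(4, 2)
    optimizer = AdamW(model.parameters(), lr=3e-4)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler


# Save a checkpoint as if epoch index 2 had just finished
model, optimizer, scheduler = make_training_objects()
path = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pt")
torch.save({
    "epoch": 2,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
}, path)

# Later: rebuild the objects, restore their state, and pick up where we left off
new_model, new_optimizer, new_scheduler = make_training_objects()
start_epoch = load_checkpoint(path, new_model, new_optimizer, new_scheduler)
print(f"Resuming at epoch index {start_epoch}")
```

In a real run you would then loop with `for epoch in range(start_epoch, num_epochs)`. Restoring the optimizer and scheduler matters as much as restoring the weights: without them, AdamW's moment estimates and the learning-rate schedule would restart from scratch.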
## 8. Training Loop

In [None]:
# Training!
train_losses = []
val_losses = []
best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f"\n{'='*50}")
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"{'='*50}")

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
    train_losses.append(train_loss)

    # Evaluate
    val_loss, perplexity = evaluate(model, val_loader, device)
    val_losses.append(val_loss)

    print(f"\nTrain Loss: {train_loss:.4f}")
    print(f"Val Loss:   {val_loss:.4f}")
    print(f"Perplexity: {perplexity:.1f}")

    # Save best checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(model, optimizer, scheduler, epoch,
                       train_loss, val_loss, "best_model.pt")

print("\nTraining complete!")
print(f"Best validation loss: {best_val_loss:.4f}")

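The loop above keeps the best checkpoint but always runs for the full number of epochs. A common way to decide when to stop is early stopping with a patience counter. The sketch below is not part of the chapter's training loop; `early_stopping_epoch`, `patience`, and `min_delta` are hypothetical names, and the loss history is invented for illustration.

```python
def early_stopping_epoch(val_losses, patience=2, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None.

    Training stops once validation loss has failed to improve by more than
    min_delta for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation losses: improving at first, then drifting upward
history = [5.2, 4.6, 4.3, 4.35, 4.4, 4.5]
stop_epoch = early_stopping_epoch(history, patience=2)
print(stop_epoch)  # 4
```

Here the best loss occurs at epoch index 2, and training stops two epochs later. Because the best checkpoint was already saved, nothing is lost by stopping after the model has started to overfit.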
## 9. Plot Loss Curves

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 5))
epochs = range(1, len(train_losses) + 1)
plt.plot(epochs, train_losses, 'b-o', label='Train Loss')
plt.plot(epochs, val_losses, 'r-o', label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## 10. The Payoff: Text Generation!

In [None]:
# Load best model
checkpoint = torch.load("best_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded checkpoint from epoch {checkpoint['epoch'] + 1}")
print(f"Validation loss: {checkpoint['val_loss']:.4f}")

In [None]:
print("\n" + "="*60)
print("AFTER TRAINING (on Shakespeare):")
print("="*60)

prompts = [
    "The king",
    "To be or not to be",
    "Friends, Romans, countrymen",
    "All the world's a stage"
]

for prompt in prompts:
    output = generate(model, tokenizer, prompt, max_new_tokens=40, temperature=0.8)
    print(f"\nPrompt: '{prompt}'")
    print(f"Output: {output}")
    print("-"*40)

## 11. Before vs After Comparison

In [None]:
# Create fresh untrained model for comparison
set_seed(123)  # Different seed for different random weights
untrained_model = MiniGPT(config).to(device)
untrained_model.eval()

prompt = "The fair maiden"

print("="*60)
print(f"Prompt: '{prompt}'")
print("="*60)
print("\nBEFORE (untrained):")
print(generate(untrained_model, tokenizer, prompt, max_new_tokens=30))
print("\nAFTER (trained on Shakespeare):")
print(generate(model, tokenizer, prompt, max_new_tokens=30, temperature=0.8))
print("\nSame architecture. Same code. Training makes all the difference!")

## 12. Temperature Comparison

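Before sampling from the trained model, it helps to see what dividing logits by a temperature does to a probability distribution. This plain-Python sketch uses three made-up logits; `softmax_with_temperature` is a hypothetical helper that mirrors the `logits / temperature` step in `generate`.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by temperature."""
    scaled = [x / temperature for x in logits]
    max_scaled = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - max_scaled) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # hypothetical scores for three candidate tokens
for temperature in [0.5, 1.0, 1.5]:
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```
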
In [None]:
prompt = "The noble lord"

print(f"Prompt: '{prompt}'\n")
for temp in [0.5, 0.8, 1.0, 1.5]:
    output = generate(model, tokenizer, prompt, max_new_tokens=30, temperature=temp)
    print(f"Temperature {temp}:")
    print(f"  {output}")
    print()

## Summary

**What we built:**

1. **Label shifting** for next-token prediction
2. **DataLoaders** that efficiently batch and tokenize text
3. **The 5-step training recipe**: zero_grad → forward → loss → backward → step
4. **Evaluation** with validation loss and perplexity
5. **Checkpointing** to save and resume training
6. **Text generation** with temperature control

**Key concepts:**

- Training is a feedback control loop: measure error, adjust weights, repeat
- Cross-entropy loss penalizes confident wrong predictions more than uncertain ones
- Overfitting = memorizing training data (train loss ↓, val loss ↑)
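- Perplexity is the exponential of the average per-token loss; it measures "how surprised" the model is, and lower is better
- Temperature controls generation diversity: low values favor the most likely tokens, high values flatten the distribution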

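The claim about cross-entropy is easy to check by hand, because the loss for a single prediction is the negative log of the probability the model gave to the correct token. A small sketch with made-up probabilities (the helper name `cross_entropy` is hypothetical and works on one probability rather than on logits):

```python
import math


def cross_entropy(prob_of_correct_token):
    """Loss for one prediction: negative log of the probability of the true token."""
    return -math.log(prob_of_correct_token)


# Hypothetical probabilities the model assigned to the correct next token
for label, p in [("confident and right", 0.9),
                 ("uncertain", 0.5),
                 ("confident and wrong", 0.01)]:
    print(f"{label:<20} p={p:<5} loss={cross_entropy(p):.2f}")
```
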
**Next:** Chapter 13 will teach you to fine-tune this model for specific tasks!

## Exercises

### Exercise 1: Learning Rate Experiment

Try different learning rates and compare results.

In [None]:
# YOUR CODE HERE
# Train with lr=1e-5, lr=3e-4, lr=1e-2
# Compare the loss curves
# Which learning rate works best?

### Exercise 2: More Epochs

In [None]:
# YOUR CODE HERE
# Train for 5-10 epochs instead of 3
# Does the model keep improving?
# Do you see signs of overfitting?

### Exercise 3: Model Size Comparison

In [None]:
# YOUR CODE HERE
# Create a smaller model (2 layers, 128 dim)
# Create a larger model (6 layers, 384 dim)
# Compare training speed and final perplexity