In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Loss Functions and Backpropagation: How GPT Learns

*Part 3 of the Vizuara series on Building a GPT-Style Model from Scratch*
*Estimated time: 50 minutes*

## 1. Why Does This Matter?

We have built the forward pass of a GPT model -- text goes in, probability distributions come out. But how does the model actually *learn*? How does it go from random predictions to generating coherent text?

The answer lies in two ideas that are deceptively simple yet profoundly powerful:

1. **The loss function** measures how wrong the model is. It assigns a single number to the quality of the model's predictions -- lower is better.
2. **Backpropagation** computes the gradient of this loss with respect to every parameter in the model, telling each parameter exactly how to change to make the predictions better.

These two ideas, combined in a loop, are the entire learning algorithm. Large models such as GPT-4, DALL-E, and AlphaFold are all trained with variants of this same loop: predict, measure the error, compute gradients, update the weights. In this notebook, we will build both the loss function and the backpropagation process from first principles, with full numerical examples.
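
Here is a minimal sketch of that four-step loop in PyTorch. The toy problem (fitting $y = 3x + 1$ with a single `nn.Linear` layer and plain SGD) is an assumption chosen only for illustration; the rest of this notebook unpacks what `loss.backward()` does inside it.

```python
import torch

# Hypothetical toy problem: recover y = 3x + 1 from 32 evenly spaced points
torch.manual_seed(0)
toy_x = torch.linspace(-1, 1, 32).unsqueeze(1)
toy_y = 3 * toy_x + 1

toy_model = torch.nn.Linear(1, 1)
toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)

for step in range(200):
    pred = toy_model(toy_x)                              # 1. predict
    loss = torch.nn.functional.mse_loss(pred, toy_y)     # 2. measure the error
    toy_opt.zero_grad()
    loss.backward()                                      # 3. compute gradients
    toy_opt.step()                                       # 4. update the weights

print(f"w = {toy_model.weight.item():.3f}, b = {toy_model.bias.item():.3f}")
```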

By the end, you will understand exactly what `loss.backward()` does under the hood.

## 2. Building Intuition

### The Mountain Analogy

Imagine you are standing on a mountain in thick fog. You cannot see the valley below, but you want to get down. What do you do? You feel the slope of the ground beneath your feet and take a small step in the downhill direction. Then you feel the slope again and take another step. Repeat thousands of times, and you reach the valley.

This is gradient descent. The "mountain" is the loss landscape -- a surface where the height represents how wrong the model is. The "slope" is the gradient -- the direction of steepest ascent. By moving in the opposite direction (downhill), we reduce the loss.

### Why Negative Log?

The cross-entropy loss uses the negative log of the predicted probability. Why not just use the probability directly? Consider two scenarios:
- The model assigns 90% to the correct answer. Probability-based loss: 0.10. Log-based loss: 0.105.
- The model assigns 1% to the correct answer. Probability-based loss: 0.99. Log-based loss: 4.605.

The log-based loss penalizes confident wrong answers *much* more severely. Assigning 1% to the correct answer is catastrophic: the log-based loss is about 44x higher than in the 90% case, whereas the probability-based loss grows only about 10x. This steep penalty gives the model a strong incentive to avoid being confidently wrong, which is exactly the behavior we want.
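
The two scenarios are easy to compute directly. This short sketch prints both losses for each case, then the factor by which each loss grows when the correct-token probability drops from 90% to 1%.

```python
import math

for p in (0.9, 0.01):
    print(f"p = {p:<5} probability-based loss = {1 - p:.3f}   log-based loss = {-math.log(p):.3f}")

# How much worse is the 1% case than the 90% case under each loss?
ratio_log = math.log(0.01) / math.log(0.9)
ratio_prob = (1 - 0.01) / (1 - 0.9)
print(f"log-based ratio:         {ratio_log:.1f}x")
print(f"probability-based ratio: {ratio_prob:.1f}x")
```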

### The Chain Rule -- One Sentence

Backpropagation is the chain rule of calculus, applied systematically from the loss backwards through every layer of the network. If the loss depends on A, which depends on B, which depends on C, then the gradient of the loss with respect to C is: (gradient of loss w.r.t. A) times (gradient of A w.r.t. B) times (gradient of B w.r.t. C). That is it. The rest is bookkeeping.

### Think About This

When you make a prediction and get it wrong, not every part of your reasoning contributed equally to the error. Some parts were on the right track; others led you astray. Backpropagation solves exactly this attribution problem -- it tells each parameter in the network precisely how much it contributed to the error, and in which direction it should change.

## 3. The Mathematics

### Cross-Entropy Loss

For a single prediction, the cross-entropy loss is:

$$\mathcal{L} = -\log P(\text{correct token})$$

Computationally: look up the probability the model assigned to the correct answer, take the natural log, and negate it. If the model assigned probability 0.7, the loss is $-\log(0.7) = 0.357$. If it assigned 0.01, the loss is $-\log(0.01) = 4.605$.

For a full sequence of $N$ tokens:

$$\mathcal{L}_{\text{total}} = -\frac{1}{N}\sum_{i=1}^{N} \log P_\theta(t_i \mid t_1, \ldots, t_{i-1})$$

This says: for each position, measure how well the model predicted the actual next token, and average across all positions.
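
A small sketch of the sequence-level formula: compute $-\log P(\text{actual next token})$ at each position, then average. The sizes (6 positions, 10-token vocabulary) and the random logits are placeholders for illustration. `F.cross_entropy` uses `reduction='mean'` by default, so it should give the same number.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 6, 10                                  # hypothetical: 6 positions, 10-token vocabulary
seq_logits = torch.randn(N, V)                # model scores for the next token at each position
seq_targets = torch.randint(0, V, (N,))       # the actual next tokens

# Per-position loss: -log P(actual next token)
log_probs = F.log_softmax(seq_logits, dim=-1)
seq_per_position = -log_probs[torch.arange(N), seq_targets]
seq_manual = seq_per_position.mean()          # average over positions

seq_builtin = F.cross_entropy(seq_logits, seq_targets)
print(seq_per_position)
print(f"manual mean: {seq_manual.item():.6f}   F.cross_entropy: {seq_builtin.item():.6f}")
```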

### The Gradient of Cross-Entropy

The gradient of cross-entropy loss with respect to the logits has a beautifully simple form:

$$\frac{\partial \mathcal{L}}{\partial \text{logit}_i} = p_i - \mathbb{1}_{i=c}$$

where $p_i$ is the softmax probability for token $i$, and $c$ is the index of the correct token. Computationally: the gradient for every token is just its predicted probability. For the correct token, subtract 1. This means: push the correct token's logit higher (its gradient is negative), push all others lower (their gradients are positive).
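
Where this form comes from: write $z_i = \text{logit}_i$ and $p_i = e^{z_i} / \sum_j e^{z_j}$, and expand the loss in terms of the logits.

$$\mathcal{L} = -\log p_c = -z_c + \log \sum_j e^{z_j}$$

$$\frac{\partial \mathcal{L}}{\partial z_i} = -\mathbb{1}_{i=c} + \frac{e^{z_i}}{\sum_j e^{z_j}} = p_i - \mathbb{1}_{i=c}$$

For example, with logits $(2, 1, 0)$ and $c = 0$, the probabilities are $p \approx (0.665, 0.245, 0.090)$, so the gradient is $\approx (-0.335, 0.245, 0.090)$.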

### The Chain Rule

For a composite function $y = f(g(x))$:

$$\frac{dy}{dx} = \frac{dy}{dg} \cdot \frac{dg}{dx}$$

In a neural network with many layers, this chain extends through every layer from the loss back to the first parameter. The beauty is that each layer only needs to know its own local gradient and the gradient coming from the layer above.
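
A quick way to see the chain rule at work is to let autograd differentiate a two-step composite and compare the result with the hand-derived product of local derivatives. The functions $g(x) = x^2$ and $f(u) = \sin u$ are arbitrary choices for illustration.

```python
import torch

# y = f(g(x)) with g(x) = x**2 and f(u) = sin(u)
x_cr = torch.tensor(1.5, requires_grad=True)
g = x_cr ** 2
y = torch.sin(g)
y.backward()                      # autograd applies the chain rule for us

# By hand: dy/dx = dy/dg * dg/dx = cos(g) * 2x
chain_manual = torch.cos(g.detach()) * 2 * x_cr.detach()
print(f"autograd: {x_cr.grad.item():.6f}   chain rule by hand: {chain_manual.item():.6f}")
```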

### Gradient Descent Update

$$\theta_{\text{new}} = \theta_{\text{old}} - \eta \cdot \frac{\partial \mathcal{L}}{\partial \theta}$$

This says: adjust each parameter by a small step (controlled by learning rate $\eta$) in the direction that decreases the loss. If the gradient is positive, the parameter is contributing to increasing the loss, so we decrease it. If negative, we increase it.
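
A minimal sketch of the update rule on a one-parameter loss, $\mathcal{L}(\theta) = (\theta - 3)^2$, chosen for illustration because its gradient $2(\theta - 3)$ is easy to write by hand. Starting below the minimum, the gradient is negative, so every update increases $\theta$ toward 3.

```python
# Gradient descent on L(theta) = (theta - 3)**2, with gradient 2 * (theta - 3)
theta = 0.0   # arbitrary starting point
eta = 0.1     # learning rate

history = [theta]
for _ in range(50):
    grad = 2 * (theta - 3)
    theta = theta - eta * grad    # theta_new = theta_old - eta * dL/dtheta
    history.append(theta)

print("first few values:", [round(t, 4) for t in history[:4]])
print(f"theta after 50 steps: {theta:.6f}")
```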

## 4. Let's Build It -- Component by Component

### 4.1 Cross-Entropy Loss from Scratch

Let us implement cross-entropy loss manually, step by step, to see exactly what PyTorch's `F.cross_entropy` does under the hood.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def cross_entropy_from_scratch(logits, targets):
    """
    Compute cross-entropy loss manually.

    Args:
        logits: (N, C) raw model outputs (before softmax)
        targets: (N,) integer class labels

    Returns:
        scalar loss value
    """
    # Step 1: Log-softmax via the log-sum-exp trick.
    # Subtracting the row max before exponentiating prevents overflow, and
    # staying in log space avoids log(0) when a probability underflows.
    logits_shifted = logits - logits.max(dim=-1, keepdim=True).values
    log_probs = logits_shifted - torch.log(torch.exp(logits_shifted).sum(dim=-1, keepdim=True))

    # Step 2: Extract the log-probability of the correct class for each sample
    N = logits.shape[0]
    correct_log_probs = log_probs[torch.arange(N), targets]

    # Step 3: Negative log-likelihood
    loss = -correct_log_probs

    # Step 4: Average over all samples
    return loss.mean()

# Test: compare with PyTorch
logits = torch.randn(5, 10)  # 5 predictions, 10 classes
targets = torch.randint(0, 10, (5,))

our_loss = cross_entropy_from_scratch(logits, targets)
pytorch_loss = F.cross_entropy(logits, targets)

print(f"Our implementation:    {our_loss.item():.6f}")
print(f"PyTorch F.cross_entropy: {pytorch_loss.item():.6f}")
print(f"Difference: {abs(our_loss.item() - pytorch_loss.item()):.8f}")
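
PyTorch documents `F.cross_entropy` as the combination of `F.log_softmax` and `F.nll_loss`. The sketch below checks that equivalence, then shows why working in log space matters: with an extreme logit gap, a naive softmax-then-log underflows to $-\log(0) = \infty$, while `F.cross_entropy` returns the finite value.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
demo_logits = torch.randn(5, 10)
demo_targets = torch.randint(0, 10, (5,))

# cross_entropy == nll_loss applied to log_softmax
ce_demo = F.cross_entropy(demo_logits, demo_targets)
two_step_demo = F.nll_loss(F.log_softmax(demo_logits, dim=-1), demo_targets)
print(f"cross_entropy: {ce_demo.item():.6f}   nll_loss(log_softmax): {two_step_demo.item():.6f}")

# Extreme logits: the correct class (index 1) is 1000 units below the other
big = torch.tensor([[1000.0, 0.0]])
tgt = torch.tensor([1])
naive_demo = -torch.log(F.softmax(big, dim=-1)[0, 1])   # exp(-1000) underflows to 0, so the loss is inf
stable_demo = F.cross_entropy(big, tgt)                 # log-sum-exp keeps it finite
print(f"naive: {naive_demo.item()}   stable: {stable_demo.item()}")
```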

In [None]:
# Visualize the negative log function
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: -log(p) curve
p = np.linspace(0.01, 1.0, 200)
loss = -np.log(p)

axes[0].plot(p, loss, 'b-', linewidth=2)
axes[0].axhline(y=0, color='gray', linestyle='-', alpha=0.3)

# Mark key points
points = [(0.01, -np.log(0.01), 'Terrible\nprediction'),
          (0.5, -np.log(0.5), 'Coin flip'),
          (1.0, -np.log(1.0), 'Perfect')]
for px, py, label in points:
    axes[0].plot(px, py, 'ro', markersize=8)
    axes[0].annotate(label, (px, py), textcoords="offset points",
                     xytext=(10, 10), fontsize=10)

axes[0].set_xlabel('P(correct token)', fontsize=12)
axes[0].set_ylabel('Loss = -log(P)', fontsize=12)
axes[0].set_title('Cross-Entropy Loss: Why -log?', fontsize=14)
axes[0].grid(True, alpha=0.3)

# Plot 2: Full example with a 4-token sequence
token_names = ['The', 'cat', 'sat', 'on']
pred_probs = [0.3, 0.7, 0.1]  # P(next correct token)
token_losses = [-np.log(p) for p in pred_probs]
avg_loss = np.mean(token_losses)

x_pos = range(len(pred_probs))
bars = axes[1].bar(x_pos, token_losses, color=['#ff7f0e', '#2ca02c', '#d62728'],
                   alpha=0.8, edgecolor='black')
axes[1].axhline(y=avg_loss, color='blue', linestyle='--',
                label=f'Average loss = {avg_loss:.3f}')

for i, (prob, loss_val) in enumerate(zip(pred_probs, token_losses)):
    axes[1].text(i, loss_val + 0.05, f'P={prob}\nL={loss_val:.3f}',
                ha='center', fontsize=10)

tick_labels = [f'After "{token_names[i]}"\npredict "{token_names[i+1]}"'
               for i in range(len(pred_probs))]
axes[1].set_xticks(x_pos)
axes[1].set_xticklabels(tick_labels, fontsize=9)
axes[1].set_ylabel('Loss', fontsize=12)
axes[1].set_title('Loss Across a Sequence: "The cat sat on"', fontsize=14)
axes[1].legend(fontsize=11)

plt.tight_layout()
plt.show()

### 4.2 The Gradient of Cross-Entropy (Manual)

In [None]:
def cross_entropy_gradient(logits, targets):
    """
    Compute the gradient of cross-entropy loss w.r.t. logits -- manually.

    The gradient has a simple closed form (per sample):
        dL/d(logit_i) = softmax(logits)_i - 1{i == target}
    With mean reduction over the batch, every row is also divided by N.

    Args:
        logits: (N, C) raw model outputs
        targets: (N,) integer class labels

    Returns:
        grad: (N, C) gradient matrix
    """
    # Softmax
    probs = F.softmax(logits, dim=-1)

    # Start with softmax probabilities
    grad = probs.clone()

    # Subtract 1 at the correct class position
    N = logits.shape[0]
    grad[torch.arange(N), targets] -= 1.0

    # Average over batch
    grad = grad / N

    return grad

# Compare with PyTorch autograd
logits = torch.randn(3, 5, requires_grad=True)
targets = torch.tensor([2, 0, 4])

# PyTorch gradient
loss = F.cross_entropy(logits, targets)
loss.backward()
pytorch_grad = logits.grad.clone()

# Our gradient
our_grad = cross_entropy_gradient(logits.detach(), targets)

print("PyTorch autograd gradient:")
print(pytorch_grad.numpy().round(4))
print("\nOur manual gradient:")
print(our_grad.numpy().round(4))
print(f"\nMax difference: {(pytorch_grad - our_grad).abs().max().item():.8f}")

In [None]:
# Visualize what the gradient means
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

logits_example = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0]])
target = torch.tensor([0])  # Correct answer is token 0

probs = F.softmax(logits_example, dim=-1)[0]
grad = cross_entropy_gradient(logits_example, target)[0]

# Plot 1: Logits
axes[0].bar(range(5), logits_example[0].numpy(), color='steelblue', edgecolor='black')
axes[0].set_title('Logits (Raw Scores)', fontsize=13)
axes[0].set_xlabel('Token ID')
axes[0].set_ylabel('Logit Value')

# Plot 2: Probabilities (after softmax)
colors = ['green' if i == 0 else 'salmon' for i in range(5)]
axes[1].bar(range(5), probs.numpy(), color=colors, edgecolor='black')
axes[1].set_title('Probabilities (After Softmax)', fontsize=13)
axes[1].set_xlabel('Token ID')
axes[1].set_ylabel('Probability')
axes[1].annotate('Correct\ntoken', xy=(0, probs[0].item()), fontsize=10,
                 xytext=(0.5, probs[0].item() + 0.05))

# Plot 3: Gradients
grad_colors = ['green' if g < 0 else 'red' for g in grad.numpy()]
axes[2].bar(range(5), grad.numpy(), color=grad_colors, edgecolor='black')
axes[2].axhline(y=0, color='black', linewidth=0.5)
axes[2].set_title('Gradients (Direction of Update)', fontsize=13)
axes[2].set_xlabel('Token ID')
axes[2].set_ylabel('dL/d(logit)')
axes[2].annotate('Push UP\n(increase prob)', xy=(0, grad[0].item()),
                 xytext=(0.8, grad[0].item() - 0.08), fontsize=9,
                 arrowprops=dict(arrowstyle='->', color='green'))

plt.suptitle('Cross-Entropy Gradient: "Push correct up, push others down"',
             fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

### 4.3 The Chain Rule in Action

Let us trace backpropagation through a tiny network to see exactly how the chain rule works.

In [None]:
# A minimal 2-layer network to trace gradients manually
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.tensor([[0.5, -0.3],
                                              [0.2, 0.8]]))
        self.w2 = nn.Parameter(torch.tensor([[0.4],
                                              [-0.6]]))

    def forward(self, x):
        h = torch.relu(x @ self.w1)       # Hidden layer
        out = h @ self.w2                   # Output
        return out

net = TinyNet()
x = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

# Forward pass (manual trace)
h_pre = x @ net.w1                          # Before ReLU
h = torch.relu(h_pre)                       # After ReLU
out = h @ net.w2                            # Final output
loss = (out - target) ** 2                  # MSE loss

print("=== FORWARD PASS ===")
print(f"Input x:         {x.numpy()}")
print(f"W1:\n{net.w1.detach().numpy()}")
print(f"h (before ReLU): {h_pre.detach().numpy()}")
print(f"h (after ReLU):  {h.detach().numpy()}")
print(f"W2:\n{net.w2.detach().numpy()}")
print(f"Output:          {out.item():.4f}")
print(f"Target:          {target.item():.4f}")
print(f"Loss (MSE):      {loss.item():.4f}")

In [None]:
# Now trace the backward pass manually using the chain rule
print("\n=== BACKWARD PASS (Chain Rule) ===")

# dL/d(out) = 2 * (out - target)
dL_dout = 2 * (out - target)
print(f"dL/d(out) = 2*(out - target) = {dL_dout.item():.4f}")

# dL/d(W2) = h^T @ dL_dout  (chain rule through matmul)
dL_dW2 = h.t() @ dL_dout
print(f"dL/d(W2) = h^T @ dL_dout = {dL_dW2.detach().numpy().round(4)}")

# dL/d(h) = dL_dout @ W2^T
dL_dh = dL_dout @ net.w2.t()
print(f"dL/d(h)  = dL_dout @ W2^T = {dL_dh.detach().numpy().round(4)}")

# dL/d(h_pre) = dL/d(h) * ReLU_derivative
# ReLU derivative: 1 if h_pre > 0, else 0
relu_mask = (h_pre > 0).float()
dL_dh_pre = dL_dh * relu_mask
print(f"ReLU mask: {relu_mask.detach().numpy()}")
print(f"dL/d(h_pre) = {dL_dh_pre.detach().numpy().round(4)}")

# dL/d(W1) = x^T @ dL_dh_pre
dL_dW1 = x.t() @ dL_dh_pre
print(f"dL/d(W1) = x^T @ dL_dh_pre =\n{dL_dW1.detach().numpy().round(4)}")

# Compare with PyTorch autograd
net.zero_grad()  # clear gradients left over if this cell is re-run
loss_auto = ((net(x) - target) ** 2).sum()
loss_auto.backward()
print("\n=== PyTorch Autograd (should match) ===")
print(f"dL/d(W1) =\n{net.w1.grad.numpy().round(4)}")
print(f"dL/d(W2) =\n{net.w2.grad.numpy().round(4)}")
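
Another standard way to check a backward pass is a finite-difference gradient check: nudge each weight by $\pm\epsilon$, re-evaluate the loss, and compare $(\mathcal{L}(w+\epsilon) - \mathcal{L}(w-\epsilon)) / 2\epsilon$ with the autograd gradient. This sketch rebuilds the same tiny network in float64 for precision. The helper `numerical_grad` is written here for illustration and is not a PyTorch API. Finite differences are unreliable right at a ReLU kink, which is not an issue here because both pre-activations (0.9 and 1.3) are far from zero.

```python
import torch

# Same weights and input as TinyNet, in float64 for an accurate comparison
x_gc = torch.tensor([[1.0, 2.0]], dtype=torch.float64)
target_gc = torch.tensor([[1.0]], dtype=torch.float64)
W1_d = torch.tensor([[0.5, -0.3], [0.2, 0.8]], dtype=torch.float64, requires_grad=True)
W2_d = torch.tensor([[0.4], [-0.6]], dtype=torch.float64, requires_grad=True)

def loss_fn():
    return ((torch.relu(x_gc @ W1_d) @ W2_d - target_gc) ** 2).sum()

def numerical_grad(param, eps=1e-6):
    """Central-difference estimate of dL/dparam, one entry at a time."""
    grad = torch.zeros_like(param)
    with torch.no_grad():
        for i in range(param.shape[0]):
            for j in range(param.shape[1]):
                orig = param[i, j].item()
                param[i, j] = orig + eps
                f_plus = loss_fn().item()
                param[i, j] = orig - eps
                f_minus = loss_fn().item()
                param[i, j] = orig                      # restore the weight
                grad[i, j] = (f_plus - f_minus) / (2 * eps)
    return grad

loss_fn().backward()
for name, p in [("W1", W1_d), ("W2", W2_d)]:
    diff = (p.grad - numerical_grad(p)).abs().max().item()
    print(f"{name}: max |autograd - numerical| = {diff:.2e}")
```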

### 4.4 Visualizing Gradient Flow Through Residual Connections

In [None]:
# Demonstrate why residual connections help gradient flow

def simulate_gradient_flow(n_layers, use_residual=True):
    """Simulate gradient magnitude through N layers."""
    gradient = 1.0
    gradients = [gradient]

    for _ in range(n_layers):
        # Simulate a layer transformation (random scaling)
        layer_grad = np.random.uniform(0.3, 0.9)

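        if use_residual:
            # Residual: the skip path adds an identity term, so the per-layer
            # factor is 1 + layer_grad > 1 and the gradient cannot vanish
            # (in this scalar toy it grows instead)
            gradient = gradient * (1 + layer_grad)
        else:
            # No residual: multiplied by layer_grad < 1, so it shrinks toward 0
            gradient = gradient * layer_grad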

        gradients.append(gradient)

    return gradients

# Run many simulations
np.random.seed(42)
n_sims = 100
n_layers = 12

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Without residual connections
for _ in range(n_sims):
    grads = simulate_gradient_flow(n_layers, use_residual=False)
    axes[0].plot(grads, alpha=0.1, color='red')

avg_no_res = np.mean([simulate_gradient_flow(n_layers, False) for _ in range(1000)], axis=0)
axes[0].plot(avg_no_res, 'r-', linewidth=3, label='Average')
axes[0].set_xlabel('Layer (from output to input)')
axes[0].set_ylabel('Gradient Magnitude')
axes[0].set_title('WITHOUT Residual Connections\n(Vanishing Gradient Problem)')
axes[0].set_yscale('log')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# With residual connections
for _ in range(n_sims):
    grads = simulate_gradient_flow(n_layers, use_residual=True)
    axes[1].plot(grads, alpha=0.1, color='green')

avg_res = np.mean([simulate_gradient_flow(n_layers, True) for _ in range(1000)], axis=0)
axes[1].plot(avg_res, 'g-', linewidth=3, label='Average')
axes[1].set_xlabel('Layer (from output to input)')
axes[1].set_ylabel('Gradient Magnitude')
axes[1].set_title('WITH Residual Connections\n(Gradient Highway)')
axes[1].set_yscale('log')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

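plt.suptitle('Simulated Gradient Flow Through 12 Layers (Scalar Toy Model)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
print("Without residual connections, gradients vanish exponentially with depth.")
print("With residual connections, the identity path keeps them from vanishing.")
print("(In this scalar toy they even grow; in real networks, normalization and")
print(" careful initialization help keep gradient norms in a reasonable range.)")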
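
The simulation above only multiplies random scalars. As a sketch of the real mechanism, autograd on a residual block $x + f(x)$ gives a derivative of exactly $1 + f'(x)$. The block $f(x) = 0.1\tanh(x)$ is a hypothetical 1-D stand-in for a transformer sublayer. Stacking twelve plain blocks multiplies twelve factors of at most 0.1, while the residual stack multiplies factors between 1 and 1.1.

```python
import torch

# Hypothetical 1-D "block" used only for illustration
def block(x):
    return 0.1 * torch.tanh(x)

def stack(x, n_layers, residual):
    """Apply the block n_layers times, with or without a skip connection."""
    for _ in range(n_layers):
        x = x + block(x) if residual else block(x)
    return x

x0 = torch.tensor(0.7, requires_grad=True)

# One layer: the residual derivative is exactly 1 + f'(x)
(g_plain,) = torch.autograd.grad(block(x0), x0)
(g_res,) = torch.autograd.grad(x0 + block(x0), x0)
print(f"f'(x)          = {g_plain.item():.4f}")
print(f"d(x + f(x))/dx = {g_res.item():.4f}")

# Twelve layers: compare the gradient that reaches the input
(deep_plain,) = torch.autograd.grad(stack(x0, 12, residual=False), x0)
(deep_res,) = torch.autograd.grad(stack(x0, 12, residual=True), x0)
print(f"12 plain layers:    {deep_plain.item():.3e}")
print(f"12 residual layers: {deep_res.item():.3e}")
```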

## 5. Your Turn

### TODO 1: Implement the Full Training Step

In [None]:
def training_step(model, x, y, optimizer, vocab_size=256):
    """
    Perform one complete training step: forward pass, loss, backward, update.

    Args:
        model: GPT model instance
        x: input token IDs, shape (batch_size, seq_len)
        y: target token IDs, shape (batch_size, seq_len)
        optimizer: torch optimizer
        vocab_size: vocabulary size for cross-entropy

    Returns:
        loss_value: the scalar loss for this step

    Steps:
        1. Forward pass: get logits from the model
        2. Reshape logits to (batch_size * seq_len, vocab_size)
        3. Reshape targets to (batch_size * seq_len,)
        4. Compute cross-entropy loss
        5. Zero the gradients (optimizer.zero_grad())
        6. Backward pass (loss.backward())
        7. Update weights (optimizer.step())
        8. Return the loss value as a float
    """
    # ============ TODO ============
    # Implement all 8 steps above.
    # The reshape is needed because PyTorch's cross_entropy
    # expects (N, C) logits and (N,) targets.
    # ==============================
    loss_value = None  # YOUR CODE HERE
    return loss_value

In [None]:
# Verification -- build a small model and run one training step
import torch.nn.functional as F_ver

class MiniGPT(nn.Module):
    def __init__(self, vocab_size=256, d_model=32, n_heads=2, n_layers=2, max_seq_len=64):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_seq_len, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        x = self.tok(idx) + self.pos(torch.arange(T, device=idx.device))
        x = self.ln(x)
        return self.head(x)

mini = MiniGPT()
opt = torch.optim.Adam(mini.parameters(), lr=1e-3)
x_test = torch.randint(0, 256, (4, 16))
y_test = torch.randint(0, 256, (4, 16))

# Get loss before training
with torch.no_grad():
    logits_before = mini(x_test)
    loss_before = F_ver.cross_entropy(logits_before.view(-1, 256), y_test.view(-1))

# Run one training step
loss_val = training_step(mini, x_test, y_test, opt)
assert loss_val is not None, "training_step returned None"
assert isinstance(loss_val, float), "training_step should return a float"

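# training_step returns the loss computed *before* its weight update, so it
# should match loss_before; re-evaluate to see the effect of the update.
with torch.no_grad():
    loss_after = F_ver.cross_entropy(mini(x_test).view(-1, 256), y_test.view(-1))

print(f"Loss before step (returned): {loss_val:.4f}")
print(f"Loss before step (recomputed): {loss_before.item():.4f}")
print(f"Loss after 1 step: {loss_after.item():.4f}")
assert abs(loss_val - loss_before.item()) < 1e-4, "returned loss should be the pre-update loss"
assert loss_after.item() < loss_before.item(), "one step on the same batch should lower the loss"
print("Training step works correctly!")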

### TODO 2: Gradient Magnitude Analysis

In [None]:
def analyze_gradients(model, x, y, vocab_size=256):
    """
    Analyze gradient magnitudes across different parts of the model.

    Run a forward+backward pass and report the mean, max, and std
    of gradient magnitudes for each named parameter group:
    - Embedding layers (token_emb, pos_emb)
    - Attention weights (qkv, proj)
    - FFN weights
    - Output head
    - Layer norm parameters

    This reveals which parts of the model are learning fastest/slowest.

    Args:
        model: GPT model instance
        x: input token IDs
        y: target token IDs
        vocab_size: vocabulary size

    Returns:
        dict mapping parameter_group_name -> {'mean': float, 'max': float, 'std': float}

    Steps:
        1. Forward pass
        2. Compute loss
        3. Backward pass
        4. Iterate over model.named_parameters()
        5. Group parameters by name pattern (contains 'emb', 'qkv', etc.)
        6. For each group, compute mean/max/std of gradient magnitudes
        7. Print a formatted table
        8. Create a bar chart of mean gradient magnitudes by group
    """
    # ============ TODO ============
    # Implement the gradient analysis.
    # Hint: Use param.grad.abs() to get gradient magnitudes.
    # Remember to call model.zero_grad() first to clear old gradients.
    # ==============================
    pass

## 6. Putting It All Together

In [None]:
# Complete training pipeline with loss tracking and analysis

class CharTokenizer:
    def __init__(self):
        self.vocab_size = 256
    def encode(self, text):
        return [ord(ch) for ch in text]
    def decode(self, ids):
        return ''.join(chr(i) for i in ids)

class SimpleGPT(nn.Module):
    """Minimal GPT for training demonstration."""
    def __init__(self, vocab_size=256, d_model=64, n_heads=4, n_layers=4, max_seq_len=128):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4*d_model,
                                       dropout=0.0, batch_first=True, norm_first=True)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.max_seq_len = max_seq_len

    def forward(self, idx):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
        mask = torch.triu(torch.ones(T, T, device=idx.device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=mask, is_causal=True)
        x = self.ln_f(x)
        return self.head(x)

tokenizer = CharTokenizer()

# Training data
text = """
All that glitters is not gold Often have you heard that told
Many a man his life hath sold But my outside to behold
Gilded tombs do worms enfold Had you been as wise as bold
Young in limbs in judgment old Your answer had not been inscrolled
""" * 50

data = torch.tensor(tokenizer.encode(text))
print(f"Training data: {len(data)} characters")
print(f"Unique characters: {len(set(data.tolist()))}")
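
Each training example pairs a window of characters with the same window shifted one position to the right, so every position's target is the next character. A tiny sketch on a hypothetical five-character string, using the same `ord` encoding as `CharTokenizer` and the same slicing the training loop uses:

```python
import torch

demo_ids = [ord(c) for c in "hello"]        # same encoding as CharTokenizer
demo_x, demo_y = demo_ids[:-1], demo_ids[1:]
for xi, yi in zip(demo_x, demo_y):
    print(f"input {chr(xi)!r} -> target {chr(yi)!r}")

# Window slicing as in the training loop: y is x shifted by one
data_demo = torch.tensor(demo_ids)
seq_len_demo, i = 3, 1
x_win = data_demo[i:i + seq_len_demo]
y_win = data_demo[i + 1:i + seq_len_demo + 1]
print(''.join(chr(t) for t in x_win.tolist()), '->', ''.join(chr(t) for t in y_win.tolist()))
```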

## 7. Training and Results

In [None]:
# Train with detailed tracking
model = SimpleGPT()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

seq_len = 64
batch_size = 16
n_steps = 1500

losses = []
learning_rates = []
grad_norms = []

for step in range(n_steps):
    # Sample batch
    idx = torch.randint(0, len(data) - seq_len - 1, (batch_size,))
    x = torch.stack([data[i:i+seq_len] for i in idx])
    y = torch.stack([data[i+1:i+seq_len+1] for i in idx])

    # Forward
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, 256), y.view(-1))

    # Backward
    optimizer.zero_grad()
    loss.backward()

    # Track gradient norm
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.norm().item() ** 2
    grad_norms.append(total_norm ** 0.5)

    optimizer.step()

    losses.append(loss.item())
    if step % 300 == 0:
        print(f"Step {step:4d}: loss = {loss.item():.4f}, grad_norm = {grad_norms[-1]:.4f}")

print(f"\nFinal loss: {losses[-1]:.4f}")
print(f"Random baseline: {np.log(256):.4f}")
print(f"Improvement: {np.log(256) - losses[-1]:.4f}")
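
The manual gradient-norm loop above computes the global L2 norm over all parameters. `torch.nn.utils.clip_grad_norm_` returns that same total (pre-clipping) norm, so with `max_norm=float("inf")` it can serve as a pure measurement. The small network here is a placeholder for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
demo_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
demo_loss = demo_net(torch.randn(3, 4)).pow(2).sum()
demo_loss.backward()

# Manual global norm: sqrt of the sum of squared per-parameter norms
manual_norm = sum(p.grad.norm().item() ** 2 for p in demo_net.parameters()
                  if p.grad is not None) ** 0.5

# max_norm=inf means nothing is rescaled; the return value is the total norm
builtin_norm = nn.utils.clip_grad_norm_(demo_net.parameters(), max_norm=float("inf"))
print(f"manual: {manual_norm:.6f}   clip_grad_norm_: {builtin_norm.item():.6f}")
```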

In [None]:
# Comprehensive training analysis
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss curve
axes[0,0].plot(losses, alpha=0.3, color='blue')
window = 50
smoothed = [np.mean(losses[max(0,i-window):i+1]) for i in range(len(losses))]
axes[0,0].plot(smoothed, 'b-', linewidth=2, label='Smoothed')
axes[0,0].axhline(y=np.log(256), color='red', linestyle='--', label=f'Random: {np.log(256):.2f}')
axes[0,0].set_xlabel('Step')
axes[0,0].set_ylabel('Loss')
axes[0,0].set_title('Training Loss')
axes[0,0].legend()
axes[0,0].grid(True, alpha=0.3)

# Gradient norms over time
axes[0,1].plot(grad_norms, alpha=0.3, color='orange')
smoothed_gn = [np.mean(grad_norms[max(0,i-window):i+1]) for i in range(len(grad_norms))]
axes[0,1].plot(smoothed_gn, color='orange', linewidth=2)
axes[0,1].set_xlabel('Step')
axes[0,1].set_ylabel('Gradient L2 Norm')
axes[0,1].set_title('Gradient Norms Over Training')
axes[0,1].grid(True, alpha=0.3)

# Loss distribution: early vs late
axes[1,0].hist(losses[:100], bins=30, alpha=0.5, label='First 100 steps', color='red')
axes[1,0].hist(losses[-100:], bins=30, alpha=0.5, label='Last 100 steps', color='green')
axes[1,0].set_xlabel('Loss')
axes[1,0].set_ylabel('Count')
axes[1,0].set_title('Loss Distribution: Early vs Late Training')
axes[1,0].legend()

# Per-parameter gradient analysis (final step)
param_grads = {}
for name, param in model.named_parameters():
    if param.grad is not None:
        category = name.split('.')[0]
        if category not in param_grads:
            param_grads[category] = []
        param_grads[category].append(param.grad.abs().mean().item())

categories = list(param_grads.keys())
means = [np.mean(v) for v in param_grads.values()]
axes[1,1].barh(categories, means, color='steelblue', edgecolor='black')
axes[1,1].set_xlabel('Mean |Gradient|')
axes[1,1].set_title('Gradient Magnitudes by Component (Final Step)')

plt.tight_layout()
plt.show()

## 8. Final Output

In [None]:
# Generate text and show the learning in action
def generate(model, tokenizer, prompt, max_new=100, temperature=0.8):
    model.eval()
    ids = torch.tensor([tokenizer.encode(prompt)])
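    for _ in range(max_new):
        # Condition on at most max_seq_len tokens (the size of the position table),
        # but keep the full history in `ids` so the returned text is not truncated
        ids_cond = ids[:, -model.max_seq_len:]
        with torch.no_grad():
            logits = model(ids_cond)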
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, 1)
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())

print("=" * 60)
print("  GPT TRAINED WITH CROSS-ENTROPY + BACKPROPAGATION")
print("=" * 60)

for prompt in ["All that ", "Young in ", "Gilded "]:
    sample = generate(model, tokenizer, prompt, max_new=80)
    print(f"\nPrompt: '{prompt}'")
    print(f"Output: {sample}")
    print("-" * 50)

# Summary statistics
print(f"\nTraining Summary:")
print(f"  Starting loss:  {losses[0]:.4f} (untrained; uniform guessing gives {np.log(256):.4f})")
print(f"  Final loss:     {losses[-1]:.4f}")
print(f"  Improvement:    {((losses[0] - losses[-1]) / losses[0] * 100):.1f}%")
print(f"  Total steps:    {n_steps}")
print(f"  Parameters:     {sum(p.numel() for p in model.parameters()):,}")

print("\nYou have trained a GPT model using cross-entropy loss")
print("and backpropagation -- the same core recipe used to pretrain large GPT models.")
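
`generate` divides the logits by `temperature` before the softmax. This sketch, on hypothetical next-token scores, shows the effect: temperatures below 1 sharpen the distribution toward the top token, and temperatures above 1 flatten it.

```python
import torch
import torch.nn.functional as F

demo_scores = torch.tensor([2.0, 1.0, 0.5, -1.0])   # hypothetical next-token logits

temp_probs = {}
for T in (0.5, 0.8, 1.0, 2.0):
    temp_probs[T] = F.softmax(demo_scores / T, dim=-1)
    print(f"T={T}: " + " ".join(f"{p:.3f}" for p in temp_probs[T].tolist()))
```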

## 9. Reflection and Next Steps

### Reflection Questions
1. Cross-entropy loss treats all wrong answers equally -- predicting "cat" when the answer is "car" gets the same penalty as predicting "cat" when the answer is "refrigerator." Is this a problem? How might you address it?
2. We used a fixed learning rate. What would happen if the learning rate were too large? Too small? How do AdamW's adaptive per-parameter step sizes help?
3. The gradient of softmax cross-entropy has the elegant form $p_i - \mathbb{1}_{i=c}$. Why is this gradient zero-sum (sums to zero across all classes)? What does this mean for the weight updates?

### Optional Challenges
1. Implement gradient clipping (cap the gradient norm to a maximum value). Train with and without clipping -- does it affect stability for your small model?
2. Implement a learning rate warmup schedule: start with a tiny learning rate and linearly increase it over the first 100 steps. Does this improve early training?
3. Track the cross-entropy loss separately for each character position in the sequence. Does the model learn to predict some positions better than others? Which ones, and why?