# Day 14: The Annotated Transformer

> **Interactive notebook for understanding the Transformer architecture**

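This notebook walks through:
1. Scaled dot-product attention and its visualization
2. Causal (subsequent) masking
3. Positional encoding
4. The Noam learning-rate schedule
5. Label smoothing
6. Training a small Transformer on a copy task
7. Inference with greedy decoding and visualization of the learned multi-head attention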

In [None]:
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

print(f"PyTorch version: {torch.__version__}")

## 1. Scaled Dot-Product Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: tensor whose last dimension is d_k, e.g. (batch, seq_q, d_k)
            or (batch, heads, seq_q, d_k).
        key: tensor matmul-compatible with ``query``; its last two dims are
            (seq_k, d_k).
        value: tensor whose second-to-last dim is seq_k.
        mask: optional tensor broadcastable to the score shape
            (..., seq_q, seq_k). Positions where ``mask == 0`` are blocked.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention probabilities.

    Returns:
        (output, p_attn): the attention-weighted sum of ``value`` and the
        attention probabilities of shape (..., seq_q, seq_k).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the score dtype rather than a
        # fixed -1e9: -1e9 is out of range for float16 (masked_fill raises), while
        # exp(min - max) still underflows to exactly 0 after softmax.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

# Demo: one batch of 4 random tokens with d_k = 8, drawn in Q, K, V order.
torch.manual_seed(42)
Q, K, V = (torch.randn(1, 4, 8) for _ in range(3))  # each (batch, seq, d_k)

output, attn_weights = attention(Q, K, V)
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

In [None]:
# Heatmap of the unmasked demo weights: rows are queries, columns are keys.
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(attn_weights[0].detach().numpy(), annot=True, fmt='.2f', cmap='Blues', ax=ax)
ax.set(xlabel='Key Position', ylabel='Query Position', title='Attention Weights')
plt.show()

## 2. Causal (Subsequent) Mask

For autoregressive decoding, position $i$ can only attend to positions $\leq i$.

In [None]:
def subsequent_mask(size):
    """Build a boolean causal mask of shape (1, size, size).

    Entry [0, i, j] is True when query position i may attend to key
    position j (j <= i) and False for future positions (j > i).
    """
    future = torch.triu(torch.ones(1, size, size), diagonal=1)
    return future.eq(0)

# A 6x6 causal mask: 1 where the query (row) may attend to the key (column).
mask = subsequent_mask(6)
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(mask[0].int().numpy(), annot=True, cmap='Greens', cbar=False, ax=ax)
ax.set(xlabel='Key Position', ylabel='Query Position', title='Causal Mask (1 = can attend)')
plt.show()

In [None]:
# Same Q, K, V as the demo above, but each query may only see itself and earlier keys.
causal = subsequent_mask(4)
output_masked, attn_masked = attention(Q, K, V, mask=causal)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(attn_masked[0].detach().numpy(), annot=True, fmt='.2f', cmap='Blues', ax=ax)
ax.set(xlabel='Key Position', ylabel='Query Position', title='Masked Attention Weights')
plt.show()

## 3. Positional Encoding

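Self-attention treats its input as an unordered set: permuting the tokens only permutes the outputs. To let the model use word order, we add a fixed sinusoidal positional encoding to the token embeddings: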

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

In [None]:
def positional_encoding(max_len, d_model):
    """Sinusoidal positional encodings from "Attention Is All You Need".

    Args:
        max_len: number of positions to encode.
        d_model: embedding width. Odd widths are supported: the last even
            column receives a sine and has no matching cosine column.

    Returns:
        Float tensor of shape (max_len, d_model) with
        pe[pos, 2i]   = sin(pos / 10000^(2i / d_model)) and
        pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model)).
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
    # 1 / 10000^(2i / d_model), computed in log space.
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even column; trim the
    # angles so the assignment does not fail with a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe

pe = positional_encoding(100, 64)  # (positions, dimensions)

fig, (ax_heat, ax_lines) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: every dimension across all positions.
im = ax_heat.imshow(pe.T, aspect='auto', cmap='RdBu')
ax_heat.set(xlabel='Position', ylabel='Dimension', title='Positional Encoding')
fig.colorbar(im, ax=ax_heat)

# Right panel: a few dimensions; lower dimensions oscillate faster than higher ones.
for dim in (0, 1, 2, 3, 10, 20):
    ax_lines.plot(pe[:, dim].numpy(), label=f'dim {dim}')
ax_lines.set(xlabel='Position', ylabel='Value', title='PE Values by Dimension')
ax_lines.legend()
ax_lines.grid(True)

fig.tight_layout()
plt.show()

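## 4. Noam Learning Rate Schedule

The learning rate increases linearly for the first $warmup$ steps. After that, it decays in proportion to the inverse square root of the step number:

$$lrate = d_{model}^{-0.5} \cdot \min\left(step^{-0.5},\; step \cdot warmup^{-1.5}\right)$$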

In [None]:
def rate(step, d_model, warmup):
    """Noam learning-rate schedule.

    lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)

    The rate grows linearly for ``warmup`` steps and then decays as
    1/sqrt(step). Step 0 is treated as step 1 so 0 is never raised to a
    negative power.
    """
    effective_step = 1 if step == 0 else step
    decay_branch = effective_step ** (-0.5)
    warmup_branch = effective_step * warmup ** (-1.5)
    return d_model ** (-0.5) * min(decay_branch, warmup_branch)

steps = list(range(20000))

# One curve per warmup length, all for d_model = 512.
fig, ax = plt.subplots(figsize=(10, 5))
for warmup in (400, 2000, 4000, 8000):
    rates = [rate(s, 512, warmup) for s in steps]
    ax.plot(steps, rates, label=f'warmup={warmup}')

ax.set(xlabel='Training Step', ylabel='Learning Rate',
       title='Noam Learning Rate Schedule (d_model=512)')
ax.legend()
ax.grid(True)
plt.show()

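## 5. Label Smoothing

Label smoothing replaces the one-hot target with a softer one. The true token keeps $1 - \epsilon$ of the probability mass, and the remaining $\epsilon$ is spread uniformly over the other $V - 1$ tokens. This discourages over-confident predictions. The bars below show the target distribution for true token 3 in a 10-token vocabulary.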

In [None]:
def label_smoothing_demo(vocab_size=10, true_class=3, smoothing=0.1):
    """Return the label-smoothed target distribution for a single token.

    The true class keeps ``1 - smoothing`` of the probability mass; the
    remaining ``smoothing`` is split evenly over the other ``vocab_size - 1``
    classes, so the entries sum to 1.

    Returns:
        1-D float numpy array of length ``vocab_size``.
    """
    off_target = smoothing / (vocab_size - 1)
    dist = np.full(vocab_size, off_target)
    dist[true_class] = 1.0 - smoothing
    return dist

# Compare target distributions for increasing smoothing; the true class is highlighted.
n_tokens, target = 10, 3
fig, axes = plt.subplots(1, 4, figsize=(14, 4))
for ax, smooth in zip(axes, [0.0, 0.1, 0.2, 0.3]):
    dist = label_smoothing_demo(vocab_size=n_tokens, true_class=target, smoothing=smooth)
    colors = ['coral' if idx == target else 'steelblue' for idx in range(n_tokens)]
    ax.bar(range(n_tokens), dist, color=colors)
    ax.set(xlabel='Token Index', ylabel='Probability', title=f'Smoothing = {smooth}')
    ax.set_ylim(0, 1.1)

fig.suptitle('Label Smoothing Effect', fontsize=14)
fig.tight_layout()
plt.show()

## 6. Training Demo: Copy Task

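Train a small Transformer to reproduce its input sequence. To keep the demo simple, training uses plain Adam with a fixed learning rate and unsmoothed cross-entropy, not the Noam schedule and label smoothing shown above.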

In [None]:
# Import the full model implementation. `subsequent_mask` is deliberately NOT
# imported: this notebook already defines its own version above, and importing
# it here would silently shadow that definition for any cell re-run later.
from implementation import make_model, Batch, greedy_decode

# Create a small model (N=2 and h=2 are presumably layers and heads; see make_model).
VOCAB = 11  # token ids 0..10; the copy task below samples ids 1..10
model = make_model(VOCAB, VOCAB, N=2, d_model=64, d_ff=128, h=2)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Training loop: plain Adam + cross-entropy on randomly generated copy-task batches.
N_EPOCHS = 20
BATCHES_PER_EPOCH = 20
BATCH_SIZE = 32
SEQ_LEN = 10
START_SYMBOL = 1

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

losses = []
model.train()

for epoch in range(N_EPOCHS):
    running = 0.0
    for _ in range(BATCHES_PER_EPOCH):
        # Random ids in [1, VOCAB); column 0 is overwritten with the start symbol.
        data = torch.randint(1, VOCAB, (BATCH_SIZE, SEQ_LEN))
        data[:, 0] = START_SYMBOL
        # NOTE(review): Batch's padding id comes from implementation.py's default;
        # confirm it is not one of the sampled ids 1..VOCAB-1.
        batch = Batch(data.clone(), data.clone())

        optimizer.zero_grad()
        model_out = model(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
        loss = criterion(model_out.reshape(-1, VOCAB), batch.tgt_y.reshape(-1))
        loss.backward()
        optimizer.step()

        running += loss.item()

    # Record the mean loss over this epoch's batches.
    losses.append(running / BATCHES_PER_EPOCH)
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch + 1}: Loss = {losses[-1]:.4f}")

In [None]:
# Plot the mean loss per epoch. Epochs are numbered from 1 so the x-axis
# matches the "Epoch N" lines printed by the training loop.
epochs = range(1, len(losses) + 1)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(epochs, losses, marker='o')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
ax.grid(True)
plt.show()

In [None]:
# Test: greedy-decode a fixed sequence and check how much of it was copied.
model.eval()
test_src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
seq_len = test_src.size(1)
test_mask = torch.ones(1, 1, seq_len)  # every source position is visible

with torch.no_grad():
    output = greedy_decode(model, test_src, test_mask, max_len=seq_len, start_symbol=1)

print(f"Input:  {test_src[0].tolist()}")
print(f"Output: {output[0].tolist()}")

# Position 0 of `output` is presumably the start symbol seeded into the decoder
# (confirm against greedy_decode in implementation.py). It equals test_src[0, 0]
# by construction, so counting it would inflate the score; compare only the
# positions the model actually generated.
generated = output[0, 1:seq_len]
expected = test_src[0, 1:1 + generated.size(0)]
match = (generated == expected).sum().item()
total = expected.size(0)
accuracy = 100 * match / total if total else 0.0
print(f"Match:  {match}/{total} ({accuracy:.0f}%)")

## 7. Visualize Learned Attention

Let's see what the trained model is attending to.

In [None]:
# Run the encoder once so its attention layers record their probabilities.
model.eval()
with torch.no_grad():
    _ = model.encode(test_src, test_mask)

# NOTE(review): assumes implementation.py's attention module keeps its last
# probabilities in `.attn` as (batch, heads, seq, seq) — confirm.
enc_attn = model.encoder.layers[0].self_attn.attn[0]  # (heads, seq, seq)
n_heads = enc_attn.size(0)
token_labels = test_src[0].tolist()

# squeeze=False keeps `axes` 2-D even for a single head; otherwise plt.subplots
# returns a bare Axes object, which cannot be indexed.
fig, axes = plt.subplots(1, n_heads, figsize=(12, 4), squeeze=False)
for head, ax in enumerate(axes[0]):
    sns.heatmap(enc_attn[head].detach().cpu().numpy(), ax=ax, cmap='Blues', cbar=False,
                xticklabels=token_labels, yticklabels=token_labels)
    ax.set_title(f'Head {head}')

fig.suptitle('Encoder Self-Attention (Layer 0)', fontsize=14)
fig.tight_layout()
plt.show()

## Summary

In this notebook we:
1. ✅ Implemented and visualized **scaled dot-product attention**
2. ✅ Created and understood **causal masking**
3. ✅ Visualized **positional encoding** patterns
4. ✅ Explored the **Noam learning rate schedule**
5. ✅ Understood **label smoothing**
6. ✅ Trained a Transformer on the **copy task**
7. ✅ Visualized **learned attention patterns**

**Next:** Try the exercises in `exercises/` or explore the full implementation in `implementation.py`!