In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1IgUoF-zZMZRikv9Wn-4nQSKVIcGjXn0Y", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\nüì¶ Python {sys.version.split()[0]}")
print(f"üî• PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"üé≤ Random seed set to {SEED}")

%matplotlib inline

# 🚀 Building a Complete Vision Transformer from Scratch

*Part 3 of 3 in the Vizuara series on Vision Transformers from Scratch*

*Estimated time: 60 minutes*

In Notebook 1, we learned how to turn images into patch embeddings. In Notebook 2, we built a Transformer encoder from scratch. Now, in this final notebook, we assemble everything into a **complete, trainable Vision Transformer**, train it on CIFAR-10, and visualize what it learns.

By the end of this notebook, you will have:
- A fully functional ViT that classifies images
- Training and validation curves showing it learning
- Attention heatmaps revealing what the model focuses on
- A position embedding similarity map showing learned 2D spatial structure

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**: it has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://course-creator-brown.vercel.app/courses/vision-transformers-from-scratch/practice/3/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
#@title üéß Listen: Why It Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_why_it_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 1: Why Does This Matter?

We have built every piece of the Vision Transformer individually. Now comes the satisfying part: **putting it all together** and watching it learn.

This is not a toy exercise. The model we build here is architecturally identical to the original ViT paper (Dosovitskiy et al., 2020), just smaller. The same principles scale to ViT-Base (86M parameters), ViT-Large (307M), and ViT-Huge (632M).

Here is what we will accomplish in the next 60 minutes:
1. **Assemble** the full ViT: patches + embeddings + encoder + classification head
2. **Train** it on CIFAR-10 (50,000 images, 10 classes) in under 10 minutes
3. **Evaluate** its predictions on unseen test images
4. **Visualize** attention maps ‚Äî where does the model look when classifying a cat vs a truck?
5. **Analyze** position embeddings ‚Äî does the model learn that patch 0 is top-left and patch 63 is bottom-right?

Let us begin.

In [None]:
#@title üéß Listen: Full Pipeline
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_full_pipeline.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 2: Building Intuition

Before we write any code, let us walk through the complete ViT pipeline one more time.

### The Full Pipeline

```
Image (32×32×3)
    ↓  Split into 4×4 patches
64 Patches (each 4×4×3 = 48 values)
    ↓  Linear projection to D=192
64 Patch Embeddings (each 192-dim)
    ↓  Prepend [CLS] token
65 Tokens (each 192-dim)
    ↓  Add position embeddings
65 Positioned Tokens
    ↓  Pass through 6 Transformer blocks
65 Encoded Tokens
    ↓  Extract [CLS] token (index 0)
1 Global Representation (192-dim)
    ↓  Layer Norm → Linear(192, 10)
10 Class Logits → Prediction
```

Every component here was built in Notebooks 1 and 2. Today we connect the wires.

Let us verify the dimensions quickly.

In [None]:
# Quick dimension check for our ViT-Tiny pipeline
img_size, patch_size, channels = 32, 4, 3
embed_dim, depth, num_heads = 192, 6, 3
num_patches = (img_size // patch_size) ** 2

print("ViT-Tiny Pipeline Dimensions")
print("=" * 40)
print(f"Input image:        {img_size}×{img_size}×{channels}")
print(f"Patch size:         {patch_size}×{patch_size}")
print(f"Num patches (N):    {num_patches}")
print(f"Patch dim (P²·C):   {patch_size**2 * channels}")
print(f"Embed dim (D):      {embed_dim}")
print(f"Sequence length:    {num_patches + 1} (patches + CLS)")
print(f"Encoder depth:      {depth} blocks")
print(f"Attention heads:    {num_heads}")
print(f"Head dim (D/H):     {embed_dim // num_heads}")

In [None]:
#@title üéß Listen: Cls Token
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_cls_token.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Why [CLS] for Classification?

The `[CLS]` token is a clever design choice from BERT that ViT inherits. Think of it this way:

**Analogy:** Imagine a meeting with 64 employees (the patch tokens). The `[CLS]` token is like a **manager** who starts the meeting knowing nothing. Over 6 rounds of discussion (Transformer layers), the manager listens to everyone through self-attention. By the end, the manager has synthesized information from all 64 employees into a single, comprehensive summary.

That summary (the final `[CLS]` representation) is what we feed to the classification head.

**Why not just average all patch tokens?** You could, and some models do (this is called **global average pooling**). But `[CLS]` is elegant: it gives the model a **dedicated slot** for the global representation without polluting individual patch representations. Each patch token can focus on encoding its local region, while `[CLS]` focuses on the big picture.
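
To make the two readout options concrete, here is a minimal sketch of both, assuming the encoder output is a `(B, N+1, D)` tensor with `[CLS]` at index 0, as in this notebook. The helper name `pool_tokens` is hypothetical and is not used anywhere else.

```python
import torch

def pool_tokens(tokens: torch.Tensor, mode: str = "cls") -> torch.Tensor:
    """Reduce encoded tokens (B, N+1, D) to one vector per image (B, D).

    Assumes the [CLS] token sits at index 0 (hypothetical helper for illustration).
    """
    if mode == "cls":
        return tokens[:, 0]               # read the dedicated [CLS] slot
    if mode == "gap":
        return tokens[:, 1:].mean(dim=1)  # global average pooling over patch tokens only
    raise ValueError(f"unknown mode: {mode!r}")

encoded = torch.randn(2, 65, 192)  # random stand-in for the encoder output
print(pool_tokens(encoded, "cls").shape)
print(pool_tokens(encoded, "gap").shape)
```

Both produce one 192-dim vector per image; the difference is whether the model learns to write its summary into a dedicated token or we average the patch tokens after the fact.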

In [None]:
#@title üéß Listen: Variants And Data Hunger
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_variants_and_data_hunger.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### ViT Variants and Our "Tiny" Model

The original paper introduced three sizes:

| Variant   | Layers | Hidden Dim | Heads | Parameters |
|-----------|--------|------------|-------|------------|
| ViT-Base  | 12     | 768        | 12    | 86M        |
| ViT-Large | 24     | 1024       | 16    | 307M       |
| ViT-Huge  | 32     | 1280       | 16    | 632M       |

For CIFAR-10 on a T4 GPU, we will use a **ViT-Tiny** configuration:

| Setting    | Value |
|------------|-------|
| Layers     | 6     |
| Hidden Dim | 192   |
| Heads      | 3     |
| MLP Ratio  | 4.0   |
| Parameters | ~2.7M |

This is small enough to train in minutes, but large enough to learn meaningful representations.

In [None]:
# Compare ViT variant sizes
variants = {
    'ViT-Tiny (ours)': {'layers': 6, 'dim': 192, 'heads': 3},
    'ViT-Base':        {'layers': 12, 'dim': 768, 'heads': 12},
    'ViT-Large':       {'layers': 24, 'dim': 1024, 'heads': 16},
    'ViT-Huge':        {'layers': 32, 'dim': 1280, 'heads': 16},
}

print(f"{'Variant':<18} {'Layers':>6} {'Dim':>6} {'Heads':>6} {'~Params':>10}")
print("-" * 50)
for name, v in variants.items():
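    # Rough estimate: ~12·D² per block (4D² attention + 8D² MLP), plus position
    # embeddings for 197 tokens (196 patches + CLS, i.e. 224×224 input, 16×16 patches)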
    approx_params = v['layers'] * 12 * v['dim']**2 + v['dim'] * 197
    print(f"{name:<18} {v['layers']:>6} {v['dim']:>6} {v['heads']:>6} {approx_params/1e6:>9.1f}M")

### The Data Hunger Problem

ViT has a well-known weakness: **it needs a lot of data**. Unlike CNNs, which have built-in inductive biases (locality, translation equivariance), ViT must learn spatial relationships almost entirely from data. The original ViT's strongest results came from pre-training on JFT-300M (300 million images!) before fine-tuning.

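We only have 50,000 CIFAR-10 training images. To compensate, we will use **data augmentation** (random crops and horizontal flips) to artificially increase the effective dataset size. We also normalize the inputs, but normalization standardizes the pixel statistics rather than adding variety.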
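
Conceptually, a padded random crop shifts the image by up to 4 pixels in each direction, and a flip mirrors it left to right. The sketch below hand-rolls both on a single `(C, H, W)` tensor to show the idea. It assumes zero padding and is only an illustration, not torchvision's implementation; `random_crop_flip` is a hypothetical helper, and the notebook itself uses `transforms.RandomCrop` and `transforms.RandomHorizontalFlip`.

```python
import torch
import torch.nn.functional as F

def random_crop_flip(img: torch.Tensor, pad: int = 4, generator=None) -> torch.Tensor:
    """Pad-and-crop plus a 50% horizontal flip for one (C, H, W) image.

    Hypothetical helper; zero padding is an assumption for illustration.
    """
    _, H, W = img.shape
    padded = F.pad(img, (pad, pad, pad, pad))  # pads W (left/right), then H (top/bottom)
    # Pick a random top-left corner inside the padded image
    top = torch.randint(0, 2 * pad + 1, (1,), generator=generator).item()
    left = torch.randint(0, 2 * pad + 1, (1,), generator=generator).item()
    out = padded[:, top:top + H, left:left + W]  # crop back to the original size
    if torch.rand(1, generator=generator).item() < 0.5:
        out = out.flip(-1)  # mirror along the width axis
    return out

img = torch.rand(3, 32, 32)  # stand-in for one un-normalized CIFAR-10 image
aug = random_crop_flip(img)
print(aug.shape)
```

Every call produces a slightly different view of the same image, which is what "increasing the effective dataset size" means in practice.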

> **Think About This:** Our tiny ViT will have ~2.7M parameters. A ResNet-18 has about 11M. Given that ViT lacks the locality inductive bias, what accuracy do you predict on CIFAR-10? Higher or lower than a CNN?

In [None]:
# Make your prediction before training!
# (No peeking ahead: write down your guess)
my_prediction = "___"  # Fill in: e.g., "75%", "85%", "60%"
print(f"My prediction for ViT-Tiny on CIFAR-10: {my_prediction}")
print("We will check this at the end of the notebook!")

In [None]:
#@title üéß Listen: Mathematics
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_mathematics.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 3: The Mathematics

Let us formalize the three mathematical ideas we need for the full model.

### 3.1 The Classification Head

After the Transformer encoder processes all tokens through $L$ layers, we extract the `[CLS]` token from the final layer:

$$\text{logits} = \text{Linear}(\text{LayerNorm}(\mathbf{z}_L^0))$$

where:
- $\mathbf{z}_L^0 \in \mathbb{R}^D$ is the `[CLS]` token at layer $L$ (the superscript 0 means it is the first token)
- LayerNorm stabilizes the representation before the final projection
- Linear projects from dimension $D$ to $C$ classes (192 → 10 for CIFAR-10)

This is deliberately simple. The entire "intelligence" of the model is in the Transformer encoder; the head just reads off the answer.
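
As a sketch of how small the head really is, the snippet below applies LayerNorm and a linear layer to a stand-in for $\mathbf{z}_L^0$ (random numbers in place of a real encoder output). With $D = 192$ and $C = 10$ it has $2D + DC + C$ parameters.

```python
import torch
import torch.nn as nn

D, num_classes = 192, 10
head = nn.Sequential(nn.LayerNorm(D), nn.Linear(D, num_classes))

z_L = torch.randn(8, 65, D)   # random stand-in for the encoder output (B, N+1, D)
cls_final = z_L[:, 0]         # z_L^0: the [CLS] token after the last layer
logits = head(cls_final)      # (8, 10) raw class scores
probs = logits.softmax(dim=-1)

print(logits.shape)
print(sum(p.numel() for p in head.parameters()))  # LayerNorm (2D) + Linear (DC + C)
```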

### 3.2 Cross-Entropy Loss

We train with the standard classification loss:

$$\mathcal{L} = -\sum_{c=1}^{C} y_c \log(\hat{y}_c)$$

where $y_c$ is 1 for the true class and 0 elsewhere, and $\hat{y}_c = \text{softmax}(\text{logits})_c$.

In practice, since $y$ is one-hot, this simplifies to:

$$\mathcal{L} = -\log(\hat{y}_{true})$$

The loss is just the **negative log probability** of the correct class. If the model assigns 90% probability to the right answer, the loss is $-\log(0.9) = 0.105$. If it assigns only 10%, the loss is $-\log(0.1) = 2.303$. The training process adjusts all ~2.7M parameters to minimize this value.
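
A quick numerical check of these values, plus a confirmation that PyTorch's `F.cross_entropy` works on raw logits: it applies log-softmax internally, which is why the model's head ends with a plain `nn.Linear` and no softmax. The example logits below are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

# The worked numbers from the text
print(round(-math.log(0.9), 3))
print(round(-math.log(0.1), 3))

# Cross-entropy on raw logits equals -log softmax at the true class
logits = torch.tensor([[2.0, 0.5, -1.0]])  # arbitrary scores for 3 classes
target = torch.tensor([0])                 # true class index
manual = -F.log_softmax(logits, dim=-1)[0, target[0]]
builtin = F.cross_entropy(logits, target)
print(manual.item(), builtin.item())
```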

### 3.3 Parameter Count Breakdown

Let us derive exactly where our ~2.7M parameters come from (ViT-Tiny with $D=192$, $L=6$, $H=3$, patch size $P=4$, CIFAR-10 images $32 \times 32 \times 3$):

**Patch Embedding:**
- Conv2d weight: $(P^2 \cdot C) \cdot D = (16 \cdot 3) \cdot 192 = 9{,}216$ weights + 192 bias = **9,408**

**CLS Token + Position Embeddings:**
- CLS token: $1 \cdot D = 192$
- Position embeddings: $(N+1) \cdot D = 65 \cdot 192 = 12{,}480$
- Total: **12,672**

**Per Transformer Block:**
- Multi-head attention (Q, K, V projections + output): $4 \cdot D^2 + 4D = 4 \cdot 192^2 + 768 = 148{,}224$
- MLP (two linear layers with expansion ratio 4): $2 \cdot 4 \cdot D^2 + (4D + D) = 8 \cdot 192^2 + 960 = 295{,}872$
- Two LayerNorms: $2 \cdot 2D = 768$
- Per block total: **444,864**
- Times 6 blocks: **2,669,184**

**Final LayerNorm + Classification Head:**
- LayerNorm: $2D = 384$
- Linear: $D \cdot C + C = 192 \cdot 10 + 10 = 1{,}930$
- Total: **2,314**

**Grand total: 2,693,578 parameters** (approximately 2.7M)

Every single one of these numbers will be adjusted during training to minimize the cross-entropy loss.

Let us verify our hand calculation with code.

In [None]:
# Verify the parameter count breakdown
D = 192          # embed_dim
L = 6            # depth
P = 4            # patch_size
C_in = 3         # channels
N = 64           # num_patches
C_out = 10       # num_classes

patch_embed_params = P * P * C_in * D + D  # Conv2d weight + bias
cls_pos_params = D + (N + 1) * D           # CLS token + position embeddings
per_block = 4 * D * D + 4 * D + 8 * D * D + (4 * D + D) + 2 * 2 * D  # attn + MLP + norms
encoder_params = L * per_block
head_params = 2 * D + D * C_out + C_out    # final LN + linear head

total = patch_embed_params + cls_pos_params + encoder_params + head_params
print(f"Patch embedding:    {patch_embed_params:>10,}")
print(f"CLS + Position:     {cls_pos_params:>10,}")
print(f"Encoder ({L} blocks): {encoder_params:>10,}")
print(f"Head:               {head_params:>10,}")
print(f"{'':->35}")
print(f"Estimated total:    {total:>10,}")
print("\n(This should match the count from model.parameters() exactly,")
print(" since our from-scratch attention uses biased nn.Linear layers.)")

In [None]:
#@title üéß Listen: Building Start
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_building_start.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 4: Let Us Build It, Component by Component

Time to write code. We will build each component cleanly, test it, and then assemble them into the full model.

In [None]:
# Import everything we need
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import math
import warnings
warnings.filterwarnings('ignore')

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
#@title üéß Listen: Patch Embedding
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_patch_embedding.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.1 Patch Embedding (Recap from Notebook 1)

We use the Conv2d trick: a convolution with kernel_size=patch_size and stride=patch_size cleanly splits the image into non-overlapping patches and projects each to the embedding dimension in a single operation.
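
To see why the trick works, the sketch below compares the Conv2d output against an explicit patchify (via `Tensor.unfold`) followed by a matrix multiply with the same weights. The two paths should agree up to floating-point error; the shapes mirror this notebook's CIFAR-10 setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 4, 192

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
x = torch.randn(B, C, H, W)

# Path 1: the Conv2d trick, (B, D, H/P, W/P) -> (B, N, D)
out_conv = conv(x).flatten(2).transpose(1, 2)

# Path 2: explicit patchify, then a matrix multiply with the same weights
patches = x.unfold(2, P, P).unfold(3, P, P)   # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5)   # (B, H/P, W/P, C, P, P)
patches = patches.reshape(B, -1, C * P * P)   # (B, N, C*P*P) = (2, 64, 48)
weight = conv.weight.reshape(D, -1)           # (D, C*P*P), same (C, P, P) order
out_linear = patches @ weight.T + conv.bias   # a per-patch linear projection

print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-5))
```

Because the stride equals the kernel size, each output position sees exactly one non-overlapping patch, so the convolution is just a shared linear layer applied to every flattened patch.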

In [None]:
class PatchEmbedding(nn.Module):
    """Convert image into patch embeddings using Conv2d.

    Input:  (B, C, H, W) image tensor
    Output: (B, num_patches, embed_dim) patch embeddings
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2  # 64 for 32/4

        # Conv2d does patch extraction AND linear projection in one step
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, H/P, W/P)
        x = self.projection(x)
        # Flatten spatial dims and transpose: (B, embed_dim, N) -> (B, N, embed_dim)
        x = x.flatten(2).transpose(1, 2)
        return x

Let us verify the shapes are correct.

In [None]:
# Quick shape test
patch_embed = PatchEmbedding(img_size=32, patch_size=4, embed_dim=192)
dummy_img = torch.randn(2, 3, 32, 32)  # Batch of 2 CIFAR-10 images
patches = patch_embed(dummy_img)
print(f"Input shape:  {dummy_img.shape}")   # (2, 3, 32, 32)
print(f"Output shape: {patches.shape}")      # (2, 64, 192)
print(f"Number of patches: {patch_embed.num_patches}")  # 64
assert patches.shape == (2, 64, 192), "Shape mismatch!"
print("Patch embedding: OK")

In [None]:
#@title üéß Listen: Attention And Block
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_attention_and_block.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 Multi-Head Attention (From Scratch, Recap from Notebook 2)

We build multi-head self-attention from raw matrix operations ‚Äî no `nn.MultiheadAttention` wrapper. This keeps the "from scratch" promise of the series and makes the attention weights easy to extract for visualization.
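
As a sanity check of the math the class below implements, this sketch compares a manual scaled dot-product attention against `torch.nn.functional.scaled_dot_product_attention`, which assumes PyTorch 2.0 or newer. The fused function does not return the attention matrix, which is one practical reason to write attention out by hand when we want heatmaps.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, heads, N, head_dim = 2, 3, 65, 64  # 192 / 3 heads = 64 dims per head
q, k, v = (torch.randn(B, heads, N, head_dim) for _ in range(3))

# Manual scaled dot-product attention (no dropout)
scores = (q @ k.transpose(-2, -1)) * head_dim ** -0.5  # (B, heads, N, N)
weights = scores.softmax(dim=-1)                       # each row sums to 1
manual = weights @ v                                   # (B, heads, N, head_dim)

# PyTorch's fused version; its default scale is 1/sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))
```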

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention built from scratch.

    Q, K, V come from one combined linear projection and are split into heads.
    Each head computes attention independently; the heads are concatenated
    and projected through W_O.
    """
    def __init__(self, embed_dim=192, num_heads=3, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Q, K, V projections (combined for efficiency)
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, return_attention=False):
        B, N, D = x.shape
        # Project to Q, K, V and reshape for multi-head
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)            # each: (B, heads, N, head_dim)

        # Scaled dot-product attention
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn_weights = attn_scores.softmax(dim=-1)
        attn_weights = self.attn_drop(attn_weights)

        # Weighted sum of values
        out = (attn_weights @ v)            # (B, heads, N, head_dim)
        out = out.transpose(1, 2).reshape(B, N, D)  # (B, N, D)
        out = self.out_proj(out)

        if return_attention:
            return out, attn_weights  # (B, heads, N, N)
        return out, None

Now the full Transformer block using our from-scratch attention.

In [None]:
class TransformerBlock(nn.Module):
    """A single Transformer block with Pre-Norm architecture.

    Pre-Norm: LayerNorm BEFORE attention/MLP (more stable training).
    Uses our from-scratch MultiHeadSelfAttention, not nn.MultiheadAttention.
    """
    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, return_attention=False):
        # Pre-Norm Attention with residual
        attn_output, attn_weights = self.attn(
            self.norm1(x), return_attention=return_attention
        )
        x = x + attn_output

        # Pre-Norm MLP with residual
        x = x + self.mlp(self.norm2(x))

        if return_attention:
            return x, attn_weights  # attn_weights: (B, heads, N, N)
        return x
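
The difference between Pre-Norm and Post-Norm comes down to where LayerNorm sits relative to the residual addition. In this minimal sketch a single `nn.Linear` stands in for the attention or MLP sub-layer; it only illustrates the two orderings and is not part of the model.

```python
import torch
import torch.nn as nn

D = 192
norm = nn.LayerNorm(D)
sublayer = nn.Linear(D, D)  # stand-in for attention or the MLP

def pre_norm_step(x):
    # Pre-Norm (this notebook): normalize the input, keep an untouched residual path
    return x + sublayer(norm(x))

def post_norm_step(x):
    # Post-Norm (original 2017 Transformer): normalize after the residual addition
    return norm(x + sublayer(x))

x = torch.randn(2, 65, D)
out_pre, out_post = pre_norm_step(x), post_norm_step(x)
print(out_pre.shape, out_post.shape)
```

In Pre-Norm the identity path `x + ...` never passes through a LayerNorm, so gradients can flow straight through all the blocks, which is the usual explanation for its more stable training.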

Let us verify the block works and inspect the attention weight shapes.

In [None]:
# Test the Transformer block
block = TransformerBlock(embed_dim=192, num_heads=3)
dummy_tokens = torch.randn(2, 65, 192)  # 64 patches + 1 CLS = 65 tokens

# Without attention weights
out = block(dummy_tokens)
print(f"Input:  {dummy_tokens.shape}")  # (2, 65, 192)
print(f"Output: {out.shape}")           # (2, 65, 192)

# With attention weights
out, attn = block(dummy_tokens, return_attention=True)
print(f"Attention weights: {attn.shape}")  # (2, 3, 65, 65): one map per head
print("Transformer block: OK")

In [None]:
#@title üéß Listen: Full Vit Class
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_full_vit_class.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 The Full VisionTransformer Class

Now for the main event. We assemble all the pieces into the complete ViT architecture.

In [None]:
class VisionTransformer(nn.Module):
    """Complete Vision Transformer for image classification.

    Architecture:
        Image → PatchEmbedding → [CLS] + PositionEmb → TransformerBlocks → Classify

    Args:
        img_size:    Input image size (assumes square images)
        patch_size:  Size of each patch
        in_channels: Number of input channels (3 for RGB)
        num_classes: Number of classification categories
        embed_dim:   Transformer hidden dimension D
        depth:       Number of Transformer blocks
        num_heads:   Number of attention heads per block
        mlp_ratio:   MLP hidden dimension as multiple of embed_dim
        dropout:     Dropout rate
    """
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        embed_dim=192,
        depth=6,
        num_heads=3,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2

        # --- Patch Embedding ---
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )

        # --- CLS Token ---
        # Learnable token prepended to the sequence
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # --- Position Embeddings ---
        # Learnable position for each token (CLS + patches)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)
        )

        self.pos_drop = nn.Dropout(dropout)

This first draft stops after the embedding stage. Rather than split `__init__` across several cells, let us define the complete class, including the encoder and classification head, in a single clean block.

In [None]:
class VisionTransformer(nn.Module):
    """Complete Vision Transformer for image classification."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 num_classes=10, embed_dim=192, depth=6, num_heads=3,
                 mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_patches = (img_size // patch_size) ** 2

        # 1. Patch embedding
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )

        # 2. CLS token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(dropout)

        # 3. Transformer encoder (stack of blocks)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # 4. Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights following ViT conventions."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

In [None]:
#@title üéß Listen: Forward Pass
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_forward_pass.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

Now the forward method. This is where the entire pipeline comes together.

In [None]:
class VisionTransformer(VisionTransformer):
    """Add forward method to VisionTransformer."""

    def forward(self, x, return_attention=False):
        """
        Forward pass of the Vision Transformer.

        Args:
            x: Input images (B, C, H, W)
            return_attention: If True, also return attention weights
                              from the last block

        Returns:
            logits: Class predictions (B, num_classes)
            attn_weights: (optional) Attention from last layer (B, heads, N, N)
        """
        B = x.shape[0]

        # Step 1: Patch embedding, (B, C, H, W) -> (B, num_patches, D)
        x = self.patch_embed(x)

        # Step 2: Prepend CLS token, (B, num_patches, D) -> (B, num_patches+1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Step 3: Add position embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Step 4: Pass through Transformer blocks
        attn_weights = None
        for i, block in enumerate(self.blocks):
            if return_attention and i == self.depth - 1:
                # Get attention from the last layer only
                x, attn_weights = block(x, return_attention=True)
            else:
                x = block(x)

        # Step 5: Extract CLS token (index 0)
        cls_output = x[:, 0]  # (B, D)

        # Step 6: Classification head
        cls_output = self.norm(cls_output)
        logits = self.head(cls_output)

        if return_attention:
            return logits, attn_weights
        return logits
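
Step 2 uses `expand` rather than copying the `[CLS]` token once per image. A small check, independent of the model above, shows that `expand` returns a view with batch stride 0 (no copy), and that gradients from every image in the batch accumulate into the single shared parameter.

```python
import torch
import torch.nn as nn

D, B = 192, 4
cls_token = nn.Parameter(torch.zeros(1, 1, D))

expanded = cls_token.expand(B, -1, -1)  # (B, 1, D) view of the same storage
print(expanded.shape, expanded.stride())  # batch stride is 0: every row aliases one token

# Gradients from all B copies are summed into the one shared token
expanded.sum().backward()
print(cls_token.grad.unique())
```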

Let us verify the full model works and count its parameters.

In [None]:
# Create the ViT-Tiny model
model = VisionTransformer(
    img_size=32, patch_size=4, in_channels=3, num_classes=10,
    embed_dim=192, depth=6, num_heads=3, mlp_ratio=4.0, dropout=0.1
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters:     {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size:           {total_params * 4 / 1024 / 1024:.1f} MB (float32)")

# Forward pass test
dummy_batch = torch.randn(4, 3, 32, 32).to(device)
logits = model(dummy_batch)
print(f"\nInput shape:  {dummy_batch.shape}")
print(f"Output shape: {logits.shape}")       # Should be (4, 10)
assert logits.shape == (4, 10), "Output shape mismatch!"

# Test with attention weights
logits, attn = model(dummy_batch, return_attention=True)
print(f"Attention shape: {attn.shape}")      # Should be (4, 3, 65, 65)
print("\nFull model: OK! Ready to train.")

In [None]:
#@title üéß Listen: Data Preparation
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_data_preparation.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 Data Preparation

CIFAR-10 consists of 60,000 32x32 color images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck. We use 50,000 for training and 10,000 for testing.

Data augmentation is crucial for ViT: without it, the model overfits quickly due to its lack of inductive biases.

In [None]:
# CIFAR-10 class names
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Training transforms: augmentation + normalization
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    ),
])

# Test transforms: normalization only (no augmentation)
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    ),
])

In [None]:
# Download and load CIFAR-10
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform
)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=128, shuffle=True,
    num_workers=2, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False,
    num_workers=2, pin_memory=True
)

print(f"Training samples:   {len(train_dataset):,}")
print(f"Test samples:       {len(test_dataset):,}")
print(f"Training batches:   {len(train_loader)}")
print(f"Test batches:       {len(test_loader)}")
print(f"Classes:            {CIFAR10_CLASSES}")

Let us visualize a batch of training images to see what our model will be working with.

In [None]:
# Visualize a batch of training images
# We need to un-normalize for display
def unnormalize(img_tensor):
    """Reverse CIFAR-10 normalization for visualization."""
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
    return (img_tensor * std + mean).clamp(0, 1)

# Get a batch
images, labels = next(iter(train_loader))

fig, axes = plt.subplots(2, 8, figsize=(14, 4))
fig.suptitle('Sample Training Images (with augmentation)', fontsize=14, fontweight='bold')
for i, ax in enumerate(axes.flat):
    img = unnormalize(images[i]).permute(1, 2, 0).numpy()
    ax.imshow(img)
    ax.set_title(CIFAR10_CLASSES[labels[i]], fontsize=9)
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Training Setup
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_training_setup.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 Training Setup

We use **AdamW** (Adam with decoupled weight decay) and a **learning rate schedule** with linear warmup followed by cosine decay. This is the standard recipe for training Vision Transformers.
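
`LambdaLR` multiplies the base learning rate by `lr_lambda(epoch)`, so the schedule only behaves as intended if `scheduler.step()` is called once per epoch. The self-contained sketch below replays the same warmup-plus-cosine rule on a dummy parameter to record the learning rate used in each epoch; the constants mirror this notebook's configuration. Stepping once per batch instead would finish the 5-step warmup within the first 5 batches.

```python
import math
import torch
import torch.optim as optim

BASE_LR, EPOCHS, WARMUP = 3e-4, 25, 5  # mirrors this notebook's settings

def warmup_cosine(epoch):
    if epoch < WARMUP:
        return (epoch + 1) / WARMUP                    # linear warmup
    progress = (epoch - WARMUP) / (EPOCHS - WARMUP)
    return 0.5 * (1 + math.cos(math.pi * progress))   # cosine decay

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, no real model
opt = optim.AdamW([param], lr=BASE_LR)
sched = optim.lr_scheduler.LambdaLR(opt, warmup_cosine)

seen = []
for epoch in range(EPOCHS):
    seen.append(opt.param_groups[0]["lr"])  # LR in effect during this epoch
    opt.step()    # a real loop would train on every batch here
    sched.step()  # advance the schedule once per epoch

print([f"{lr:.1e}" for lr in seen[:6]])
```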

In [None]:
# Hyperparameters
EPOCHS = 25
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.05
WARMUP_EPOCHS = 5

# Print hyperparameter summary
print("=" * 50)
print("         TRAINING CONFIGURATION")
print("=" * 50)
print(f"  Model:          ViT-Tiny")
print(f"  Parameters:     {total_params:,}")
print(f"  Epochs:         {EPOCHS}")
print(f"  Batch size:     128")
print(f"  Learning rate:  {LEARNING_RATE}")
print(f"  Weight decay:   {WEIGHT_DECAY}")
print(f"  Warmup epochs:  {WARMUP_EPOCHS}")
print(f"  Optimizer:      AdamW")
print(f"  LR schedule:    Linear warmup + Cosine decay")
print(f"  Device:         {device}")
print("=" * 50)

In [None]:
# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999),
)

# Learning rate scheduler: linear warmup then cosine decay
def lr_lambda(epoch):
    """Linear warmup for warmup_epochs, then cosine decay."""
    if epoch < WARMUP_EPOCHS:
        return (epoch + 1) / WARMUP_EPOCHS
    else:
        progress = (epoch - WARMUP_EPOCHS) / (EPOCHS - WARMUP_EPOCHS)
        return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Loss function
criterion = nn.CrossEntropyLoss()

# Visualize the LR schedule
lrs = [lr_lambda(e) * LEARNING_RATE for e in range(EPOCHS)]
plt.figure(figsize=(8, 3))
plt.plot(range(EPOCHS), lrs, 'b-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule: Warmup + Cosine Decay')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Todo Forward
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_todo_forward.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 5: Your Turn!

Before we run the full training, let us make sure you understand the key pieces by implementing them yourself.

### TODO 1: Implement the VisionTransformer Forward Pass

The `__init__` is done for you. Your job is to implement the `forward()` method that connects all the pieces.

In [None]:
class VisionTransformerTODO(nn.Module):
    """Vision Transformer: implement the forward pass!"""

    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 num_classes=10, embed_dim=192, depth=6, num_heads=3,
                 mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """
        Implement the full ViT forward pass.

        Steps:
          1. Apply self.patch_embed to get patch tokens (B, N, D)
          2. Expand self.cls_token to batch size and prepend it (B, N+1, D)
          3. Add self.pos_embed and apply self.pos_drop
          4. Pass through each block in self.blocks
          5. Extract the CLS token (index 0) from the output
          6. Apply self.norm then self.head
          7. Return the logits (B, num_classes)

        Hints:
          - cls_tokens = self.cls_token.expand(B, -1, -1)
          - Use torch.cat([cls_tokens, x], dim=1) to prepend
          - x[:, 0] extracts the first token from each batch
        """
        B = x.shape[0]

        # ============ YOUR CODE HERE ============
        # TODO: Implement the 7 steps above

        raise NotImplementedError("Implement the forward pass!")
        # =========================================

In [None]:
# --- Verification for TODO 1 ---
# Uncomment after implementing:

# model_todo = VisionTransformerTODO().to(device)
# test_input = torch.randn(4, 3, 32, 32).to(device)
# test_output = model_todo(test_input)
# assert test_output.shape == (4, 10), f"Expected (4, 10), got {test_output.shape}"
#
# # Check that we can compute loss
# test_labels = torch.randint(0, 10, (4,)).to(device)
# loss = criterion(test_output, test_labels)
# loss.backward()
# print(f"Output shape: {test_output.shape} -- Correct!")
# print(f"Loss value:   {loss.item():.4f} -- Gradient flows!")
# print("TODO 1: PASSED!")

In [None]:
#@title üéß Listen: Todo Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/14_todo_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### TODO 2: Implement the Training Step

Implement a function that trains the model for one epoch. This is the standard PyTorch training loop, but writing it yourself helps cement the pattern.

In [None]:
def train_one_epoch_TODO(model, dataloader, optimizer, scheduler, criterion, device):
    """
    Train the model for one epoch.

    Steps for each batch:
      1. Move images and labels to device
      2. Zero the optimizer gradients
      3. Forward pass: logits = model(images)
      4. Compute loss: loss = criterion(logits, labels)
      5. Backward pass: loss.backward()
      6. Optimizer step: optimizer.step()
      7. Track running loss and accuracy

    After all batches:
      8. Step the scheduler (once per epoch)

    Returns:
        avg_loss (float): Average loss over the epoch
        accuracy (float): Training accuracy (0-100)

    Hints:
      - predictions = logits.argmax(dim=1) gives predicted classes
      - Compare predictions == labels and count correct ones
      - Don't forget model.train() at the start!
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        # ============ YOUR CODE HERE ============
        # TODO: Implement steps 1-7

        raise NotImplementedError("Implement the training step!")
        # =========================================

    # Step 8: Update learning rate
    scheduler.step()

    avg_loss = running_loss / len(dataloader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

In [None]:
# --- Verification for TODO 2 ---
# Uncomment after implementing:

# # Reset the model for a fair test
# model_todo2 = VisionTransformer(
#     img_size=32, patch_size=4, in_channels=3, num_classes=10,
#     embed_dim=192, depth=6, num_heads=3, mlp_ratio=4.0, dropout=0.1
# ).to(device)
# opt_todo2 = optim.AdamW(model_todo2.parameters(), lr=3e-4, weight_decay=0.05)
# sched_todo2 = optim.lr_scheduler.LambdaLR(opt_todo2, lr_lambda)
#
# loss1, acc1 = train_one_epoch_TODO(
#     model_todo2, train_loader, opt_todo2, sched_todo2, criterion, device
# )
# print(f"Epoch 1 - Loss: {loss1:.4f}, Accuracy: {acc1:.1f}%")
#
# loss2, acc2 = train_one_epoch_TODO(
#     model_todo2, train_loader, opt_todo2, sched_todo2, criterion, device
# )
# print(f"Epoch 2 - Loss: {loss2:.4f}, Accuracy: {acc2:.1f}%")
#
# assert loss2 < loss1, f"Loss should decrease! Epoch 1: {loss1:.4f}, Epoch 2: {loss2:.4f}"
# print("TODO 2: PASSED! Loss is decreasing.")

In [None]:
#@title üéß Listen: Training Begins
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/15_training_begins.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 6: Training

Now let us train our ViT-Tiny on CIFAR-10. This is the moment of truth: can a hand-built Transformer learn to classify images?

First, we define clean training and evaluation functions.

In [None]:
def train_one_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Train for one epoch. Returns (avg_loss, accuracy)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    scheduler.step()
    return running_loss / len(dataloader), 100.0 * correct / total

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Evaluate model on a dataset. Returns (avg_loss, accuracy)."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)

        running_loss += loss.item()
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return running_loss / len(dataloader), 100.0 * correct / total

Let us create a freshly initialized model and start training. On a T4 GPU, this should take approximately 5-8 minutes, though slow data loading and augmentation can make it take longer.

In [None]:
# Fresh model for training
model = VisionTransformer(
    img_size=32, patch_size=4, in_channels=3, num_classes=10,
    embed_dim=192, depth=6, num_heads=3, mlp_ratio=4.0, dropout=0.1
).to(device)

optimizer = optim.AdamW(
    model.parameters(), lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999)
)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
criterion = nn.CrossEntropyLoss()

total_params = sum(p.numel() for p in model.parameters())
print(f"Starting training with {total_params:,} parameters...")
print(f"{'Epoch':>5} | {'Train Loss':>10} | {'Train Acc':>9} | "
      f"{'Val Loss':>8} | {'Val Acc':>7} | {'LR':>10}")
print("-" * 65)

In [None]:
#@title üéß Listen: Training Loop
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/16_training_loop.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# Training loop
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [], 'lr': []
}

for epoch in range(EPOCHS):
    # Record the LR used for THIS epoch. train_one_epoch calls scheduler.step()
    # at its end, so reading the LR afterwards would give next epoch's value.
    current_lr = optimizer.param_groups[0]['lr']

    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, scheduler, criterion, device
    )

    # Evaluate
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    # Record
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)

    # Print progress
    print(f"{epoch+1:>5d} | {train_loss:>10.4f} | {train_acc:>8.2f}% | "
          f"{val_loss:>8.4f} | {val_acc:>6.2f}% | {current_lr:>10.6f}")

print(f"\nBest validation accuracy: {max(history['val_acc']):.2f}% "
      f"(epoch {np.argmax(history['val_acc'])+1})")

In [None]:
#@title üéß Listen: Training Curves
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/17_training_curves.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

Let us visualize the training curves. These tell us a lot about the training dynamics.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
epochs = range(1, len(history['train_loss']) + 1)  # 1-indexed, matching the printed log

# Plot 1: Loss curves
ax1 = axes[0]
ax1.plot(epochs, history['train_loss'], label='Train Loss', color='#2196F3', linewidth=2)
ax1.plot(epochs, history['val_loss'], label='Val Loss', color='#FF5722', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Plot 2: Accuracy curves
ax2 = axes[1]
ax2.plot(epochs, history['train_acc'], label='Train Accuracy', color='#4CAF50', linewidth=2)
ax2.plot(epochs, history['val_acc'], label='Val Accuracy', color='#9C27B0', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

best_epoch = int(np.argmax(history['val_acc'])) + 1  # 1-indexed epoch number
ax2.axvline(x=best_epoch, color='gray', linestyle='--', alpha=0.5)
ax2.annotate(f"Best: {max(history['val_acc']):.1f}%",
             xy=(best_epoch, max(history['val_acc'])),
             fontsize=10, fontweight='bold',
             xytext=(best_epoch + 1, max(history['val_acc']) - 5),
             arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()

A few things to notice about the training curves:

- **Warmup phase** (epochs 1-5): The loss drops rapidly as the learning rate ramps up. The model is going from random initialization to something useful.
- **Training vs validation gap**: If train accuracy is much higher than val accuracy, that is overfitting. Some gap is expected, but if it is too large, we need more regularization or data.
- **Cosine decay**: The learning rate gradually decreases, allowing the model to fine-tune its weights in later epochs.

Our ViT-Tiny should reach roughly **75-82%** validation accuracy on CIFAR-10. This is respectable for a model with roughly 2.7M parameters and no pre-training! For context, a ResNet-18 (about 11M parameters) typically reaches around 93% on CIFAR-10 with a standard training setup. Much of that gap is usually attributed to the inductive biases built into convolutions.

In [None]:
#@title üéß Listen: Predictions
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/18_predictions.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 7: Final Output

This is the payoff of the entire series. We will create a comprehensive visualization that shows our ViT in action.

### Panel 1: Predictions on Test Images

Let us see how our trained model performs on individual test images.

In [None]:
@torch.no_grad()
def get_predictions(model, dataloader, device, n_samples=10):
    """Get predictions for the first n_samples images."""
    model.eval()
    images, labels = next(iter(dataloader))
    images = images[:n_samples].to(device)
    labels = labels[:n_samples]

    logits = model(images)
    probs = F.softmax(logits, dim=1)
    preds = logits.argmax(dim=1).cpu()
    confidences = probs.max(dim=1).values.cpu()

    return images.cpu(), labels, preds, confidences

In [None]:
images, labels, preds, confidences = get_predictions(model, test_loader, device)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle('Model Predictions on Test Images', fontsize=16, fontweight='bold', y=1.02)

for i, ax in enumerate(axes.flat):
    img = unnormalize(images[i]).permute(1, 2, 0).numpy()
    ax.imshow(img)

    true_label = CIFAR10_CLASSES[labels[i]]
    pred_label = CIFAR10_CLASSES[preds[i]]
    conf = confidences[i].item()
    correct = preds[i] == labels[i]

    color = '#2E7D32' if correct else '#C62828'
    symbol = 'correct' if correct else 'WRONG'
    ax.set_title(f'Pred: {pred_label} ({conf:.0%})\nTrue: {true_label} [{symbol}]',
                 fontsize=9, color=color, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Attention Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/19_attention_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Panel 2: Attention Visualization

This is one of the most illuminating visualizations in deep learning. We will see **where the model looks** when classifying an image: specifically, what the `[CLS]` token attends to in the final Transformer layer.

In [None]:
@torch.no_grad()
def get_attention_maps(model, images, device):
    """Extract CLS token attention from the last layer.

    Returns:
        cls_attn: (B, num_heads, num_patches), CLS attention to each patch
        preds:    (B,), predicted class indices
    """
    model.eval()
    images = images.to(device)
    logits, attn_weights = model(images, return_attention=True)

    # attn_weights: (B, num_heads, N+1, N+1) where N+1 includes CLS
    # We want: CLS (row 0) attending to patches (columns 1:)
    cls_attn = attn_weights[:, :, 0, 1:]  # (B, num_heads, num_patches)

    return cls_attn.cpu(), logits.argmax(dim=1).cpu()

In [None]:
def visualize_attention(images, attn_maps, labels, preds, n_images=3, patch_size=4):
    """Visualize CLS attention overlaid on images."""
    num_heads = attn_maps.shape[1]
    grid_size = int(math.sqrt(attn_maps.shape[2]))  # 8 for 64 patches

    fig, axes = plt.subplots(n_images, num_heads + 1,
                              figsize=(3.5 * (num_heads + 1), 3.5 * n_images))
    fig.suptitle('Attention Maps: What [CLS] Focuses On (Last Layer)',
                 fontsize=16, fontweight='bold', y=1.02)

    for i in range(n_images):
        img = unnormalize(images[i]).permute(1, 2, 0).numpy()

        # Original image
        axes[i, 0].imshow(img)
        pred_name = CIFAR10_CLASSES[preds[i]]
        true_name = CIFAR10_CLASSES[labels[i]]
        axes[i, 0].set_title(f'Original\nPred: {pred_name}\nTrue: {true_name}',
                              fontsize=10)
        axes[i, 0].axis('off')

        # Per-head attention
        for h in range(num_heads):
            attn = attn_maps[i, h].reshape(grid_size, grid_size).numpy()

            # Upsample attention to image size
            attn_upsampled = np.kron(attn, np.ones((patch_size, patch_size)))

            axes[i, h + 1].imshow(img)
            axes[i, h + 1].imshow(attn_upsampled, alpha=0.6,
                                   cmap='hot', interpolation='bilinear')
            axes[i, h + 1].set_title(f'Head {h+1}', fontsize=10)
            axes[i, h + 1].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Get a fresh batch of test images
test_images, test_labels = next(iter(test_loader))
sample_images = test_images[:3]
sample_labels = test_labels[:3]

# Extract attention maps
attn_maps, preds = get_attention_maps(model, sample_images, device)
print(f"Attention map shape: {attn_maps.shape}")  # (3, 3, 64)

# Visualize!
visualize_attention(sample_images, attn_maps, sample_labels, preds,
                    n_images=3, patch_size=4)

Look at the attention maps closely! Different heads often attend to different aspects of the image:
- Some heads focus on the **object** (for example, an animal's body or a vehicle's outline)
- Some focus on **edges** and boundaries
- Some attend more **uniformly**, gathering global context

With a small model and 32x32 inputs the maps can be noisy, so treat these as tendencies rather than guarantees. This is the power of multi-head attention: the model can develop several complementary "ways of looking" at the image, all in parallel.
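The "focused vs. uniform" distinction can also be checked numerically by computing the entropy of each head's [CLS] attention distribution. This sketch assumes `attn_maps` has shape (B, num_heads, num_patches), as returned by `get_attention_maps` above. Because that function drops the [CLS]-to-[CLS] weight (column 0), each row may sum to slightly less than 1, so the helper renormalizes first. The name `attention_entropy` is our own.

```python
import math
import torch

def attention_entropy(attn_maps, eps=1e-12):
    """Normalized entropy of each head's attention over patches.

    attn_maps: (B, num_heads, num_patches), non-negative weights.
    Returns (B, num_heads) values in [0, 1]: near 1 = spread uniformly,
    near 0 = concentrated on a single patch.
    """
    # Renormalize, since the CLS->CLS weight was sliced off
    p = attn_maps / attn_maps.sum(dim=-1, keepdim=True).clamp_min(eps)
    ent = -(p * (p + eps).log()).sum(dim=-1)
    return ent / math.log(attn_maps.shape[-1])  # divide by the maximum, log(N)

# Sanity check on synthetic maps: head 0 is uniform, head 1 is one-hot
demo = torch.zeros(1, 2, 64)
demo[0, 0] = 1.0 / 64
demo[0, 1, 10] = 1.0
print(attention_entropy(demo))  # approximately [[1.0, 0.0]]
```

On the trained model, `attention_entropy(attn_maps).mean(dim=0)` gives one average value per head across the sampled images.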

In [None]:
#@title üéß Listen: Position Embeddings
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/20_position_embeddings.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Panel 3: Learned Position Embeddings

This is a fascinating visualization. Recall that our model received **1D position indices** (0, 1, 2, ..., 63), yet the patches are arranged on a **2D grid** (8x8). Did the model learn the 2D spatial structure from data alone?

In [None]:
@torch.no_grad()
def visualize_position_embeddings(model):
    """Visualize learned position embedding similarity."""
    # Extract position embeddings (skip CLS at index 0)
    pos_embed = model.pos_embed[0, 1:, :].cpu()  # (64, 192)

    # Compute cosine similarity between all pairs
    pos_embed_norm = F.normalize(pos_embed, dim=1)
    similarity = torch.mm(pos_embed_norm, pos_embed_norm.t()).numpy()

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Full similarity matrix
    im = axes[0].imshow(similarity, cmap='RdBu_r', vmin=-1, vmax=1)
    axes[0].set_title('Position Embedding Cosine Similarity\n(64 patches x 64 patches)',
                       fontsize=12, fontweight='bold')
    axes[0].set_xlabel('Patch Index')
    axes[0].set_ylabel('Patch Index')
    plt.colorbar(im, ax=axes[0], shrink=0.8)

    # Show similarity for specific patches
    # Pick patch 0 (top-left), 4 (top-middle), 28 (center), 63 (bottom-right)
    interesting_patches = [0, 4, 28, 63]
    grid_size = 8

    axes[1].set_title('Similarity to Selected Patches\n(brighter = more similar)',
                       fontsize=12, fontweight='bold')

    combined = np.zeros((grid_size * 2, grid_size * 2))
    titles = ['Patch 0\n(top-left)', 'Patch 4\n(top-mid)',
              'Patch 28\n(center)', 'Patch 63\n(bottom-right)']

    for idx, patch_idx in enumerate(interesting_patches):
        row, col = idx // 2, idx % 2
        sim_map = similarity[patch_idx].reshape(grid_size, grid_size)
        combined[row*grid_size:(row+1)*grid_size,
                 col*grid_size:(col+1)*grid_size] = sim_map

    im2 = axes[1].imshow(combined, cmap='viridis')

    # Add grid lines to separate the 4 panels
    axes[1].axhline(y=grid_size - 0.5, color='white', linewidth=2)
    axes[1].axvline(x=grid_size - 0.5, color='white', linewidth=2)

    # Label each sub-panel
    for idx in range(4):
        row, col = idx // 2, idx % 2
        axes[1].text(col * grid_size + grid_size / 2, row * grid_size + 1,
                     titles[idx], ha='center', va='top',
                     fontsize=9, fontweight='bold', color='white')

    axes[1].axis('off')
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    plt.tight_layout()
    plt.show()

    return similarity

similarity = visualize_position_embeddings(model)

The position embedding similarity matrix should reveal a striking pattern: **patches that are spatially close on the 2D grid have similar embeddings**, even though we only gave the model 1D indices!

On the left matrix, look for two kinds of structure: bright 8x8 blocks along the diagonal (patches in the same row of the 8x8 grid) and bright bands offset from the diagonal by multiples of 8 (patches in the same column). On the right, notice that each patch tends to be most similar to its immediate neighbors and least similar to patches far away.

This is a remarkable result. The model **discovered 2D geometry** purely from the training signal. Nobody told it that patch 0 is next to patch 1 horizontally and next to patch 8 vertically; it learned this from seeing thousands of images where nearby patches tend to share visual features.
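To put a number on "nearby patches have similar embeddings", one option is to correlate the cosine similarity of every pair of positions with their Euclidean distance on the 8x8 grid. A clearly negative correlation suggests 2D structure. The helpers below (`grid_coords`, `similarity_distance_corr`) are our own names, and they assume `similarity` is the (64, 64) NumPy matrix returned by `visualize_position_embeddings`. The demo cannot assume a trained model, so it uses synthetic embeddings built directly from 2D grid coordinates.

```python
import numpy as np

def grid_coords(grid_size=8):
    """(row, col) of each 1D patch index i: row = i // grid_size, col = i % grid_size."""
    idx = np.arange(grid_size * grid_size)
    return np.stack([idx // grid_size, idx % grid_size], axis=1).astype(float)

def similarity_distance_corr(similarity, grid_size=8):
    """Pearson correlation between pairwise similarity and 2D grid distance."""
    coords = grid_coords(grid_size)
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    iu = np.triu_indices(grid_size * grid_size, k=1)  # each pair once, skip the diagonal
    return np.corrcoef(similarity[iu], dist[iu])[0, 1]

# Synthetic "position embeddings" that encode 2D layout: RBF features of grid coordinates
coords = grid_coords(8)
sq_dist = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(-1)
emb = np.exp(-sq_dist / (2 * 2.0 ** 2))             # (64, 64) embedding matrix
emb /= np.linalg.norm(emb, axis=1, keepdims=True)   # unit-normalize each row
synthetic_sim = emb @ emb.T                          # cosine similarity matrix

print(f"Synthetic 2D embeddings: r = {similarity_distance_corr(synthetic_sim):.2f}")
```

On the trained model, `similarity_distance_corr(similarity)` gives a single summary number. Values near zero would suggest the embeddings have not (yet) picked up much spatial structure.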

In [None]:
#@title üéß Listen: Per Class
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/21_per_class.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Panel 4: Per-Class Accuracy

Not all classes are equally easy for a Vision Transformer. Let us see which CIFAR-10 classes our model handles well and which ones are challenging.

In [None]:
@torch.no_grad()
def per_class_accuracy(model, dataloader, device, class_names):
    """Compute accuracy for each class."""
    model.eval()
    class_correct = torch.zeros(len(class_names))
    class_total = torch.zeros(len(class_names))

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        preds = logits.argmax(dim=1)

        for c in range(len(class_names)):
            mask = labels == c
            class_total[c] += mask.sum().item()
            class_correct[c] += (preds[mask] == labels[mask]).sum().item()

    accuracies = (class_correct / class_total * 100).numpy()
    return accuracies

In [None]:
class_accs = per_class_accuracy(model, test_loader, device, CIFAR10_CLASSES)

# Sort by accuracy for the plot
sorted_indices = np.argsort(class_accs)
sorted_names = [CIFAR10_CLASSES[i] for i in sorted_indices]
sorted_accs = class_accs[sorted_indices]

fig, ax = plt.subplots(figsize=(10, 5))
colors = plt.cm.RdYlGn(sorted_accs / 100)  # Red for low, green for high
bars = ax.barh(range(10), sorted_accs, color=colors, edgecolor='gray', linewidth=0.5)

# Add value labels on bars
for i, (acc, bar) in enumerate(zip(sorted_accs, bars)):
    ax.text(acc + 0.5, i, f'{acc:.1f}%', va='center', fontsize=10, fontweight='bold')

ax.set_yticks(range(10))
ax.set_yticklabels(sorted_names, fontsize=11)
ax.set_xlabel('Accuracy (%)', fontsize=12)
ax.set_title('Per-Class Accuracy on CIFAR-10', fontsize=14, fontweight='bold')
ax.set_xlim(0, 105)
ax.grid(axis='x', alpha=0.3)
ax.axvline(x=np.mean(class_accs), color='blue', linestyle='--',
           linewidth=1.5, label=f'Mean: {np.mean(class_accs):.1f}%')
ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

Typical patterns you might observe:
- **Ship** and **automobile** tend to score well ‚Äî they have distinctive geometric shapes
- **Cat** and **dog** are often confused with each other ‚Äî they share similar shapes and textures
- **Bird** and **deer** can be tricky ‚Äî they appear against varied backgrounds

These patterns mirror what CNNs struggle with too, but the specific errors may differ because ViT has a different inductive bias.
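Per-class accuracy alone cannot show *which* classes get mixed up (for example, cats predicted as dogs), but a confusion matrix can. Here is a minimal sketch that assumes the `model`, `test_loader`, `device`, and `CIFAR10_CLASSES` defined earlier. The helper names `collect_predictions`, `confusion_matrix`, and `top_confusions` are our own, not a library API. `torch.bincount` counts each (true, predicted) pair after encoding it as the single index `true * C + pred`.

```python
import torch

def confusion_matrix(preds, labels, num_classes):
    """cm[t, p] = number of samples with true class t predicted as class p."""
    idx = labels.long() * num_classes + preds.long()
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def top_confusions(cm, class_names, k=3):
    """The k most frequent off-diagonal (true, predicted, count) errors."""
    off_diag = cm.clone()
    off_diag.fill_diagonal_(0)          # ignore correct predictions
    flat = off_diag.flatten()
    n = cm.shape[0]
    top = flat.argsort(descending=True)[:k]
    return [(class_names[i // n], class_names[i % n], flat[i].item())
            for i in top.tolist()]

@torch.no_grad()
def collect_predictions(model, dataloader, device):
    """Gather predictions and labels for a whole dataset (returned on the CPU)."""
    model.eval()
    all_preds, all_labels = [], []
    for images, labels in dataloader:
        logits = model(images.to(device))
        all_preds.append(logits.argmax(dim=1).cpu())
        all_labels.append(labels)
    return torch.cat(all_preds), torch.cat(all_labels)

# Tiny self-check with 3 classes
demo_cm = confusion_matrix(torch.tensor([0, 1, 1, 1, 0]), torch.tensor([0, 0, 1, 1, 2]), 3)
print(demo_cm)

# With the trained model (uncomment in the notebook):
# preds, labels = collect_predictions(model, test_loader, device)
# cm = confusion_matrix(preds, labels, num_classes=10)
# for true_c, pred_c, n in top_confusions(cm, CIFAR10_CLASSES):
#     print(f"{true_c:>10} -> {pred_c:<10} {n}")
```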

In [None]:
#@title üéß Listen: Grand Summary
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/22_grand_summary.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### The Grand Summary

Let us bring together the key results into one compact summary.

In [None]:
print("=" * 60)
print("     VISION TRANSFORMER TRAINING SUMMARY")
print("=" * 60)
print(f"  Architecture:       ViT-Tiny (D=192, L=6, H=3)")
print(f"  Parameters:         {total_params:,}")
print(f"  Training epochs:    {EPOCHS}")
print(f"  Final train acc:    {history['train_acc'][-1]:.2f}%")
print(f"  Final val acc:      {history['val_acc'][-1]:.2f}%")
print(f"  Best val acc:       {max(history['val_acc']):.2f}% (epoch {np.argmax(history['val_acc'])+1})")
print(f"  Mean per-class acc: {np.mean(class_accs):.2f}%")
print(f"  Best class:         {CIFAR10_CLASSES[np.argmax(class_accs)]} ({np.max(class_accs):.1f}%)")
print(f"  Worst class:        {CIFAR10_CLASSES[np.argmin(class_accs)]} ({np.min(class_accs):.1f}%)")
print("=" * 60)

In [None]:
print("\n" + "=" * 70)
print("  You built, trained, and analyzed a Vision Transformer from scratch!")
print("=" * 70)
print("\n  In this 3-notebook series, you have:")
print("    Notebook 1: Converted images into patch embeddings")
print("    Notebook 2: Built a Transformer encoder with multi-head attention")
print("    Notebook 3: Assembled the full ViT, trained it, and visualized")
print("                what it learned")
print("\n  Key takeaways:")
print(f"    - ViT-Tiny with {total_params:,} params achieves ~{max(history['val_acc']):.0f}% on CIFAR-10")
print("    - Attention heads learn to focus on different image regions")
print("    - Position embeddings discover 2D spatial structure from 1D inputs")
print("    - ViT needs more data than CNNs but scales better with compute")
print()

In [None]:
#@title üéß Listen: Reflection
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/23_reflection.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## Section 8: Reflection and Next Steps

### Reflection Questions

Take a moment to think about what we have built and observed.

**1. The Accuracy Gap**

Our ViT-Tiny landed somewhere around 75-82% on CIFAR-10 (check the best validation accuracy printed above), while a ResNet-18 typically reaches around 93%. Why the gap?

The key insight: **ViT has no built-in knowledge about images**. A CNN knows that nearby pixels are related (locality) and that features should work regardless of position (translation equivariance). ViT must learn all of this from data alone. With only 50,000 training images, it simply does not have enough examples to fully learn these spatial priors.

What would help ViT catch up?
- **More data**: The original ViT was pre-trained on 300M images. With enough data, ViT actually surpasses CNNs.
- **Data-efficient tricks**: DeiT showed that with strong augmentation, regularization, and knowledge distillation, ViT can be competitive on ImageNet even without massive pre-training.
- **Hybrid architectures**: Models like CoAtNet use convolutional stages early in the network and Transformer stages later, to get the best of both worlds.
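The translation-equivariance prior mentioned above is easy to check numerically for a convolution: shifting the input shifts the output by the same amount. This sketch uses circular padding so the property holds exactly at the borders (with zero padding it holds only away from the edges); the layer sizes and shift are arbitrary choices. A ViT's learned position embeddings carry no such built-in guarantee, so any similar behavior has to be learned from data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")
x = torch.randn(1, 3, 32, 32)

shift = (2, 5)  # move the image 2 pixels down and 5 pixels right (wrapping around)
out_of_shifted = conv(torch.roll(x, shifts=shift, dims=(2, 3)))
shifted_out = torch.roll(conv(x), shifts=shift, dims=(2, 3))

# Equivariance: conv(shift(x)) == shift(conv(x)), up to float rounding
print(torch.allclose(out_of_shifted, shifted_out, atol=1e-5))
```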

**2. Attention Patterns**

Look back at the attention visualization. Different heads attend to different things; this is often called **attention head specialization**. You might see:
- Heads that attend to edges and boundaries
- Heads that focus on the central object
- Heads that capture global context more uniformly

This diversity is not explicitly programmed ‚Äî it emerges from training. The model finds it useful to look at images in multiple complementary ways, and the multi-head mechanism provides the capacity for this.

**3. Position Embeddings Learning 2D Structure**

Perhaps the most surprising result: position embeddings given only 1D indices (0, 1, 2, ..., 63) learned to encode 2D spatial relationships. How?

During training, the model sees that patch 0 (top-left) and patch 1 (one step right) consistently share visual features, while patch 0 and patch 63 (bottom-right) rarely do. Over thousands of gradient updates, the position embeddings adjust to reflect these statistical patterns ‚Äî which happen to encode 2D geometry. The model does not "know" about 2D; it just learns which positions tend to have similar content.

In [None]:
#@title üéß Listen: Challenges
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/24_challenges.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Optional Challenges

If you want to go deeper, try these extensions:

**Challenge 1: Global Average Pooling vs [CLS] Token**

Replace the [CLS] token approach with global average pooling:

In [None]:
# Instead of: cls_output = x[:, 0]
# Try:        cls_output = x[:, 1:].mean(dim=1)  # Average all patch tokens

Does accuracy change? Does training behave differently? For small models, GAP is sometimes slightly better because it uses information from all tokens directly.
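The only change is how the final token sequence is reduced to one vector per image. Here is a minimal sketch of both options on a dummy tensor. The `pool_tokens` helper and its `pool` argument are hypothetical, not part of the notebook's `VisionTransformer`. In the model, you would keep the rest of the notebook's order: pool, then `self.norm`, then `self.head`.

```python
import torch

def pool_tokens(x, pool="cls"):
    """Reduce encoder output x of shape (B, N+1, D) to (B, D).

    pool="cls":  take the [CLS] token at index 0 (what the notebook does)
    pool="mean": global average pooling over the N patch tokens (skip index 0)
    """
    if pool == "cls":
        return x[:, 0]
    if pool == "mean":
        return x[:, 1:].mean(dim=1)
    raise ValueError(f"unknown pool mode: {pool!r}")

x = torch.randn(4, 65, 192)  # (batch, 1 CLS + 64 patches, embed_dim)
print(pool_tokens(x, "cls").shape, pool_tokens(x, "mean").shape)
```

With GAP the [CLS] token no longer feeds the head, so you could also drop it (and its position embedding) entirely.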

**Challenge 2: Patch Size Ablation**

Try different patch sizes and compare accuracy vs training speed:

In [None]:
# Patch size 2: 256 patches (more detail, slower training)
# Patch size 4:  64 patches (our default)
# Patch size 8:  16 patches (very coarse, fast training)

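Smaller patches give the model finer-grained information but lengthen the token sequence. Halving the patch size quadruples the number of patches (N = (32/p)^2). That makes the projection and MLP layers about 4x more expensive and the attention-score computation about 16x more expensive, since self-attention is quadratic in N. There is a sweet spot for each image resolution.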
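Here is a back-of-the-envelope sketch of how token count and per-layer compute scale with patch size for 32x32 inputs at D=192. The multiply-accumulate (MAC) counts are rough estimates under our own assumptions. The attention term covers QK^T and attention-times-V. The linear term covers the QKV and output projections plus a 4x MLP. Softmax, LayerNorm, and biases are ignored, and `patch_cost` is a hypothetical helper.

```python
def patch_cost(img_size=32, patch_size=4, embed_dim=192, mlp_ratio=4):
    """Rough per-layer multiply-accumulates (MACs) for one image."""
    n = (img_size // patch_size) ** 2 + 1                     # patches + [CLS]
    attn_macs = 2 * n * n * embed_dim                          # QK^T and attn @ V: quadratic in n
    linear_macs = n * (4 + 2 * mlp_ratio) * embed_dim ** 2     # QKV + proj + MLP: linear in n
    return n, attn_macs, linear_macs

for p in (2, 4, 8):
    n, attn, lin = patch_cost(patch_size=p)
    print(f"patch {p}: {n:>3} tokens, attention ~{attn / 1e6:5.1f}M MACs, "
          f"linear ~{lin / 1e6:5.1f}M MACs per layer")
```

At these small sizes the linear layers still dominate, so going from 4x4 to 2x2 patches roughly quadruples total compute. The quadratic attention term only takes over as the token count grows.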

**Challenge 3: DeiT-style Knowledge Distillation**

Train a ResNet-18 teacher, then add a distillation token to ViT that learns to match the teacher's predictions. This can boost ViT accuracy noticeably on small datasets. The DeiT paper showed that, with a CNN teacher, ViTs trained on ImageNet alone (no external data) become competitive with strong CNNs.
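Here is a minimal sketch of DeiT's hard-label distillation objective. It assumes the student returns two sets of logits: one from the [CLS] token and one from an extra distillation token (a hypothetical addition to our model). It also assumes `teacher_logits` come from a frozen, already-trained CNN. Following the DeiT paper, the distillation head is trained on the teacher's argmax predictions, the two cross-entropy terms are weighted equally, and at test time the two heads' softmax outputs are added. The function names are our own.

```python
import torch
import torch.nn.functional as F

def deit_hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """0.5 * CE(class head, true labels) + 0.5 * CE(distillation head, teacher's hard labels)."""
    teacher_labels = teacher_logits.argmax(dim=1)   # the teacher's hard decision
    loss_cls = F.cross_entropy(cls_logits, labels)
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)
    return 0.5 * loss_cls + 0.5 * loss_dist

@torch.no_grad()
def deit_predict(cls_logits, dist_logits):
    """Late fusion at test time: add the two heads' softmax outputs."""
    return (F.softmax(cls_logits, dim=1) + F.softmax(dist_logits, dim=1)).argmax(dim=1)

# Shapes-only demo with random tensors (batch of 8, 10 classes)
torch.manual_seed(0)
cls_logits, dist_logits = torch.randn(8, 10), torch.randn(8, 10)
teacher_logits, labels = torch.randn(8, 10), torch.randint(0, 10, (8,))
print(deit_hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels))
print(deit_predict(cls_logits, dist_logits))
```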

**Challenge 4: Scale Up to ViT-Small**

In [None]:
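# A ViT-Small-style config: embed_dim=384, depth=8, num_heads=6
# ~14M parameters at 32x32 with 4x4 patches (the standard ViT-S/DeiT-S uses depth=12,
# ~22M parameters on ImageNet); much larger than ViT-Tiny, needs longer training
model_small = VisionTransformer(
    img_size=32, patch_size=4, in_channels=3, num_classes=10,
    embed_dim=384, depth=8, num_heads=6, mlp_ratio=4.0, dropout=0.1
)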

Compare the accuracy improvement vs the increase in training time.

In [None]:
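# Quick comparison: how model size scales with config (32x32 images, 4x4 patches, 10 classes)
configs = {
    'ViT-Tiny (ours)': (192, 6, 3),
    'ViT-Small':       (384, 8, 6),
    'ViT-Base':        (768, 12, 12),
}
for name, (d, l, h) in configs.items():
    # Rough estimate: per block ~12*d^2 weights (QKV + proj + 4x MLP) plus ~13*d biases/LayerNorm;
    # then patch embedding (4*4*3 inputs), [CLS], 65 position embeddings, final LayerNorm, head
    p = l * (12*d*d + 13*d) + (48*d + d) + d + 65*d + 2*d + (10*d + 10)
    print(f"{name:<18}: D={d}, L={l}, H={h}, ~{p/1e6:.1f}M params")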

In [None]:
#@title üéß Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/25_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### The Bigger Picture: ViT and Its Descendants

The Vision Transformer opened the floodgates for Transformer-based vision models. Here are the key descendants:

| Model | Year | Key Innovation |
|-------|------|---------------|
| **DeiT** | 2021 | Data-efficient training + distillation token |
| **Swin Transformer** | 2021 | Hierarchical + shifted windows for efficiency |
| **BEiT** | 2021 | BERT-style pre-training for vision (masked image modeling) |
| **MAE** | 2022 | Masked Autoencoder: mask 75% of patches and reconstruct them |
| **DINO / DINOv2** | 2021-23 | Self-supervised learning discovers objects without labels |
| **EVA** | 2023 | Scaled ViT to 1B+ params with improved training |

Most of these build on the core ViT architecture we implemented today. The patch embedding, position embedding, Transformer encoder, and classification head remain recognizable even as the training methodology and scale evolve. Swin is the largest architectural departure, replacing global attention with shifted-window attention in a hierarchical design.

### The Unifying Vision

The most profound contribution of ViT was not architectural ‚Äî it was **philosophical**. Before ViT, vision and language were separate worlds with different architectures (CNNs vs Transformers). ViT showed that a single architecture, the Transformer, can handle both modalities.

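This unification helped pave the way for:
- **CLIP** (connecting vision and language in a shared embedding space, with ViT as one of its image-encoder options)
- **GPT-4V / Gemini** (multimodal models that see and read)
- **Stable Diffusion** (text-to-image generation conditioned on a Transformer text encoder; later versions such as Stable Diffusion 3 also use a Transformer as the denoising backbone)

The journey from "An Image is Worth 16x16 Words" to today's multimodal AI is remarkably direct. And you just built the starting point from scratch.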

### What We Built Across Three Notebooks

| Notebook | Topic | Key Concept |
|----------|-------|-------------|
| **1** | Input Pipeline | Images are just matrices of patches; Conv2d is an efficient patching trick |
| **2** | Transformer Encoder | Self-attention lets every patch see every other patch; multi-head gives diversity |
| **3** | Complete ViT | Assembly + training reveals attention patterns and learned spatial structure |

**The core insight of this series**: A Vision Transformer is simpler than you might think. It is a patch embedding, position embedding, a stack of attention + MLP blocks, and a linear head. Nothing more. The magic comes from scale and data.
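To see how little code that really is, here is a compact sketch of the same architecture built on PyTorch's `nn.TransformerEncoderLayer` instead of our hand-written blocks. `norm_first=True` gives the pre-norm layout ViT uses, and the class name `MiniViT` is hypothetical. This is not the notebook's `VisionTransformer`. Details such as dropout after the position embedding and weight initialization differ, but the building blocks are the same four pieces listed above.

```python
import torch
import torch.nn as nn

class MiniViT(nn.Module):
    """Patch embedding + [CLS] + position embedding + encoder stack + linear head."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=192, depth=6, num_heads=3, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # Conv2d with stride == kernel == patch_size cuts and projects patches in one step
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio), dropout=dropout,
            activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth,
                                             enable_nested_tensor=False)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)   # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)      # (B, 1, D)
        x = torch.cat([cls, x], dim=1) + self.pos_embed      # (B, N+1, D)
        x = self.encoder(x)
        return self.head(self.norm(x[:, 0]))                 # logits from [CLS]

vit = MiniViT()
print(vit(torch.randn(2, 3, 32, 32)).shape)                        # torch.Size([2, 10])
print(f"{sum(p.numel() for p in vit.parameters()):,} parameters")  # 2,693,578 parameters
```

With ViT-Tiny's settings this sketch has 2,693,578 parameters, in line with the roughly 2.7M of the model trained above.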

In [None]:
print("Congratulations! You have built, trained, and analyzed a Vision")
print("Transformer entirely from scratch.")
print()
print("You now understand not just WHAT a ViT does, but WHY each component")
print("exists and HOW they work together. This foundation will serve you")
print("well as you explore more advanced architectures.")
print()
print("The complete code from this series can be adapted for any image")
print("classification task. Try it on your own datasets!")