In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Vision Transformers from Scratch -- Vizuara

**We build the Vision Transformer (ViT) from first principles: patch embeddings, self-attention, and the full encoder, each implemented step by step before we assemble and train the complete model in PyTorch.**

In this notebook, we will:
1. Understand why CNNs struggle with global context
2. Implement patch embeddings from scratch
3. Build the self-attention mechanism step by step
4. Assemble a complete Vision Transformer
5. Train it on CIFAR-10 and visualize its attention maps

**Runtime:** Google Colab (GPU recommended, T4 is sufficient)
**Estimated time:** 60-75 minutes

## 1. Why Does This Matter?

In the previous notebook, we built a CNN that classifies images by sliding small filters across them. This works remarkably well, but it has a fundamental limitation: **each filter only sees a tiny local patch.**

Consider an image of a bird flying over the ocean. To classify it correctly, the network needs to relate the bird (top of the image) to the water (bottom of the image). In a CNN, each convolution widens the receptive field by only a few pixels, so information from these distant regions must pass through many layers (or pooling steps) before it can interact. In a Vision Transformer, every part of the image can attend to every other part in a single attention layer.

This is the key idea behind the Vision Transformer: **treat an image as a sequence of patches and let every patch attend to every other patch.** By the end of this notebook, you will have built one from scratch.

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
torch.manual_seed(42)

## 2. Building Intuition

Let us think about images differently. Instead of processing pixels with sliding filters, what if we:

1. **Cut the image into fixed-size patches** (like puzzle pieces)
2. **Flatten each patch** into a vector
3. **Treat each patch as a "word"** in a sentence
4. **Use the Transformer** (from NLP) to process this "sentence"

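This is exactly the idea from the 2020 paper "An Image is Worth 16x16 Words" (Dosovitskiy et al.), whose title refers to splitting a 224x224 image into 16x16-pixel patches. Our CIFAR-10 images are only 32x32, so we will use 4x4 patches instead. Let us see this concretely.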

In [None]:
# Load a sample CIFAR-10 image
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
sample_img, sample_label = dataset[0]
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Image shape: {sample_img.shape}")
print(f"Label: {classes[sample_label]}")

# Visualize the image and its patches
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Original image
axes[0].imshow(sample_img.permute(1, 2, 0))
axes[0].set_title(f'Original Image: {classes[sample_label]}')

# Image with patch grid overlay
axes[1].imshow(sample_img.permute(1, 2, 0))
patch_size = 4  # Using 4x4 patches for 32x32 images
# imshow centers pixel k at coordinate k, so pixel boundaries sit at k - 0.5
for i in range(0, 32 + 1, patch_size):
    axes[1].axhline(y=i - 0.5, color='red', linewidth=1)
    axes[1].axvline(x=i - 0.5, color='red', linewidth=1)
axes[1].set_title(f'Divided into {(32//patch_size)**2} patches ({patch_size}x{patch_size})')

plt.tight_layout()
plt.show()

num_patches = (32 // patch_size) ** 2
print(f"\nWith patch_size={patch_size}: {num_patches} patches")
print(f"Each patch: {patch_size}x{patch_size}x3 = {patch_size*patch_size*3} values")

## 3. The Mathematics

### Patch Embedding

Given an image of size $H \times W$ with $C$ channels, and patch size $P$:

1. Number of patches: $N = \frac{H \times W}{P^2}$
2. Each patch is flattened: $\mathbf{x}_i^{\text{patch}} \in \mathbb{R}^{P^2 C}$
3. Linear projection: $\mathbf{z}_i = \mathbf{E} \cdot \mathbf{x}_i^{\text{patch}} + \mathbf{e}_i^{\text{pos}}$

where $\mathbf{E} \in \mathbb{R}^{D \times (P^2 C)}$ is a learnable projection matrix, $D$ is the embedding dimension, and $\mathbf{e}_i^{\text{pos}} \in \mathbb{R}^{D}$ is a learnable positional embedding for position $i$.

Let us plug in some numbers for our CIFAR-10 images:
- Image: $32 \times 32 \times 3$, Patch size: $P = 4$
- Each patch: $4 \times 4 \times 3 = 48$ values
- Number of patches: $\frac{32 \times 32}{4^2} = 64$ patches
- Projection to dimension $D = 64$

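So the image becomes a sequence of 64 patch tokens, each of dimension 64. Following the ViT paper, we also prepend a learnable [CLS] token whose final state is used for classification, giving 65 tokens in total. This is why the positional embedding in the code below has $N + 1$ rows.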
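Written out for the whole image, and including a learnable [CLS] token $\mathbf{x}_{\text{cls}} \in \mathbb{R}^{D}$ placed first, the input to the first Transformer block is an $(N+1) \times D$ matrix. This follows Eq. 1 of the ViT paper, rewritten in this notebook's column-vector convention:

$$
\mathbf{Z}_0 =
\begin{bmatrix}
\mathbf{x}_{\text{cls}}^{\top} \\
(\mathbf{E}\,\mathbf{x}_1^{\text{patch}})^{\top} \\
\vdots \\
(\mathbf{E}\,\mathbf{x}_N^{\text{patch}})^{\top}
\end{bmatrix}
+ \mathbf{E}^{\text{pos}},
\qquad
\mathbf{E}^{\text{pos}} =
\begin{bmatrix}
(\mathbf{e}_0^{\text{pos}})^{\top} \\
\vdots \\
(\mathbf{e}_N^{\text{pos}})^{\top}
\end{bmatrix}
\in \mathbb{R}^{(N+1) \times D}
$$

Rows are indexed from 0, so row $i \ge 1$ of $\mathbf{Z}_0$ is exactly $\mathbf{z}_i^{\top}$ from step 3 above. For our CIFAR-10 numbers, $\mathbf{Z}_0 \in \mathbb{R}^{65 \times 64}$.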

### Self-Attention

The attention mechanism computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

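where $Q = XW_Q$, $K = XW_K$, and $V = XW_V$ are linear projections of the input tokens $X$ (the code also adds a bias), and $d_k$ is the dimension of each query and key vector. The scaling by $\sqrt{d_k}$ keeps the dot products from growing with $d_k$; without it, large scores push the softmax toward a one-hot distribution whose gradients are nearly zero. In multi-head attention with $h$ heads, each head applies this formula to its own slice of size $d_k = D/h$, and the head outputs are concatenated and mixed by an output projection $W_O$.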
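The claim that scaling keeps the scores in a reasonable range can be checked numerically. This is a minimal sketch that assumes query and key entries are independent standard-normal values (an idealization for illustration, not a property of trained weights). Under that assumption $\mathbf{q} \cdot \mathbf{k}$ has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000

for d_k in (4, 16, 64, 256):
    # Assumption for illustration: i.i.d. standard-normal query and key entries
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)   # raw scores q . k, one per sample
    scaled = dots / d_k ** 0.5   # scaled scores, as in the attention formula
    print(f"d_k={d_k:3d}  var(q.k)={dots.var().item():7.2f}  "
          f"var(q.k / sqrt(d_k))={scaled.var().item():.2f}")
```

The raw variance grows roughly linearly with $d_k$, while the scaled version stays near 1 regardless of head size.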

Let us implement each component.

## 4. Let's Build It -- Component by Component

### Component 1: Patch Embedding (from scratch)

In [None]:
def extract_patches(image, patch_size):
    """
    Extract non-overlapping patches from an image.

    Args:
        image: tensor of shape (C, H, W)
        patch_size: int (P)

    Returns:
        patches: tensor of shape (N, P*P*C) where N = (H*W) / P^2
    """
    C, H, W = image.shape
    assert H % patch_size == 0 and W % patch_size == 0

    # Reshape image into grid of patches
    # (C, H, W) -> (C, H/P, P, W/P, P)
    patches = image.reshape(C, H // patch_size, patch_size, W // patch_size, patch_size)
    # Rearrange: (H/P, W/P, C, P, P) -> flatten last 3 dims
    patches = patches.permute(1, 3, 0, 2, 4).contiguous()
    patches = patches.reshape(-1, C * patch_size * patch_size)

    return patches

# Extract patches from our sample image
patch_size = 4
patches = extract_patches(sample_img, patch_size)
print(f"Image shape: {sample_img.shape}")
print(f"Patches shape: {patches.shape}")
print(f"  -> {patches.shape[0]} patches, each with {patches.shape[1]} values")

# Visualize the first 16 patches
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
fig.suptitle(f'First 16 patches (out of {patches.shape[0]})', fontsize=14)
for i in range(16):
    ax = axes[i // 8, i % 8]
    patch = patches[i].reshape(3, patch_size, patch_size).permute(1, 2, 0).numpy()
    ax.imshow(patch)
    ax.set_title(f'P{i}', fontsize=9)
    ax.axis('off')
plt.tight_layout()
plt.show()
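As a sanity check, the reshape/permute trick in `extract_patches` should produce exactly the same patch ordering as PyTorch's `torch.nn.functional.unfold` with `kernel_size` and `stride` both equal to the patch size. The sketch below re-defines the function so the cell runs on its own, compares the two on a random image, and spot-checks one patch against plain slicing. `extract_patches_unfold` is a hypothetical helper name introduced here for illustration.

```python
import torch
import torch.nn.functional as F

def extract_patches(image, patch_size):
    # Same reshape/permute approach as the function above
    C, H, W = image.shape
    P = patch_size
    patches = image.reshape(C, H // P, P, W // P, P)
    patches = patches.permute(1, 3, 0, 2, 4).contiguous()
    return patches.reshape(-1, C * P * P)

def extract_patches_unfold(image, patch_size):
    # Hypothetical helper: F.unfold takes (B, C, H, W) and returns (B, C*P*P, num_patches)
    cols = F.unfold(image.unsqueeze(0), kernel_size=patch_size, stride=patch_size)
    return cols[0].transpose(0, 1)  # -> (num_patches, C*P*P)

torch.manual_seed(0)
img = torch.rand(3, 32, 32)
P = 4

a = extract_patches(img, P)
b = extract_patches_unfold(img, P)
print(a.shape, torch.equal(a, b))

# Patch 9 is row 1, column 1 of the 8x8 grid (patches are numbered row by row)
print(torch.equal(a[9], img[:, 4:8, 4:8].reshape(-1)))
```

Both orderings flatten each patch channel-first, then row, then column, which is also the order the patch visualization above relies on when it reshapes a patch back to `(3, P, P)`.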

In [None]:
# Now implement the full patch embedding with linear projection

class PatchEmbeddingManual(nn.Module):
    """Patch embedding implemented step by step."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size

        # Linear projection: (P*P*C) -> D
        self.projection = nn.Linear(patch_dim, embed_dim)

        # Learnable positional embeddings
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)

        # Learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]

        # Step 1: Extract patches
        patches = []
        for img in x:
            patches.append(extract_patches(img, self.patch_size))
        patches = torch.stack(patches)  # (B, N, P*P*C)

        # Step 2: Linear projection
        embeddings = self.projection(patches)  # (B, N, D)

        # Step 3: Prepend [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        embeddings = torch.cat([cls_tokens, embeddings], dim=1)  # (B, N+1, D)

        # Step 4: Add positional embeddings
        embeddings = embeddings + self.pos_embedding  # (B, N+1, D)

        return embeddings

# Test it
patch_embed = PatchEmbeddingManual(img_size=32, patch_size=4, embed_dim=64)
test_batch = sample_img.unsqueeze(0)  # (1, 3, 32, 32)
output = patch_embed(test_batch)
print(f"Input: {test_batch.shape}")
print(f"Output: {output.shape}")
print(f"  -> {output.shape[1]} tokens (1 CLS + {output.shape[1]-1} patches)")
print(f"  -> each token is {output.shape[2]}-dimensional")
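The `SimpleViT` model in Section 6 replaces this `nn.Linear` on flattened patches with an `nn.Conv2d` whose kernel size and stride both equal the patch size. The sketch below checks that the two are the same computation: copying the linear weights into the conv kernel, reshaped from $(D, P^2 C)$ to $(D, C, P, P)$, gives matching outputs up to floating-point round-off. It also patchifies the whole batch with a single reshape instead of the per-image Python loop in `PatchEmbeddingManual.forward`. The sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 4, 16
x = torch.randn(B, C, H, W)

# Batched patch extraction: (B, C, H, W) -> (B, N, C*P*P), patches in row-major order
patches = x.reshape(B, C, H // P, P, W // P, P)
patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * P * P)

linear = nn.Linear(C * P * P, D)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# Copy the linear weights into the conv kernel: (D, C*P*P) -> (D, C, P, P)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(D, C, P, P))
    conv.bias.copy_(linear.bias)

out_linear = linear(patches)                   # (B, N, D)
out_conv = conv(x).flatten(2).transpose(1, 2)  # (B, D, H/P, W/P) -> (B, N, D)

print(out_linear.shape, torch.allclose(out_linear, out_conv, atol=1e-5))
```

The convolution version avoids materializing the flattened patches and uses an optimized kernel, which is why it is the usual choice in practice.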

### Component 2: Self-Attention (from scratch)

Now let us implement self-attention step by step. This is the core mechanism that allows every patch to communicate with every other patch.

In [None]:
class SelfAttentionManual(nn.Module):
    """Self-attention implemented step by step for pedagogical clarity."""

    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # d_k

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, return_attention=False):
        B, N, D = x.shape
        H = self.num_heads
        d_k = self.head_dim

        # Step 1: Project to Q, K, V
        Q = self.W_q(x)  # (B, N, D)
        K = self.W_k(x)
        V = self.W_v(x)

        # Step 2: Reshape for multi-head: (B, N, D) -> (B, H, N, d_k)
        Q = Q.reshape(B, N, H, d_k).transpose(1, 2)
        K = K.reshape(B, N, H, d_k).transpose(1, 2)
        V = V.reshape(B, N, H, d_k).transpose(1, 2)

        # Step 3: Compute attention scores: QK^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
        # scores shape: (B, H, N, N)

        # Step 4: Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        # attention_weights shape: (B, H, N, N)

        # Step 5: Weighted sum of values
        output = torch.matmul(attention_weights, V)
        # output shape: (B, H, N, d_k)

        # Step 6: Reshape back and project
        output = output.transpose(1, 2).reshape(B, N, D)
        output = self.W_o(output)

        if return_attention:
            return output, attention_weights
        return output

# Test it with our patch embeddings
attn = SelfAttentionManual(embed_dim=64, num_heads=4)
embeddings = patch_embed(test_batch)
attn_output, attn_weights = attn(embeddings, return_attention=True)

print(f"Input embeddings: {embeddings.shape}")
print(f"Attention output: {attn_output.shape}")
print(f"Attention weights: {attn_weights.shape}")
print(f"  -> {attn_weights.shape[1]} heads, each producing a {attn_weights.shape[2]}x{attn_weights.shape[3]} attention map")
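Steps 3 to 5 above are exactly scaled dot-product attention, which PyTorch 2.0 and later also provide as `torch.nn.functional.scaled_dot_product_attention`. The sketch below, assuming PyTorch >= 2.0, compares the manual formula with that function on random tensors shaped like the multi-head tensors in `SelfAttentionManual`, that is $(B, H, N, d_k)$. With no mask and no dropout, the function's default scale is $1/\sqrt{d_k}$, so the two should agree up to round-off.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, d_k = 2, 4, 65, 16
Q = torch.randn(B, H, N, d_k)
K = torch.randn(B, H, N, d_k)
V = torch.randn(B, H, N, d_k)

# Manual scaled dot-product attention (Steps 3-5 of SelfAttentionManual)
weights = F.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)  # (B, H, N, N)
manual = weights @ V                                              # (B, H, N, d_k)

# Built-in version (PyTorch >= 2.0); default scale is 1/sqrt(d_k)
builtin = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, builtin, atol=1e-5))
print(torch.allclose(weights.sum(dim=-1), torch.ones(B, H, N)))  # each row sums to 1
```

The manual version is kept in this notebook because it exposes the attention weights for visualization; the built-in function returns only the output.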

In [None]:
# Visualization checkpoint 1: Attention patterns
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
fig.suptitle('Untrained Attention Weights from 4 Heads (CLS token attending to all patches)', fontsize=13)

for head in range(4):
    # Get CLS token's attention to all patches (row 0 of attention matrix)
    cls_attn = attn_weights[0, head, 0, 1:].detach().numpy()  # Exclude CLS-to-CLS
    # Reshape to spatial grid
    grid_size = int(np.sqrt(len(cls_attn)))
    attn_map = cls_attn.reshape(grid_size, grid_size)

    axes[head].imshow(attn_map, cmap='hot')
    axes[head].set_title(f'Head {head+1}')
    axes[head].axis('off')

plt.tight_layout()
plt.show()
print("These weights come from an untrained, randomly initialized layer,")
print("so any differences between heads reflect random weights, not learned behavior.")

## 5. Your Turn

### TODO 1: Compute Attention Manually for 3 Patches

Given these Q, K, V matrices for 3 patches (4-dimensional each), compute the attention output step by step.

In [None]:
def compute_attention_manual(Q, K, V):
    """
    Compute scaled dot-product attention manually.

    Args:
        Q: tensor of shape (N, d_k) -- queries
        K: tensor of shape (N, d_k) -- keys
        V: tensor of shape (N, d_k) -- values

    Returns:
        output: tensor of shape (N, d_k) -- attention output
        weights: tensor of shape (N, N) -- attention weights

    Steps:
        1. Compute QK^T (matrix multiplication)
        2. Scale by 1/sqrt(d_k)
        3. Apply softmax row-wise
        4. Multiply by V
    """
    d_k = Q.shape[-1]

    # TODO: Step 1 - Compute raw scores
    # scores = ...

    # TODO: Step 2 - Scale
    # scores = scores / ...

    # TODO: Step 3 - Softmax
    # weights = F.softmax(scores, dim=...)

    # TODO: Step 4 - Weighted sum
    # output = torch.matmul(weights, V)

    # return output, weights
    pass

# Test with known values
Q_test = torch.tensor([[1., 0., 1., 0.],
                        [0., 1., 0., 1.],
                        [1., 1., 0., 0.]])
K_test = torch.tensor([[1., 1., 0., 0.],
                        [0., 0., 1., 1.],
                        [1., 0., 1., 0.]])
V_test = torch.tensor([[1., 0., 0., 1.],
                        [0., 1., 1., 0.],
                        [1., 1., 0., 0.]])

# Uncomment when ready:
# output, weights = compute_attention_manual(Q_test, K_test, V_test)
# print("Attention weights:")
# print(weights.numpy().round(3))
# print("\nOutput:")
# print(output.numpy().round(3))
# print("\nExpected row 0 weights: ~[0.274, 0.274, 0.452] (patch 1 attends most to patch 3)")
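To check your work, here is row 0 (patch 1) worked by hand. The first query is $\mathbf{q}_1 = [1, 0, 1, 0]$, and with $d_k = 4$ the scale factor is $\sqrt{4} = 2$:

$$
\mathbf{q}_1 \mathbf{K}^{\top} = [1,\ 1,\ 2]
\;\xrightarrow{\ \div 2\ }\; [0.5,\ 0.5,\ 1.0]
\;\xrightarrow{\ \text{softmax}\ }\;
\frac{[e^{0.5},\ e^{0.5},\ e^{1}]}{2e^{0.5} + e^{1}}
\approx [0.274,\ 0.274,\ 0.452]
$$

$$
\text{output}_1 \approx 0.274\,[1, 0, 0, 1] + 0.274\,[0, 1, 1, 0] + 0.452\,[1, 1, 0, 0] = [0.726,\ 0.726,\ 0.274,\ 0.274]
$$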

### TODO 2: Add Layer Normalization and MLP to Create a Transformer Block

Complete the Transformer encoder block below. Each block has:
- LayerNorm -> Multi-Head Self-Attention -> Residual
- LayerNorm -> MLP (2 linear layers with GELU) -> Residual

In [None]:
class TransformerBlock(nn.Module):
    """
    A single Transformer encoder block.

    Architecture:
        x' = MSA(LN(x)) + x       (attention with residual)
        z  = MLP(LN(x')) + x'      (feedforward with residual)

    Hint: MLP has two linear layers with GELU activation.
          Hidden dim is typically 4x the embed_dim.
    """
    def __init__(self, embed_dim=64, num_heads=4, mlp_ratio=4):
        super().__init__()
        # TODO: Define components
        # self.norm1 = nn.LayerNorm(embed_dim)
        # self.attn = SelfAttentionManual(embed_dim, num_heads)
        # self.norm2 = nn.LayerNorm(embed_dim)
        # self.mlp = nn.Sequential(
        #     nn.Linear(embed_dim, embed_dim * mlp_ratio),
        #     nn.GELU(),
        #     nn.Linear(embed_dim * mlp_ratio, embed_dim),
        # )
        pass

    def forward(self, x):
        # TODO: Implement forward with residual connections
        # x = self.attn(self.norm1(x)) + x    # Attention + residual
        # x = self.mlp(self.norm2(x)) + x      # MLP + residual
        # return x
        pass

# Verification:
# block = TransformerBlock(embed_dim=64, num_heads=4)
# test_input = torch.randn(2, 65, 64)  # (batch=2, tokens=65, dim=64)
# test_output = block(test_input)
# assert test_output.shape == test_input.shape, "Shape mismatch!"
# print(f"Input: {test_input.shape} -> Output: {test_output.shape}")
# print("Transformer block working correctly!")

## 6. Putting It All Together

Let us assemble the complete Vision Transformer.
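In equations, with $L$ encoder blocks, the model below computes the following, mirroring Eqs. 2 to 4 of the ViT paper. Here $\mathbf{Z}_0$ is the $(N+1) \times D$ token matrix produced by the patch embedding (with the [CLS] token in row 0), $\mathbf{z}_L^{0}$ is row 0 of the last block's output, and $\mathbf{W}_{\text{head}}, \mathbf{b}_{\text{head}}$ are the weights of the linear classification head:

$$
\begin{aligned}
\mathbf{Z}'_\ell &= \text{MSA}\big(\text{LN}(\mathbf{Z}_{\ell-1})\big) + \mathbf{Z}_{\ell-1}, & \ell &= 1, \dots, L \\
\mathbf{Z}_\ell &= \text{MLP}\big(\text{LN}(\mathbf{Z}'_\ell)\big) + \mathbf{Z}'_\ell, & \ell &= 1, \dots, L \\
\hat{\mathbf{y}} &= \mathbf{W}_{\text{head}}\, \text{LN}\big(\mathbf{z}_L^{0}\big) + \mathbf{b}_{\text{head}}
\end{aligned}
$$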

In [None]:
class TransformerBlockFull(nn.Module):
    """Complete Transformer encoder block."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = SelfAttentionManual(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(embed_dim * mlp_ratio, embed_dim),
        )

    def forward(self, x):
        x = self.attn(self.norm1(x)) + x
        x = self.mlp(self.norm2(x)) + x
        return x


class SimpleViT(nn.Module):
    """Complete Vision Transformer."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 embed_dim=128, num_heads=4, num_layers=4, num_classes=10):
        super().__init__()

        self.num_patches = (img_size // patch_size) ** 2

        # Patch embedding (using Conv2d for efficiency)
        self.patch_embed = nn.Conv2d(in_channels, embed_dim,
                                      kernel_size=patch_size, stride=patch_size)

        # CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerBlockFull(embed_dim, num_heads) for _ in range(num_layers)
        ])

        # Final normalization and classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)       # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)

        # Prepend CLS token
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)

        # Add positional embeddings
        x = x + self.pos_embed

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Classification
        x = self.norm(x)
        cls_output = x[:, 0]  # CLS token
        return self.head(cls_output)


# Create the model
vit = SimpleViT(img_size=32, patch_size=4, embed_dim=128,
                num_heads=4, num_layers=4, num_classes=10).to(device)

total_params = sum(p.numel() for p in vit.parameters())
print(f"SimpleViT: {total_params:,} parameters")
print("\nArchitecture components:")
print(f"  Patch embedding: {vit.patch_embed}")
print(f"  Patches: {vit.num_patches}")
print(f"  Transformer blocks: {len(vit.blocks)}")
print(f"  Embed dim: {vit.cls_token.shape[-1]}, Heads: {vit.blocks[0].attn.num_heads}")
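As a cross-check on the printed total, the parameter count can be derived by hand. Each `TransformerBlockFull` with embedding dimension $D$ and `mlp_ratio=4` has four $D \times D$ attention projections with biases ($4D^2 + 4D$), two LayerNorms ($4D$), and an MLP ($8D^2 + 5D$), for $12D^2 + 13D$ per block. The number of heads does not change the count, because heads split $D$ rather than add weights. The sketch below adds the patch embedding, [CLS] token, positional embeddings, final LayerNorm, and head; `vit_param_count` and `n_params` are hypothetical helper names, and the per-block formula is checked against real modules of the same shapes.

```python
import torch.nn as nn

def n_params(module):
    # Total number of scalar parameters in a module
    return sum(p.numel() for p in module.parameters())

def vit_param_count(img_size=32, patch_size=4, in_channels=3,
                    embed_dim=128, num_layers=4, num_classes=10, mlp_ratio=4):
    """Hypothetical helper: closed-form parameter count for SimpleViT."""
    D, M = embed_dim, embed_dim * mlp_ratio
    N = (img_size // patch_size) ** 2
    patch_embed = D * in_channels * patch_size ** 2 + D  # Conv2d weight + bias
    cls_and_pos = D + (N + 1) * D                        # [CLS] token + positional embeddings
    attn = 4 * (D * D + D)                               # W_q, W_k, W_v, W_o with biases
    norms = 2 * (2 * D)                                  # two LayerNorms (weight + bias)
    mlp = (D * M + M) + (M * D + D)                      # Linear(D, M) + Linear(M, D)
    per_block = attn + norms + mlp                       # = 12*D^2 + 13*D when M = 4*D
    head = 2 * D + (D * num_classes + num_classes)       # final LayerNorm + classifier
    return patch_embed + cls_and_pos + num_layers * per_block + head, per_block

# Cross-check the per-block formula against real modules of the same shapes
D = 128
block_modules = nn.ModuleList(
    [nn.Linear(D, D) for _ in range(4)]
    + [nn.LayerNorm(D), nn.LayerNorm(D)]
    + [nn.Linear(D, 4 * D), nn.Linear(4 * D, D)]
)

total, per_block = vit_param_count()
print(f"Per block (formula): {per_block:,}  (modules: {n_params(block_modules):,})")
print(f"SimpleViT total (formula): {total:,}")
```

For the default configuration the formula gives 198,272 parameters per block and 809,354 in total, which should match the `SimpleViT` count printed above. Almost all parameters live in the encoder blocks and scale with $D^2$.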

## 7. Training and Results

In [None]:
# Load CIFAR-10
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Training
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(vit.parameters(), lr=3e-4, weight_decay=0.05)

train_losses, test_accs = [], []
NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):
    vit.train()
    running_loss = 0.0
    correct, total = 0, 0

    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        outputs = vit(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss = running_loss / len(trainloader)
    train_acc = 100. * correct / total
    train_losses.append(train_loss)

    vit.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = vit(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    test_acc = 100. * correct / total
    test_accs.append(test_acc)

    print(f'Epoch {epoch+1}/{NUM_EPOCHS}: Loss={train_loss:.3f}, '
          f'Train={train_acc:.1f}%, Test={test_acc:.1f}%')

print(f'\nBest test accuracy: {max(test_accs):.1f}%')

In [None]:
# Visualization checkpoint 2: Training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
epochs = range(1, len(train_losses) + 1)
ax1.plot(epochs, train_losses, 'b-', linewidth=2)
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss')
ax1.grid(alpha=0.3)
ax2.plot(epochs, test_accs, 'r-', linewidth=2)
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Accuracy (%)'); ax2.set_title('Test Accuracy')
ax2.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## 8. Final Output

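Let us visualize the attention maps. For each test image, we plot the final layer's attention from the CLS token to every patch, one map per head. These maps show where the CLS token gathers information in that layer; they are a useful but rough indicator of what the prediction depends on, since earlier layers have already mixed information between patches.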

In [None]:
# Visualization checkpoint 3: Attention maps on real images
vit.eval()
fig, axes = plt.subplots(3, 5, figsize=(18, 10))
fig.suptitle('ViT Attention Maps: Where the Model Looks', fontsize=15)

for row in range(3):
    idx = row * 100  # Pick different images
    img, label = testset[idx]
    img_gpu = img.unsqueeze(0).to(device)

    # Get attention from the last block
    with torch.no_grad():
        x = vit.patch_embed(img_gpu).flatten(2).transpose(1, 2)
        cls = vit.cls_token.expand(1, -1, -1)
        x = torch.cat([cls, x], dim=1) + vit.pos_embed
        for block in vit.blocks[:-1]:
            x = block(x)
        # Get attention from last block
        attn_out, attn_w = vit.blocks[-1].attn(vit.blocks[-1].norm1(x), return_attention=True)

    # Show original image
    img_show = img.permute(1, 2, 0).numpy()
    img_show = (img_show - img_show.min()) / (img_show.max() - img_show.min())
    axes[row, 0].imshow(img_show)
    axes[row, 0].set_title(f'{classes[label]}')
    axes[row, 0].axis('off')

    # Show attention from each head (CLS to patches)
    for head in range(4):
        cls_attn = attn_w[0, head, 0, 1:].cpu().numpy()
        grid_size = int(np.sqrt(len(cls_attn)))
        attn_map = cls_attn.reshape(grid_size, grid_size)
        # Upsample attention map to image size
        attn_map_resized = np.kron(attn_map, np.ones((32//grid_size, 32//grid_size)))

        axes[row, head+1].imshow(img_show)
        axes[row, head+1].imshow(attn_map_resized, cmap='hot', alpha=0.5)
        axes[row, head+1].set_title(f'Head {head+1}')
        axes[row, head+1].axis('off')

plt.tight_layout()
plt.show()

print("Compare the heads: after training, they may specialize on different regions.")
print("Some heads may focus on the object, while others spread over the background or edges.")

## 9. Reflection and Next Steps

**What we built:** We implemented the entire Vision Transformer from scratch:
- Patch extraction and embedding
- Multi-head self-attention
- Transformer encoder blocks with residual connections
- Full ViT model trained on CIFAR-10

**Key takeaways:**
1. Vision Transformers treat images as sequences of patches -- just like words in a sentence
2. Self-attention lets every patch attend to every other patch, capturing global relationships
3. Multi-head attention allows the model to attend to different aspects simultaneously
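4. ViTs need more data (or stronger augmentation and regularization) to match CNN performance, because they lack the locality and translation-equivariance inductive biases built into convolutions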

**Reflection questions:**
- What happens if you double the patch size from 4 to 8? How does this affect the number of tokens and the accuracy?
- Why does the ViT need positional embeddings while a CNN does not?
- If you remove the CLS token and instead average all patch outputs, how does the accuracy change?

**Next:** In the next notebook, we will directly compare CNNs and ViTs on the same task, exploring their different strengths and trade-offs.