In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# CNN vs. Vision Transformer: A Hands-On Comparison -- Vizuara

**We train both a CNN and a ViT on the same dataset and directly compare their behavior, accuracy, computational cost, and what they learn.**

In this notebook, we will:
1. Train both architectures under identical conditions
2. Compare accuracy, parameter count, and training speed
3. Analyze how data size affects each architecture differently
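4. Visualize how their learning curves, training speed, and data efficiency differ (learned representations and attention maps are explored in the next notebook)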

**Runtime:** Google Colab (GPU, T4 sufficient)
**Estimated time:** 45-60 minutes

## 1. Why Does This Matter?

CNNs and Vision Transformers are the two dominant paradigms for visual understanding. Choosing between them is one of the most important architectural decisions in modern computer vision. But the choice is not always obvious -- it depends on your dataset size, computational budget, and the type of visual reasoning your task requires.

In this notebook, we will run a controlled experiment: **same data, same training schedule, same evaluation** -- different architecture. By the end, you will have a clear intuition for when to use each.

In [None]:
# Imports used throughout the notebook (stdlib first, then third-party).
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Re-seed every RNG so this cell is reproducible on its own, even if the setup cell
# was skipped. torch.manual_seed also seeds all CUDA devices. The trailing ';'
# suppresses the torch.Generator object it returns, which would otherwise be
# displayed as this cell's output.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

## 2. Building Intuition

The key difference between CNNs and ViTs can be summarized in one sentence: **CNNs have strong built-in assumptions about images; ViTs learn everything from data.**

CNNs assume:
- **Locality**: nearby pixels are more related than distant ones
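- **Translation equivariance**: the same filters slide over the whole image, so shifting the input shifts the feature map by the same amount -- a cat detector fires wherever the cat appears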

These assumptions are incredibly useful -- they act as a form of regularization. But they also limit what the network can learn.

ViTs make almost no assumptions. They must learn locality, translation equivariance, and spatial structure entirely from the training data. This gives them more flexibility but requires more data.

Let us set up both architectures.

In [None]:
# Define both architectures

class SimpleCNN(nn.Module):
    """Small 3-stage CNN baseline: (conv3x3 -> ReLU -> maxpool2) x 3, then a 2-layer MLP head.

    Args:
        num_classes: number of output logits.
        img_size: height/width of the (square) input images. Each of the three 2x2
            max-pools halves the spatial size (rounding down), so the flattened
            feature vector has ``128 * (img_size // 8) ** 2`` entries (2048 for 32x32).
        in_channels: number of input image channels (3 for RGB).

    Raises:
        ValueError: if ``img_size`` is too small to survive three 2x2 poolings.
    """

    def __init__(self, num_classes=10, img_size=32, in_channels=3):
        super().__init__()
        # Spatial size left after three floor-halving 2x2 max-pools (4 for 32x32 inputs).
        feat_size = img_size // 8
        if feat_size < 1:
            raise ValueError(f"img_size must be at least 8, got {img_size}")
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * feat_size * feat_size, 256), nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map images of shape (B, in_channels, img_size, img_size) to logits (B, num_classes)."""
        return self.classifier(self.features(x))


class SimpleViT(nn.Module):
    """Minimal Vision Transformer: patchify -> [CLS] + learned positions -> pre-norm encoder -> linear head.

    Args:
        img_size: height/width of the (square) input images.
        patch_size: side of each non-overlapping patch; the image is cut into
            ``(img_size // patch_size) ** 2`` patch tokens (64 for the defaults).
        embed_dim: token width.
        num_heads: attention heads per layer (must divide ``embed_dim``).
        num_layers: number of transformer encoder layers.
        num_classes: number of output logits.
        in_channels: number of input image channels (3 for RGB).
    """

    def __init__(self, img_size=32, patch_size=4, embed_dim=128,
                 num_heads=4, num_layers=4, num_classes=10, in_channels=3):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A conv whose kernel and stride both equal patch_size is a per-patch linear projection.
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learned position vector per patch, plus one for the [CLS] token.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
            activation='gelu', batch_first=True, norm_first=True)
        # The nested-tensor fast path cannot be used with norm_first=True; disable it
        # explicitly instead of having PyTorch warn about it on every construction.
        # NOTE: TransformerEncoder deep-copies encoder_layer, so all layers start
        # from identical initial weights.
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers,
                                                 enable_nested_tensor=False)
        # Pre-norm blocks leave the residual stream un-normalized, so add a final LayerNorm.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map images of shape (B, in_channels, img_size, img_size) to logits (B, num_classes)."""
        B = x.shape[0]
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        # Prepend the shared [CLS] token to every sequence, then add positional embeddings.
        x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1) + self.pos_embed
        x = self.norm(self.transformer(x))
        # Classify from the final representation of the [CLS] token.
        return self.head(x[:, 0])


# Build one instance of each architecture and compare their sizes.
cnn, vit = SimpleCNN().to(device), SimpleViT().to(device)

cnn_params, vit_params = (
    sum(param.numel() for param in net.parameters()) for net in (cnn, vit)
)

for label, count in (("SimpleCNN", cnn_params), ("SimpleViT", vit_params)):
    print(f"{label}: {count:,} parameters")
print(f"ViT/CNN ratio: {vit_params/cnn_params:.1f}x")

## 3. The Mathematics

The key mathematical difference:

**CNN**: Each output neuron is a function of a **local** region:
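$$y_{i,j} = \sum_{c}\sum_{m,n} W_{c,m,n} \cdot x_{c,\,i+m,\,j+n}$$
where $(m, n)$ ranges over the small kernel window (3x3 here) and $c$ over the input channels.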

**ViT**: Each output token is a function of **all** input tokens:
$$y_i = \sum_j \alpha_{ij} \cdot V_j, \quad \alpha_{ij} = \frac{\exp(Q_i \cdot K_j / \sqrt{d})}{\sum_k \exp(Q_i \cdot K_k / \sqrt{d})}$$

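Let us plug in some numbers:
- For a 32x32 image with 3x3 convolutions: each output pixel depends on a **3x3 window -- 9 spatial positions** (across all input channels), i.e. a local neighborhood
- For a 32x32 image with 4x4 patches: there are 64 patch tokens plus 1 [CLS] token, and each token attends to **all 65 tokens**, including itself, i.e. the whole image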

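This global attention is both the ViT's greatest strength and its main computational cost. The attention matrix has one entry per pair of tokens, so compute and memory grow quadratically with the number of tokens ($65^2 = 4{,}225$ pairs per head here).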

## 4. Let's Build It -- Component by Component: The Training Framework

In [None]:
# Unified training function for fair comparison


def _evaluate(model, loader, device):
    """Return top-1 accuracy (%) of ``model`` on ``loader``; puts the model in eval mode.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = model(images).max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100. * correct / total


def train_model(model, trainloader, testloader, epochs=20, lr=3e-4, device=None):
    """Train ``model`` with AdamW + cross-entropy, evaluating on ``testloader`` after every epoch.

    Both architectures go through this exact loop so the comparison is fair.

    Args:
        model: classifier returning logits of shape (B, num_classes).
        trainloader: iterable of (images, labels) training batches.
        testloader: iterable of (images, labels) evaluation batches.
        epochs: number of passes over ``trainloader``.
        lr: AdamW learning rate (weight decay is fixed at 0.05; no LR schedule).
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        dict of per-epoch lists: 'train_loss' (mean batch loss), 'train_acc' (%),
        'test_acc' (%), and 'epoch_time' (seconds spent in the training pass only;
        evaluation time is not included).

    Raises:
        ValueError: if ``trainloader`` or ``testloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)

    history = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'epoch_time': []}

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        running_loss, correct, total, num_batches = 0.0, 0, 0, 0

        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() synchronizes with the GPU each step, so the wall-clock epoch
            # time below measures completed work rather than queued kernels.
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            num_batches += 1

        epoch_time = time.time() - start_time
        if total == 0:
            raise ValueError("trainloader yielded no samples")
        train_loss = running_loss / num_batches
        train_acc = 100. * correct / total
        test_acc = _evaluate(model, testloader, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        history['epoch_time'].append(epoch_time)

        # Log every 5th epoch, and always the last one (even when epochs % 5 != 0).
        if (epoch + 1) % 5 == 0 or epoch + 1 == epochs:
            print(f'  Epoch {epoch+1}/{epochs}: Loss={train_loss:.3f}, '
                  f'Test={test_acc:.1f}%, Time={epoch_time:.1f}s')

    return history

In [None]:
# Prepare data: CIFAR-10 with light augmentation for training, none for testing.

# Per-channel normalization shared by both pipelines (Normalize is stateless).
# NOTE(review): these std values are the widely copied (0.2023, 0.1994, 0.2010);
# the dataset-wide per-channel std is closer to (0.247, 0.243, 0.261). Kept as-is.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # random crop from a 4-px zero-padded image
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([transforms.ToTensor(), normalize])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

loader_kwargs = dict(batch_size=128, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(f"Training: {len(trainset)}, Test: {len(testset)}")

## 5. Your Turn

### TODO 1: Predict Which Architecture Wins

Before running the experiment, write down your predictions:

In [None]:
# TODO: Fill in your predictions before running the training!
predictions = dict(
    which_has_higher_accuracy="___",  # "CNN" or "ViT"
    which_trains_faster="___",        # "CNN" or "ViT"
    accuracy_gap_percent="___",       # e.g., "5" for a 5% gap
    reasoning="___",
)

# Echo the predictions so they are recorded in the cell output.
print("\n".join(f"  {key}: {value}" for key, value in predictions.items()))
print("\nNow let us see if you are right!")

## 6. Putting It All Together

In [None]:
# Train fresh instances of both architectures through the same train_model() loop
# (same data, same default learning rate, same 20-epoch schedule).
banner_rule = "=" * 50

print(banner_rule, "Training CNN...", banner_rule, sep="\n")
cnn = SimpleCNN().to(device)
cnn_history = train_model(cnn, trainloader, testloader, epochs=20)

print("\n" + banner_rule, "Training ViT...", banner_rule, sep="\n")
vit = SimpleViT().to(device)
vit_history = train_model(vit, trainloader, testloader, epochs=20)

## 7. Training and Results

In [None]:
# Visualization checkpoint 1: Side-by-side comparison
# Epochs are plotted 1-indexed so the x-axis matches the "Epoch k/N" training logs.
cnn_epochs = range(1, len(cnn_history['train_loss']) + 1)
vit_epochs = range(1, len(vit_history['train_loss']) + 1)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curves
axes[0].plot(cnn_epochs, cnn_history['train_loss'], 'b-', label='CNN', linewidth=2)
axes[0].plot(vit_epochs, vit_history['train_loss'], 'r-', label='ViT', linewidth=2)
axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss'); axes[0].legend(); axes[0].grid(alpha=0.3)

# Accuracy curves
axes[1].plot(cnn_epochs, cnn_history['test_acc'], 'b-', label='CNN', linewidth=2)
axes[1].plot(vit_epochs, vit_history['test_acc'], 'r-', label='ViT', linewidth=2)
axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Test Accuracy (%)')
axes[1].set_title('Test Accuracy'); axes[1].legend(); axes[1].grid(alpha=0.3)

# Time per epoch (training pass only; evaluation time is not included)
axes[2].bar(['CNN', 'ViT'],
            [np.mean(cnn_history['epoch_time']), np.mean(vit_history['epoch_time'])],
            color=['#4a90d9', '#e74c3c'])
axes[2].set_ylabel('Seconds per Epoch')
axes[2].set_title('Training Speed')

plt.tight_layout()
plt.show()

# "last" = final epoch, "best" = max over epochs. Picking the best epoch by test
# accuracy is optimistic (the test set doubles as a selection set), so report both.
print("\nFinal Results (last epoch / best epoch test accuracy):")
for name, hist in (('CNN', cnn_history), ('ViT', vit_history)):
    print(f"  {name}: {hist['test_acc'][-1]:.1f}% / {max(hist['test_acc']):.1f}% accuracy, "
          f"{np.mean(hist['epoch_time']):.1f}s/epoch")

In [None]:
# Visualization checkpoint 2: Data efficiency experiment
# Train fresh CNN and ViT models on random training subsets of increasing size
# (10 epochs each) and record the best test accuracy each one reaches.

subset_sizes = [1000, 5000, 10000, 25000, 50000]
cnn_accs_by_size = []
vit_accs_by_size = []

print("Data Efficiency Experiment")
print("-" * 40)

for size in subset_sizes:
    # Draw a random subset; .tolist() gives Subset plain Python ints instead of 0-d tensors.
    indices = torch.randperm(len(trainset))[:size].tolist()
    subset = torch.utils.data.Subset(trainset, indices)
    subloader = torch.utils.data.DataLoader(subset, batch_size=128, shuffle=True, num_workers=2)
    # NOTE: the epoch count is fixed, so smaller subsets also get fewer optimizer steps
    # (1,000 samples -> 8 batches/epoch -> 80 steps in total).

    # Train CNN
    cnn_sub = SimpleCNN().to(device)
    h = train_model(cnn_sub, subloader, testloader, epochs=10)
    cnn_accs_by_size.append(max(h['test_acc']))

    # Train ViT
    vit_sub = SimpleViT().to(device)
    h = train_model(vit_sub, subloader, testloader, epochs=10)
    vit_accs_by_size.append(max(h['test_acc']))

    print(f"  {size:,} samples: CNN={cnn_accs_by_size[-1]:.1f}%, "
          f"ViT={vit_accs_by_size[-1]:.1f}%")

# Plot data efficiency
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(subset_sizes, cnn_accs_by_size, 'bo-', label='CNN', linewidth=2, markersize=8)
ax.plot(subset_sizes, vit_accs_by_size, 'rs-', label='ViT', linewidth=2, markersize=8)
ax.set_xscale('log')
ax.set_xlabel('Training Set Size')
ax.set_ylabel('Test Accuracy (%)')
ax.set_title('Data Efficiency: CNN vs. ViT')
ax.legend(fontsize=12)
ax.grid(alpha=0.3)
plt.show()

# Derive the takeaway from the measured numbers instead of printing a fixed conclusion:
# a small ViT trained from scratch on <= 50k CIFAR-10 images may never catch the CNN.
cnn_leads = [c - v for c, v in zip(cnn_accs_by_size, vit_accs_by_size)]
catch_up_size = next((s for s, lead in zip(subset_sizes, cnn_leads) if lead <= 0), None)

print(f"\nCNN lead over ViT: {cnn_leads[0]:+.1f} points at {subset_sizes[0]:,} samples, "
      f"{cnn_leads[-1]:+.1f} points at {subset_sizes[-1]:,} samples "
      f"(the lead {'shrank' if cnn_leads[-1] < cnn_leads[0] else 'did not shrink'} as data grew).")
if catch_up_size is None:
    print("Key insight: the CNN beat the ViT at every dataset size tested -- at this scale, "
          "its inductive biases outweigh the ViT's flexibility.")
else:
    print(f"Key insight: the ViT first matched or beat the CNN at {catch_up_size:,} training samples.")

### TODO 2: Analyze the Results

Answer these questions based on the experiments above:

In [None]:
# TODO: Fill in your analysis after running the experiments
analysis = dict(
    which_model_won_with_full_data="___",
    which_model_won_with_1000_samples="___",
    at_what_dataset_size_did_vit_catch_up="___",
    why_does_cnn_perform_better_with_small_data="___",
    main_takeaway="___",
)

# Echo the answers so they are recorded in the cell output.
print("\n".join(f"  {key}: {value}" for key, value in analysis.items()))

## 8. Final Output

Let us create a comprehensive summary visualization.

In [None]:
# Visualization checkpoint 3: Comprehensive comparison
colors = ['#4a90d9', '#e74c3c']
best_accs = [max(cnn_history['test_acc']), max(vit_history['test_acc'])]
mean_epoch_times = [np.mean(cnn_history['epoch_time']), np.mean(vit_history['epoch_time'])]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('CNN vs. Vision Transformer: Complete Comparison', fontsize=16, fontweight='bold')

# 1. Parameter comparison: count the trained models rather than building throwaway
#    instances (which would also advance the global RNG).
cnn_p = sum(p.numel() for p in cnn.parameters())
vit_p = sum(p.numel() for p in vit.parameters())
axes[0, 0].bar(['CNN', 'ViT'], [cnn_p/1e6, vit_p/1e6], color=colors)
axes[0, 0].set_ylabel('Parameters (Millions)')
axes[0, 0].set_title('Model Size')

# 2. Accuracy comparison: the y-range follows the data, because a fixed 60-95%
#    window would hide any bar below 60%.
axes[0, 1].bar(['CNN', 'ViT'], best_accs, color=colors)
axes[0, 1].set_ylabel('Test Accuracy (%)')
axes[0, 1].set_title('Best Accuracy (Full Dataset)')
axes[0, 1].set_ylim(max(0, min(best_accs) - 10), min(100, max(best_accs) + 5))

# 3. Training speed (training pass only; evaluation time is not included)
axes[1, 0].bar(['CNN', 'ViT'], mean_epoch_times, color=colors)
axes[1, 0].set_ylabel('Seconds per Epoch')
axes[1, 0].set_title('Training Speed')

# 4. Data efficiency
axes[1, 1].plot(subset_sizes, cnn_accs_by_size, 'bo-', label='CNN', linewidth=2)
axes[1, 1].plot(subset_sizes, vit_accs_by_size, 'rs-', label='ViT', linewidth=2)
axes[1, 1].set_xlabel('Dataset Size')
axes[1, 1].set_ylabel('Accuracy (%)')
axes[1, 1].set_title('Data Efficiency')
axes[1, 1].legend()
axes[1, 1].set_xscale('log')
axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\nSummary:")
print(f"  CNN: {cnn_p:,} params, {best_accs[0]:.1f}% best accuracy, "
      f"{mean_epoch_times[0]:.1f}s/epoch")
print(f"  ViT: {vit_p:,} params, {best_accs[1]:.1f}% best accuracy, "
      f"{mean_epoch_times[1]:.1f}s/epoch")

## 9. Reflection and Next Steps

**What we learned:**
- CNNs and ViTs have different strengths that depend on the data regime
- With limited data, CNN's inductive biases give it a significant advantage
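- With more data, the ViT's gap to the CNN typically narrows. At much larger scales (or with heavy augmentation and distillation, as in DeiT), its flexibility lets it learn richer representations and match or surpass CNNs. Check whether the CIFAR-10 subsets above actually show a crossover.
- Hybrid approaches that combine convolutions and attention are often the best practical choice (not tested in this notebook)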

**Reflection questions:**
- If you were building a production image classifier with only 5,000 labeled images, which architecture would you choose and why?
- The original ViT paper needed pre-training on JFT-300M (~300M images) to beat strong CNNs on ImageNet. DeiT later achieved competitive results training on ImageNet-1k alone (~1.28M images). What changed?
- Modern architectures like ConvNeXt bring ViT design principles back to CNNs. Does this blur the distinction between the two paradigms?

**Next:** In the next notebook, we will dive deeper into attention visualization and explore what Vision Transformers actually learn about images.