# Topic 6: Convolutional Neural Networks (CNNs)

## Learning Objectives

By the end of this notebook, you will:
- Understand **why** CNNs were invented and what problem they solve
- Learn how convolutions exploit spatial structure in images
- Build CNNs from scratch using PyTorch
- Understand pooling, stride, padding, and their purposes
- Recognize CNN architectures and their evolution
- Connect CNNs to the broader deep learning ecosystem

## The Big Picture: Why CNNs?

### The Problem CNNs Solve

Imagine you want to build a classifier for 224x224 RGB images using a fully-connected neural network:

- **Input size**: 224 × 224 × 3 = 150,528 input values (50,176 pixels × 3 color channels)
- **First hidden layer** (1,000 neurons): 150,528 × 1,000 ≈ **150 million weights**
- That's just the first layer!
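
To make the gap concrete, the short calculation below compares the fully-connected layer above with a convolutional layer. The 64-filter, 3×3 configuration is an assumption chosen purely for illustration.

```python
# Fully-connected first layer on a 224x224 RGB image
height, width, channels = 224, 224, 3
hidden_units = 1000

input_values = height * width * channels                # one weight per input value, per neuron
fc_params = input_values * hidden_units + hidden_units  # weights + biases

# Hypothetical conv layer: 64 filters of size 3x3 spanning all 3 input channels
num_filters, kernel_size = 64, 3
conv_params = num_filters * (kernel_size * kernel_size * channels) + num_filters

print(f"Input values:           {input_values:,}")
print(f"Fully-connected params: {fc_params:,}")
print(f"Conv layer params:      {conv_params:,}")
```

The convolutional layer needs only a few thousand parameters because its weights do not grow with the image size.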

**Problems**:
1. **Too many parameters**: Computationally expensive, prone to overfitting
2. **Ignores spatial structure**: Every pixel gets its own independent weight, so the network has no built-in notion that nearby pixels are related
3. **Not translation invariant**: A cat in the top-left corner activates entirely different weights from a cat in the bottom-right, so each position must be learned separately

### The CNN Solution

CNNs solve these problems by:
1. **Local connectivity**: Each neuron only looks at a small region (receptive field)
2. **Parameter sharing**: Same weights are reused across the entire image
3. **Hierarchy of features**: Early layers detect edges, later layers detect complex patterns
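
A quick way to see parameter sharing in practice is to apply one convolutional layer to images of different sizes. The sketch below uses arbitrary layer sizes (3 input channels, 16 filters) chosen only for illustration.

```python
import torch
import torch.nn as nn

# One conv layer: 16 filters, each 3x3 over 3 input channels
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
num_params = sum(p.numel() for p in conv.parameters())  # 16*3*3*3 weights + 16 biases

# The same weights slide over inputs of any spatial size
small = torch.randn(1, 3, 32, 32)
large = torch.randn(1, 3, 128, 128)
out_small = conv(small)
out_large = conv(large)

print(f"Parameters: {num_params}")
print(f"32x32 input   -> {tuple(out_small.shape)}")
print(f"128x128 input -> {tuple(out_large.shape)}")
```

The parameter count stays fixed while the output grows with the input, which is exactly what a fully-connected layer cannot do.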

**Real-world impact**: CNNs power:
- Computer vision (object detection, segmentation, face recognition)
- Medical imaging (tumor detection, X-ray analysis)
- Self-driving cars (lane detection, obstacle recognition)
- Even parts of modern vision transformers (ViT's patch embedding is equivalent to a strided convolution)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)

print(f"PyTorch version: {torch.__version__}")

## Understanding Convolutions: The Core Operation

### What is a Convolution?

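A **convolution** slides a small matrix (the kernel, or filter) over an image; at each position it multiplies the kernel element-wise with the underlying patch and sums the result into one output value. (Deep learning libraries, including PyTorch, actually compute cross-correlation, which skips the kernel flip of the textbook definition. Because the kernel is learned, the distinction rarely matters.)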

**Why this works**:
- **Local patterns matter**: Edges, textures, and shapes are local phenomena
- **Translation equivariance**: The same kernel detects the same feature anywhere in the image, and if the input shifts, the feature map shifts with it
- **Parameter efficiency**: A 3×3 kernel has only 9 parameters, but can process an entire image
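
The sliding-window description can be sketched as an explicit double loop. The helper `conv2d_naive` below is a hypothetical name introduced for illustration; it handles only a single-channel image with no padding and stride 1, and it is checked against `F.conv2d`.

```python
import torch
import torch.nn.functional as F

def conv2d_naive(image, kernel):
    """Slide `kernel` over a 2D `image` (no padding, stride 1)."""
    H, W = image.shape
    kH, kW = kernel.shape
    out = torch.zeros(H - kH + 1, W - kW + 1)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            patch = image[i:i + kH, j:j + kW]   # region under the kernel
            out[i, j] = (patch * kernel).sum()  # element-wise product, then sum
    return out

# Deterministic test image: values increase by 1 from left to right
image = torch.arange(36.).view(6, 6)
kernel = torch.tensor([[-1., 0., 1.],
                       [-2., 0., 2.],
                       [-1., 0., 1.]])

naive = conv2d_naive(image, kernel)
reference = F.conv2d(image.view(1, 1, 6, 6), kernel.view(1, 1, 3, 3))[0, 0]

print(naive)
print("Matches F.conv2d:", torch.allclose(naive, reference))
```

Because the image has a constant horizontal gradient, every output value is the same, which makes the result easy to verify by hand.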

### Convolution Step-by-Step

Let's see how a simple edge detection kernel works:

In [None]:
# Create a simple image (8x8) with a vertical edge
image = torch.zeros(1, 1, 8, 8)  # batch=1, channels=1, height=8, width=8
image[:, :, :, 4:] = 1.0  # Right half is white

# Vertical edge detection kernel
# Why this works: Detects transitions from dark to light (or vice versa)
kernel = torch.tensor([[
    [-1, 0, 1],
    [-2, 0, 2],
    [-1, 0, 1]
]], dtype=torch.float32).view(1, 1, 3, 3)

# Apply convolution
output = F.conv2d(image, kernel, padding=1)

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(image[0, 0], cmap='gray')
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(kernel[0, 0], cmap='gray')
axes[1].set_title('Edge Detection Kernel')
axes[1].axis('off')

axes[2].imshow(output[0, 0].detach(), cmap='gray')
axes[2].set_title('Detected Edges')
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("Notice how the edge is highlighted in the output!")
print("The dark stripe at the right border is an artifact of zero padding.")

## Convolution Parameters: Padding, Stride, and Dilation

### 1. Padding

**What**: Adding extra pixels around the image border

**Why it's needed**:
- **Preserve spatial dimensions**: Without padding, each conv layer shrinks the image
- **Use edge information**: Edge pixels are processed fewer times without padding

**Common choices**:
- `padding=0`: Valid convolution (shrinks output)
- `padding=kernel_size//2`: Same convolution (preserves dimensions for odd kernel sizes at stride 1)

### 2. Stride

**What**: Step size, in pixels, between successive kernel positions

**Why it's used**:
- **Reduce spatial dimensions**: Stride > 1 downsamples the image
- **Computational efficiency**: Fewer operations than processing every pixel
- **Alternative to pooling**: Modern architectures sometimes use strided convs instead of pooling

### 3. Dilation

**What**: Spacing between kernel elements (gaps in the filter)

**Why it's used**:
- **Larger receptive field**: See more context without adding parameters
- **Multi-scale processing**: Capture patterns at different scales
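
Dilation is easiest to see through shapes. A minimal sketch, assuming an 8×8 single-channel input and no padding: a 3×3 kernel with `dilation=2` covers a 5×5 area while keeping 9 weights.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 8, 8)

kernel_size, dilation = 3, 2
# A dilated kernel spans dilation * (kernel_size - 1) + 1 pixels per side
effective_size = dilation * (kernel_size - 1) + 1

conv_dense = nn.Conv2d(1, 1, kernel_size=3, dilation=1, bias=False)
conv_dilated = nn.Conv2d(1, 1, kernel_size=3, dilation=2, bias=False)

out_dense = conv_dense(x)      # behaves like a 3x3 window
out_dilated = conv_dilated(x)  # behaves like a 5x5 window with gaps

print(f"Effective kernel size with dilation=2: {effective_size}")
print(f"dilation=1 output: {tuple(out_dense.shape)}, params: {conv_dense.weight.numel()}")
print(f"dilation=2 output: {tuple(out_dilated.shape)}, params: {conv_dilated.weight.numel()}")
```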

In [None]:
# Demonstrate different padding and stride effects
input_tensor = torch.randn(1, 1, 8, 8)

# Same padding (output size = input size)
conv_same = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
out_same = conv_same(input_tensor)

# Valid convolution (no padding)
conv_valid = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
out_valid = conv_valid(input_tensor)

# Strided convolution (downsampling)
conv_stride = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1, bias=False)
out_stride = conv_stride(input_tensor)

print(f"Input shape: {input_tensor.shape}")
print(f"Same padding (padding=1, stride=1): {out_same.shape}")
print(f"Valid convolution (padding=0, stride=1): {out_valid.shape}")
print(f"Strided convolution (padding=1, stride=2): {out_stride.shape}")

print("\nNotice:")
print("- Same padding preserves dimensions: 8x8 -> 8x8")
print("- Valid conv shrinks: 8x8 -> 6x6 (lost 2 pixels per dimension)")
print("- Stride=2 downsamples: 8x8 -> 4x4 (halved dimensions)")

## Pooling Layers: Why Downsample?

### Purpose of Pooling

Pooling reduces spatial dimensions while preserving important features.

**Why pooling is needed**:
1. **Reduce computation**: Fewer pixels = fewer operations in later layers
2. **Increase receptive field**: Each neuron sees a larger portion of the original image
3. **Approximate translation invariance**: Small shifts in the input often leave the pooled output unchanged
4. **Reduce overfitting risk**: Fewer activations feed later layers, which acts as a mild form of regularization

### Types of Pooling

**Max Pooling**: Takes maximum value in each region
- **Why**: Preserves strongest activations (most prominent features)
- **Common choice**: 2×2 with stride 2 (halves dimensions)

**Average Pooling**: Takes average of values in each region
- **Why**: Smoother downsampling, often used in final layers

**Global Average Pooling**: Averages each feature map to a single value
- **Why**: Replaces the flatten + large fully-connected stack before the classifier, which cuts parameters
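
A worked example on a tiny tensor shows what each pooling type computes. The 4×4 input holding the values 1 to 16 is an arbitrary choice that keeps the arithmetic easy to check.

```python
import torch
import torch.nn.functional as F

# 4x4 feature map with values 1..16, shaped (batch, channels, height, width)
x = torch.arange(1., 17.).view(1, 1, 4, 4)

max_out = F.max_pool2d(x, kernel_size=2, stride=2)   # largest value per 2x2 block
avg_out = F.avg_pool2d(x, kernel_size=2, stride=2)   # mean of each 2x2 block
gap_out = F.adaptive_avg_pool2d(x, (1, 1))           # mean of the whole map

print("Input:\n", x[0, 0])
print("Max pool:\n", max_out[0, 0])      # [[6, 8], [14, 16]]
print("Average pool:\n", avg_out[0, 0])  # [[3.5, 5.5], [11.5, 13.5]]
print("Global average pool:", gap_out.item())  # 8.5
```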

In [None]:
# Demonstrate pooling operations
feature_map = torch.randn(1, 1, 8, 8)

# Max pooling (2x2, stride 2)
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
out_max = max_pool(feature_map)

# Average pooling (2x2, stride 2)
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
out_avg = avg_pool(feature_map)

# Global average pooling
gap = nn.AdaptiveAvgPool2d((1, 1))
out_gap = gap(feature_map)

print(f"Original shape: {feature_map.shape}")
print(f"After max pooling: {out_max.shape}")
print(f"After average pooling: {out_avg.shape}")
print(f"After global average pooling: {out_gap.shape}")

print("\nGlobal Average Pooling (GAP) is especially useful:")
print("- Replaces flatten + fully-connected layers")
print("- Works with any input size")
print("- Reduces overfitting (fewer parameters)")

## Building a Simple CNN from Scratch

Let's build a classic CNN architecture for CIFAR-10 classification.

### Architecture Design Principles

**Why this structure**:
1. **Progressive feature extraction**: Early layers detect simple patterns (edges), later layers detect complex patterns (objects)
2. **Increase channels, decrease spatial dimensions**: Trade resolution for feature richness
3. **Batch normalization**: Stabilizes training, allows higher learning rates
4. **Dropout**: Prevents overfitting by randomly dropping neurons
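
Dropout behaves differently in training and evaluation, which is why `model.train()` and `model.eval()` matter. A minimal sketch, using an all-ones input chosen only so the effect is easy to read:

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 1000)

dropout.train()
train_out = dropout(x)  # each element is zeroed with probability 0.5;
                        # survivors are scaled by 1 / (1 - p) = 2.0

dropout.eval()
eval_out = dropout(x)   # identity at evaluation time

print("Values seen in train mode:", train_out.unique().tolist())
print("Eval output unchanged:", torch.equal(eval_out, x))
```

The rescaling in training mode keeps the expected activation the same in both modes.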

**Common pattern**:
```
Conv -> BatchNorm -> ReLU -> Pool -> 
Conv -> BatchNorm -> ReLU -> Pool -> 
... -> 
Global Average Pool -> Classifier
```
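
One way to express this pattern in code is a small factory function. `conv_block` below is a hypothetical helper, not part of PyTorch, and the channel sizes are illustrative.

```python
import torch
import torch.nn as nn

def conv_block(in_channels, out_channels):
    """Conv -> BatchNorm -> ReLU -> Pool, halving height and width."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )

# Two stacked blocks: 3 -> 32 -> 64 channels, 32x32 -> 16x16 -> 8x8
features = nn.Sequential(conv_block(3, 32), conv_block(32, 64))

x = torch.randn(4, 3, 32, 32)
out = features(x)
print(f"{tuple(x.shape)} -> {tuple(out.shape)}")
```

Writing each layer out by hand, as the next cell does, is more verbose but makes every step visible.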

In [None]:
class SimpleCNN(nn.Module):
    """
    A simple but effective CNN for CIFAR-10.
    
    Architecture:
    - 3 convolutional blocks (conv -> bn -> relu -> pool)
    - Each block doubles the channels and halves spatial dimensions
    - Global average pooling to reduce parameters
    - Single fully-connected layer for classification
    """
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        
        # Why these channel sizes? 
        # Start small (32), gradually increase (64, 128) to capture complexity
        
        # Block 1: 32x32x3 -> 16x16x32
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)  # Stabilizes training
        self.pool1 = nn.MaxPool2d(2, 2)  # Downsample by 2
        
        # Block 2: 16x16x32 -> 8x8x64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        # Block 3: 8x8x64 -> 4x4x128
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        # Global average pooling: 4x4x128 -> 1x1x128
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        
        # Classifier
        self.dropout = nn.Dropout(0.5)  # Regularization
        self.fc = nn.Linear(128, num_classes)
    
    def forward(self, x):
        # Block 1
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        
        # Block 2
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        
        # Block 3
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        
        # Global pooling and classification
        x = self.gap(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.dropout(x)
        x = self.fc(x)
        
        return x

# Create model and inspect
model = SimpleCNN(num_classes=10)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Test forward pass
dummy_input = torch.randn(4, 3, 32, 32)  # Batch of 4 CIFAR-10 images
output = model(dummy_input)
print(f"\nInput shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print("Output represents logits for 10 classes")

## Training the CNN on CIFAR-10

Let's train our CNN on a real dataset to see it in action.
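
The data pipeline in the next cell normalizes each channel with mean 0.5 and standard deviation 0.5. The sketch below reproduces that arithmetic by hand on a random tensor (a stand-in for a real image, used here as an assumption) to show that values in [0, 1] end up in [-1, 1].

```python
import torch

# Stand-in for the output of transforms.ToTensor(): shape (C, H, W), values in [0, 1]
image = torch.rand(3, 32, 32)

# Per-channel statistics, reshaped so they broadcast over height and width
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

normalized = (image - mean) / std  # same per-channel formula as transforms.Normalize

print(f"Before: min={image.min().item():.3f}, max={image.max().item():.3f}")
print(f"After:  min={normalized.min().item():.3f}, max={normalized.max().item():.3f}")
```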

In [None]:
# Data preparation
# Why these transforms?
# - Normalization: Helps optimization by centering data around 0
# - Data augmentation (train only): Improves generalization

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Augmentation
    transforms.RandomCrop(32, padding=4),  # Augmentation
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Training samples: {len(trainset)}")
print(f"Test samples: {len(testset)}")
print(f"Classes: {classes}")

In [None]:
# Training function
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return running_loss / len(loader), 100. * correct / total

# Evaluation function
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return running_loss / len(loader), 100. * correct / total

# Setup training
device = torch.device('cpu')  # Use CPU for this demo
model = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train for a few epochs (use more for better results)
num_epochs = 3

print("Training started...\n")
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, trainloader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, testloader, criterion, device)
    
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

print("\nTraining complete!")
print(f"Final test accuracy: {test_acc:.2f}%")

## Visualizing What CNNs Learn

Let's visualize the learned filters to understand what features the CNN detects.

In [None]:
# Visualize first layer filters
def visualize_filters(model):
    # Get first conv layer weights
    weights = model.conv1.weight.data.cpu()
    
    # Normalize to [0, 1] for visualization
    weights = (weights - weights.min()) / (weights.max() - weights.min())
    
    # Plot up to 32 filters in a 4x8 grid
    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
    for idx, ax in enumerate(axes.flat):
        if idx < weights.shape[0]:
            # Convert RGB channels to grayscale for visualization
            filter_img = weights[idx].mean(0)
            ax.imshow(filter_img, cmap='viridis')
        ax.axis('off')
    
    plt.suptitle(f'First Layer Filters (showing {min(32, weights.shape[0])} of {weights.shape[0]})')
    plt.tight_layout()
    plt.show()

visualize_filters(model)

print("First-layer filters typically respond to basic patterns:")
print("- Edges at different orientations")
print("- Color gradients")
print("- Simple textures")
print("\nDeeper layers combine these to detect more complex features!")

## Modern CNN Architectures: Evolution and Design Patterns

### Historical Evolution

**1. LeNet-5 (1998)**: One of the first successful CNNs
- **Why important**: Proved CNNs work for digit recognition
- **Architecture**: Conv -> Pool -> Conv -> Pool -> FC

**2. AlexNet (2012)**: ImageNet breakthrough
- **Why revolutionary**: Won ImageNet by huge margin, revived deep learning
- **Innovations**: ReLU, Dropout, GPU training, data augmentation

**3. VGG (2014)**: Simplicity and depth
- **Why influential**: Showed deeper is better (16-19 layers)
- **Design principle**: Small 3×3 filters, consistent architecture

**4. ResNet (2015)**: Residual connections
- **Why critical**: Addressed the degradation problem (deeper plain networks training worse than shallower ones) and enabled 100+ layer networks
- **Innovation**: Skip connections (x + F(x))

**5. EfficientNet (2019)**: Compound scaling
- **Why modern**: Scales depth, width, and resolution together with a single compound coefficient
- **Innovation**: Compound scaling applied to a baseline network found by neural architecture search (NAS)
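
VGG's preference for small 3×3 filters can be checked with a parameter count. Two stacked 3×3 convolutions see the same 5×5 neighborhood as one 5×5 convolution but use fewer weights. The 64-channel width below is an assumption for illustration, and biases are left out to keep the count simple.

```python
import torch
import torch.nn as nn

channels = 64

# Two stacked 3x3 convs: combined receptive field of 5x5
stacked_3x3 = nn.Sequential(
    nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
)

# One 5x5 conv: same receptive field in a single layer
single_5x5 = nn.Conv2d(channels, channels, kernel_size=5, padding=2, bias=False)

def count_params(module):
    return sum(p.numel() for p in module.parameters())

stacked_params = count_params(stacked_3x3)  # 2 * 64 * 64 * 9
single_params = count_params(single_5x5)    # 64 * 64 * 25

print(f"Two 3x3 convs: {stacked_params:,} parameters")
print(f"One 5x5 conv:  {single_params:,} parameters")
```

The stacked version also inserts an extra non-linearity between the two layers.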

### Key Design Patterns

**Residual Blocks** (ResNet, ResNeXt):
```python
# Skip connection allows gradients to flow directly
out = F.relu(x + conv_block(x))
```

**Inception Modules** (GoogLeNet):
- Multiple filter sizes in parallel
- Captures multi-scale features
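
A simplified sketch of the parallel-branch idea follows. `MultiScaleBlock` is a hypothetical name, and this is not GoogLeNet's exact module, which also uses 1×1 reductions and a pooling branch.

```python
import torch
import torch.nn as nn

class MultiScaleBlock(nn.Module):
    """Apply 1x1, 3x3, and 5x5 convs in parallel and concatenate the results."""
    def __init__(self, in_channels, branch_channels):
        super().__init__()
        # Padding is chosen so every branch keeps the spatial size
        self.branch1 = nn.Conv2d(in_channels, branch_channels, kernel_size=1)
        self.branch3 = nn.Conv2d(in_channels, branch_channels, kernel_size=3, padding=1)
        self.branch5 = nn.Conv2d(in_channels, branch_channels, kernel_size=5, padding=2)

    def forward(self, x):
        # Concatenate along the channel dimension
        return torch.cat([self.branch1(x), self.branch3(x), self.branch5(x)], dim=1)

block = MultiScaleBlock(in_channels=32, branch_channels=16)
x = torch.randn(2, 32, 16, 16)
out = block(x)
print(f"{tuple(x.shape)} -> {tuple(out.shape)}")  # 3 branches x 16 channels = 48
```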

**Depthwise Separable Convolutions** (MobileNet):
- Factorize convolution into depthwise + pointwise
- Dramatically fewer parameters

**Squeeze-and-Excitation** (SENet):
- Channel-wise attention mechanism
- Reweights feature maps by importance
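
Channel-wise attention can be sketched in a few lines. `SEBlock` below is a minimal illustration under assumed settings (a reduction ratio of 16 and 64 channels), not a reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    """Squeeze-and-excitation: rescale each channel by a learned weight in (0, 1)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, _, _ = x.shape
        squeezed = x.mean(dim=(2, 3))                 # squeeze: one value per channel
        weights = torch.sigmoid(self.fc2(F.relu(self.fc1(squeezed))))  # excitation
        return x * weights.view(b, c, 1, 1)           # reweight the feature maps

se = SEBlock(channels=64)
x = torch.randn(2, 64, 8, 8)
out = se(x)
print(f"{tuple(x.shape)} -> {tuple(out.shape)}")
```

The output keeps the input shape, so a block like this can be dropped into an existing architecture after a convolutional stage.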

In [None]:
# Example: Simple Residual Block
class ResidualBlock(nn.Module):
    """
    Basic residual block with skip connection.
    
    Why skip connections work:
    - Gradient flows directly through skip connection
    - Network can learn identity function (do nothing) if needed
    - Easier to optimize very deep networks
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # If dimensions change, use 1x1 conv to match
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = self.shortcut(x)
        
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        # Add skip connection BEFORE activation
        out += identity
        out = F.relu(out)
        
        return out

# Test residual block
res_block = ResidualBlock(64, 128, stride=2)
test_input = torch.randn(1, 64, 32, 32)
test_output = res_block(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
print("\nNotice: Channels doubled (64->128) and spatial dimensions halved (32->16)")
print("The skip connection was automatically adjusted to match dimensions!")

## Connections to Modern Deep Learning

### CNNs in the Age of Transformers

**Why CNNs are still relevant**:

1. **Vision Transformers (ViT)**: Use CNN-like patch embeddings
   - Input image is divided into patches (like conv kernels)
   - Some architectures combine CNNs with transformers

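2. **Modernized and hybrid architectures**:
   - ConvNeXt (2022): A pure ConvNet that modernizes ResNet with transformer-inspired design choices and is reported to match comparable vision transformers
   - Hybrid models combine convolutional stages with attention layers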

3. **Efficiency**: CNNs are often faster and more parameter-efficient on vision tasks

4. **Inductive biases**: CNNs have built-in assumptions (locality, translation equivariance) that help on images
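
The link between patch embeddings and convolutions can be made concrete: splitting an image into non-overlapping patches and projecting each one linearly is the same as a convolution whose kernel size equals its stride. The embedding size of 64 below is an arbitrary assumption to keep the example light.

```python
import torch
import torch.nn as nn

patch_size, embed_dim = 16, 64

# kernel_size == stride means each 16x16 patch is seen exactly once
patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

images = torch.randn(2, 3, 224, 224)
feature_grid = patch_embed(images)                # (2, 64, 14, 14)
tokens = feature_grid.flatten(2).transpose(1, 2)  # (2, 196, 64): one token per patch

print(f"Images:       {tuple(images.shape)}")
print(f"Feature grid: {tuple(feature_grid.shape)}")
print(f"Tokens:       {tuple(tokens.shape)}")
```

A 224×224 image yields 14 × 14 = 196 tokens, which a transformer then processes as a sequence.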

### Key Concepts to Remember

**Why CNNs work**:
- Exploit spatial structure through local connectivity
- Parameter sharing reduces overfitting
- Hierarchical feature learning (edges -> textures -> objects)

**When to use CNNs**:
- Image classification, detection, segmentation
- Any spatial data (medical images, satellite imagery)
- When efficiency matters (mobile devices)

**Connection to transformers**:
- Both build hierarchical representations
- Transformers replace local connectivity with attention
- Hybrid approaches combine strengths of both

## Mini Exercises

Test your understanding with these exercises.

### Exercise 1: Calculate Output Dimensions

Given an input of size 64×64 and a conv layer with:
- kernel_size=5
- stride=2
- padding=2

What is the output size?

Formula: output_size = floor((input_size + 2×padding - kernel_size) / stride) + 1

In [None]:
# YOUR CODE HERE
# Calculate the output size


# SOLUTION (run to reveal)
def show_solution_1():
    input_size = 64
    kernel_size = 5
    stride = 2
    padding = 2
    
    output_size = (input_size + 2*padding - kernel_size) // stride + 1
    print(f"Output size: {output_size}")
    print("\nCalculation: floor((64 + 2*2 - 5) / 2) + 1 = floor(63 / 2) + 1 = 31 + 1 = 32")
    
    # Verify with PyTorch
    conv = nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2)
    x = torch.randn(1, 3, 64, 64)
    out = conv(x)
    print(f"PyTorch verification: {out.shape}")

# Uncomment to see solution:
# show_solution_1()

### Exercise 2: Design a Deeper CNN

Modify `SimpleCNN` to have 4 convolutional blocks instead of 3. Make sure to:
- Follow the pattern of doubling channels each block
- Use batch normalization
- Maintain the same final classifier structure

In [None]:
# YOUR CODE HERE
class DeeperCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeeperCNN, self).__init__()
        # Add your code here
        pass
    
    def forward(self, x):
        # Add your code here
        pass


# SOLUTION (run to reveal)
def show_solution_2():
    class DeeperCNN(nn.Module):
        def __init__(self, num_classes=10):
            super(DeeperCNN, self).__init__()
            
            # Block 1: 32x32 -> 16x16, channels: 3 -> 32
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(32)
            self.pool1 = nn.MaxPool2d(2, 2)
            
            # Block 2: 16x16 -> 8x8, channels: 32 -> 64
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(64)
            self.pool2 = nn.MaxPool2d(2, 2)
            
            # Block 3: 8x8 -> 4x4, channels: 64 -> 128
            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.bn3 = nn.BatchNorm2d(128)
            self.pool3 = nn.MaxPool2d(2, 2)
            
            # Block 4: 4x4 -> 2x2, channels: 128 -> 256
            self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.bn4 = nn.BatchNorm2d(256)
            self.pool4 = nn.MaxPool2d(2, 2)
            
            # Global average pooling and classifier
            self.gap = nn.AdaptiveAvgPool2d((1, 1))
            self.dropout = nn.Dropout(0.5)
            self.fc = nn.Linear(256, num_classes)
        
        def forward(self, x):
            x = self.pool1(F.relu(self.bn1(self.conv1(x))))
            x = self.pool2(F.relu(self.bn2(self.conv2(x))))
            x = self.pool3(F.relu(self.bn3(self.conv3(x))))
            x = self.pool4(F.relu(self.bn4(self.conv4(x))))
            
            x = self.gap(x)
            x = x.view(x.size(0), -1)
            x = self.dropout(x)
            x = self.fc(x)
            return x
    
    model = DeeperCNN()
    print(model)
    print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

# Uncomment to see solution:
# show_solution_2()

### Exercise 3: Implement Depthwise Separable Convolution

Depthwise separable convolution is used in MobileNet for efficiency. It consists of:
1. **Depthwise conv**: Apply separate filters to each input channel
2. **Pointwise conv**: 1×1 conv to combine channels

Why is this more efficient? Count the parameters!

In [None]:
# YOUR CODE HERE
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(DepthwiseSeparableConv, self).__init__()
        # Add your code here
        pass
    
    def forward(self, x):
        # Add your code here
        pass


# SOLUTION (run to reveal)
def show_solution_3():
    class DepthwiseSeparableConv(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
            super(DepthwiseSeparableConv, self).__init__()
            
            # Depthwise: Apply 3x3 filter to EACH input channel separately
            # groups=in_channels means each input channel gets its own filter
            self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                       stride=stride, padding=padding, groups=in_channels,
                                       bias=False)
            self.bn1 = nn.BatchNorm2d(in_channels)
            
            # Pointwise: 1x1 conv to combine channels
            self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                       bias=False)
            self.bn2 = nn.BatchNorm2d(out_channels)
        
        def forward(self, x):
            x = F.relu(self.bn1(self.depthwise(x)))
            x = F.relu(self.bn2(self.pointwise(x)))
            return x
    
    # Compare parameter count
    in_ch, out_ch = 64, 128
    
    # Standard conv
    standard_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
    standard_params = sum(p.numel() for p in standard_conv.parameters())
    
    # Depthwise separable conv
    dw_conv = DepthwiseSeparableConv(in_ch, out_ch)
    dw_params = sum(p.numel() for p in dw_conv.parameters())
    
    print(f"Standard Conv2d parameters: {standard_params:,}")
    print(f"Depthwise Separable parameters: {dw_params:,}")
    print(f"\nReduction: {standard_params / dw_params:.2f}x fewer parameters!")
    print("\nWhy this matters: Mobile devices have limited compute/memory")

# Uncomment to see solution:
# show_solution_3()

## Comprehensive Exercise: Build a Modern CNN

Build a CNN that combines modern techniques:
1. Residual connections (like ResNet)
2. Batch normalization
3. Global average pooling
4. Data augmentation

Train it on CIFAR-10 and achieve >75% test accuracy.

In [None]:
# YOUR CODE HERE
# Build and train your modern CNN


# SOLUTION (run to reveal)
def show_comprehensive_solution():
    class ModernCNN(nn.Module):
        def __init__(self, num_classes=10):
            super(ModernCNN, self).__init__()
            
            # Initial conv
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            
            # Residual blocks
            self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
            self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
            self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
            
            # Global average pooling and classifier
            self.gap = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(256, num_classes)
        
        def _make_layer(self, in_channels, out_channels, num_blocks, stride):
            layers = []
            # First block may downsample
            layers.append(ResidualBlock(in_channels, out_channels, stride))
            # Remaining blocks maintain dimensions
            for _ in range(1, num_blocks):
                layers.append(ResidualBlock(out_channels, out_channels, stride=1))
            return nn.Sequential(*layers)
        
        def forward(self, x):
            x = F.relu(self.bn1(self.conv1(x)))
            
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            
            x = self.gap(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x
    
    # Training code
    device = torch.device('cpu')
    model = ModernCNN(num_classes=10).to(device)
    
    print("Modern CNN Architecture:")
    print(model)
    print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
    
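    # Setup training (the loop itself is omitted: reuse train_one_epoch and evaluate,
    # call scheduler.step() once per epoch, and expect to need more than a few epochs for >75%)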
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    
    print("\nThis architecture combines:")
    print("✓ Residual connections for deep training")
    print("✓ Batch normalization for stable optimization")
    print("✓ Global average pooling for fewer parameters")
    print("✓ Cosine annealing scheduler for better convergence")

# Uncomment to see solution:
# show_comprehensive_solution()

## Key Takeaways

### Core Concepts

1. **CNNs exploit spatial structure**:
   - Local connectivity reduces parameters
   - Parameter sharing enables translation equivariance
   - Hierarchical features (edges → objects)

2. **Key components and their purposes**:
   - **Convolution**: Feature extraction with parameter sharing
   - **Pooling**: Downsampling and approximate translation invariance
   - **Batch Norm**: Stable training and faster convergence
   - **Residual connections**: Enable very deep networks

3. **Design principles**:
   - Increase channels while decreasing spatial dimensions
   - Use small filters (3×3) with multiple layers
   - Add skip connections for depth
   - Use global average pooling to reduce overfitting

4. **Modern relevance**:
   - Still competitive with the state of the art on many vision tasks
   - Often more data-efficient than transformers on smaller datasets
   - Hybrid CNN-transformer architectures are emerging
   - Principles apply to any spatial data

### Connection to Broader Ecosystem

**CNNs → Vision Transformers**:
- ViT uses patch embeddings (similar to conv kernels)
- Some architectures combine CNNs and attention

**CNNs → Modern Architectures**:
- Residual connections appear everywhere (Transformers use them too)
- Batch normalization inspired related schemes such as Layer Norm, the usual choice in transformers
- Attention can loosely be viewed as a convolution whose weights depend on the input

### What's Next?

You've mastered CNNs! Next topics:
- **Attention mechanisms**: The foundation of transformers
- **Positional encodings**: How to inject sequence information
- **Transformer architecture**: The revolution in deep learning

CNNs taught us about hierarchical feature learning and parameter sharing, concepts that carry forward into transformers and beyond!

## Further Reading

### Classic Papers
1. **LeCun et al. (1998)**: "Gradient-Based Learning Applied to Document Recognition" (LeNet)
2. **Krizhevsky et al. (2012)**: "ImageNet Classification with Deep Convolutional Neural Networks" (AlexNet)
3. **He et al. (2015)**: "Deep Residual Learning for Image Recognition" (ResNet)
4. **Tan & Le (2019)**: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks"

### Modern Perspectives
5. **Liu et al. (2022)**: "A ConvNet for the 2020s" (ConvNeXt)
6. **Dosovitskiy et al. (2020)**: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Vision Transformer)

### Resources
- **CS231n**: Stanford's CNN course (http://cs231n.stanford.edu/)
- **PyTorch Vision Models**: torchvision.models for pre-trained architectures
- **Papers With Code**: Compare architectures and benchmarks