# Part 7: Convolutional Neural Networks (CNNs) 👁️

Welcome to the finale! 

Standard fully connected networks (MLPs) struggle with images. Why?
1. **Too many parameters**: A 1000x1000 grayscale image has 1M pixels. If the first hidden layer has 1000 neurons, that layer alone needs 1 billion weights!
2. **Ignored structure**: They flatten the image into one long vector, losing the idea that pixels next to each other are related.

**Solution**: CNNs. They slide the same small filters across the whole image, so they can detect features (edges, curves) wherever they appear while using far fewer weights.
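
To make the parameter argument concrete, here is a minimal sketch comparing the dense layer from the example above with a small convolutional layer. The 16-filter, 3x3 configuration is an illustrative assumption, chosen to resemble the layers used later in this notebook.

```python
import torch.nn as nn

# Dense layer from the example: 1000x1000 grayscale image -> 1000 hidden neurons.
# Computed arithmetically; actually allocating this layer would take about 4 GB in float32.
n_pixels = 1000 * 1000
n_hidden = 1000
dense_params = n_pixels * n_hidden + n_hidden  # weights + biases

# Conv layer with 16 filters of size 3x3 on a 1-channel image (illustrative sizes).
conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
conv_params = sum(p.numel() for p in conv.parameters())  # 16*1*3*3 weights + 16 biases

print(f"Dense layer parameters: {dense_params:,}")
print(f"Conv layer parameters:  {conv_params:,}")
```

The convolutional layer's parameter count does not depend on the image size, because the same filters are reused at every position.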

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
%matplotlib inline

## 1. The Building Blocks

### Convolution (`nn.Conv2d`)
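A small window (kernel) slides over the image. At each position, it multiplies its weights element-wise with the pixels underneath and sums the result into a single output value. The kernel weights are learned during training, so each kernel becomes a detector for a particular feature.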
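
As a rough illustration of the sliding-window idea, the sketch below applies a hand-written 3x3 kernel to a tiny 5x5 image using `torch.nn.functional.conv2d`. The image and kernel values are made up for this example; in a real `nn.Conv2d` layer the kernel values are learned rather than written by hand.

```python
import torch
import torch.nn.functional as F

# 1 image, 1 channel, 5x5: left columns dark (0), right columns bright (1)
image = torch.tensor([[0., 0., 0., 1., 1.]] * 5).reshape(1, 1, 5, 5)

# Hand-written 3x3 kernel that responds to dark-to-bright vertical edges
kernel = torch.tensor([[-1., 0., 1.],
                       [-1., 0., 1.],
                       [-1., 0., 1.]]).reshape(1, 1, 3, 3)

# No padding, so the 5x5 image shrinks to 3x3
response = F.conv2d(image, kernel)
print(response[0, 0])
```

The response is 0 where the window covers only the flat dark region and 3 where it overlaps the dark-to-bright edge.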

### Pooling (`nn.MaxPool2d`)
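Shrinks each feature map by keeping only the maximum value in each region (for example, each 2x2 window). This reduces computation in later layers and makes the output less sensitive to small shifts in the input.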
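
Here is a small sketch of what `nn.MaxPool2d(kernel_size=2)` does to a single 4x4 feature map. The input values are made up so that the maximum of each 2x2 window is easy to spot.

```python
import torch
import torch.nn as nn

# 1 image, 1 channel, 4x4 feature map
x = torch.tensor([[1., 2., 5., 6.],
                  [3., 4., 7., 8.],
                  [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]).reshape(1, 1, 4, 4)

pool = nn.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size
pooled = pool(x)

print(pooled[0, 0])  # one maximum per non-overlapping 2x2 window
```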

In [None]:
# Example of a simple CNN block
conv_block = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2)
)

# Fake image batch: 4 images, 3 channels (RGB), 32x32 pixels
fake_image = torch.randn(4, 3, 32, 32)
output = conv_block(fake_image)

# Conv (padding=1) keeps 32x32; MaxPool(2) halves it to 16x16.
# Channels go from 3 to 16, so the output is [4, 16, 16, 16].
print(f"Input Shape: {fake_image.shape}")
print(f"Output Shape: {output.shape}")

## 2. Building a Classifier

Let's build a network that classifies 32x32 RGB images (like those in CIFAR-10) into 10 classes.

Architecture:
`ConvBlock -> ConvBlock -> Flatten -> Linear -> Output`
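
Before writing the model, it helps to trace how the spatial size changes through each layer, since the first linear layer needs to know the flattened size. The helper below, `conv_out_size`, is a hypothetical name introduced for this sketch. It applies the standard output-size formula for convolution and pooling layers, assuming dilation 1, and it assumes each ConvBlock is a 3x3 convolution with padding 1 followed by 2x2 max pooling, with 32 output channels in the second block.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 32  # input images are 32x32
for block in (1, 2):
    size = conv_out_size(size, kernel_size=3, padding=1)  # 3x3 conv with padding 1 keeps the size
    size = conv_out_size(size, kernel_size=2, stride=2)   # 2x2 max pool halves the size
    print(f"After block {block}: {size}x{size}")

out_channels = 32  # assumed channel count of the second block
flat_features = out_channels * size * size
print(f"Flattened features: {flat_features}")
```

Deriving the flattened size this way avoids guessing the `in_features` argument of the first `nn.Linear`.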

In [None]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        
        # Block 1: 32x32 -> 16x16
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        # Block 2: 16x16 -> 8x8
        self.block2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        # Flatten and Classify
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, 64), # 32 channels * 8 * 8 pixels
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
        
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.classifier(x)
        return x

model = SimpleCNN()
print(model)

## 3. Does it run?

Before training, always do a "dummy pass" to check shapes.

In [None]:
dummy_input = torch.randn(1, 3, 32, 32)
try:
    output = model(dummy_input)
    print("✅ Forward pass successful!")
    print(f"Output shape: {output.shape} (1 batch, 10 classes)")
except Exception as e:
    print(f"❌ Error: {e}")

## 4. Seeing What CNNs Learn 🔍

Let's train a version of our CNN adapted to MNIST (28x28 grayscale digits) and **visualize what features it learns**. This will make the abstract concept concrete!

In [None]:
# Need to adapt our model for grayscale MNIST (1 channel instead of 3)
class MNISTSimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        
        # Block 1: Input is now 1 channel (grayscale)
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),  # Changed from 3 to 1 input channel
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        # Block 2: Same as before
        self.block2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        # MNIST is 28x28, so two MaxPool(2) layers give 28 -> 14 -> 7 (7x7 maps)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
        
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.classifier(x)
        return x

# Create model
mnist_model = MNISTSimpleCNN()
print("Model created!")

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Download and prepare MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Simple normalization
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Quick training (1 epoch is enough to see learned features)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mnist_model.parameters(), lr=0.001)

print("Training for 1 epoch...")
mnist_model.train()
for batch_idx, (images, labels) in enumerate(train_loader):
    optimizer.zero_grad()
    outputs = mnist_model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    
    if batch_idx % 100 == 0:
        print(f"  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

print("✅ Training complete!")
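
After training, one way to check that the model learned something is to measure classification accuracy. The sketch below is not part of the original pipeline: `evaluate_accuracy` is a hypothetical helper, and it is demonstrated on a toy linear model with random data so that it runs without downloading anything.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, loader):
    """Return the fraction of samples in `loader` that `model` classifies correctly."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images).argmax(dim=1)  # index of the highest logit per sample
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Toy stand-ins (assumptions for this demo): random "images" and random labels
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
toy_loader = DataLoader(toy_data, batch_size=16)

acc = evaluate_accuracy(toy_model, toy_loader)
print(f"Toy accuracy: {acc:.2%}")
```

The same helper could be called with `mnist_model` and a `DataLoader` wrapping the MNIST test split (`train=False`).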

### 📊 Visualizing Learned Filters

The first convolutional layer has 16 filters. Let's see what patterns they learned to detect!

In [None]:
# Extract the learned filters from the first conv layer
filters = mnist_model.block1[0].weight.detach().cpu()  # shape: (16, 1, 3, 3)

# Plot all 16 filters in a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.suptitle('Learned Filters (First Conv Layer)', fontsize=16)

for i, ax in enumerate(axes.flat):
    # Each filter is shape (1, 3, 3) - we take the single channel
    filter_img = filters[i, 0, :, :]
    ax.imshow(filter_img, cmap='gray')
    ax.set_title(f'Filter {i+1}', fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.show()

print("Notice: Some filters may resemble edge detectors (horizontal/vertical/diagonal),")
print("        others may still look noisy after 1 epoch. The network learned these automatically!")

### 🎯 Visualizing Feature Maps

Now let's see how these filters **activate** on a real MNIST digit!

In [None]:
# Get a test image
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_image, test_label = test_dataset[0]

# Forward pass through first conv layer only
mnist_model.eval()
with torch.no_grad():
    test_image_batch = test_image.unsqueeze(0)  # Add batch dimension
    feature_maps = mnist_model.block1[0](test_image_batch)  # Just the conv layer
    feature_maps = torch.relu(feature_maps)  # Apply ReLU

# Visualize
fig, axes = plt.subplots(3, 6, figsize=(12, 6))
fig.suptitle(f'Feature Maps for Digit "{test_label}"', fontsize=16)

# Show original image in first subplot
axes[0, 0].imshow(test_image.squeeze(), cmap='gray')
axes[0, 0].set_title('Original', fontsize=10)
axes[0, 0].axis('off')

# Show all 16 feature maps in grid slots 1..16 (slot 0 holds the original image)
for i in range(16):
    ax = axes.flat[i + 1]
    fmap = feature_maps[0, i, :, :].cpu()
    ax.imshow(fmap, cmap='viridis')
    ax.set_title(f'Filter {i+1}', fontsize=8)
    ax.axis('off')

# Hide the one unused subplot (last slot of the 3x6 grid)
axes[2, 5].axis('off')

plt.tight_layout()
plt.show()

print("Each feature map shows where that filter 'fires' (activates) on the image.")
print("Bright areas = strong activation. This is how CNNs 'see' the image!")

## 🧠 Summary & Next Steps

This concludes the **Zero to Hero** prerequisites pipeline!

**You have learned:**
1. **Tensors**: The mathematical building blocks.
2. **Autograd**: The learning engine.
3. **Training Loop**: The standard recipe.
4. **DataLoaders**: Handling data.
5. **CNNs**: Handling images, and you've **seen** what they learn!

**You are now ready to tackle the main NanoJEPA notebooks!** 🚀