# Day 3: Training CNNs on CIFAR-10
## CV Bootcamp 2024

Train your first CNN from scratch and learn to debug training issues!

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
print(f'Using device: {device}')

## 1. Prepare CIFAR-10 Dataset

CIFAR-10: 60,000 32×32 color images in 10 classes (50,000 for training, 10,000 for testing).

In [None]:
# Preprocessing shared by both splits: PIL image -> float tensor in [0, 1],
# then per-channel (x - 0.5) / 0.5, which maps values into [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Both splits are stored under ./data and downloaded if not already present.
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Mini-batches of 64; only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

# Human-readable names, indexed by the integer label the dataset returns
classes = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

In [None]:
# Visualize some samples
def show_images(images, labels, classes, num=8):
    """Plot the first ``num`` images of a batch in a single row, titled by class name.

    Args:
        images: batch of image tensors on the CPU; assumed (N, 3, H, W) and
            normalized with mean=std=0.5 like ``transform`` above — TODO confirm
            for other inputs.
        labels: integer class indices matching ``images`` (tensor or sequence).
        classes: class names indexed by label.
        num: how many images to show; clamped to the batch size.
    """
    num = min(num, len(images))
    if num <= 0:
        return
    # squeeze=False keeps ``axes`` 2-D even when num == 1 (otherwise subplots
    # returns a single Axes object that cannot be indexed).
    _, axes = plt.subplots(1, num, figsize=(15, 2), squeeze=False)
    for ax, image, label in zip(axes[0], images[:num], labels[:num]):
        # CHW -> HWC for matplotlib, then undo Normalize(0.5, 0.5): x * 0.5 + 0.5.
        img = image.numpy().transpose(1, 2, 0) * 0.5 + 0.5
        # Clip away float rounding so imshow gets valid [0, 1] RGB data.
        ax.imshow(np.clip(img, 0.0, 1.0))
        ax.set_title(classes[int(label)])
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Get a batch
dataiter = iter(train_loader)
images, labels = next(dataiter)
show_images(images, labels, classes)

## 2. Define CNN Model

In [None]:
class CIFAR10CNN(nn.Module):
    """Small three-stage CNN classifier for 32x32 RGB images (CIFAR-10).

    Each stage is conv3x3 (padding=1) -> ReLU -> 2x2 max-pool. A 32x32 input
    therefore shrinks to 16 -> 8 -> 4 spatially while the channels grow
    3 -> 32 -> 64 -> 128. The 128x4x4 feature map is flattened into a 512-unit
    hidden layer (followed by dropout) and a final linear layer that returns raw
    logits (no softmax; pair with ``nn.CrossEntropyLoss``).

    Args:
        num_classes: number of output logits (default 10, as in CIFAR-10).
        dropout: dropout probability applied after ``fc1`` (default 0.5).
    """

    def __init__(self, num_classes=10, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 128 channels * 4 * 4 spatial positions after three 2x poolings of 32x32.
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample while keeping the batch dimension. The previous
        # view(-1, 2048) would silently reshape a wrongly sized input into extra
        # "samples"; now fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Build the network, move its parameters to the selected device, and show its layout and size.
model = CIFAR10CNN()
model = model.to(device)
print(model)
print(f'\nTotal parameters: {sum(weight.numel() for weight in model.parameters()):,}')

## 3. Sanity Check: Overfit One Batch

**Before training on full dataset, verify the model can learn by overfitting one batch!**

This tests:
- Model architecture is correct
- Gradients flow properly
- Loss function works
- Optimizer is configured

In [None]:
# Take one batch from the *training* loader. (The test_ prefix means "batch
# under test"; these samples do not come from test_dataset.)
test_images, test_labels = next(iter(train_loader))
test_images, test_labels = test_images.to(device), test_labels.to(device)

# Separate model + optimizer, so the real training run starts from fresh weights
sanity_model = CIFAR10CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(sanity_model.parameters(), lr=0.001)

print("Sanity Check: Overfitting one batch...")
print("Loss should decrease to near 0\n")

sanity_model.train()
for i in range(100):
    outputs = sanity_model(test_images)
    loss = criterion(outputs, test_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 20 == 0:
        _, predicted = torch.max(outputs, 1)
        acc = (predicted == test_labels).sum().item() / test_labels.size(0)
        print(f'Iteration {i:3d}: Loss = {loss.item():.4f}, Acc = {acc*100:.1f}%')

# Judge the result with one extra forward pass in eval mode. The last `loss`
# from the loop was computed before the final optimizer step and with dropout
# active, so it is a stale, noisy measure of how well the batch was memorized.
sanity_model.eval()
with torch.no_grad():
    outputs = sanity_model(test_images)
    loss = criterion(outputs, test_labels)
_, predicted = torch.max(outputs, 1)
acc = (predicted == test_labels).sum().item() / test_labels.size(0)
print(f'Final (eval mode): Loss = {loss.item():.4f}, Acc = {acc*100:.1f}%')

print("\n✓ Sanity check passed! Model can learn." if loss.item() < 0.1 else "✗ Something wrong - loss should be near 0")

## 4. Training Loop

In [None]:
# Re-initialize model for actual training (the sanity-check model is discarded)
model = CIFAR10CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
train_losses = []  # mean per-sample training loss, one entry per epoch
train_accs = []    # training accuracy in percent, one entry per epoch

print("Starting training...\n")

for epoch in range(num_epochs):
    model.train()  # dropout active during training
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward
        outputs = model(images)
        loss = criterion(outputs, labels)  # mean loss over this batch

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics. CrossEntropyLoss averages over the batch, so weight it by
        # the batch size. The last batch can be smaller when the dataset size
        # is not a multiple of 64; averaging batch means would over-weight it.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        predicted = outputs.detach().argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        if (i + 1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    train_losses.append(epoch_loss)
    train_accs.append(epoch_acc)

    print(f'Epoch {epoch+1} Summary: Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%\n')

## 5. Training Diagnostics

### How to Read Your Training Curves

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# 1-based epoch numbers so the x-axis matches the "Epoch N" training log.
epochs = range(1, len(train_losses) + 1)

ax1.plot(epochs, train_losses, marker='o')
ax1.set_title('Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.grid(True, alpha=0.3)

ax2.plot(epochs, train_accs, marker='o', color='green')
ax2.set_title('Training Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Diagnose training with simple rule-of-thumb checks on the per-epoch curves
print("\n📊 Training Diagnostics:")
print("=" * 50)

# Heuristic: final-epoch loss should be under half of the first-epoch loss.
if train_losses[-1] < train_losses[0] * 0.5:
    print("✓ Loss is decreasing - Good!")
else:
    print("⚠ Loss not decreasing enough - Try:")
    print("  - Lower learning rate")
    print("  - Check data normalization")

if train_accs[-1] > 60:
    print("✓ Accuracy is improving - Good!")
else:
    print("⚠ Low accuracy - Try:")
    print("  - Train more epochs")
    print("  - Increase model capacity")

# Strictly decreasing from each epoch to the next counts as "stable".
if all(prev > cur for prev, cur in zip(train_losses, train_losses[1:])):
    print("✓ Smooth decrease - Training is stable")
else:
    print("⚠ Unstable training - Consider:")
    print("  - Reduce learning rate")
    print("  - Use learning rate scheduling")

## 6. Evaluation on Test Set

In [None]:
model.eval()  # disable dropout for evaluation
correct = 0
total = 0
class_correct = [0] * len(classes)
class_total = [0] * len(classes)

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        matches = predicted == labels  # one bool per sample in the batch

        total += labels.size(0)
        correct += matches.sum().item()

        # Per-class accuracy. No .squeeze() here: for a batch of one it would
        # produce a 0-d tensor that cannot be indexed. tolist() also avoids a
        # host/device round-trip for every element.
        for label, hit in zip(labels.tolist(), matches.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

test_accuracy = 100 * correct / total
print(f'Overall Test Accuracy: {test_accuracy:.2f}%\n')

print('Per-class Accuracy:')
for i, name in enumerate(classes):
    if class_total[i] == 0:
        # Guard against division by zero when a class never appears.
        print(f'{name:10s}: n/a (no samples)')
        continue
    acc = 100 * class_correct[i] / class_total[i]
    print(f'{name:10s}: {acc:.1f}%')

## 7. Visualize Predictions

In [None]:
# Get a batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# Make predictions. Set eval mode explicitly, because this cell may run right
# after training (dropout active). Skip autograd, which inference does not need.
model.eval()
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, 1)

# Show results
images = images.cpu()
labels = labels.cpu()
predicted = predicted.cpu()

fig, axes = plt.subplots(2, 4, figsize=(15, 8))
axes = axes.ravel()

# Fill up to 8 panels, but never index past the end of a short batch.
for i in range(min(len(axes), len(images))):
    # CHW -> HWC, undo Normalize(0.5, 0.5), and clip float rounding for imshow.
    img = images[i].numpy().transpose(1, 2, 0)
    img = np.clip(img * 0.5 + 0.5, 0.0, 1.0)

    axes[i].imshow(img)

    true_label = classes[labels[i]]
    pred_label = classes[predicted[i]]

    color = 'green' if labels[i] == predicted[i] else 'red'
    axes[i].set_title(f'True: {true_label}\nPred: {pred_label}', color=color)
    axes[i].axis('off')

plt.tight_layout()
plt.show()

## 8. Common Problems & Solutions

### Problem: Loss is NaN
```python
# Solutions:
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Reduce LR
# Check for NaN/inf in data
torch.isfinite(images).all()  # Should be True
```

### Problem: Loss not decreasing
```python
# Check learning rate
for param_group in optimizer.param_groups:
    print(param_group['lr'])  # Should be 0.001-0.0001

# Check gradients
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f'{name}: {param.grad.norm():.4f}')
```

### Problem: Training acc high, test acc low (Overfitting)
```python
# Add data augmentation (use it for train_dataset only; keep the plain transform for test_dataset)
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Increase dropout
self.dropout = nn.Dropout(0.7)  # Was 0.5
```

## 9. Save Your Model

In [None]:
# Persist the trained weights together with the optimizer state (so training
# could be resumed) and a little run metadata, as one dict-style checkpoint.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        test_accuracy=test_accuracy,
        epoch=num_epochs,
    ),
    'cifar10_cnn.pth',
)
print("Model saved to 'cifar10_cnn.pth'")

# To restore later (map_location lets a GPU-saved file load on a CPU-only machine):
# checkpoint = torch.load('cifar10_cnn.pth', map_location=device)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

## Summary

You've learned:
- ✓ Preparing CIFAR-10 dataset
- ✓ Building CNN architecture
- ✓ Sanity checking before full training
- ✓ Training loop implementation
- ✓ Model evaluation
- ✓ Visualizing training progress
- ✓ Debugging common issues
- ✓ Saving and loading models

**Congratulations! You trained your first CNN!**

**Next:** Transfer Learning for even better results with less data!