# Lesson 4: Data Augmentation with CIFAR-10

This notebook demonstrates data augmentation techniques using fastai (built on PyTorch) and the CIFAR-10 dataset.

**Learning objectives:**
- Understand why data augmentation improves model generalization
- Apply common augmentation transforms to images
- Compare model performance with and without augmentation


In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from fastai.vision.all import *
from fastai.callback.all import *

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
set_seed(42, reproducible=True)

print(f'PyTorch version: {torch.__version__}')
print(f'Using device: {"cuda" if torch.cuda.is_available() else "cpu"}')


## 1. Load CIFAR-10 Dataset

CIFAR-10 consists of 60,000 32×32 color images in 10 classes:
- airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck


In [None]:
# Download CIFAR-10 dataset
path = untar_data(URLs.CIFAR)
print(f'Dataset downloaded to: {path}')

# CIFAR-10 has 10 classes
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(f'Classes: {classes}')
print(f'Total images: 60,000 (50,000 train + 10,000 test)')


## 2. Visualize Original Images


In [None]:
# Create a basic DataBlock to visualize original images
dblock_viz = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='train', valid_name='test'),
    item_tfms=Resize(32)
)

dls_viz = dblock_viz.dataloaders(path, bs=16)

# Show a batch of images
print('Sample images from CIFAR-10:')
dls_viz.show_batch(max_n=16, nrows=2, figsize=(14, 6))


## 3. Common Data Augmentation Techniques

Augmentation artificially expands the training set by applying random transformations:
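- **Horizontal flip** (`RandomHorizontalFlip` in torchvision, `Flip` in fastai): Mirrors images left-right
- **Random crop** (`RandomCrop` in torchvision): Pads the image and crops a random patch of the original size
- **Color jitter** (`ColorJitter` in torchvision): Randomly changes brightness, contrast, and saturation
- **Rotation** (`RandomRotation` in torchvision, `Rotate` in fastai): Rotates images by a random angle

This notebook uses fastai's `aug_transforms`, which bundles flips, rotation, zoom, perspective warping, and brightness/contrast changes.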
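
To make these transformations concrete, here is a minimal NumPy sketch of three of them applied to a random stand-in image. It is not how fastai or torchvision implement them; the function names (`random_hflip`, `random_crop_with_padding`, `random_brightness`) and the parameter values are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_hflip(img, p=0.5):
    """Flip an H x W x C image left-right with probability p."""
    return img[:, ::-1, :] if rng.random() < p else img

def random_crop_with_padding(img, pad=4):
    """Reflect-pad by `pad` pixels, then crop a random window of the original size."""
    h, w, _ = img.shape
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w, :]

def random_brightness(img, max_delta=0.3):
    """Shift pixel values (assumed to lie in [0, 1]) by a random offset, then clip."""
    delta = rng.uniform(-max_delta, max_delta)
    return np.clip(img + delta, 0.0, 1.0)

# A random array stands in for a 32x32 RGB CIFAR-10 image.
image = rng.random((32, 32, 3))
augmented = random_brightness(random_crop_with_padding(random_hflip(image)))
print(augmented.shape)  # (32, 32, 3)
```

Each call draws fresh random parameters, so the same source image yields a slightly different training example every epoch.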


In [None]:
# Create augmentation pipeline using FastAI
# aug_transforms provides common augmentations for computer vision
aug_tfms = aug_transforms(
    size=32,
    do_flip=True,          # Random horizontal flip
    flip_vert=False,       # Don't flip vertically (objects don't appear upside down)
    max_rotate=15.0,       # Random rotation up to 15 degrees
    max_lighting=0.3,      # Brightness/contrast adjustments
    max_warp=0.2,          # Perspective warping
    p_affine=0.75,         # Probability of applying geometric transforms
    p_lighting=0.75        # Probability of applying lighting transforms
)

print('Augmentation transforms configured:')
print('- Random horizontal flips')
print('- Random zoom (up to 1.1x, the aug_transforms default)')
print('- Random rotations (±15°)')
print('- Brightness and contrast adjustments')
print('- Perspective warping')


## 4. Compare Original vs Augmented Images


In [None]:
# Create a DataBlock with augmentation to visualize the effects
dblock_aug_viz = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='train', valid_name='test'),
    item_tfms=Resize(32),
    batch_tfms=[*aug_tfms, Normalize.from_stats(*cifar_stats)]
)

dls_aug_viz = dblock_aug_viz.dataloaders(path, bs=16)

# Visualize augmented images
print('A batch of training images with augmentation applied:')
dls_aug_viz.show_batch(max_n=16, nrows=2, figsize=(14, 6))

print('\nNotice the variations:')
print('- Some images are flipped horizontally')
print('- Images have different brightness/contrast')
print('- Slight rotations and perspective changes')


## 5. Prepare Data with FastAI

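The training pipelines reuse the same `DataBlock` pattern and add normalization with the CIFAR channel statistics. We start with a baseline that applies no augmentation.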
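
The `Normalize.from_stats(*cifar_stats)` step used in this notebook's pipelines standardizes each color channel with a fixed mean and standard deviation. The NumPy sketch below shows the arithmetic only. The statistics are assumed values close to fastai's `cifar_stats`, and `normalize` / `denormalize` are hypothetical helper names, not fastai functions.

```python
import numpy as np

# Assumed per-channel RGB mean and std, approximating fastai's `cifar_stats`.
CIFAR_MEAN = np.array([0.491, 0.482, 0.447])
CIFAR_STD = np.array([0.247, 0.243, 0.261])

def normalize(batch, mean=CIFAR_MEAN, std=CIFAR_STD):
    """Standardize an N x C x H x W batch channel by channel."""
    return (batch - mean[None, :, None, None]) / std[None, :, None, None]

def denormalize(batch, mean=CIFAR_MEAN, std=CIFAR_STD):
    """Invert `normalize`, e.g. before displaying images."""
    return batch * std[None, :, None, None] + mean[None, :, None, None]

# Random values in [0, 1] stand in for a batch of eight CIFAR-10 images.
batch = np.random.default_rng(0).random((8, 3, 32, 32))
normalized = normalize(batch)
restored = denormalize(normalized)
print(np.allclose(restored, batch))  # True
```

The inverse step is why normalized batches can still be displayed with their original colors.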


In [None]:
# untar_data returns the cached path if CIFAR-10 was already downloaded above
path = untar_data(URLs.CIFAR)
print(f'Data path: {path}')

# Create DataBlock WITHOUT augmentation (baseline)
dblock_basic = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='train', valid_name='test'),
    item_tfms=Resize(32),
    batch_tfms=[Normalize.from_stats(*cifar_stats)]
)

dls_basic = dblock_basic.dataloaders(path, bs=128)
print(f'Training batches: {len(dls_basic.train)}')
print(f'Validation batches: {len(dls_basic.valid)}')
print(f'Classes: {dls_basic.vocab}')


## 6. Visualize with FastAI


In [None]:
# FastAI's built-in visualization
dls_basic.show_batch(max_n=16, nrows=2, figsize=(12, 6))


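## 7. Define the XResNet Model

We'll use fastai's xresnet18, a ResNet variant that incorporates the tweaks from the "Bag of Tricks for Image Classification" paper. Despite the "x" prefix, it is not an Inception or Xception network.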


In [None]:
# Create a learner with the xresnet18 architecture
# xresnet is fastai's ResNet variant with the "Bag of Tricks" tweaks (not an Inception model)
# Note: vision_learner defaults to pretrained=True. If your fastai version cannot load
# pretrained weights for xresnet18, pass pretrained=False and train from scratch.
learn_basic = vision_learner(
    dls_basic,
    xresnet18,
    metrics=[accuracy, error_rate],
    loss_func=CrossEntropyLossFlat()
)

print('Model: xresnet18 (ResNet variant)')
print(f'Total parameters: {sum(p.numel() for p in learn_basic.model.parameters()):,}')


## 8. Train Without Augmentation (Baseline)
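
The `fine_tune` call in this section trains in two phases: first with the body frozen for `freeze_epochs`, then with all layers unfrozen for the remaining epochs, each phase following a one-cycle learning rate schedule. The pure-Python sketch below approximates the learning rate curve of the unfrozen phase. It is based on a reading of fastai's source rather than on the library itself, so the defaults (`pct_start=0.3`, `div=5.0`, `div_final=1e5`, and the halving of `base_lr`) are assumptions that may differ between versions, and `sched_cos` / `one_cycle_lr` are hypothetical names.

```python
import math

def sched_cos(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max, pct_start=0.3, div=5.0, div_final=1e5):
    """Learning rate at training progress `pos` in [0, 1]."""
    if pos < pct_start:
        # Warm-up: rise from lr_max/div to lr_max.
        return sched_cos(lr_max / div, lr_max, pos / pct_start)
    # Annealing: decay from lr_max to lr_max/div_final.
    return sched_cos(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

base_lr = 3e-3
lr_max = base_lr / 2  # assumption: fine_tune halves base_lr before unfreezing
schedule = [one_cycle_lr(i / 100, lr_max) for i in range(101)]

print(f"start: {schedule[0]:.2e}, peak: {max(schedule):.2e}, end: {schedule[-1]:.2e}")
```

With discriminative learning rates, this curve would describe the last layer group; earlier groups would follow the same shape at a smaller scale.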


In [None]:
print('Training WITHOUT augmentation...')
learn_basic.fine_tune(10, base_lr=3e-3, freeze_epochs=2)


## 9. Create DataLoader WITH Augmentation


In [None]:
# Create DataBlock WITH augmentation using FastAI's aug_transforms
dblock_augmented = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='train', valid_name='test'),
    item_tfms=Resize(32),
    batch_tfms=[
        *aug_transforms(
            size=32,
            do_flip=True,
            flip_vert=False,
            max_rotate=15.0,
            max_lighting=0.3,
            max_warp=0.2,
            p_affine=0.75,
            p_lighting=0.75
        ),
        Normalize.from_stats(*cifar_stats)
    ]
)

dls_augmented = dblock_augmented.dataloaders(path, bs=128)

# Visualize augmented samples
print('Augmented training samples:')
dls_augmented.show_batch(max_n=16, nrows=2, figsize=(12, 6))


## 10. Train WITH Augmentation


In [None]:
print('Training WITH augmentation...')
learn_augmented = vision_learner(
    dls_augmented, 
    xresnet18,
    metrics=[accuracy, error_rate],
    loss_func=CrossEntropyLossFlat()
)

learn_augmented.fine_tune(10, base_lr=3e-3, freeze_epochs=2)


## 11. Compare Results

FastAI automatically tracks training history. Let's visualize and compare.


In [None]:
# Plot training history for both models
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

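# recorder.values holds one row per epoch: [train_loss, valid_loss, *metrics].
# recorder.metric_names additionally starts with 'epoch', hence the -1 offset.
# Only the final (unfrozen) phase of fine_tune is recorded here.
def metric_history(learn, name):
    col = learn.recorder.metric_names.index(name) - 1
    return [row[col] for row in learn.recorder.values]

# Extract losses
basic_losses = metric_history(learn_basic, 'train_loss')
aug_losses = metric_history(learn_augmented, 'train_loss')

# Extract accuracies (convert to percentage)
basic_acc = [a * 100 for a in metric_history(learn_basic, 'accuracy')]
aug_acc = [a * 100 for a in metric_history(learn_augmented, 'accuracy')]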

# Plot training loss
ax1.plot(basic_losses, label='Without Augmentation', marker='o', linewidth=2)
ax1.plot(aug_losses, label='With Augmentation', marker='s', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Training Loss', fontsize=12)
ax1.set_title('Training Loss Comparison', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Plot validation accuracy
ax2.plot(basic_acc, label='Without Augmentation', marker='o', linewidth=2)
ax2.plot(aug_acc, label='With Augmentation', marker='s', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Validation Accuracy (%)', fontsize=12)
ax2.set_title('Validation Accuracy Comparison', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f'\nFinal Results:')
print(f'Without Augmentation: {basic_acc[-1]:.2f}% accuracy')
print(f'With Augmentation: {aug_acc[-1]:.2f}% accuracy')
print(f'Improvement: {aug_acc[-1] - basic_acc[-1]:.2f} percentage points')


## 12. Visualize Predictions

Let's see how the augmented model performs on validation images (the CIFAR-10 test split).


In [None]:
# Show predictions from augmented model
learn_augmented.show_results(max_n=12, figsize=(14, 10))


## 13. Confusion Matrix

Analyze which classes are most commonly confused.
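
As a rough illustration of what a confusion matrix and a "most confused" listing contain, the NumPy sketch below builds both from a handful of made-up labels. The toy targets and predictions are invented for the example and are not results from this notebook; `confusion_matrix` and `most_confused` here are hypothetical stand-ins, not the fastai methods.

```python
import numpy as np

def confusion_matrix(targets, preds, n_classes):
    """Count samples per (actual, predicted) pair; rows are actual classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (targets, preds), 1)
    return cm

def most_confused(cm, labels, min_val=1):
    """List off-diagonal (actual, predicted, count) entries, largest first."""
    pairs = [
        (labels[i], labels[j], int(cm[i, j]))
        for i in range(len(labels))
        for j in range(len(labels))
        if i != j and cm[i, j] >= min_val
    ]
    return sorted(pairs, key=lambda t: t[2], reverse=True)

# Made-up labels for three classes, for illustration only.
labels = ["cat", "dog", "ship"]
targets = np.array([0, 0, 0, 1, 1, 2, 2, 2])
preds = np.array([0, 1, 1, 1, 0, 2, 2, 0])

cm = confusion_matrix(targets, preds, n_classes=3)
print(cm)
print(most_confused(cm, labels))
```

The diagonal holds correct predictions, so the off-diagonal entries are the ones worth inspecting.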


In [None]:
# Get interpretation for augmented model
interp = ClassificationInterpretation.from_learner(learn_augmented)

# Plot confusion matrix (plot_confusion_matrix creates its own figure)
interp.plot_confusion_matrix(figsize=(10, 10))
plt.show()

# Show most confused classes
print('\nMost confused pairs:')
interp.most_confused(min_val=50)


## 14. Understanding the XResNet Architecture

The xresnet architecture is a ResNet with several refinements:
- **Residual connections**: Help with gradient flow
- **Batch normalization**: Stabilizes training
- **Reworked stem**: Replaces the single large input convolution of the original ResNet with a stack of 3×3 convolutions
- **Pooling before the classifier**: Reduces parameters compared with flattening into fully connected layers
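
Residual connections, batch normalization, and global average pooling can be illustrated with a toy PyTorch network. This is a simplified sketch and not the actual xresnet18 definition; `TinyResBlock`, `TinyResNet`, and the channel count are hypothetical choices made for brevity.

```python
import torch
import torch.nn as nn

class TinyResBlock(nn.Module):
    """Two 3x3 conv + batch norm layers with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.act(out + x)  # residual connection: add the input back

class TinyResNet(nn.Module):
    """Stem -> one residual block -> global average pooling -> linear classifier."""

    def __init__(self, channels=16, n_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        self.block = TinyResBlock(channels)
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.head = nn.Linear(channels, n_classes)

    def forward(self, x):
        x = self.block(self.stem(x))
        x = self.pool(x).flatten(1)  # N x C x 1 x 1 -> N x C
        return self.head(x)

model = TinyResNet().eval()
with torch.no_grad():
    logits = model(torch.randn(4, 3, 32, 32))  # a random CIFAR-sized batch
print(logits.shape)  # torch.Size([4, 10])
```

Because pooling collapses each feature map to a single value, the classifier needs only `channels * n_classes` weights regardless of the input resolution.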


In [None]:
# Inspect model architecture
print('Model Architecture Summary:')
print(learn_augmented.model)
print(f'\nTotal trainable parameters: {sum(p.numel() for p in learn_augmented.model.parameters() if p.requires_grad):,}')


## Key Takeaways

1. **Data augmentation** artificially expands the training set without collecting new data
2. **Common techniques**: flips, crops, rotations, color/lighting adjustments, warping
3. **Benefits**: Better generalization, reduced overfitting, improved test accuracy
4. **XResNet architectures**: Combine residual connections and batch normalization with further ResNet refinements
5. **FastAI advantages**: Simplified API, built-in one-cycle learning rate scheduling, optional mixed precision
6. **Best practice**: Apply augmentation only to training data, not validation/test data
7. **Transfer learning**: Fine-tuning pretrained models often improves performance on small datasets
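
The train-only rule for augmentation can be sketched in a few lines of NumPy. fastai handles this split automatically for its random batch transforms, so the `load_batch` helper and its `training` flag below are hypothetical and only illustrate the idea.

```python
import numpy as np

rng = np.random.default_rng(0)

def hflip(img):
    """Mirror an H x W x C image left-right."""
    return img[:, ::-1, :]

def load_batch(images, training):
    """Return a batch, applying the random flip only when `training` is True."""
    if not training:
        return images.copy()  # validation data passes through unchanged
    return np.stack([hflip(im) if rng.random() < 0.5 else im for im in images])

# Random arrays stand in for four CIFAR-10 images.
images = rng.random((4, 32, 32, 3))
train_batch = load_batch(images, training=True)
valid_batch = load_batch(images, training=False)
print(np.array_equal(valid_batch, images))  # True
```

Keeping validation data deterministic means accuracy differences between runs reflect the model rather than random transforms.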
