# Convolutional Autoencoder for MNIST

Train a convolutional autoencoder to compress and reconstruct MNIST images using a 12-dimensional latent space.

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim
from model import ConvAutoencoder
from data import get_mnist_dataloaders
from utils import get_device, plot_reconstructions, plot_training_curves, visualize_latent_space
from tqdm import tqdm
import matplotlib.pyplot as plt

## Configuration

In [None]:
# Run configuration (these four values are also written into the checkpoint's 'config' entry)
LATENT_DIM = 12        # passed to ConvAutoencoder as latent_dim
BATCH_SIZE = 128       # passed to get_mnist_dataloaders as batch_size
LEARNING_RATE = 1e-3   # passed to the Adam optimizer as lr
EPOCHS = 10            # number of train + validate passes in the training loop

# Resolve the compute device via the project helper and report it
device = get_device()
print(f"Using device: {device}")

## Data Loading

In [None]:
# Build the train / validation / test loaders with the configured batch size
train_loader, val_loader, test_loader = get_mnist_dataloaders(batch_size=BATCH_SIZE)

# Report the size of the dataset behind each loader, one line per split
print(
    f"Training samples: {len(train_loader.dataset)}",
    f"Validation samples: {len(val_loader.dataset)}",
    f"Test samples: {len(test_loader.dataset)}",
    sep="\n",
)

## Model Initialization

In [None]:
# Instantiate the autoencoder and place it on the selected device in one step
model = ConvAutoencoder(latent_dim=LATENT_DIM).to(device)

# Show the module layout, then the total and trainable parameter counts
print("Model Architecture:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## Training Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimization pass over every batch of ``dataloader``.

    The model is put into training mode; for each batch the loss between the
    model output and the batch targets is back-propagated and followed by a
    single optimizer step.

    Args:
        model: Module being optimized; invoked as ``model(images)``
        dataloader: Iterable of ``(images, targets)`` batches that supports ``len()``
        criterion: Callable ``criterion(output, targets)`` returning a scalar loss tensor
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called once per batch
        device: Device both batch tensors are moved to before the forward pass

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``

    NOTE(review): the second element of each batch is used as the
    reconstruction reference; presumably the loader yields image targets
    rather than class labels -- confirm against get_mnist_dataloaders.
    """
    model.train()
    running_loss = 0.0

    progress = tqdm(dataloader, desc='Training', leave=False)
    for batch_inputs, batch_targets in progress:
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

        # Forward: score the model output against the batch targets
        batch_loss = criterion(model(batch_inputs), batch_targets)

        # Backward: drop stale gradients, back-propagate, apply the update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Every batch counts equally, whatever its size
    return running_loss / len(dataloader)


def validate(model, dataloader, criterion, device):
    """
    Measure the average loss of ``model`` over ``dataloader`` without training.

    The model is put into evaluation mode and all forward passes run with
    gradient tracking disabled; no parameters are updated.

    Args:
        model: Module being evaluated; invoked as ``model(images)``
        dataloader: Iterable of ``(images, targets)`` batches that supports ``len()``
        criterion: Callable ``criterion(output, targets)`` returning a scalar loss tensor
        device: Device both batch tensors are moved to before the forward pass

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``
    """
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation', leave=False)
        for batch_inputs, batch_targets in progress:
            batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

            # Forward only: score the model output against the batch targets
            batch_loss = criterion(model(batch_inputs), batch_targets)
            running_loss += batch_loss.item()

    # Every batch counts equally, whatever its size
    return running_loss / len(dataloader)

## Training Loop

In [None]:
# Adam over all model parameters, scored with mean-squared error
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Per-epoch loss histories (consumed later for plotting and checkpointing)
train_losses = []
val_losses = []

print("Starting training...\n")
for epoch in range(1, EPOCHS + 1):
    # One optimization pass over the training split
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)

    # One evaluation pass over the validation split
    val_loss = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)

    # Epoch summary: three report lines followed by a blank separator line
    print(
        f"Epoch {epoch}/{EPOCHS}",
        f"  Train Loss: {train_loss:.6f}",
        f"  Val Loss:   {val_loss:.6f}",
        "",
        sep="\n",
    )

print("Training complete!")

## Visualization

In [None]:
# Plot training curves: hand the per-epoch train/validation loss histories
# (one entry appended per epoch by the training loop) to the project helper.
plot_training_curves(train_losses, val_losses)

In [None]:
# Plot reconstruction samples from test set
# Switch to evaluation mode before the forward pass; what this changes
# depends on the layers inside ConvAutoencoder (not visible here).
model.eval()
with torch.no_grad():  # forward pass is for display only, so no autograd tracking
    # Only the first batch yielded by test_loader is used; the second element
    # of that batch is discarded.
    test_images, _ = next(iter(test_loader))
    test_images = test_images.to(device)
    reconstructed = model(test_images)

# n_samples=10 is forwarded to the helper -- presumably the number of
# original/reconstruction pairs drawn; confirm in utils.plot_reconstructions.
plot_reconstructions(test_images, reconstructed, n_samples=10)

In [None]:
# Visualize latent space with PCA
# method='pca' and n_samples=5000 are passed straight to the project helper;
# how samples are drawn from test_loader and how they are encoded is defined
# in utils.visualize_latent_space (not visible here).
visualize_latent_space(model, test_loader, device, method='pca', n_samples=5000)

## Save Checkpoint

In [None]:
# Hyperparameters of this run, stored under the checkpoint's 'config' key
run_config = {
    'latent_dim': LATENT_DIM,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'epochs': EPOCHS
}

# Bundle weights, optimizer state, loss histories and configuration.
# 'latent_dim' is stored both at the top level and inside 'config'.
checkpoint = {
    'epoch': EPOCHS,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'latent_dim': LATENT_DIM,
    'config': run_config
}

# Write the bundle to the current working directory and confirm
torch.save(checkpoint, 'autoencoder_checkpoint.pth')
print("Checkpoint saved to autoencoder_checkpoint.pth")