# Q1: Training ResNet-18 on CIFAR-100

**Objectives:**
- Train a ResNet-18 classifier on CIFAR-100
- Visualize the training curves
- Save a checkpoint for later use
- Generate a training GIF for the report

## Setup

In [None]:
# Colab-only step: attach Google Drive so checkpoints can be saved there.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Clone repository 
import os
if not os.path.exists('/content/OOD-Detection-Project---CSC_5IA23'):
    !git clone https://github.com/DiegoFleury/OOD-Detection-Project---CSC_5IA23.git
%cd /content/OOD-Detection-Project---CSC_5IA23

In [None]:
# Install dependencies
# The packages are installed into the notebook's environment through the
# shell (`!`); `-q` asks pip for quiet output.
!pip install -q torch torchvision matplotlib seaborn scikit-learn pyyaml imageio tqdm

In [None]:
# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml

from src.models import ResNet18
from src.data import get_cifar100_loaders
from src.utils import Trainer, plot_training_curves, create_training_gif, plot_final_metrics

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
print(f"Using device: {device}")
if has_cuda:
    # Report the name of the first visible GPU.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Read the experiment settings from the YAML file and echo them for reference.
CONFIG_PATH = 'configs/config.yaml'

with open(CONFIG_PATH, 'r') as config_file:
    config = yaml.safe_load(config_file)

print("Configuration:")
# Block style (default_flow_style=False) prints one key per line.
config_text = yaml.dump(config, default_flow_style=False)
print(config_text)

## 1. Load Data

In [None]:
# Build the train/val/test loaders from the `data` and `training` config sections.
print("Loading CIFAR-100 dataset...")

data_cfg = config['data']
training_cfg = config['training']

train_loader, val_loader, test_loader = get_cifar100_loaders(
    data_dir=data_cfg['data_dir'],
    batch_size=training_cfg['batch_size'],
    num_workers=data_cfg['num_workers'],
    augment=data_cfg['augment'],
    val_split=training_cfg['val_split'],
)

# Report how many batches each split yields.
for split_name, loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"{split_name} batches: {len(loader)}")

In [None]:
# Visualize sample batch: show up to 16 training images in a 2x8 grid and
# save the figure under <figures>/training/sample_batch.png.
images, labels = next(iter(train_loader))

# Per-channel statistics used to undo the input normalization.
# NOTE(review): presumably the same mean/std the data loader applies —
# confirm against src.data.
channel_mean = np.array([0.5071, 0.4867, 0.4408])
channel_std = np.array([0.2675, 0.2565, 0.2761])

fig, axes = plt.subplots(2, 8, figsize=(16, 4))

# Hide every axis up front so unused cells stay blank when the batch holds
# fewer images than the grid has cells.
for ax in axes.flat:
    ax.axis('off')

# zip() stops at the shorter sequence, so a small batch cannot index past its end.
for ax, image, label in zip(axes.flat, images, labels):
    # (C, H, W) tensor -> (H, W, C) array, as imshow expects channels last.
    img = image.permute(1, 2, 0).numpy()
    # Denormalize and clamp to the displayable [0, 1] range.
    img = np.clip(img * channel_std + channel_mean, 0, 1)
    ax.imshow(img)
    ax.set_title(f"Class {label.item()}", fontsize=8)

plt.tight_layout()

sample_batch_path = os.path.join(config['paths']['figures'], 'training', 'sample_batch.png')
# savefig does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(sample_batch_path), exist_ok=True)
plt.savefig(sample_batch_path, dpi=150, bbox_inches='tight')
plt.show()

## 2. Create Model

In [None]:
# Instantiate the classifier, report its size, and run a smoke-test forward pass.
print("Creating ResNet-18 model...")

num_classes = config['model']['num_classes']
model = ResNet18(num_classes=num_classes)

# Tally total and trainable parameter counts in a single pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Feed two random 3x32x32 inputs through the model (no gradients needed)
# and print the resulting output shape.
x_test = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    out = model(x_test)
print(f"Output shape: {out.shape}")

## 3. Train Model

In [None]:
import glob
import re

# Locate any previously saved checkpoints for this model.
checkpoint_dir = config['paths']['checkpoints']
checkpoint_pattern = os.path.join(checkpoint_dir, 'resnet18_cifar100_*.pth')
checkpoints = glob.glob(checkpoint_pattern)

# Build the trainer from the loaders and the optimizer settings in the config.
optim_cfg = config['training']
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    lr=optim_cfg['lr'],
    momentum=optim_cfg['momentum'],
    weight_decay=optim_cfg['weight_decay'],
    device=device,
)

def get_epoch_num(path):
    """Return the epoch number embedded in a checkpoint's file name.

    Only the file name is inspected, so an ``epoch<N>`` fragment in a parent
    directory cannot be mistaken for the checkpoint's epoch. Files without an
    ``epoch<N>`` fragment rank as epoch 0.
    """
    match = re.search(r'epoch(\d+)', os.path.basename(path))
    return int(match.group(1)) if match else 0


# Decide how many epochs remain: resume from the newest checkpoint when one
# exists, otherwise train for the full configured number of epochs.
if checkpoints:
    # Get checkpoint with highest epoch number
    latest = max(checkpoints, key=get_epoch_num)
    epoch_num = get_epoch_num(latest)

    print(f"Resuming from: {os.path.basename(latest)} (epoch {epoch_num})")
    trainer.load_checkpoint(latest)

    # The restored history length is the count of epochs already completed.
    epochs_done = len(trainer.history['train_loss'])
    epochs_left = config['training']['epochs'] - epochs_done
    print(f"Epoch {epochs_done}/{config['training']['epochs']} | Best val: {trainer.best_val_acc:.2f}%")
else:
    print("Starting fresh training")
    epochs_left = config['training']['epochs']

# Train for the remaining epochs, or reuse the restored history when the
# configured number of epochs has already been reached.
if epochs_left <= 0:
    print("Already trained!")
    history = trainer.history
else:
    schedule_cfg = config['training']
    history = trainer.train(
        epochs=epochs_left,
        save_dir=checkpoint_dir,
        early_stopping_patience=schedule_cfg['early_stopping_patience'],
        checkpoint_frequency=schedule_cfg['checkpoint_frequency'],
    )

finish_banner = "=" * 50
print("\n" + finish_banner)
print("Training finished!")
print(finish_banner)

## 4. Visualize Results

In [None]:
# Render the training curves from the recorded history and save them as a PNG
# under the configured figures directory.
figures_root = config['paths']['figures']
curves_path = os.path.join(figures_root, 'training', 'training_curves.png')
plot_training_curves(history, save_path=curves_path)

In [None]:
# Build the animated version of the training curves (for README) and store it
# in the configured GIF directory.
gif_dir = config['paths']['gifs']
gif_path = os.path.join(gif_dir, 'training_curves.gif')
create_training_gif(history, save_path=gif_path, fps=10)

## 5. Print Final Summary

In [None]:
# Print a summary of the run: last-epoch accuracies, the best validation
# epoch, and where the artifacts were written.
summary_banner = "=" * 60

print("\n" + summary_banner)
print("TRAINING SUMMARY")
print(summary_banner)

# The last entry of each series is the most recent epoch.
print(f"\nFinal Train Accuracy: {history['train_acc'][-1]:.2f}%")
print(f"Final Val Accuracy: {history['val_acc'][-1]:.2f}%")
print(f"Final Test Accuracy: {history['test_acc'][-1]:.2f}%")

best_val_acc = max(history['val_acc'])
# list.index returns the first occurrence; +1 turns it into a 1-based epoch.
best_epoch = history['val_acc'].index(best_val_acc) + 1
print(f"\nBest Val Accuracy: {best_val_acc:.2f}% (Epoch {best_epoch})")

print(f"\nTotal Epochs: {len(history['train_loss'])}")

# Report the locations from the config rather than hard-coded defaults, so the
# summary matches the directories the notebook actually wrote to.
# NOTE(review): the file name 'resnet18_cifar100_best.pth' is carried over
# from the original message — confirm it matches what the trainer saves.
best_checkpoint_path = os.path.join(config['paths']['checkpoints'], 'resnet18_cifar100_best.pth')
training_figures_dir = os.path.join(config['paths']['figures'], 'training')
print(f"\nCheckpoint saved at: {best_checkpoint_path}")
print(f"Figures saved in: {training_figures_dir}")
print("\n" + summary_banner)

## 6. Commit Results to GitHub

In [None]:
# Stage the generated figures and GIF, commit them, and push. Every command is
# still run (as the `!` lines did), but the success message is only printed
# when none of them reported a non-zero exit status.
git_commands = [
    'git add results/figures/training/',
    'git add results/figures/gifs/training_curves.gif',
    'git commit -m "Add Q1 training results: curves and GIF"',
    'git push',
]

shell = get_ipython()
failed_commands = []
for command in git_commands:
    # Equivalent to `!<command>`; IPython stores the command's exit status in
    # the user namespace as `_exit_code`.
    shell.system(command)
    if shell.user_ns.get('_exit_code', 0) != 0:
        failed_commands.append(command)

if failed_commands:
    print("Some git commands failed; results may not be on GitHub:")
    for command in failed_commands:
        print(f"  {command}")
else:
    print("Results committed to GitHub!")