# 02 - Baseline CNN Training

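This notebook trains a baseline CNN (LeNet-inspired) for facial emotion recognition, saves the weights that reach the best validation accuracy, and evaluates that checkpoint on the test set.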

## Contents
1. Setup and configuration
2. Load data
3. Build model
4. Train model
5. Evaluate results

In [None]:
import random
import sys
from pathlib import Path

sys.path.append('..')

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm

from src.models import BaselineCNN
from src.data import get_dataloaders
from src.utils.metrics import MetricsTracker, calculate_metrics
from src.utils.visualization import plot_training_history, plot_confusion_matrix

%matplotlib inline

# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Using device: {device}')

## 1. Configuration

In [None]:
# Hyperparameters and run settings shared by the cells below.
CONFIG = {
    'data_dir': '../data',      # passed to get_dataloaders
    'batch_size': 64,           # passed to get_dataloaders
    'num_workers': 4,           # passed to get_dataloaders
    'epochs': 50,               # number of passes in the training loop
    'learning_rate': 0.001,     # initial Adam learning rate
    'weight_decay': 1e-4,       # Adam weight decay
    'dropout': 0.5,             # passed to BaselineCNN
    'num_classes': 7,           # passed to BaselineCNN
    'seed': 42,                 # seed applied to every RNG below
}

# Seed the Python, NumPy and PyTorch RNGs before any loader or model is built,
# so that a Restart & Run All starts from the same random state each time.
# NOTE(review): this alone may not give bit-identical GPU runs (cuDNN kernels,
# DataLoader worker processes) -- see the PyTorch reproducibility notes.
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

## 2. Load Data

In [None]:
# Collect the loader settings from CONFIG, then build the three loaders.
# NOTE(review): augment=True is forwarded as-is; which splits it affects is
# decided inside get_dataloaders -- confirm there.
loader_kwargs = {
    'data_dir': CONFIG['data_dir'],
    'batch_size': CONFIG['batch_size'],
    'num_workers': CONFIG['num_workers'],
    'augment': True,
}
train_loader, val_loader, test_loader = get_dataloaders(**loader_kwargs)

# Report how many batches each split yields per pass.
print(f'Train batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')
print(f'Test batches: {len(test_loader)}')

## 3. Build Model

In [None]:
# Instantiate the baseline network from CONFIG and move it to the selected device.
model = BaselineCNN(num_classes=CONFIG['num_classes'], dropout=CONFIG['dropout'])
model = model.to(device)

# Print model architecture
print(model)

# Count parameters: every parameter tensor, and the subset that requires gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print(f'\nTotal parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

## 4. Train Model

In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()

adam_settings = {
    'lr': CONFIG['learning_rate'],
    'weight_decay': CONFIG['weight_decay'],
}
optimizer = optim.Adam(model.parameters(), **adam_settings)

# Multiply the learning rate by 0.1 once the monitored value (the validation
# loss passed to scheduler.step in the training loop) has stopped decreasing
# for longer than `patience` epochs.
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5, mode='min')

# Metrics tracker
tracker = MetricsTracker()

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` computed over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages would
            otherwise fail with an uninformative ``ZeroDivisionError``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc='Training'):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller final batch does not
        # skew the epoch average (assumes criterion returns the batch mean,
        # as nn.CrossEntropyLoss does by default).
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError('train_epoch received an empty loader: no samples to average over')

    return running_loss / total, 100. * correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` computed over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages would
            otherwise fail with an uninformative ``ZeroDivisionError``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch loss by its size so a smaller final batch does
            # not skew the average (assumes criterion returns the batch mean,
            # as nn.CrossEntropyLoss does by default).
            running_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError('validate received an empty loader: no samples to average over')

    return running_loss / total, 100. * correct / total

In [None]:
# Training loop
# Train for CONFIG['epochs'] epochs, validating after each one and keeping the
# weights that achieve the highest validation accuracy.
CHECKPOINT_PATH = Path('../checkpoints/baseline_best.pth')
# torch.save does not create missing directories, so make sure the target
# folder exists before the first checkpoint is written.
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

best_val_acc = 0

for epoch in range(1, CONFIG['epochs'] + 1):
    print(f'\nEpoch {epoch}/{CONFIG["epochs"]}')
    print('-' * 30)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Read the learning rate that was actually used for this epoch *before*
    # the scheduler may lower it, so the logged value matches the metrics.
    lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)

    tracker.update(epoch, train_loss, train_acc, val_loss, val_acc, lr)

    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

    # Checkpoint only on a strict improvement in validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print(f'New best model saved! Val Acc: {val_acc:.2f}%')

## 5. Results

In [None]:
# Plot training history
# Fetch the per-epoch records accumulated by the tracker, then plot them.
history = tracker.get_history()
plot_training_history(history)

In [None]:
# Final evaluation on test set
# Restore the best-validation weights saved during training. map_location
# remaps the stored tensors onto the current device, so a checkpoint written
# on a GPU machine still loads when only the CPU is available.
best_state = torch.load('../checkpoints/baseline_best.pth', map_location=device)
model.load_state_dict(best_state)

test_loss, test_acc = validate(model, test_loader, criterion, device)
print('\nTest Results:')
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_acc:.2f}%')