# Model Training for Emotion Recognition

This notebook demonstrates how to train emotion recognition models with the components implemented in the `src` package (models, losses, trainer, and utilities).

**What we'll do:**
1. Load the dataset using our data loading pipeline
2. Create models (ResNet18, EfficientNet-B0)
3. Train with different loss functions (MSE, L1, KL Divergence)
4. Monitor training with TensorBoard
5. Save and load checkpoints

## 1. Setup and Imports

In [None]:
import sys
import torch
import torch.optim as optim
from pathlib import Path
import matplotlib.pyplot as plt

# Import our custom modules
from src.models import create_model
from src.losses import get_loss_function
from src.train import create_trainer
from src.utils import get_device, set_seed, count_parameters

# Additional imports used by the EmotionDataset class below (copied from the data loading notebook)
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Set random seed for reproducibility
set_seed(42)

# Get device
device = get_device()
print(f"Using device: {device}")

## 2. Load Dataset

We redefine the `EmotionDataset` class from the data loading notebook here so that this notebook runs on its own.

In [None]:
# Emotion labels
EMOTION_LABELS = [
    'neutral', 'happy', 'sad', 'surprised', 'fear', 'disgust', 'angry',
    'contempt', 'serene', 'contemplative', 'secure', 'untroubled', 'quiet'
]

class EmotionDataset(Dataset):
    """PyTorch Dataset for emotion recognition with probability distributions."""
    
    def __init__(self, images_dir, annots_dir, transform=None, target_transform=None):
        self.images_dir = Path(images_dir)
        self.annots_dir = Path(annots_dir)
        self.transform = transform
        self.target_transform = target_transform
        
        # Get all image files
        self.image_files = sorted(
            list(self.images_dir.glob('*.jpg')) + 
            list(self.images_dir.glob('*.png'))
        )
        
        self._verify_annotations()
        
    def _verify_annotations(self):
        """Verify all images have corresponding annotation files."""
        valid_files = []
        for img_path in self.image_files:
            annot_path = self.annots_dir / f"{img_path.stem}_prob_rank.txt"
            if annot_path.exists():
                valid_files.append(img_path)
        
        self.image_files = valid_files
        print(f"✓ Verified {len(self.image_files)} samples in {self.images_dir.parent.name}")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        # Load image
        img_path = self.image_files[idx]
        image = Image.open(img_path).convert('RGB')
        
        # Load probabilities (comma-separated)
        annot_path = self.annots_dir / f"{img_path.stem}_prob_rank.txt"
        with open(annot_path, 'r') as f:
            line = f.read().strip()
            probs = np.array([float(val) for val in line.split(',')], dtype=np.float32)
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        if self.target_transform:
            probs = self.target_transform(probs)
        
        probs = torch.from_numpy(probs)
        
        return image, probs
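
Each annotation file is expected to hold a single comma-separated line of 13 probabilities, one per entry in `EMOTION_LABELS`. The sketch below parses a line the same way `__getitem__` does and adds sanity checks before training. `parse_prob_line` is a hypothetical helper (not part of `src`), and the sum-to-one tolerance is an assumption about the data; `EmotionDataset` itself does not enforce it.

```python
import numpy as np

NUM_EMOTIONS = 13  # matches len(EMOTION_LABELS)


def parse_prob_line(line, num_emotions=NUM_EMOTIONS, atol=1e-3):
    """Parse a comma-separated probability line and validate it (hypothetical helper)."""
    probs = np.array([float(v) for v in line.strip().split(',')], dtype=np.float32)
    if probs.shape != (num_emotions,):
        raise ValueError(f"expected {num_emotions} values, got {probs.shape[0]}")
    if np.any(probs < 0):
        raise ValueError("probabilities must be non-negative")
    # Assumed invariant: each annotation is a proper probability distribution
    if not np.isclose(probs.sum(), 1.0, atol=atol):
        raise ValueError(f"probabilities sum to {probs.sum():.4f}, not 1")
    return probs


example = "0.6,0.2,0.2," + ",".join(["0"] * 10)
print(parse_prob_line(example)[:3])
```

Running this over every file in `annotations/` once is a cheap way to catch malformed annotations before they show up as NaN losses.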

In [None]:
# Dataset paths
DATASET_ROOT = Path('AffectNetFused')
TRAIN_IMAGES_DIR = DATASET_ROOT / 'train_set' / 'images'
TRAIN_ANNOTS_DIR = DATASET_ROOT / 'train_set' / 'annotations'
VAL_IMAGES_DIR = DATASET_ROOT / 'val_set' / 'images'
VAL_ANNOTS_DIR = DATASET_ROOT / 'val_set' / 'annotations'

# Image transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets
train_dataset = EmotionDataset(TRAIN_IMAGES_DIR, TRAIN_ANNOTS_DIR, transform=train_transform)
val_dataset = EmotionDataset(VAL_IMAGES_DIR, VAL_ANNOTS_DIR, transform=val_transform)

# Create dataloaders
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

print(f"\n✓ DataLoaders ready:")
print(f"  Training: {len(train_dataset)} samples, {len(train_loader)} batches")
print(f"  Validation: {len(val_dataset)} samples, {len(val_loader)} batches")
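
Both transforms end with `transforms.Normalize` using the ImageNet channel statistics, which computes `(x - mean) / std` per channel. To display a batch (for example when plotting predictions), that step has to be inverted. A minimal sketch, with `denormalize` as a hypothetical helper name:

```python
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def denormalize(images, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert transforms.Normalize for a (C, H, W) or (N, C, H, W) tensor."""
    # Reshape to (C, 1, 1) so the statistics broadcast over height and width
    mean = mean.to(images.device).view(-1, 1, 1)
    std = std.to(images.device).view(-1, 1, 1)
    return (images * std + mean).clamp(0.0, 1.0)


# Round trip: normalize a random batch in [0, 1], then undo it
x = torch.rand(2, 3, 224, 224)
x_norm = (x - IMAGENET_MEAN.view(-1, 1, 1)) / IMAGENET_STD.view(-1, 1, 1)
x_back = denormalize(x_norm)
print(torch.allclose(x, x_back, atol=1e-5))
```

To show a single image with matplotlib, `denormalize(img).permute(1, 2, 0).cpu().numpy()` gives the `(H, W, C)` layout that `plt.imshow` expects.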

## 3. Create and Inspect Model

In [None]:
# Create a model
model = create_model('resnet18', pretrained=True, num_emotions=13)
print(f"\n✓ Created model: {model}")

# Count parameters
params = count_parameters(model)
print(f"\nüìä Model parameters:")
print(f"  Total: {params['total']:,}")
print(f"  Trainable: {params['trainable']:,}")
print(f"  Non-trainable: {params['non_trainable']:,}")

# Test forward pass
dummy_input = torch.randn(2, 3, 224, 224)
model.eval()
with torch.no_grad():
    output = model(dummy_input)

print(f"\n✓ Forward pass test:")
print(f"  Input shape: {dummy_input.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Output sums: {output.sum(dim=1)} (should be ~1.0)")
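
The `~1.0` expectation only holds if the model's last operation turns logits into a probability distribution. How `create_model` does this is not shown in this notebook; one common approach, sketched below with a hypothetical `SoftmaxHead` wrapper and a small stand-in backbone, is to apply a softmax over the emotion dimension.

```python
import torch
import torch.nn as nn


class SoftmaxHead(nn.Module):
    """Wrap a backbone that outputs raw logits so it returns probability distributions."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        logits = self.backbone(x)
        # Softmax over the emotion dimension: each row is non-negative and sums to 1
        return torch.softmax(logits, dim=1)


# Stand-in backbone (hypothetical): flattens a small image into 13 emotion logits
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 13))
toy_model = SoftmaxHead(backbone).eval()
with torch.no_grad():
    probs = toy_model(torch.randn(2, 3, 8, 8))
print(probs.shape, probs.sum(dim=1))
```

If a loss works on log-probabilities (as KL-based losses often do), returning `torch.log_softmax` instead and exponentiating only for display is numerically safer than taking the log of softmax outputs.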

## 4. Configure Training

Choose the loss function, optimizer, and other hyperparameters.

In [None]:
# Configuration
EXPERIMENT_NAME = 'resnet18_mse'  # Change this for different experiments
NUM_EPOCHS = 10
LEARNING_RATE = 0.001
LOSS_TYPE = 'mse'  # Options: 'mse', 'l1', 'kl', 'ce', 'js'

print(f"\n⚙️ Training Configuration:")
print(f"  Experiment: {EXPERIMENT_NAME}")
print(f"  Model: ResNet18")
print(f"  Loss: {LOSS_TYPE.upper()}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Batch Size: {BATCH_SIZE}")
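
`get_loss_function` comes from `src.losses`, whose code is not shown here. For reference, these are standard definitions of the five `LOSS_TYPE` options for a predicted distribution `pred` and a target `target`, both of shape `(batch, 13)`. This is a sketch, not the project's implementation; in particular the `EPS` clamping and the batch-mean reduction are assumptions.

```python
import torch

EPS = 1e-8  # assumed clamp value to avoid log(0)


def mse_loss(pred, target):
    return ((pred - target) ** 2).mean()


def l1_loss(pred, target):
    return (pred - target).abs().mean()


def kl_loss(pred, target):
    # KL(target || pred) = sum_i target_i * (log target_i - log pred_i), averaged over the batch
    kl = target * (torch.log(target.clamp_min(EPS)) - torch.log(pred.clamp_min(EPS)))
    return kl.sum(dim=1).mean()


def soft_ce_loss(pred, target):
    # Cross-entropy with soft labels: -sum_i target_i * log pred_i
    return -(target * torch.log(pred.clamp_min(EPS))).sum(dim=1).mean()


def js_loss(pred, target):
    # Jensen-Shannon divergence: symmetric, bounded above by log(2)
    m = 0.5 * (pred + target)
    return 0.5 * kl_loss(m, target) + 0.5 * kl_loss(m, pred)


LOSSES = {'mse': mse_loss, 'l1': l1_loss, 'kl': kl_loss, 'ce': soft_ce_loss, 'js': js_loss}

target = torch.tensor([[0.7, 0.2, 0.1]])
pred = torch.tensor([[0.5, 0.3, 0.2]])
for name, fn in LOSSES.items():
    print(f"{name}: {fn(pred, target).item():.4f}")
```

Soft cross-entropy equals KL divergence plus the entropy of the target, and that entropy does not depend on the model. The two losses therefore produce the same gradients; only their reported values differ (cross-entropy never reaches zero for soft targets).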

In [None]:
# Create fresh model for training
model = create_model('resnet18', pretrained=True, num_emotions=13)

# Create loss function
loss_fn = get_loss_function(LOSS_TYPE)
print(f"\n✓ Loss function: {loss_fn}")

# Create optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(f"✓ Optimizer: Adam (lr={LEARNING_RATE})")

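# Learning rate scheduler (optional): halves the LR when validation loss stops improving
# (`verbose` is omitted because newer PyTorch releases deprecate it)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)
print(f"✓ Scheduler: ReduceLROnPlateau (factor=0.5, patience=3)")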
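
`ReduceLROnPlateau` only changes the learning rate when it is given the monitored value, i.e. `scheduler.step(val_loss)` once per epoch; presumably the trainer does this internally, but the notebook does not show it. With `mode='min'`, `factor=0.5`, and `patience=3`, the learning rate is halved once the validation loss has failed to improve for more than 3 consecutive epochs. The toy run below, on a throwaway parameter, shows when that happens:

```python
import torch
import torch.optim as optim

# Throwaway parameter and optimizer, only used to observe the scheduler
param = torch.nn.Parameter(torch.zeros(1))
toy_optimizer = optim.Adam([param], lr=0.001)
toy_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    toy_optimizer, mode='min', factor=0.5, patience=3
)

# Validation loss sets a best value once, then plateaus for four epochs
val_losses = [1.0, 1.0, 1.0, 1.0, 1.0]
lrs = []
for val_loss in val_losses:
    toy_scheduler.step(val_loss)
    lrs.append(toy_optimizer.param_groups[0]['lr'])

print(lrs)
```

The first call records the best loss; the next three non-improving epochs stay within `patience`, and the fourth triggers the reduction to `0.0005`.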

## 5. Create Trainer and Start Training

In [None]:
# Create trainer
trainer = create_trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    checkpoint_dir=f'checkpoints/{EXPERIMENT_NAME}',
    log_dir=f'runs/{EXPERIMENT_NAME}',
    metrics=['mse', 'kl', 'tvd']  # Metrics to track
)

print(f"\n✓ Trainer created!")
print(f"  Checkpoints will be saved to: checkpoints/{EXPERIMENT_NAME}")
print(f"  TensorBoard logs: runs/{EXPERIMENT_NAME}")
print(f"\nTo view training in TensorBoard, run:")
print(f"  tensorboard --logdir=runs")
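
The trainer tracks `mse`, `kl`, and `tvd`. Total variation distance is half the L1 distance between two distributions, `TVD(p, q) = 0.5 * sum_i |p_i - q_i|`, so it ranges from 0 (identical) to 1 (disjoint support). A sketch of a per-batch computation; the batch-mean reduction is an assumption about how `src` aggregates it:

```python
import torch


def total_variation_distance(pred, target):
    """Mean TVD over the batch for (batch, num_emotions) probability tensors."""
    return 0.5 * (pred - target).abs().sum(dim=1).mean()


p = torch.tensor([[1.0, 0.0], [0.5, 0.5]])
q = torch.tensor([[0.0, 1.0], [0.5, 0.5]])
print(total_variation_distance(p, q).item())
```

Here the first pair of rows is disjoint (TVD 1) and the second identical (TVD 0), so the batch mean is 0.5.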

In [None]:
# Start training!
history = trainer.train(
    num_epochs=NUM_EPOCHS,
    scheduler=scheduler,
    early_stopping_patience=5
)

# Save final model
trainer.save_final_checkpoint()
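
`trainer.train` hides the loop itself. As a rough sketch of what such a loop typically does (not the actual `src.train` code), each epoch runs a training pass and a validation pass, steps the plateau scheduler with the validation loss, tracks the best weights, and stops once validation loss has not improved for `early_stopping_patience` epochs. The names `run_epoch` and `fit` and the in-memory `best_state` are hypothetical; the real trainer writes checkpoints to disk instead.

```python
import copy

import torch
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, loss_fn, device, optimizer=None):
    """One pass over `loader`; trains if an optimizer is given, otherwise evaluates."""
    training = optimizer is not None
    model.train(training)
    total, count = 0.0, 0
    with torch.set_grad_enabled(training):
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            loss = loss_fn(model(images), targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item() * images.size(0)
            count += images.size(0)
    return total / count


def fit(model, train_loader, val_loader, loss_fn, optimizer, device,
        num_epochs, scheduler=None, early_stopping_patience=5):
    history = {'train_loss': [], 'val_loss': []}
    best_loss, best_state, bad_epochs = float('inf'), None, 0
    model.to(device)
    for epoch in range(num_epochs):
        history['train_loss'].append(run_epoch(model, train_loader, loss_fn, device, optimizer))
        val_loss = run_epoch(model, val_loader, loss_fn, device)
        history['val_loss'].append(val_loss)
        if scheduler is not None:
            scheduler.step(val_loss)  # ReduceLROnPlateau needs the monitored value
        if val_loss < best_loss:
            best_loss, best_state, bad_epochs = val_loss, copy.deepcopy(model.state_dict()), 0
        else:
            bad_epochs += 1
            if bad_epochs >= early_stopping_patience:
                print(f"Early stopping after epoch {epoch + 1}")
                break
    if best_state is not None:
        model.load_state_dict(best_state)  # one design choice: end with the best weights
    return history


# Tiny synthetic run: 4 input features, 3-class target distributions
torch.manual_seed(0)
toy_x = torch.randn(64, 4)
toy_y = torch.softmax(torch.randn(64, 3), dim=1)
toy_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=16)
toy_net = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Softmax(dim=1))
toy_history = fit(toy_net, toy_loader, toy_loader, torch.nn.MSELoss(),
                  torch.optim.Adam(toy_net.parameters(), lr=0.01), 'cpu', num_epochs=3)
print(toy_history)
```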

## 6. Visualize Training Results

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Metrics
if history['val_metrics']:
    metric_name = list(history['val_metrics'][0].keys())[0]
    metric_values = [m[metric_name] for m in history['val_metrics']]
    axes[1].plot(metric_values, label=metric_name.upper(), linewidth=2, color='green')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Metric Value')
    axes[1].set_title(f'Validation {metric_name.upper()}')
    axes[1].legend()
    axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig(f'checkpoints/{EXPERIMENT_NAME}/training_curves.png', dpi=150)
plt.show()

print(f"\n✓ Training curves saved to: checkpoints/{EXPERIMENT_NAME}/training_curves.png")

## 7. Quick Predictions Visualization

Let's see how the model performs on some validation samples.

In [None]:
from src.utils import visualize_predictions

# Get a batch from validation set
images, targets = next(iter(val_loader))
images = images.to(device)
targets = targets.to(device)

# Make predictions
model.eval()
with torch.no_grad():
    predictions = model(images)

# Visualize
fig = visualize_predictions(
    images[:4], 
    predictions[:4], 
    targets[:4],
    num_samples=4,
    save_path=f'checkpoints/{EXPERIMENT_NAME}/sample_predictions.png'
)
plt.show()
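
Beyond the plots, a quick numeric check is whether the most probable predicted emotion matches the most probable annotated one. The sketch below uses `argmax` and repeats the `EMOTION_LABELS` list so it runs on its own; `top1_agreement` and `top1_labels` are hypothetical helpers, and ties are resolved by `argmax`, which returns the first maximal index.

```python
import torch

EMOTION_LABELS = [
    'neutral', 'happy', 'sad', 'surprised', 'fear', 'disgust', 'angry',
    'contempt', 'serene', 'contemplative', 'secure', 'untroubled', 'quiet'
]


def top1_agreement(predictions, targets):
    """Fraction of samples whose most probable predicted emotion matches the target's."""
    pred_idx = predictions.argmax(dim=1)
    target_idx = targets.argmax(dim=1)
    return (pred_idx == target_idx).float().mean().item()


def top1_labels(probs):
    """Map each row's most probable emotion index to its name."""
    return [EMOTION_LABELS[i] for i in probs.argmax(dim=1).tolist()]


# Toy example with two samples
preds = torch.zeros(2, 13)
preds[0, 1] = 1.0   # predicts 'happy'
preds[1, 2] = 1.0   # predicts 'sad'
tgts = torch.zeros(2, 13)
tgts[0, 1] = 1.0    # annotated 'happy'
tgts[1, 6] = 1.0    # annotated 'angry'
print(top1_labels(preds), top1_labels(tgts), top1_agreement(preds, tgts))
```

Applied to the batch above, `top1_agreement(predictions, targets)` gives a single number to set beside the qualitative plots.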

## 8. Next Steps

**To run more experiments:**

1. **Try different loss functions:**
   - Set `LOSS_TYPE = 'kl'` for KL Divergence
   - Set `LOSS_TYPE = 'l1'` for L1 loss
   - Set `LOSS_TYPE = 'ce'` for Cross-Entropy

2. **Try different models:**
   - `create_model('resnet50', pretrained=True, num_emotions=13)`
   - `create_model('efficientnet_b0', pretrained=True, num_emotions=13)`

3. **Adjust hyperparameters:**
   - Learning rate: `LEARNING_RATE = 0.0001`
   - Batch size: `BATCH_SIZE = 64` (then re-run the DataLoader cell so the loaders pick it up)
   - More epochs: `NUM_EPOCHS = 20`

4. **View TensorBoard:**
   ```bash
   tensorboard --logdir=runs
   ```
   Then open http://localhost:6006 in your browser.

5. **Comprehensive evaluation:**
   - See `03_evaluation.ipynb` for detailed model comparison
   - See `04_visualization.ipynb` for qualitative analysis
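
To compare several of the options above, one approach is to enumerate the combinations and derive `EXPERIMENT_NAME` from each, so every run gets its own checkpoint and TensorBoard directory. A sketch with `itertools.product`; the grid values are examples only, and each config would still be fed through the training cells above.

```python
from itertools import product

MODELS = ['resnet18', 'efficientnet_b0']
LOSS_TYPES = ['mse', 'l1', 'kl']
LEARNING_RATES = [1e-3, 1e-4]

# One dict per combination; 'name' doubles as EXPERIMENT_NAME
experiments = [
    {
        'name': f"{model_name}_{loss_type}_lr{lr:g}",
        'model': model_name,
        'loss': loss_type,
        'lr': lr,
    }
    for model_name, loss_type, lr in product(MODELS, LOSS_TYPES, LEARNING_RATES)
]

print(len(experiments), experiments[0]['name'])
```

Encoding the hyperparameters in the name keeps the TensorBoard run list (`--logdir=runs`) readable when many runs are overlaid.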

## Summary

✓ Loaded dataset with emotion probability distributions  
✓ Created and trained emotion recognition model  
✓ Tracked metrics with TensorBoard  
✓ Saved checkpoints for the best model  
✓ Visualized predictions vs ground truth  

The best model is saved to `checkpoints/{EXPERIMENT_NAME}/best_model.pth`.
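
The plan at the top lists loading checkpoints, but this notebook only saves them. The file layout written by the trainer is not shown, so the sketch below assumes the `.pth` file holds either a bare `state_dict` or a dictionary with a `'model_state_dict'` key (a common convention, not confirmed for `src.train`). `load_checkpoint` is a hypothetical helper.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint(model, path, device='cpu'):
    """Load weights into `model` from a checkpoint file and switch it to eval mode."""
    checkpoint = torch.load(path, map_location=device)
    # Assumed layouts: a wrapper dict with 'model_state_dict', or a bare state_dict
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    return model.to(device).eval()


# Self-contained round trip with a stand-in model
toy = nn.Linear(4, 13)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'best_model.pth')
    torch.save({'model_state_dict': toy.state_dict(), 'epoch': 3}, path)
    restored = load_checkpoint(nn.Linear(4, 13), path)
print(torch.equal(restored.weight, toy.weight))

# In this notebook, assuming the layout above, it would look like:
# model = create_model('resnet18', pretrained=False, num_emotions=13)
# model = load_checkpoint(model, f'checkpoints/{EXPERIMENT_NAME}/best_model.pth', device)
```

The model must be built with the same architecture and `num_emotions` as the saved one, since `load_state_dict` checks parameter names and shapes strictly by default.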