# Model Training and Evaluation

## EE4745 Neural Networks Final Project

This notebook demonstrates the complete training pipeline for sports image classification models.

### Objectives:
- Model architecture comparison and selection
- Training process demonstration with real-time monitoring
- Hyperparameter experimentation
- Model evaluation and performance analysis
- Training visualization and convergence analysis

---

## 1. Setup and Configuration

Import necessary libraries and set up the training environment.

In [None]:
import os
import sys
import time
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics import precision_recall_fscore_support

import warnings
warnings.filterwarnings('ignore')

# Add src directory to path
sys.path.append('../src')

# Import custom modules
from dataset.sports_dataset import SportsDataset, get_dataloaders
from models.simple_cnn import SimpleCNN, create_simple_cnn
from models.resnet_small import ResNetSmall, create_resnet_small
from training.trainer import Trainer
from training.utils import set_seed, get_device, count_parameters, format_time

# Plot styling shared by every figure in this notebook
plt.style.use('default')
sns.set_palette('viridis')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# Fix the random seed for reproducibility (see training.utils.set_seed)
set_seed(42)

print("Training Environment Setup")
print("=" * 30)
print(f"PyTorch version: {torch.__version__}")
device = get_device()
print(f"Device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
# `torch.backends.mps` only exists in PyTorch >= 1.12; on older installs report
# False instead of crashing the setup cell with AttributeError.
_mps_backend = getattr(torch.backends, 'mps', None)
print(f"MPS available: {_mps_backend is not None and _mps_backend.is_available()}")

## 2. Data Loading and Preparation

Load the training and validation splits, then load a single batch to check that the data pipeline works.

In [None]:
# Dataset configuration
DATA_DIR = '../data'
IMAGE_SIZE = 32  # Start with 32x32 for faster training
BATCH_SIZE = 32
NUM_WORKERS = 2

# Load datasets (train/val loaders plus the number of classes)
print("Loading datasets...")
train_loader, val_loader, num_classes = get_dataloaders(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    num_workers=NUM_WORKERS
)

print("\nDataset Information:")
print(f"  Number of classes: {num_classes}")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Training samples: {len(train_loader.dataset):,}")
print(f"  Validation samples: {len(val_loader.dataset):,}")
print(f"  Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"  Batch size: {BATCH_SIZE}")

# Class names as defined by SportsDataset
# (presumably class_names[i] is the name of integer label i — confirm in the dataset module)
class_names = SportsDataset.CLASSES
print(f"\nClasses: {class_names}")

# Smoke-test the pipeline by pulling a single batch. The timing covers only the
# fetch (including iterator / worker start-up), not the prints below.
print("\nTesting data loading...")
start_time = time.perf_counter()
first_batch = next(iter(train_loader), None)
load_time = time.perf_counter() - start_time
if first_batch is None:
    print("  Warning: training loader produced no batches")
else:
    images, labels = first_batch
    print(f"  Sample batch shape: {images.shape}")
    print(f"  Label batch shape: {labels.shape}")
    print(f"  Image range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"  Sample labels: {labels[:8].tolist()}")
print(f"  First batch load time: {load_time:.3f}s")

## 3. Model Architecture Comparison

Compare different model architectures and their characteristics.

In [None]:
def analyze_model_architecture(model, model_name, input_size=32):
    """Print parameter counts, an estimated size, a timed forward pass and leaf layers.

    Args:
        model: ``nn.Module`` to inspect. The probe feeds it a random tensor of
            shape ``(1, 3, input_size, input_size)``, so a 3-channel square
            image input is assumed.
        model_name: Label used in the printed header.
        input_size: Side length of the square dummy input.

    The dummy input is created on the device that holds the model's
    parameters, and the model's train/eval mode is restored before returning.
    """
    print(f"\n{model_name} Architecture Analysis:")
    print("=" * 40)

    # Parameter counts (total, trainable) from the project helper
    total_params, trainable_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Size estimate assumes every parameter is stored as a 4-byte float32
    param_size = total_params * 4 / (1024 * 1024)
    print(f"Estimated model size: {param_size:.2f} MB")

    # Build the probe on the model's own device; a CPU tensor fed to a model
    # living on GPU/MPS would raise a device-mismatch error.
    first_param = next(model.parameters(), None)
    probe_device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_input = torch.randn(1, 3, input_size, input_size, device=probe_device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            start_time = time.perf_counter()
            output = model(dummy_input)
            if probe_device.type == 'cuda':
                # CUDA kernels run asynchronously; wait so the timing covers the work.
                torch.cuda.synchronize(probe_device)
            inference_time = time.perf_counter() - start_time
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    print(f"Output shape: {output.shape}")
    # Single, un-warmed forward pass: a rough indication only.
    print(f"Inference time: {inference_time*1000:.2f}ms")

    # Architecture summary: leaf modules that own parameters
    print("\nArchitecture Summary:")
    for name, module in model.named_modules():
        if next(module.children(), None) is not None:
            continue  # container module; its leaves are listed individually
        params = sum(p.numel() for p in module.parameters())
        if params > 0:
            print(f"  {name:30}: {str(module):50} | {params:,} params")

# Create and analyze models
print("Model Architecture Comparison")
print("=" * 50)

# SimpleCNN
simple_cnn = create_simple_cnn(num_classes=num_classes, input_size=IMAGE_SIZE)
analyze_model_architecture(simple_cnn, "SimpleCNN", IMAGE_SIZE)

# ResNetSmall
resnet_small = create_resnet_small(num_classes=num_classes, input_size=IMAGE_SIZE)
analyze_model_architecture(resnet_small, "ResNetSmall", IMAGE_SIZE)

# Comparison table: count each model's parameters once and derive every
# column from that count.
BYTES_PER_PARAM = 4  # float32 storage assumption
param_counts = {
    'SimpleCNN': count_parameters(simple_cnn)[0],
    'ResNetSmall': count_parameters(resnet_small)[0],
}
models_comparison = {
    'Model': list(param_counts),
    'Parameters': list(param_counts.values()),
    'Size (MB)': [n * BYTES_PER_PARAM / (1024 * 1024) for n in param_counts.values()],
}

comparison_df = pd.DataFrame(models_comparison)
comparison_df['Parameters'] = comparison_df['Parameters'].map('{:,}'.format)
comparison_df['Size (MB)'] = comparison_df['Size (MB)'].map('{:.2f}'.format)

print("\n" + "=" * 50)
print("MODEL COMPARISON TABLE")
print("=" * 50)
print(comparison_df.to_string(index=False))

## 4. Training Configuration

Define training configurations for different experiments.

In [None]:
# Base training configuration shared by all experiments
def get_base_config(model_name, experiment_name=None):
    """Return a fresh dict of default training settings for ``model_name``.

    If ``experiment_name`` is omitted, a name is generated from the
    lower-cased model name and the current local time (``YYYYmmdd_HHMMSS``).
    The ``'device'`` entry is ``str()`` of the notebook-level ``device``.
    """
    run_name = experiment_name
    if run_name is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        run_name = f"{model_name.lower()}_{timestamp}"

    return dict(
        model_name=model_name,
        experiment_name=run_name,
        device=str(device),
        epochs=50,
        learning_rate=1e-3,
        weight_decay=1e-4,
        optimizer='adam',
        scheduler='cosine',
        eta_min=1e-6,
        patience=10,
        min_delta=0.001,
        use_tensorboard=True,
        log_dir='../logs',
        checkpoint_dir='../checkpoints',
        save_best_only=True,
    )

# Training configurations for different experiments.
# Each run starts from get_base_config() and overrides only what differs.
training_configs = dict(
    simple_cnn_baseline={
        **get_base_config('SimpleCNN', 'simple_cnn_baseline'),
        'epochs': 30,
        'learning_rate': 1e-3,
    },
    resnet_baseline={
        **get_base_config('ResNetSmall', 'resnet_baseline'),
        'epochs': 40,
        'learning_rate': 1e-3,
    },
    simple_cnn_tuned={
        **get_base_config('SimpleCNN', 'simple_cnn_tuned'),
        'epochs': 40,
        'learning_rate': 5e-4,
        'weight_decay': 5e-4,
        'scheduler': 'step',
        'step_size': 15,
        'gamma': 0.1,
    },
)

# Path/logging plumbing is hidden from the printout; only hyperparameters are shown.
_hidden_config_keys = ('log_dir', 'checkpoint_dir', 'use_tensorboard', 'save_best_only')

print("Training Configurations:")
print("=" * 30)
for run_name, run_config in training_configs.items():
    print(f"\n{run_name}:")
    visible = [(k, v) for k, v in run_config.items() if k not in _hidden_config_keys]
    for setting, setting_value in visible:
        print(f"  {setting}: {setting_value}")

# Make sure the output directories exist before any run writes to them
for _output_dir in ('../logs', '../checkpoints'):
    os.makedirs(_output_dir, exist_ok=True)
print("\nCreated necessary directories for logging and checkpoints.")

## 5. Training Experiment 1: SimpleCNN Baseline

Train a SimpleCNN model as our baseline.

In [None]:
def train_model(model, train_loader, val_loader, config, verbose=True):
    """Train ``model`` with a ``Trainer`` built from ``config``.

    Args:
        model: The network to train.
        train_loader: Training data loader, passed straight to ``Trainer``.
        val_loader: Validation data loader, passed straight to ``Trainer``.
        config: Training configuration dict; ``config['experiment_name']`` is
            used in the start banner.
        verbose: When False, suppress the banner and completion message printed
            by this function (output produced inside ``Trainer`` is unaffected).

    Returns:
        ``(trainer, history)``: the ``Trainer`` instance and the value returned
        by ``trainer.train()``.
    """
    if verbose:
        print(f"\nStarting training: {config['experiment_name']}")
        print("=" * 50)

    trainer = Trainer(model, train_loader, val_loader, config)

    # perf_counter is monotonic, so wall-clock adjustments can't skew the duration
    start_time = time.perf_counter()
    history = trainer.train()
    total_time = time.perf_counter() - start_time

    if verbose:
        print(f"\nTraining completed in {format_time(total_time)}")

    return trainer, history

# Experiment 1: SimpleCNN Baseline
print("EXPERIMENT 1: SimpleCNN Baseline")
print("=" * 40)

# Fresh model plus the baseline hyperparameters from training_configs
model_simple = create_simple_cnn(num_classes=num_classes, input_size=IMAGE_SIZE)
config_simple = training_configs['simple_cnn_baseline']

# Echo the hyperparameters, skipping path/logging plumbing
print("\nTraining Configuration:")
_plumbing_keys = {'log_dir', 'checkpoint_dir', 'use_tensorboard', 'save_best_only'}
for cfg_key, cfg_value in config_simple.items():
    if cfg_key in _plumbing_keys:
        continue
    print(f"  {cfg_key}: {cfg_value}")

trainer_simple, history_simple = train_model(model_simple, train_loader, val_loader, config_simple)

# Summarise the run: last-epoch values plus the best accuracy / lowest loss seen
results_simple = dict(
    config=config_simple,
    history=history_simple,
    final_train_acc=history_simple['train_acc'][-1],
    final_val_acc=history_simple['val_acc'][-1],
    best_val_acc=max(history_simple['val_acc']),
    final_train_loss=history_simple['train_loss'][-1],
    final_val_loss=history_simple['val_loss'][-1],
    min_val_loss=min(history_simple['val_loss']),
)

print("\nSimpleCNN Results:")
print(f"  Best validation accuracy: {results_simple['best_val_acc']:.2f}%")
print(f"  Final validation accuracy: {results_simple['final_val_acc']:.2f}%")
print(f"  Minimum validation loss: {results_simple['min_val_loss']:.4f}")

## 6. Training Experiment 2: ResNetSmall

Train a ResNetSmall model to compare with the SimpleCNN baseline.

In [None]:
# Experiment 2: ResNetSmall Baseline
print("\nEXPERIMENT 2: ResNetSmall Baseline")
print("=" * 40)

# Fresh residual model plus its baseline hyperparameters
model_resnet = create_resnet_small(num_classes=num_classes, input_size=IMAGE_SIZE)
config_resnet = training_configs['resnet_baseline']

# Echo the hyperparameters, skipping path/logging plumbing
print("\nTraining Configuration:")
_plumbing_keys = {'log_dir', 'checkpoint_dir', 'use_tensorboard', 'save_best_only'}
for cfg_key, cfg_value in config_resnet.items():
    if cfg_key in _plumbing_keys:
        continue
    print(f"  {cfg_key}: {cfg_value}")

trainer_resnet, history_resnet = train_model(model_resnet, train_loader, val_loader, config_resnet)

# Summarise the run: last-epoch values plus the best accuracy / lowest loss seen
results_resnet = dict(
    config=config_resnet,
    history=history_resnet,
    final_train_acc=history_resnet['train_acc'][-1],
    final_val_acc=history_resnet['val_acc'][-1],
    best_val_acc=max(history_resnet['val_acc']),
    final_train_loss=history_resnet['train_loss'][-1],
    final_val_loss=history_resnet['val_loss'][-1],
    min_val_loss=min(history_resnet['val_loss']),
)

print("\nResNetSmall Results:")
print(f"  Best validation accuracy: {results_resnet['best_val_acc']:.2f}%")
print(f"  Final validation accuracy: {results_resnet['final_val_acc']:.2f}%")
print(f"  Minimum validation loss: {results_resnet['min_val_loss']:.4f}")

## 7. Training Visualization and Analysis

Visualize training curves and analyze model performance.

In [None]:
def plot_training_curves(experiments_dict, title="Training Curves Comparison"):
    """Plot train/val loss and accuracy for several experiments on a 2x2 grid.

    Args:
        experiments_dict: Mapping of display name -> results dict. Each results
            dict must contain ``'history'`` with per-epoch ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'`` and ``'val_acc'`` sequences.
        title: Figure super-title.

    Any number of experiments is supported; the colour palette wraps around
    once it is exhausted.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(title, fontsize=16, fontweight='bold')

    colors = ['blue', 'red', 'green', 'orange', 'purple']

    # (axis, history key, panel title, y-axis label) for each subplot
    panels = [
        (axes[0, 0], 'train_loss', 'Training Loss', 'Loss'),
        (axes[0, 1], 'val_loss', 'Validation Loss', 'Loss'),
        (axes[1, 0], 'train_acc', 'Training Accuracy', 'Accuracy (%)'),
        (axes[1, 1], 'val_acc', 'Validation Accuracy', 'Accuracy (%)'),
    ]

    for ax, key, panel_title, ylabel in panels:
        for i, (name, results) in enumerate(experiments_dict.items()):
            series = results['history'][key]
            epochs = range(1, len(series) + 1)
            # Modulo keeps more than len(colors) experiments from raising IndexError
            ax.plot(epochs, series, color=colors[i % len(colors)],
                    label=f'{name}', linewidth=2)
        ax.set_title(panel_title, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Gather each run's results under the name shown in the plot legends
all_experiments = dict(SimpleCNN=results_simple, ResNetSmall=results_resnet)

# Side-by-side training curves for the two baselines
plot_training_curves(all_experiments, "Model Training Comparison")

# Per-model diagnostic plots: raw curves, smoothed loss and train/val gap
def plot_learning_curves_detailed(history, title, smoothing_window=3):
    """Plot a 2x2 diagnostic figure for a single training history.

    Args:
        history: Dict with per-epoch ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'`` sequences of equal length.
        title: Model name used in the figure title.
        smoothing_window: Width, in epochs, of the trailing moving average used
            for the smoothed-loss panel (default 3).
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'{title} - Detailed Training Analysis', fontsize=14, fontweight='bold')

    epochs = range(1, len(history['train_loss']) + 1)

    # Loss curves
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Training', linewidth=2)
    axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Validation', linewidth=2)
    axes[0, 0].set_title('Loss Curves')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy curves
    axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Training', linewidth=2)
    axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Validation', linewidth=2)
    axes[0, 1].set_title('Accuracy Curves')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Trailing moving average. min_periods=1 averages over the points available
    # in the first epochs instead of producing NaN, so the smoothed curve starts
    # at epoch 1 and histories shorter than the window need no special case.
    train_loss_smooth = (pd.Series(history['train_loss'])
                         .rolling(smoothing_window, min_periods=1).mean())
    val_loss_smooth = (pd.Series(history['val_loss'])
                       .rolling(smoothing_window, min_periods=1).mean())

    axes[1, 0].plot(epochs, train_loss_smooth, 'b-', label='Training (smoothed)', linewidth=2)
    axes[1, 0].plot(epochs, val_loss_smooth, 'r-', label='Validation (smoothed)', linewidth=2)
    axes[1, 0].set_title('Smoothed Loss Curves')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Training-validation gap: accuracy gap on the left axis, loss gap on a
    # twin right axis because the two quantities are on different scales.
    acc_gap = [t - v for t, v in zip(history['train_acc'], history['val_acc'])]
    loss_gap = [v - t for t, v in zip(history['train_loss'], history['val_loss'])]

    ax2 = axes[1, 1]
    acc_line, = ax2.plot(epochs, acc_gap, 'g-', linewidth=2, label='Accuracy Gap')
    ax2.set_ylabel('Accuracy Gap (Train - Val)', color='g')
    ax2.tick_params(axis='y', labelcolor='g')
    ax2.grid(True, alpha=0.3)

    ax2_twin = ax2.twinx()
    loss_line, = ax2_twin.plot(epochs, loss_gap, color='orange', linewidth=2, label='Loss Gap')
    ax2_twin.set_ylabel('Loss Gap (Val - Train)', color='orange')
    ax2_twin.tick_params(axis='y', labelcolor='orange')

    # The two lines live on different axes, so collect both handles into one
    # legend drawn on the top (twin) axis.
    ax2_twin.legend(handles=[acc_line, loss_line], loc='best')

    ax2.set_title('Overfitting Analysis')
    ax2.set_xlabel('Epoch')

    plt.tight_layout()
    plt.show()

# One detailed diagnostic figure per trained baseline model
for _run_history, _run_title in ((history_simple, "SimpleCNN"), (history_resnet, "ResNetSmall")):
    plot_learning_curves_detailed(_run_history, _run_title)

## 8. Model Evaluation and Testing

Evaluate the trained models on the validation set and generate detailed performance reports: confusion matrices and per-class precision, recall, and F1.

In [None]:
def evaluate_model(model, dataloader, device, class_names):
    """Run ``model`` over ``dataloader`` and compute classification metrics.

    Args:
        model: Classifier returning one logit per class.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the images are moved to before the forward pass.
        class_names: Class names in label order; assumes integer labels
            ``0 .. len(class_names) - 1`` (``target_names`` already relies on
            this mapping).

    Returns:
        dict with per-sample ``predictions``/``labels`` lists, the softmax
        ``probabilities`` array, overall ``accuracy`` (a 0-1 fraction),
        per-class ``precision``/``recall``/``f1``/``support`` arrays, the
        sklearn ``classification_report`` dict and the ``confusion_matrix``.
        Every per-class output has exactly ``len(class_names)`` entries, even
        when some class never occurs in this split or is never predicted.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probs = []

    print("Evaluating model...")

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluation"):
            # Labels are only used for bookkeeping, so they stay on the CPU.
            outputs = model(images.to(device))
            probabilities = F.softmax(outputs, dim=1)
            predictions = outputs.argmax(dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probabilities.cpu().numpy())

    # Score against the full class-index range. Without `labels=`, sklearn only
    # uses the classes it observes: classification_report then raises when
    # len(target_names) differs from that count, and the confusion matrix
    # shrinks and no longer lines up with class_names.
    label_ids = list(range(len(class_names)))

    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels, all_predictions, labels=label_ids, average=None, zero_division=0
    )

    class_report = classification_report(
        all_labels, all_predictions,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

    return {
        'predictions': all_predictions,
        'labels': all_labels,
        'probabilities': np.array(all_probs),
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'support': support,
        'classification_report': class_report,
        'confusion_matrix': cm
    }

def plot_confusion_matrix(cm, class_names, title, normalize=False):
    """Draw ``cm`` as an annotated heatmap.

    With ``normalize=True`` each row (true label) is divided by its total, so
    cells show the fraction of that class predicted as each label; rows with
    no samples become all zeros instead of NaN.
    """
    matrix, cell_fmt = cm, 'd'
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        matrix = np.nan_to_num(cm.astype('float') / row_totals)
        cell_fmt = '.2f'

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt=cell_fmt,
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title(title, fontweight='bold')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

def plot_classification_report(class_report, class_names, title):
    """Bar-chart per-class precision, recall and F1 from a classification report.

    Args:
        class_report: Dict produced by ``classification_report(...,
            output_dict=True)``; must contain an entry for every name in
            ``class_names``.
        class_names: Class names, in the order they should be plotted.
        title: Figure super-title.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    positions = range(len(class_names))

    for ax, metric in zip(axes, ['precision', 'recall', 'f1-score']):
        values = [class_report[cls][metric] for cls in class_names]

        bars = ax.bar(positions, values, alpha=0.8)
        ax.set_title(metric.capitalize(), fontweight='bold')
        ax.set_xlabel('Classes')
        ax.set_ylabel(metric.capitalize())
        ax.set_xticks(positions)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        # Headroom above 1.0 so value labels on near-perfect bars stay inside
        # the plot area instead of colliding with the panel title.
        ax.set_ylim(0, 1.1)
        ax.grid(True, alpha=0.3)

        # Value label just above each bar
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.show()

def _print_eval_summary(model_label, eval_result):
    """Print accuracy plus macro/weighted F1 from an ``evaluate_model`` result."""
    report = eval_result['classification_report']
    print(f"{model_label} Validation Results:")
    print(f"  Accuracy: {eval_result['accuracy']:.4f}")
    print(f"  Macro F1: {report['macro avg']['f1-score']:.4f}")
    print(f"  Weighted F1: {report['weighted avg']['f1-score']:.4f}")


# Evaluate both baseline models on the validation loader
print("\nEVALUATING SIMPLECNN MODEL")
print("=" * 40)
simple_eval = evaluate_model(trainer_simple.get_model(), val_loader, device, class_names)
_print_eval_summary("SimpleCNN", simple_eval)

print("\nEVALUATING RESNETSMALL MODEL")
print("=" * 40)
resnet_eval = evaluate_model(trainer_resnet.get_model(), val_loader, device, class_names)
_print_eval_summary("ResNetSmall", resnet_eval)

_baseline_evals = (("SimpleCNN", simple_eval), ("ResNetSmall", resnet_eval))

# Raw counts and row-normalized confusion matrices, one model at a time
for model_label, eval_result in _baseline_evals:
    plot_confusion_matrix(eval_result['confusion_matrix'], class_names,
                          f"{model_label} - Confusion Matrix")
    plot_confusion_matrix(eval_result['confusion_matrix'], class_names,
                          f"{model_label} - Normalized Confusion Matrix", normalize=True)

# Per-class precision / recall / F1 bars
for model_label, eval_result in _baseline_evals:
    plot_classification_report(eval_result['classification_report'], class_names,
                               f"{model_label} - Per-Class Performance")

## 9. Hyperparameter Tuning Experiment

Demonstrate hyperparameter tuning with the SimpleCNN model.

In [None]:
# Experiment 3: SimpleCNN with tuned hyperparameters
print("EXPERIMENT 3: SimpleCNN Hyperparameter Tuning")
print("=" * 50)

# Same architecture as the baseline, fresh weights, tuned settings
model_tuned = create_simple_cnn(num_classes=num_classes, input_size=IMAGE_SIZE)
config_tuned = training_configs['simple_cnn_tuned']

# Show each changed setting next to its baseline value
print("\nTuned Configuration Changes:")
_compared_settings = (
    ("Learning rate", 'learning_rate'),
    ("Weight decay", 'weight_decay'),
    ("Scheduler", 'scheduler'),
    ("Epochs", 'epochs'),
)
for setting_label, setting_key in _compared_settings:
    print(f"  {setting_label}: {config_tuned[setting_key]} (vs {config_simple[setting_key]})")

trainer_tuned, history_tuned = train_model(model_tuned, train_loader, val_loader, config_tuned)

# Summarise the run: last-epoch values plus the best accuracy / lowest loss seen
results_tuned = dict(
    config=config_tuned,
    history=history_tuned,
    final_train_acc=history_tuned['train_acc'][-1],
    final_val_acc=history_tuned['val_acc'][-1],
    best_val_acc=max(history_tuned['val_acc']),
    final_train_loss=history_tuned['train_loss'][-1],
    final_val_loss=history_tuned['val_loss'][-1],
    min_val_loss=min(history_tuned['val_loss']),
)

print("\nTuned SimpleCNN Results:")
print(f"  Best validation accuracy: {results_tuned['best_val_acc']:.2f}%")
print(f"  Final validation accuracy: {results_tuned['final_val_acc']:.2f}%")
print(f"  Minimum validation loss: {results_tuned['min_val_loss']:.4f}")

# Register the tuned run and redraw curves for all three models
all_experiments['SimpleCNN_Tuned'] = results_tuned
plot_training_curves(all_experiments, "Complete Model Comparison")

# Evaluate tuned model on the validation loader
print("\nEVALUATING TUNED SIMPLECNN MODEL")
print("=" * 40)
tuned_eval = evaluate_model(trainer_tuned.get_model(), val_loader, device, class_names)

_tuned_report = tuned_eval['classification_report']
print("Tuned SimpleCNN Validation Results:")
print(f"  Accuracy: {tuned_eval['accuracy']:.4f}")
print(f"  Macro F1: {_tuned_report['macro avg']['f1-score']:.4f}")
print(f"  Weighted F1: {_tuned_report['weighted avg']['f1-score']:.4f}")

## 10. Model Comparison and Analysis

Comprehensive comparison of all trained models.

In [None]:
# Create comprehensive comparison
def create_model_comparison_table(experiments, evaluations, model_names=None):
    """Build a one-row-per-model summary DataFrame.

    Args:
        experiments: Mapping of model name -> training results dict with
            ``'best_val_acc'``, ``'final_val_acc'``, ``'min_val_loss'`` and
            ``'history'['train_loss']``.
        evaluations: ``evaluate_model`` results, aligned position-by-position
            with ``model_names``.
        model_names: Names to report, in order. Defaults to
            ``['SimpleCNN', 'ResNetSmall', 'SimpleCNN_Tuned']``.

    Returns:
        pandas DataFrame with one row for every name in ``model_names`` that is
        also a key of ``experiments`` (other names are skipped).

    Raises:
        ValueError: If ``evaluations`` and ``model_names`` differ in length.
    """
    if model_names is None:
        model_names = ['SimpleCNN', 'ResNetSmall', 'SimpleCNN_Tuned']
    evaluations = list(evaluations)
    if len(evaluations) != len(model_names):
        raise ValueError(
            f"Got {len(evaluations)} evaluations for {len(model_names)} model names"
        )

    comparison_data = []
    for model_name, eval_result in zip(model_names, evaluations):
        if model_name not in experiments:
            continue
        exp_data = experiments[model_name]
        report = eval_result['classification_report']

        comparison_data.append({
            'Model': model_name,
            'Best Val Acc (%)': f"{exp_data['best_val_acc']:.2f}",
            'Final Val Acc (%)': f"{exp_data['final_val_acc']:.2f}",
            'Min Val Loss': f"{exp_data['min_val_loss']:.4f}",
            # Accuracy from the evaluate_model() pass, as a 0-1 fraction. Named
            # neutrally because the evaluated split is chosen by the caller.
            'Eval Accuracy': f"{eval_result['accuracy']:.4f}",
            'Macro F1': f"{report['macro avg']['f1-score']:.4f}",
            'Weighted F1': f"{report['weighted avg']['f1-score']:.4f}",
            'Epochs Trained': len(exp_data['history']['train_loss'])
        })

    return pd.DataFrame(comparison_data)

# Create comparison table (evaluations listed in the same order as the models)
comparison_df = create_model_comparison_table(
    all_experiments,
    [simple_eval, resnet_eval, tuned_eval]
)

print("\n" + "="*80)
print("COMPREHENSIVE MODEL COMPARISON")
print("="*80)
print(comparison_df.to_string(index=False))

# Performance improvement analysis, based on the best validation accuracy of each run
print("\n" + "="*50)
print("PERFORMANCE IMPROVEMENT ANALYSIS")
print("="*50)

baseline_acc = results_simple['best_val_acc']
resnet_acc = results_resnet['best_val_acc']
tuned_acc = results_tuned['best_val_acc']

print(f"\nBaseline SimpleCNN: {baseline_acc:.2f}%")
print(f"ResNetSmall vs Baseline: {resnet_acc:.2f}% ({resnet_acc - baseline_acc:+.2f}%)")
print(f"Tuned SimpleCNN vs Baseline: {tuned_acc:.2f}% ({tuned_acc - baseline_acc:+.2f}%)")

# The emoji below were previously stored as mojibake (UTF-8 bytes decoded as Mac Roman).
if resnet_acc > baseline_acc:
    print("\n✅ ResNetSmall shows improvement over baseline")
else:
    print("\n⚠️ ResNetSmall does not improve over baseline")

if tuned_acc > baseline_acc:
    print("✅ Hyperparameter tuning shows improvement")
else:
    print("⚠️ Hyperparameter tuning does not improve performance")

# Best performing model by best validation accuracy
best_model = max(all_experiments,
                 key=lambda name: all_experiments[name]['best_val_acc'])
best_acc = all_experiments[best_model]['best_val_acc']

print(f"\n🏆 Best performing model: {best_model} ({best_acc:.2f}%)")

## 11. Learning Rate and Optimizer Analysis

Analyze the effect of different learning rates and optimizers.

In [None]:
def quick_training_experiment(model_fn, config_base, modifications, experiment_name,
                              epochs=15):
    """Train a fresh model for a short run with a few config overrides.

    Uses the notebook-level ``num_classes``, ``IMAGE_SIZE``, ``train_loader``,
    ``val_loader`` and ``train_model``.

    Args:
        model_fn: Factory called as ``model_fn(num_classes=..., input_size=...)``.
        config_base: Base training config dict; copied, never mutated.
        modifications: Config keys to override on top of ``config_base``.
        experiment_name: Stored as ``config['experiment_name']``.
        epochs: Length of the quick run (default 15). An explicit ``'epochs'``
            entry in ``modifications`` takes precedence over this value.

    Returns:
        dict with the effective ``config``, the training ``history``, and
        ``best_val_acc`` / ``final_val_acc`` taken from ``history['val_acc']``.

    Raises:
        ValueError: if training recorded no validation-accuracy values.
    """
    # Shallow copy, then apply the overrides so config_base is left untouched.
    config = config_base.copy()
    config.update(modifications)
    config['experiment_name'] = experiment_name
    # Quick runs are shorter than full training; previously a hard-coded 15
    # silently overrode any 'epochs' the caller passed in `modifications`.
    config['epochs'] = modifications.get('epochs', epochs)

    model = model_fn(num_classes=num_classes, input_size=IMAGE_SIZE)

    print(f"\nQuick experiment: {experiment_name}")
    print(f"Modifications: {modifications}")

    # train_model returns (trainer, history); only the history is used here.
    _, history = train_model(model, train_loader, val_loader, config)

    val_accs = history['val_acc']
    if not val_accs:
        raise ValueError(
            f"Quick experiment {experiment_name!r} recorded no validation accuracy"
        )

    return {
        'config': config,
        'history': history,
        'best_val_acc': max(val_accs),
        'final_val_acc': val_accs[-1],
    }

print("LEARNING RATE AND OPTIMIZER EXPERIMENTS")
print("=" * 50)

# Shared baseline for every quick sweep below; each run overrides only the
# keys in its own `modifications` dict.
quick_config = dict(
    model_name='SimpleCNN_Quick',
    device=str(device),
    epochs=15,
    learning_rate=1e-3,
    weight_decay=1e-4,
    optimizer='adam',
    scheduler='none',
    patience=10,
    min_delta=0.001,
    use_tensorboard=False,
    log_dir='../logs',
    checkpoint_dir='../checkpoints',
)

# Learning-rate sweep; successful runs are stored under keys like "lr_1e-04".
learning_rates = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]
lr_experiments = {}

print("\nTesting different learning rates...")
for lr in learning_rates:
    exp_name = f"lr_{lr:.0e}"
    modifications = {'learning_rate': lr}

    try:
        result = quick_training_experiment(create_simple_cnn, quick_config, modifications, exp_name)
        lr_experiments[exp_name] = result
        print(f"  LR {lr:.0e}: Best Val Acc = {result['best_val_acc']:.2f}%")
    except Exception as e:
        # Best-effort sweep: report the failing rate and move on to the next one.
        print(f"  LR {lr:.0e}: Failed - {e}")

# Optimizer comparison at a fixed learning rate of 1e-3.
optimizer_experiments = {}
optimizers = ['adam', 'sgd']

print("\nTesting different optimizers...")
for opt in optimizers:
    exp_name = f"opt_{opt}"
    # SGD additionally receives momentum; Adam runs with only the shared overrides.
    modifications = {
        'optimizer': opt,
        'learning_rate': 1e-3,
        **({'momentum': 0.9} if opt == 'sgd' else {}),
    }

    try:
        result = quick_training_experiment(create_simple_cnn, quick_config, modifications, exp_name)
        optimizer_experiments[exp_name] = result
        print(f"  {opt.upper()}: Best Val Acc = {result['best_val_acc']:.2f}%")
    except Exception as e:
        # Best-effort: report the failure and continue with the next optimizer.
        print(f"  {opt.upper()}: Failed - {e}")

# Visualize the learning-rate sweep (skipped when every run failed).
if lr_experiments:
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: best validation accuracy per learning rate on a log x-axis.
    lr_names = list(lr_experiments.keys())
    lr_accs = [lr_experiments[name]['best_val_acc'] for name in lr_names]
    # Take the exact rate from each run's config instead of parsing the
    # experiment name, whose ".0e" format keeps only one significant digit
    # (e.g. 2.5e-4 would be plotted/reported at the wrong position).
    lr_values = [float(lr_experiments[name]['config']['learning_rate']) for name in lr_names]

    axes[0].semilogx(lr_values, lr_accs, 'bo-', linewidth=2, markersize=8)
    axes[0].set_title('Learning Rate vs Performance', fontweight='bold')
    axes[0].set_xlabel('Learning Rate')
    axes[0].set_ylabel('Best Validation Accuracy (%)')
    axes[0].grid(True, alpha=0.3)

    # Label each point with its accuracy, offset slightly above the marker.
    for lr, acc in zip(lr_values, lr_accs):
        axes[0].annotate(f'{acc:.1f}%', (lr, acc),
                         textcoords="offset points", xytext=(0, 10), ha='center')

    # Right panel: training curves of the best-performing rate
    # (max() keeps the first rate in sweep order on ties).
    best_lr_name, best_lr_exp = max(
        lr_experiments.items(), key=lambda item: item[1]['best_val_acc']
    )
    epochs = range(1, len(best_lr_exp['history']['val_acc']) + 1)

    axes[1].plot(epochs, best_lr_exp['history']['train_acc'], 'b-', label='Training')
    axes[1].plot(epochs, best_lr_exp['history']['val_acc'], 'r-', label='Validation')
    axes[1].set_title('Best Learning Rate - Training Curves', fontweight='bold')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    best_lr_value = float(best_lr_exp['config']['learning_rate'])
    best_lr_acc = best_lr_exp['best_val_acc']

    print(f"\n🎯 Optimal learning rate: {best_lr_value:.0e} (Val Acc: {best_lr_acc:.2f}%)")

# Text summary of the quick hyperparameter sweeps.
print("\n" + "=" * 60)
print("HYPERPARAMETER EXPERIMENT SUMMARY")
print("=" * 60)

if lr_experiments:
    print("\nLearning Rate Results:")
    for name, result in lr_experiments.items():
        # Read the rate from the run's config (the source of truth) rather than
        # parsing it back out of the experiment name.
        lr_val = float(result['config']['learning_rate'])
        print(f"  {lr_val:.0e}: {result['best_val_acc']:.2f}%")

if optimizer_experiments:
    print("\nOptimizer Results:")
    for name, result in optimizer_experiments.items():
        opt_name = result['config']['optimizer'].upper()
        print(f"  {opt_name}: {result['best_val_acc']:.2f}%")

## 12. Training Summary and Conclusions

Summarize all experiments and provide training recommendations.

In [None]:
print("\n" + "=" * 80)
print("TRAINING EXPERIMENT SUMMARY AND CONCLUSIONS")
print("=" * 80)

# Per-experiment recap of the main training runs and the config each used.
print("\n📊 MAIN EXPERIMENT RESULTS:")
print("-" * 40)

for model_name, results in all_experiments.items():
    config = results['config']
    print(f"\n{model_name}:")
    print(f"  Best validation accuracy: {results['best_val_acc']:.2f}%")
    print(f"  Final validation accuracy: {results['final_val_acc']:.2f}%")
    print(f"  Minimum validation loss: {results['min_val_loss']:.4f}")
    # Epochs actually run = number of recorded training-loss entries.
    print(f"  Training epochs: {len(results['history']['train_loss'])}")
    print(f"  Optimizer: {config['optimizer']}")
    print(f"  Learning rate: {config['learning_rate']}")
    print(f"  Scheduler: {config['scheduler']}")

# Key findings: rank all main experiments by best validation accuracy.
print("\n🔍 KEY FINDINGS:")
print("-" * 20)

# Best and worst experiments (max()/min() keep the first in insertion order on ties).
best_model = max(all_experiments, key=lambda name: all_experiments[name]['best_val_acc'])
best_acc = all_experiments[best_model]['best_val_acc']
worst_model = min(all_experiments, key=lambda name: all_experiments[name]['best_val_acc'])
worst_acc = all_experiments[worst_model]['best_val_acc']

print("\n1. Model Performance Ranking:")
sorted_models = sorted(all_experiments.items(),
                       key=lambda item: item[1]['best_val_acc'],
                       reverse=True)
for i, (name, results) in enumerate(sorted_models, 1):
    print(f"   {i}. {name}: {results['best_val_acc']:.2f}%")

print(f"\n2. Performance Range: {worst_acc:.2f}% - {best_acc:.2f}% ({best_acc - worst_acc:.2f}% spread)")

# Architecture comparison: best run of each model family, matched by a
# substring of the experiment name.
simple_models = [k for k in all_experiments if 'SimpleCNN' in k]
resnet_models = [k for k in all_experiments if 'ResNet' in k]

if simple_models and resnet_models:
    best_simple = max(simple_models, key=lambda x: all_experiments[x]['best_val_acc'])
    best_resnet = max(resnet_models, key=lambda x: all_experiments[x]['best_val_acc'])

    simple_acc = all_experiments[best_simple]['best_val_acc']
    resnet_acc = all_experiments[best_resnet]['best_val_acc']

    print("\n3. Architecture Comparison:")
    print(f"   Best SimpleCNN: {simple_acc:.2f}%")
    print(f"   Best ResNet: {resnet_acc:.2f}%")
    if resnet_acc > simple_acc:
        print(f"   ✅ ResNet shows {resnet_acc - simple_acc:.2f}% improvement")
    elif simple_acc > resnet_acc:
        print(f"   ⚠️ SimpleCNN performs {simple_acc - resnet_acc:.2f}% better")
    else:
        # Equal accuracy: "performs 0.00% better" would be misleading.
        print("   ResNet and SimpleCNN perform equally")

# Hyperparameter tuning impact: tuned vs. baseline SimpleCNN, when both were trained.
if 'SimpleCNN_Tuned' in all_experiments and 'SimpleCNN' in all_experiments:
    baseline_acc = all_experiments['SimpleCNN']['best_val_acc']
    tuned_acc = all_experiments['SimpleCNN_Tuned']['best_val_acc']

    print("\n4. Hyperparameter Tuning Impact:")
    print(f"   Baseline: {baseline_acc:.2f}%")
    print(f"   Tuned: {tuned_acc:.2f}%")
    if tuned_acc > baseline_acc:
        print(f"   ✅ Tuning improved performance by {tuned_acc - baseline_acc:.2f}%")
    elif tuned_acc < baseline_acc:
        print(f"   ⚠️ Tuning decreased performance by {baseline_acc - tuned_acc:.2f}%")
    else:
        # Equal accuracy: "decreased by 0.00%" would be misleading.
        print("   Tuning did not change performance")

# Learning-rate sweep findings; skipped when the sweep cell was not run or
# every run in it failed.
if 'lr_experiments' in locals() and lr_experiments:
    best_lr_name, best_lr_exp = max(
        lr_experiments.items(), key=lambda item: item[1]['best_val_acc']
    )
    # Exact rate from the run's config; the "lr_<.0e>" experiment name keeps
    # only one significant digit.
    best_lr_value = float(best_lr_exp['config']['learning_rate'])
    lowest_lr_acc = min(r['best_val_acc'] for r in lr_experiments.values())

    print("\n5. Learning Rate Analysis:")
    print(f"   Optimal LR: {best_lr_value:.0e}")
    print(f"   Performance range: {lowest_lr_acc:.2f}% - {best_lr_exp['best_val_acc']:.2f}%")

# Share of each configured epoch budget that was actually run
# (presumably shorter when early stopping triggers — see 'patience' in the configs).
print("\n6. Training Efficiency:")
for name, results in all_experiments.items():
    epochs_trained = len(results['history']['train_loss'])
    config_epochs = results['config']['epochs']
    completion_rate = epochs_trained / config_epochs * 100
    print(
        f"   {name}: {epochs_trained}/{config_epochs} epochs "
        f"({completion_rate:.0f}%)"
    )

# Recommendations derived from the findings above.
print("\n🎯 TRAINING RECOMMENDATIONS:")
print("-" * 30)

print("\n1. Model Architecture:")
# simple_acc / resnet_acc are only computed by the architecture comparison when
# both model families were trained. Without this guard the cell raised NameError
# (simple_acc unbound) or compared a stale resnet_acc from an earlier cell.
if not (simple_models and resnet_models):
    print("   ⚠️ Not enough architectures were trained to compare")
elif resnet_acc > simple_acc:
    print("   ✅ Use ResNet architecture for better performance")
    print("   • ResNet's residual connections help with gradient flow")
    print("   • Better feature extraction capabilities")
else:
    print("   ✅ SimpleCNN is sufficient for this task")
    print("   • Faster training and inference")
    print("   • Lower computational requirements")

# Report the config of the best main experiment as the recommended settings.
print("\n2. Hyperparameter Settings:")
print(f"   • Best performing config: {best_model}")
best_config = all_experiments[best_model]['config']
print(f"   • Learning rate: {best_config['learning_rate']}")
print(f"   • Optimizer: {best_config['optimizer']}")
print(f"   • Scheduler: {best_config['scheduler']}")
print(f"   • Weight decay: {best_config['weight_decay']}")

# best_lr_value is set by the learning-rate findings under this same guard.
if 'lr_experiments' in locals() and lr_experiments:
    print(f"   • Optimal learning rate from experiments: {best_lr_value:.0e}")

print("\n3. Training Strategy:")
# Mean number of epochs actually run across the main experiments.
avg_epochs = np.mean([len(r['history']['train_loss']) for r in all_experiments.values()])
print(f"   • Training duration: ~{avg_epochs:.0f} epochs typically sufficient")
print("   • Use early stopping to prevent overfitting")
print("   • Monitor validation accuracy as primary metric")
print("   • Consider cosine annealing scheduler for better convergence")

print("\n4. Further Improvements:")
print("   • Try different data augmentation strategies")
print("   • Experiment with larger image sizes (64x64 or 128x128)")
print("   • Consider ensemble methods")
print("   • Implement label smoothing for regularization")
print("   • Try different optimizers (AdamW, RMSprop)")

print("\n" + "=" * 80)
print(f"🏆 BEST MODEL: {best_model} with {best_acc:.2f}% validation accuracy")
print("=" * 80)