# Training Results Analysis & Visualization

**Purpose:** Analyze and visualize results from `src/run_training.py`

This notebook loads saved training artifacts and provides the following analysis (items marked *planned* are placeholders in Section 7 and are not yet implemented):
- Training curves (loss, accuracy, learning rate)
- Model performance metrics
- Class-wise performance analysis *(planned)*
- Confusion matrix *(planned)*
- Misclassification analysis *(planned)*

**Input Files:**
- `checkpoints/resnet50_full/training_history.csv`
- `checkpoints/resnet50_full/best_model.pth`
- `checkpoints/resnet50_full/training_metadata.json`

## 1. Setup

In [None]:
import sys
import json
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, classification_report

# Make project modules importable by putting the project root first on the
# import path. The root is taken to be two directories above the current
# working directory (NOTE(review): confirm the notebook's launch location).
project_root = Path.cwd().parent.parent
sys.path[:0] = [str(project_root)]

# Plot appearance. Order matters: the style sheet is applied first so that the
# seaborn palette set afterwards is the colour cycle that ends up active.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✓ Imports complete")

## 2. Configuration

In [None]:
# Directory of the training run to analyze (relative to the working directory)
CHECKPOINT_DIR = '../../checkpoints/resnet50_full'
_checkpoint_root = Path(CHECKPOINT_DIR)

# Artifacts read by the cells below
HISTORY_PATH = _checkpoint_root / 'training_history.csv'
METADATA_PATH = _checkpoint_root / 'training_metadata.json'
MODEL_PATH = _checkpoint_root / 'best_model.pth'

# Figures and the summary report are written next to the checkpoints.
# Only the leaf directory is created; the checkpoint directory must already exist.
OUTPUT_DIR = _checkpoint_root / 'analysis'
OUTPUT_DIR.mkdir(exist_ok=True)

for _label, _location in (
    ("Checkpoint directory", CHECKPOINT_DIR),
    ("Output directory", OUTPUT_DIR),
):
    print(f"{_label}: {_location}")

## 3. Load Training Artifacts

In [None]:
# Per-epoch metrics: one CSV row per recorded epoch
print("Loading training history...")
history_df = pd.read_csv(HISTORY_PATH)
print(f"  ✓ Loaded {len(history_df)} epochs")

# Run metadata: parse the JSON file into a plain dict
print("\nLoading training metadata...")
metadata = json.loads(METADATA_PATH.read_text())
print(f"  ✓ Model: {metadata['model']}")
print(f"  ✓ Classes: {metadata['num_classes']}")
training_config = metadata['training_config']
print(f"  ✓ Epochs trained: {training_config['epochs']}")
print(f"  ✓ Batch size: {training_config['batch_size']}")

# The bare expression on the last line renders the preview as the cell output
print("\nTraining history preview:")
history_df.head()

## 4. Training Summary Statistics

In [None]:
# Best epoch = row with the lowest validation loss. `history_df` was read
# without an index column, so the label returned by `idxmin` is a 0-based row
# number; it is shifted by one wherever it is shown as an epoch number.
best_epoch = history_df['val_loss'].idxmin()
best_val_loss = history_df.loc[best_epoch, 'val_loss']
best_val_acc = history_df.loc[best_epoch, 'val_acc']

# Metrics of the last recorded epoch
final_train_loss = history_df['train_loss'].iloc[-1]
final_train_acc = history_df['train_acc'].iloc[-1]
final_val_loss = history_df['val_loss'].iloc[-1]
final_val_acc = history_df['val_acc'].iloc[-1]

print("=" * 70)
print("TRAINING SUMMARY")
print("=" * 70)

print(f"\nBest Results (Epoch {best_epoch + 1}):")
print(f"  Val Loss: {best_val_loss:.4f}")
print(f"  Val Accuracy: {best_val_acc:.2f}%")

print(f"\nFinal Results (Epoch {len(history_df)}):")
print(f"  Train Loss: {final_train_loss:.4f}")
print(f"  Train Accuracy: {final_train_acc:.2f}%")
print(f"  Val Loss: {final_val_loss:.4f}")
print(f"  Val Accuracy: {final_val_acc:.2f}%")

# Change in accuracy between the first and the last recorded epoch
train_acc_improvement = final_train_acc - history_df['train_acc'].iloc[0]
val_acc_improvement = final_val_acc - history_df['val_acc'].iloc[0]

# The `+` format flag prints an explicit sign for both directions, so a
# regression shows as "-1.23%" instead of the malformed "+-1.23%".
print("\nImprovement from Epoch 1:")
print(f"  Train Accuracy: {train_acc_improvement:+.2f}%")
print(f"  Val Accuracy: {val_acc_improvement:+.2f}%")

# Overfitting check on the final epoch: scalar gap between train and
# validation accuracy, bucketed at 5 and 10 percentage points.
train_val_gap = final_train_acc - final_val_acc
print("\nOverfitting Analysis:")
print(f"  Train-Val Gap: {train_val_gap:.2f}%")
if train_val_gap < 5:
    print("  Status: ✅ No significant overfitting")
elif train_val_gap < 10:
    print("  Status: ⚠️ Slight overfitting")
else:
    print("  Status: ❌ Significant overfitting - consider regularization")

## 5. Training Curves Visualization

In [None]:
# 2x2 overview of the run: loss, accuracy, learning rate and train/val gap.
# Reads `history_df`, `best_epoch` and `OUTPUT_DIR` from earlier cells.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

epochs = history_df.index + 1          # 0-based row index -> 1-based epoch numbers
best_epoch_number = best_epoch + 1
label_font = {'fontsize': 12, 'fontweight': 'bold'}
title_font = {'fontsize': 14, 'fontweight': 'bold'}
curve_style = {'linewidth': 2, 'markersize': 6}

# Plot 1: Loss curves
ax = axes[0, 0]
ax.plot(epochs, history_df['train_loss'], label='Train Loss', marker='o', **curve_style)
ax.plot(epochs, history_df['val_loss'], label='Val Loss', marker='s', **curve_style)
ax.axvline(best_epoch_number, color='red', linestyle='--', alpha=0.5,
           label=f'Best Epoch ({best_epoch_number})')
ax.set_xlabel('Epoch', **label_font)
ax.set_ylabel('Loss', **label_font)
ax.set_title('Training and Validation Loss', **title_font)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Plot 2: Accuracy curves, with the 60% target drawn as a reference line
ax = axes[0, 1]
ax.plot(epochs, history_df['train_acc'], label='Train Acc', marker='o', **curve_style)
ax.plot(epochs, history_df['val_acc'], label='Val Acc', marker='s', **curve_style)
ax.axvline(best_epoch_number, color='red', linestyle='--', alpha=0.5,
           label=f'Best Epoch ({best_epoch_number})')
ax.axhline(60, color='green', linestyle=':', alpha=0.5, label='Target (60%)')
ax.set_xlabel('Epoch', **label_font)
ax.set_ylabel('Accuracy (%)', **label_font)
ax.set_title('Training and Validation Accuracy', **title_font)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Plot 3: Learning rate schedule (log-scaled y-axis)
ax = axes[1, 0]
ax.plot(epochs, history_df['learning_rates'], marker='o', color='green', **curve_style)
ax.set_xlabel('Epoch', **label_font)
ax.set_ylabel('Learning Rate', **label_font)
ax.set_title('Learning Rate Schedule', **title_font)
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

# Plot 4: Train-Val gap per epoch.
# The per-epoch series gets its own name so that the scalar `train_val_gap`
# computed in the summary section is not overwritten with a Series.
gap_by_epoch = history_df['train_acc'] - history_df['val_acc']
ax = axes[1, 1]
ax.plot(epochs, gap_by_epoch, marker='o', color='orange', **curve_style)
ax.axhline(0, color='black', linestyle='-', alpha=0.3)
ax.axhline(5, color='yellow', linestyle='--', alpha=0.5, label='Slight overfitting')
ax.axhline(10, color='red', linestyle='--', alpha=0.5, label='Significant overfitting')
ax.set_xlabel('Epoch', **label_font)
ax.set_ylabel('Train - Val Accuracy (%)', **label_font)
ax.set_title('Overfitting Analysis (Train-Val Gap)', **title_font)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Save before showing so the written file contains the rendered figure
figure_path = OUTPUT_DIR / 'training_curves.png'
plt.tight_layout()
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Figure saved to {figure_path}")

## 6. Epoch-by-Epoch Analysis

In [None]:
# Detailed per-epoch table, built on a copy so `history_df` itself stays untouched
history_display = history_df.copy()
history_display.index = history_display.index + 1   # 1-based epoch numbers
history_display.index.name = 'Epoch'

# Epoch-over-epoch change of the validation metrics (NaN for the first epoch)
history_display['val_acc_delta'] = history_display['val_acc'].diff()
history_display['val_loss_delta'] = history_display['val_loss'].diff()

# Fixed four-decimal rendering for float columns (global pandas display option)
pd.options.display.float_format = '{:.4f}'.format

# Four decimals would render small learning rates (e.g. 1e-5) as 0.0000, so the
# learning-rate column is shown in scientific notation instead. Only the
# displayed table is affected; `history_display` keeps its numeric column.
epoch_table = history_display[
    ['train_loss', 'train_acc', 'val_loss', 'val_acc',
     'val_acc_delta', 'val_loss_delta', 'learning_rates']
].assign(learning_rates=lambda frame: frame['learning_rates'].map('{:.2e}'.format))

print("Epoch-by-Epoch Statistics:")
print("=" * 100)
epoch_table

## 7. Model Evaluation (Optional - Requires Data)

In [None]:
# Placeholder cell: no model is loaded and no inference is run here.
# Intended future contents:
#   1. Load the best model
#   2. Run inference on validation set
#   3. Generate confusion matrix
#   4. Analyze per-class performance
#   5. Show misclassified examples

print("Model evaluation section (to be implemented)")
print(f"Model checkpoint available at: {MODEL_PATH}")

print("\nTo add evaluation:")
for _todo_line in (
    "1. Load validation data",
    "2. Load best_model.pth",
    "3. Run inference",
    "4. Generate confusion matrix and classification report",
):
    print(_todo_line)

## 8. Recommendations

In [None]:
print("=" * 70)
print("RECOMMENDATIONS")
print("=" * 70)

# Suggested next steps, chosen by the validation accuracy of the best epoch
if best_val_acc >= 60:
    print("\n✅ Excellent results! Val accuracy >= 60%")
    print("\nNext steps:")
    print("  1. Fine-tune with unfrozen backbone:")
    print("     python src/run_training.py --no_freeze_backbone --lr 1e-5 --epochs 10")
    print("  2. Try ensemble methods")
    print("  3. Deploy to production")
elif best_val_acc >= 50:
    print("\n⚠️ Moderate results. Consider:")
    print("\nRecommended actions:")
    print("  1. Train longer:")
    print("     python src/run_training.py --epochs 25")
    print("  2. Fine-tune backbone:")
    print("     python src/run_training.py --no_freeze_backbone --lr 1e-5")
    print("  3. Try different model:")
    print("     python src/run_training.py --model vit")
else:
    print("\n❌ Results below expectations. Investigate:")
    print("\nPossible issues:")
    print("  1. Data quality - check images and labels")
    print("  2. Model architecture - try different models")
    print("  3. Hyperparameters - adjust learning rate, batch size")
    print("  4. Training duration - train for more epochs")

# Overfitting advice, based on the final-epoch train/val accuracy gap.
# The gap is computed here from the final-epoch scalars instead of reusing
# `train_val_gap`, so this check always compares a single number and does not
# depend on what other cells have bound that name to (a Series in an `if`
# raises "truth value of a Series is ambiguous").
final_train_val_gap = final_train_acc - final_val_acc
if final_train_val_gap > 10:
    print("\n⚠️ Significant overfitting detected!")
    print("\nTry:")
    print("  - Increase dropout rate")
    print("  - Add more data augmentation")
    print("  - Reduce model complexity")
    print("  - Add L2 regularization")

print("\n" + "=" * 70)

## 9. Export Summary Report

In [None]:
# Build a JSON-serializable summary of the run. Values taken from the history
# DataFrame are converted to built-in `int`/`float` so that `json` can encode them.
report = {
    'training_info': {
        'model': metadata['model'],
        'num_classes': metadata['num_classes'],
        'total_epochs': len(history_df),
        'batch_size': metadata['training_config']['batch_size'],
        'learning_rate': metadata['training_config']['learning_rate'],
    },
    'best_results': {
        'epoch': int(best_epoch + 1),   # 1-based epoch number
        'val_loss': float(best_val_loss),
        'val_accuracy': float(best_val_acc),
    },
    'final_results': {
        'train_loss': float(final_train_loss),
        'train_accuracy': float(final_train_acc),
        'val_loss': float(final_val_loss),
        'val_accuracy': float(final_val_acc),
    },
    'metrics': {
        'train_acc_improvement': float(train_acc_improvement),
        'val_acc_improvement': float(val_acc_improvement),
        # Final-epoch gap, derived from the final-epoch scalars rather than
        # from `train_val_gap`: `float()` only accepts a single value, so the
        # report must not depend on what other cells have bound that name to.
        'train_val_gap': float(final_train_acc - final_val_acc),
    }
}

# Write the report next to the figures in the analysis directory
report_path = OUTPUT_DIR / 'training_summary.json'
with open(report_path, 'w') as f:
    json.dump(report, f, indent=2)

print(f"✓ Summary report saved to {report_path}")
print("\nReport contents:")
print(json.dumps(report, indent=2))