# Training Analysis

This notebook analyzes training logs and results to visualize model performance.

## Objectives
- Load training logs (JSON, CSV, or TensorBoard)
- Plot Dice score curves (training and validation)
- Plot loss curves (training and validation)
- Compare validation vs training performance


In [1]:
import sys
from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Optional, Tuple
import pandas as pd

sys.path.insert(0, str(Path().absolute().parent))

# For TensorBoard support (optional)
try:
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False
    print("TensorBoard not available. Install with: pip install tensorboard")


ModuleNotFoundError: No module named 'numpy'

## 1. Load Training Logs

Load training history from JSON, CSV, or TensorBoard format.


In [None]:
def load_json_logs(json_path: Path) -> Dict:
    """Load training history from a JSON file.

    The file is decoded explicitly as UTF-8 so that the result does not
    depend on the platform's locale default encoding.

    Args:
        json_path: Path to JSON file containing training history.

    Returns:
        The parsed JSON content (expected to be a dictionary of metrics).

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_csv_logs(csv_path: Path) -> Dict:
    """Read a CSV training log into a column-oriented dictionary.

    Args:
        csv_path: Path to CSV file containing training history.

    Returns:
        Dictionary mapping each CSV column name to the list of values
        found in that column, in row order.
    """
    frame = pd.read_csv(csv_path)
    # One entry per column; each column becomes a plain Python list.
    return {column: frame[column].tolist() for column in frame.columns}


def load_tensorboard_logs(tb_path: Path, tags: Optional[List[str]] = None) -> Dict:
    """Read scalar series from a TensorBoard log directory.

    Args:
        tb_path: Path to TensorBoard log directory.
        tags: Optional list of scalar tags to extract. When ``None``, every
            scalar tag found in the logs is extracted. Requested tags that
            are not present in the logs are skipped.

    Returns:
        Dictionary mapping each extracted tag to the list of its recorded
        scalar values.

    Raises:
        ImportError: If TensorBoard could not be imported.
    """
    if not TENSORBOARD_AVAILABLE:
        raise ImportError("TensorBoard is not available. Install with: pip install tensorboard")

    accumulator = EventAccumulator(str(tb_path))
    accumulator.Reload()

    scalar_tags = accumulator.Tags()['scalars']
    requested = scalar_tags if tags is None else tags

    # Keep only tags that actually exist as scalars in the event files.
    return {
        name: [event.value for event in accumulator.Scalars(name)]
        for name in requested
        if name in scalar_tags
    }


# Configuration: choose which training log to analyze.
# Option 1: JSON file (default format from train.py)
log_path = Path("../logs/training_history_20240101_120000.json")

# Option 2: CSV file (uncomment to use)
# log_path = Path("../logs/training_history.csv")

# Option 3: TensorBoard directory (uncomment to use)
# log_path = Path("../logs/tensorboard")

# Pick a loader from the path: a directory is read as TensorBoard logs,
# a file is dispatched on its suffix.
print(f"Loading logs from: {log_path}")

if log_path.is_dir():
    if not TENSORBOARD_AVAILABLE:
        raise ImportError("TensorBoard is not available")
    history = load_tensorboard_logs(log_path)
    print("Loaded TensorBoard logs")
elif log_path.suffix == '.json':
    if not log_path.exists():
        print(f"ERROR: Log file not found at {log_path}")
        print("Please run training first or specify correct path.")
        history = None
    else:
        history = load_json_logs(log_path)
        print("Loaded JSON logs")
elif log_path.suffix == '.csv':
    if not log_path.exists():
        print(f"ERROR: Log file not found at {log_path}")
        history = None
    else:
        history = load_csv_logs(log_path)
        print("Loaded CSV logs")
else:
    print(f"ERROR: Unsupported file format: {log_path.suffix}")
    history = None

# Report which metrics were loaded; the epoch count is taken from 'train_loss'.
if history is not None:
    print(f"\nAvailable metrics: {list(history.keys())}")
    print(f"Number of epochs: {len(history.get('train_loss', []))}")


## 2. Plot Dice Score Curves

Visualize Dice score evolution for training and validation sets.


In [None]:
def plot_dice_curves(history: Dict, save_path: Optional[Path] = None):
    """Plot training and validation Dice score curves in two panels.

    Each panel builds its epoch axis from the length of its own series, so
    the curves still render when ``train_dice`` and ``val_dice`` differ in
    length or when only one of them is present. A panel whose series is
    missing or empty shows a placeholder message instead.

    Args:
        history: Training history dictionary. Reads the optional keys
            ``'train_dice'`` and ``'val_dice'`` (lists of scores, one entry
            per plotted epoch).
        save_path: Optional path to save the figure. Parent directories are
            created if needed. When omitted, the figure is shown instead.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Training Dice
    train_dice = history.get('train_dice', [])
    if len(train_dice) > 0:
        n_train = len(train_dice)
        axes[0].plot(range(1, n_train + 1), train_dice, 'b-',
                     label='Training Dice', linewidth=2)
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Dice Score')
        axes[0].set_title('Training Dice Score')
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()
        axes[0].set_ylim([0, 1])

        # Add final value annotation
        final_dice = train_dice[-1]
        axes[0].annotate(f'Final: {final_dice:.4f}',
                         xy=(n_train, final_dice),
                         xytext=(n_train * 0.7, final_dice + 0.1),
                         arrowprops=dict(arrowstyle='->', color='blue'))
    else:
        axes[0].text(0.5, 0.5, 'No training Dice data',
                     ha='center', va='center', transform=axes[0].transAxes)
        axes[0].set_title('Training Dice Score')

    # Validation Dice
    val_dice = history.get('val_dice', [])
    if len(val_dice) > 0:
        # Use the validation series' own length: reusing the training length
        # makes plot() fail when the two series are not equally long.
        n_val = len(val_dice)
        axes[1].plot(range(1, n_val + 1), val_dice, 'r-',
                     label='Validation Dice', linewidth=2)
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Dice Score')
        axes[1].set_title('Validation Dice Score')
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()
        axes[1].set_ylim([0, 1])

        # Add final and best value annotations (best = highest Dice)
        final_dice = val_dice[-1]
        best_dice = max(val_dice)
        best_epoch = val_dice.index(best_dice) + 1
        axes[1].annotate(f'Final: {final_dice:.4f}',
                         xy=(n_val, final_dice),
                         xytext=(n_val * 0.7, final_dice + 0.1),
                         arrowprops=dict(arrowstyle='->', color='red'))
        axes[1].annotate(f'Best: {best_dice:.4f} (Epoch {best_epoch})',
                         xy=(best_epoch, best_dice),
                         xytext=(best_epoch, best_dice + 0.15),
                         arrowprops=dict(arrowstyle='->', color='green'))
    else:
        axes[1].text(0.5, 0.5, 'No validation Dice data',
                     ha='center', va='center', transform=axes[1].transAxes)
        axes[1].set_title('Validation Dice Score')

    fig.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")
    else:
        plt.show()

    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

# Draw the Dice curves only when a training history was loaded.
if history is None:
    print("Cannot plot Dice curves: history not loaded.")
else:
    plot_dice_curves(history)


In [None]:
def plot_loss_curves(history: Dict, save_path: Optional[Path] = None):
    """Plot training and validation loss curves in two panels.

    Each panel builds its epoch axis from the length of its own series, so
    the curves still render when ``train_loss`` and ``val_loss`` differ in
    length or when only one of them is present. A panel whose series is
    missing or empty shows a placeholder message instead.

    Args:
        history: Training history dictionary. Reads the optional keys
            ``'train_loss'`` and ``'val_loss'`` (lists of loss values, one
            entry per plotted epoch).
        save_path: Optional path to save the figure. Parent directories are
            created if needed. When omitted, the figure is shown instead.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Training Loss
    train_loss = history.get('train_loss', [])
    if len(train_loss) > 0:
        n_train = len(train_loss)
        axes[0].plot(range(1, n_train + 1), train_loss, 'b-',
                     label='Training Loss', linewidth=2)
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training Loss')
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()

        # Add final value annotation
        final_loss = train_loss[-1]
        axes[0].annotate(f'Final: {final_loss:.4f}',
                         xy=(n_train, final_loss),
                         xytext=(n_train * 0.7, final_loss * 1.1),
                         arrowprops=dict(arrowstyle='->', color='blue'))
    else:
        axes[0].text(0.5, 0.5, 'No training loss data',
                     ha='center', va='center', transform=axes[0].transAxes)
        axes[0].set_title('Training Loss')

    # Validation Loss
    val_loss = history.get('val_loss', [])
    if len(val_loss) > 0:
        # Use the validation series' own length: reusing the training length
        # makes plot() fail when the two series are not equally long.
        n_val = len(val_loss)
        axes[1].plot(range(1, n_val + 1), val_loss, 'r-',
                     label='Validation Loss', linewidth=2)
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Loss')
        axes[1].set_title('Validation Loss')
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()

        # Add final and best value annotations (best = lowest loss)
        final_loss = val_loss[-1]
        best_loss = min(val_loss)
        best_epoch = val_loss.index(best_loss) + 1
        axes[1].annotate(f'Final: {final_loss:.4f}',
                         xy=(n_val, final_loss),
                         xytext=(n_val * 0.7, final_loss * 1.1),
                         arrowprops=dict(arrowstyle='->', color='red'))
        axes[1].annotate(f'Best: {best_loss:.4f} (Epoch {best_epoch})',
                         xy=(best_epoch, best_loss),
                         xytext=(best_epoch, best_loss * 1.2),
                         arrowprops=dict(arrowstyle='->', color='green'))
    else:
        axes[1].text(0.5, 0.5, 'No validation loss data',
                     ha='center', va='center', transform=axes[1].transAxes)
        axes[1].set_title('Validation Loss')

    fig.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")
    else:
        plt.show()

    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

# Draw the loss curves only when a training history was loaded.
if history is None:
    print("Cannot plot loss curves: history not loaded.")
else:
    plot_loss_curves(history)


## 4. Compare Training vs Validation

Overlay the training and validation curves on shared axes (Dice score on top, loss below) to compare performance directly.


In [None]:
def plot_training_vs_validation(history: Dict, save_path: Optional[Path] = None):
    """Overlay training and validation curves for Dice (top) and loss (bottom).

    A panel is drawn only when both its training and validation series are
    present and non-empty; otherwise it shows a placeholder message. Each
    curve builds its epoch axis from its own length, so the Dice panel does
    not depend on ``train_loss`` and series of different lengths still plot.

    Args:
        history: Training history dictionary. Reads the optional keys
            ``'train_dice'``, ``'val_dice'``, ``'train_loss'`` and
            ``'val_loss'`` (sequences with one entry per plotted epoch).
        save_path: Optional path to save the figure. Parent directories are
            created if needed. When omitted, the figure is shown instead.
    """
    fig, axes = plt.subplots(2, 1, figsize=(12, 10))

    # Dice comparison
    train_dice = history.get('train_dice', [])
    val_dice = history.get('val_dice', [])
    if len(train_dice) > 0 and len(val_dice) > 0:
        axes[0].plot(range(1, len(train_dice) + 1), train_dice, 'b-',
                     label='Training Dice', linewidth=2, alpha=0.7)
        axes[0].plot(range(1, len(val_dice) + 1), val_dice, 'r-',
                     label='Validation Dice', linewidth=2, alpha=0.7)
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Dice Score')
        axes[0].set_title('Training vs Validation - Dice Score')
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()
        axes[0].set_ylim([0, 1])

        # Add statistics (best = highest validation Dice)
        train_final = train_dice[-1]
        val_final = val_dice[-1]
        val_best = max(val_dice)
        axes[0].text(0.02, 0.98,
                     f'Train Final: {train_final:.4f}\nVal Final: {val_final:.4f}\nVal Best: {val_best:.4f}',
                     transform=axes[0].transAxes,
                     verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    else:
        axes[0].text(0.5, 0.5, 'No Dice data available',
                     ha='center', va='center', transform=axes[0].transAxes)
        axes[0].set_title('Training vs Validation - Dice Score')

    # Loss comparison
    train_loss = history.get('train_loss', [])
    val_loss = history.get('val_loss', [])
    if len(train_loss) > 0 and len(val_loss) > 0:
        axes[1].plot(range(1, len(train_loss) + 1), train_loss, 'b-',
                     label='Training Loss', linewidth=2, alpha=0.7)
        axes[1].plot(range(1, len(val_loss) + 1), val_loss, 'r-',
                     label='Validation Loss', linewidth=2, alpha=0.7)
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Loss')
        axes[1].set_title('Training vs Validation - Loss')
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()

        # Add statistics (best = lowest validation loss)
        train_final = train_loss[-1]
        val_final = val_loss[-1]
        val_best = min(val_loss)
        axes[1].text(0.02, 0.98,
                     f'Train Final: {train_final:.4f}\nVal Final: {val_final:.4f}\nVal Best: {val_best:.4f}',
                     transform=axes[1].transAxes,
                     verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    else:
        axes[1].text(0.5, 0.5, 'No loss data available',
                     ha='center', va='center', transform=axes[1].transAxes)
        axes[1].set_title('Training vs Validation - Loss')

    fig.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")
    else:
        plt.show()

    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

# Draw the train/validation comparison only when a history was loaded.
if history is None:
    print("Cannot plot comparison: history not loaded.")
else:
    plot_training_vs_validation(history)


## 5. Additional Metrics

Visualize additional metrics like IoU and learning rate.


In [None]:
def plot_additional_metrics(history: Dict, save_path: Optional[Path] = None):
    """Plot IoU curves (left) and the learning-rate schedule (right).

    Every curve builds its x-axis from the length of its own series, so the
    panels do not depend on ``train_loss`` being present or equally long.
    A panel with no non-empty series shows a placeholder message instead.

    Args:
        history: Training history dictionary. Reads the optional keys
            ``'train_iou'``, ``'val_iou'`` and ``'learning_rate'``
            (sequences with one entry per plotted point).
        save_path: Optional path to save the figure. Parent directories are
            created if needed. When omitted, the figure is shown instead.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # IoU scores: plot whichever of the two series has data.
    iou_series = (
        ('train_iou', 'b-', 'Training IoU'),
        ('val_iou', 'r-', 'Validation IoU'),
    )
    iou_plotted = False
    for key, line_style, label in iou_series:
        values = history.get(key, [])
        if len(values) > 0:
            axes[0].plot(range(1, len(values) + 1), values, line_style,
                         label=label, linewidth=2, alpha=0.7)
            iou_plotted = True

    if iou_plotted:
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('IoU Score')
        axes[0].set_title('IoU Score Evolution')
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()
        axes[0].set_ylim([0, 1])
    else:
        # Also covers keys that exist but are empty, where calling legend()
        # would find no labelled artists.
        axes[0].text(0.5, 0.5, 'No IoU data available',
                     ha='center', va='center', transform=axes[0].transAxes)
        axes[0].set_title('IoU Score Evolution')

    # Learning rate
    learning_rate = history.get('learning_rate', [])
    if len(learning_rate) > 0:
        n_lr = len(learning_rate)
        axes[1].plot(range(1, n_lr + 1), learning_rate, 'g-',
                     label='Learning Rate', linewidth=2)
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Learning Rate')
        axes[1].set_title('Learning Rate Schedule')
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()
        axes[1].set_yscale('log')

        # Annotate the initial and final learning rates at the last point
        final_lr = learning_rate[-1]
        initial_lr = learning_rate[0]
        axes[1].annotate(f'Initial: {initial_lr:.2e}\nFinal: {final_lr:.2e}',
                         xy=(n_lr, final_lr),
                         xytext=(n_lr * 0.7, final_lr * 2),
                         arrowprops=dict(arrowstyle='->', color='green'))
    else:
        axes[1].text(0.5, 0.5, 'No learning rate data available',
                     ha='center', va='center', transform=axes[1].transAxes)
        axes[1].set_title('Learning Rate Schedule')

    fig.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")
    else:
        plt.show()

    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

# Draw the IoU / learning-rate figure only when a history was loaded.
if history is None:
    print("Cannot plot additional metrics: history not loaded.")
else:
    plot_additional_metrics(history)


## 6. Summary Statistics

Display key statistics from the training run.


In [None]:
def _format_percent(change: float, baseline: float) -> str:
    """Return ``change / baseline`` as a percentage string, or ``"N/A"``.

    A zero baseline makes the ratio undefined, so ``"N/A"`` is returned
    instead of raising ``ZeroDivisionError``.
    """
    if baseline == 0:
        return "N/A"
    return f"{change / baseline * 100:.2f}%"


def print_summary_statistics(history: Dict):
    """Print summary statistics from training history.

    Prints initial/final/best values for each metric that is present and
    non-empty, plus a train-vs-validation Dice gap check. Relative changes
    whose baseline is zero (initial loss or learning rate of 0, or an
    initial Dice of exactly 1.0) are reported as ``N/A``.

    Args:
        history: Training history dictionary. Reads the optional keys
            ``'train_loss'``, ``'train_dice'``, ``'val_loss'``,
            ``'val_dice'``, ``'val_iou'`` and ``'learning_rate'`` (lists of
            numbers). The epoch count is taken from ``'train_loss'``.
    """
    rule = "=" * 60
    print(rule)
    print("TRAINING SUMMARY STATISTICS")
    print(rule)

    train_loss = history.get('train_loss', [])
    train_dice = history.get('train_dice', [])
    val_loss = history.get('val_loss', [])
    val_dice = history.get('val_dice', [])
    val_iou = history.get('val_iou', [])
    learning_rate = history.get('learning_rate', [])

    print(f"\nTotal epochs: {len(train_loss)}")

    # Training metrics
    if len(train_loss) > 0:
        print("\nTraining Loss:")
        print(f"  Initial: {train_loss[0]:.4f}")
        print(f"  Final: {train_loss[-1]:.4f}")
        print(f"  Best: {min(train_loss):.4f}")
        # Loss improvement: relative drop from the initial loss.
        print(f"  Improvement: {_format_percent(train_loss[0] - train_loss[-1], train_loss[0])}")

    if len(train_dice) > 0:
        print("\nTraining Dice:")
        print(f"  Initial: {train_dice[0]:.4f}")
        print(f"  Final: {train_dice[-1]:.4f}")
        print(f"  Best: {max(train_dice):.4f}")
        # Dice improvement: share of the initial headroom (1 - initial) gained.
        print(f"  Improvement: {_format_percent(train_dice[-1] - train_dice[0], 1 - train_dice[0])}")

    # Validation metrics
    if len(val_loss) > 0:
        best_loss = min(val_loss)
        print("\nValidation Loss:")
        print(f"  Initial: {val_loss[0]:.4f}")
        print(f"  Final: {val_loss[-1]:.4f}")
        print(f"  Best: {best_loss:.4f} (Epoch {val_loss.index(best_loss) + 1})")
        print(f"  Improvement: {_format_percent(val_loss[0] - val_loss[-1], val_loss[0])}")

    if len(val_dice) > 0:
        best_dice = max(val_dice)
        best_epoch = val_dice.index(best_dice) + 1
        print("\nValidation Dice:")
        print(f"  Initial: {val_dice[0]:.4f}")
        print(f"  Final: {val_dice[-1]:.4f}")
        print(f"  Best: {best_dice:.4f} (Epoch {best_epoch})")
        print(f"  Improvement: {_format_percent(val_dice[-1] - val_dice[0], 1 - val_dice[0])}")

    if len(val_iou) > 0:
        print("\nValidation IoU:")
        print(f"  Initial: {val_iou[0]:.4f}")
        print(f"  Final: {val_iou[-1]:.4f}")
        print(f"  Best: {max(val_iou):.4f}")

    if len(learning_rate) > 0:
        print("\nLearning Rate:")
        print(f"  Initial: {learning_rate[0]:.2e}")
        print(f"  Final: {learning_rate[-1]:.2e}")
        print(f"  Reduction: {_format_percent(learning_rate[0] - learning_rate[-1], learning_rate[0])}")

    # Overfitting check: compare the final training and validation Dice.
    if len(train_dice) > 0 and len(val_dice) > 0:
        gap = train_dice[-1] - val_dice[-1]
        print("\nOverfitting Analysis:")
        print(f"  Train-Val Dice Gap: {gap:.4f}")
        if gap > 0.1:
            print("  [WARNING] Large gap detected - possible overfitting")
        elif gap < 0.05:
            print("  [OK] Small gap - good generalization")
        else:
            print("  [OK] Moderate gap - acceptable")

    print(rule)

# Print the summary only when a training history was loaded.
if history is None:
    print("Cannot print statistics: history not loaded.")
else:
    print_summary_statistics(history)


## Summary

This notebook analyzed training logs by:

1. **Loading logs**: Supported JSON, CSV, and TensorBoard formats
2. **Dice curves**: Visualized Dice score evolution for training and validation
3. **Loss curves**: Visualized loss evolution for training and validation
4. **Training vs Validation**: Overlaid training and validation curves for direct comparison
5. **Additional metrics**: IoU scores and learning rate schedule
6. **Summary statistics**: Key metrics and overfitting analysis

### Key Insights:
- Monitor Dice score to track segmentation quality
- Compare training vs validation to detect overfitting
- Learning rate schedule affects convergence
- Best model checkpoint is typically at peak validation Dice
