# Main Training Notebook

This notebook provides a comprehensive training system with auto-resume capabilities, enhanced checkpointing,
and real-time monitoring. It replaces the original train.py script with an interactive notebook interface.

## Setup and Imports

Import all necessary libraries and initialize the training environment.

In [None]:
import os
import sys
import math
import json
import warnings
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from dataclasses import asdict
from datetime import datetime

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# MLflow for experiment tracking
import mlflow
import mlflow.pytorch

# Hide library warnings (deprecation notices and the like) to keep cell output readable.
warnings.simplefilter('ignore')

# Put the local 'src' directory on the import path, at most once per kernel.
# NOTE(review): the project imports below are written as `src.<module>`, which
# resolve through the notebook's working directory rather than through this entry.
if 'src' not in sys.path:
    sys.path.append('src')

# Import project modules
from src.data.dataset import DataModule
from src.models.model import EvidenceModel
from src.utils import (
    evaluate, flatten_dict, get_optimizer, get_scheduler, 
    prepare_thresholds, set_seed
)
from src.utils.ema import EMA
from src.utils.training import compute_loss

# Report the PyTorch / CUDA environment this notebook is running against.
print("\n".join([
    "‚úÖ All imports successful!",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]))
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")

## Load Configuration and Checkpoint Managers

Run the configuration and checkpoint notebooks to load the configuration manager and initialize the checkpoint managers.

In [None]:
# Run the configuration and checkpoint notebooks to get managers.
# IPython `%run` executes each notebook and copies its top-level names into this
# namespace. Later cells use config_manager, checkpoint_manager,
# auto_resume_manager and ExperimentConfig without defining them, so they are
# presumably provided by these two notebooks — confirm against them.
%run 01_Configuration_Management.ipynb
%run 02_Enhanced_Checkpoint_System.ipynb

print("‚úÖ Configuration and checkpoint managers loaded!")

## Training Configuration Selection

Select or create a configuration for training.

In [None]:
# Interactive configuration selection
def create_config_selector():
    """Build the widget panel used to pick and load a saved configuration.

    The panel holds a configuration dropdown, an experiment-name field, an
    auto-resume checkbox, a load button and an output area. Clicking the button
    loads the selected configuration with ``config_manager.load_config``. It then
    overrides ``mlflow.experiment_name`` and ``training.auto_resume`` from the
    widgets and publishes the result as the notebook globals ``current_config``
    and ``current_experiment_name``.

    Returns:
        The ``widgets.VBox`` panel, or None (after printing a hint) when
        ``config_manager.list_configs()`` returns no configurations.
    """
    configs = config_manager.list_configs()
    if not configs:
        print("No configurations found. Please run the Configuration Management notebook first.")
        return None

    # ``configs`` is known to be non-empty here, so the first entry is the default.
    dropdown = widgets.Dropdown(
        options=configs,
        value=configs[0],
        description='Configuration:',
    )
    name_box = widgets.Text(value='training_experiment', description='Experiment:')
    resume_box = widgets.Checkbox(value=True, description='Auto Resume')
    load_btn = widgets.Button(description='Load Configuration', button_style='success')
    log = widgets.Output()

    def _handle_load(_button):
        # Publish the loaded configuration where the training cells look for it.
        global current_config, current_experiment_name

        with log:
            log.clear_output()
            try:
                chosen = dropdown.value
                cfg = config_manager.load_config(chosen)

                # Widget values take precedence over the stored configuration.
                cfg.mlflow.experiment_name = name_box.value
                cfg.training.auto_resume = resume_box.value

                current_config = cfg
                current_experiment_name = name_box.value

                print(f"‚úÖ Configuration loaded: {chosen}")
                print(f"   Experiment: {name_box.value}")
                print(f"   Model: {cfg.model.encoder.type}")
                print(f"   Auto-resume: {resume_box.value}")

                display(config_manager.get_config_summary(cfg).head(15))
            except Exception as e:
                print(f"‚ùå Error loading configuration: {e}")

    load_btn.on_click(_handle_load)

    return widgets.VBox([
        widgets.HTML("<h3>Training Configuration</h3>"),
        dropdown,
        name_box,
        resume_box,
        load_btn,
        log,
    ])

# Show the selector. When there are no saved configurations, fall back to a
# default ExperimentConfig so the training cells still have something to use.
config_selector = create_config_selector()
if config_selector:
    display(config_selector)
else:
    current_config = ExperimentConfig()
    current_experiment_name = "default_training"
    print("Using default configuration")

## Enhanced Training Loop

The main training loop with auto-resume, enhanced checkpointing, and real-time monitoring.

In [None]:
def enhanced_train_loop(config: ExperimentConfig, experiment_name: str) -> Dict[str, float]:
    """Enhanced training loop with auto-resume and comprehensive checkpointing."""
    
    print(f"üöÄ Starting training: {experiment_name}")
    print(f"   Model: {config.model.encoder.type}")
    print(f"   Max epochs: {config.training.max_epochs}")
    print(f"   Batch size: {config.training.batch_size}")
    print(f"   Learning rate: {config.training.optimizer.learning_rate}")
    
    # Set seed for reproducibility
    set_seed(config.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   Device: {device}")
    
    # Setup MLflow
    mlflow.set_tracking_uri(config.mlflow.tracking_uri)
    mlflow.set_experiment(config.mlflow.experiment_name)
    
    # Setup data
    print("\nüìä Setting up data...")
    data_module = DataModule(config.data, config.model)
    train_loader, val_loader, test_loader = data_module.dataloaders(
        batch_size=config.training.batch_size,
        val_batch_size=config.training.val_batch_size,
        test_batch_size=config.training.test_batch_size,
        num_workers=config.training.num_workers,
    )
    print(f"   Train samples: {len(train_loader.dataset)}")
    print(f"   Val samples: {len(val_loader.dataset)}")
    print(f"   Test samples: {len(test_loader.dataset)}")
    
    # Setup model
    print("\nü§ñ Setting up model...")
    model = EvidenceModel(config.model)
    model.to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    
    head_cfgs = model.head_configs
    head_thresholds = {
        head_name: prepare_thresholds(head_cfg)
        for head_name, head_cfg in head_cfgs.items()
        if head_cfg.get("type") == "multi_label"
    }
    
    # Setup optimizer and scheduler
    print("\n‚öôÔ∏è Setting up optimizer and scheduler...")
    optimizer = get_optimizer(model, config.training.optimizer)
    
    updates_per_epoch = math.ceil(
        len(train_loader) / config.training.gradient_accumulation_steps
    )
    total_steps = updates_per_epoch * config.training.max_epochs
    scheduler = get_scheduler(optimizer, config.training.scheduler, total_steps)
    
    print(f"   Optimizer: {config.training.optimizer.name}")
    print(f"   Scheduler: {config.training.scheduler.name}")
    print(f"   Total steps: {total_steps:,}")
    
    # Setup EMA if enabled
    ema = None
    ema_decay = config.training.get("ema_decay", 0.0)
    if ema_decay > 0:
        ema = EMA(model, decay=ema_decay)
        print(f"   EMA enabled with decay: {ema_decay}")
    
    # Setup mixed precision training
    torch.set_float32_matmul_precision("medium")
    amp_enabled = config.training.amp and device.type == "cuda"
    is_bf16_supported = bool(
        getattr(torch.cuda, "is_bf16_supported", lambda: False)()
    )
    use_bf16 = amp_enabled and config.training.bf16 and is_bf16_supported
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
    scaler = GradScaler(enabled=amp_enabled and not use_bf16)
    
    if amp_enabled:
        print(f"   Mixed precision: {amp_dtype}")
    
    # Check for auto-resume
    training_state = None
    start_epoch = 1
    
    if config.training.auto_resume:
        print("\nüîÑ Checking for resumable checkpoint...")
        config_dict = asdict(config)
        should_resume, checkpoint_id = auto_resume_manager.should_resume_training(
            experiment_name, config_dict
        )
        
        if should_resume and checkpoint_id:
            training_state, _ = auto_resume_manager.resume_training(
                checkpoint_id, model, optimizer, scheduler, device
            )
            start_epoch = training_state.epoch + 1
    
    # Create new training state if not resuming
    if training_state is None:
        print("\nüÜï Starting fresh training...")
        training_state = auto_resume_manager.create_training_state()
    
    # Training configuration
    loss_weights = config.training.loss_weights
    focal_cfg = config.training.focal
    best_metric = training_state.best_metric
    best_epoch = training_state.best_epoch
    epochs_without_improve = training_state.epochs_without_improve
    
    # Start MLflow run
    config_dict = asdict(config)
    flattened = flatten_dict(config_dict)
    
    with mlflow.start_run(run_name=f"{experiment_name}_{config.model.encoder.type}"):
        mlflow.log_params(flattened)
        if config.mlflow.autolog:
            mlflow.pytorch.autolog(log_models=False)
        
        print(f"\nüéØ Training from epoch {start_epoch} to {config.training.max_epochs}")
        
        # Training loop with progress tracking
        epoch_progress = tqdm(
            range(start_epoch, config.training.max_epochs + 1),
            desc="Training",
            position=0,
            leave=True
        )
        
        for epoch in epoch_progress:
            model.train()
            running_loss = 0.0
            progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False, position=1)
            optimizer.zero_grad(set_to_none=True)
            
            for step, batch in enumerate(progress, start=1):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                token_type_ids = batch.get("token_type_ids")
                if token_type_ids is not None:
                    token_type_ids = token_type_ids.to(device)
                
                with autocast(
                    enabled=amp_enabled,
                    dtype=amp_dtype,
                ):
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                    )
                    head_outputs = outputs["head_outputs"]
                    loss = compute_loss(
                        head_outputs, batch, head_cfgs, loss_weights, focal_cfg, device
                    )
                    loss = loss / config.training.gradient_accumulation_steps
                
                if scaler.is_enabled():
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                
                if step % config.training.gradient_accumulation_steps == 0:
                    if scaler.is_enabled():
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.training.max_grad_norm
                        )
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.training.max_grad_norm
                        )
                        optimizer.step()
                    
                    # Step scheduler (except ReduceLROnPlateau)
                    if not isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                    
                    # Update EMA
                    if ema is not None:
                        ema.update()
                
                running_loss += loss.item()
                avg_loss = running_loss / step
                
                # Get learning rate
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    current_lr = optimizer.param_groups[0]['lr']
                else:
                    current_lr = scheduler.get_last_lr()[0]
                
                progress.set_postfix({
                    "loss": f"{avg_loss:.4f}", 
                    "lr": f"{current_lr:.2e}"
                })
                
                if step % config.training.logging_interval == 0:
                    mlflow.log_metric(
                        "train_loss",
                        avg_loss,
                        step=(epoch - 1) * len(train_loader) + step,
                    )
            
            # Validation
            if ema is not None:
                ema.apply_shadow()
            
            val_loss, val_metrics = evaluate(
                model, val_loader, device, config, head_thresholds
            )
            
            if ema is not None:
                ema.restore()
            
            # Log metrics
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            for name, value in val_metrics.items():
                if name != "val_loss":
                    mlflow.log_metric(name, value, step=epoch)
            
            # Update training state
            training_state.epoch = epoch
            training_state.step = (epoch - 1) * len(train_loader) + len(train_loader)
            
            # Add to training history
            from src.utils.training import TrainingMetrics
            epoch_metrics = TrainingMetrics(
                epoch=epoch,
                train_loss=avg_loss,
                val_loss=val_loss,
                val_metrics=val_metrics,
                learning_rate=current_lr,
                timestamp=datetime.now().isoformat()
            )
            training_state.training_history.append(epoch_metrics)
            
            # Early stopping and checkpointing
            monitor_metric = val_metrics.get(
                config.training.early_stopping.monitor, float("-inf")
            )
            
            # Step ReduceLROnPlateau scheduler if used
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(monitor_metric)
            
            improved = monitor_metric > best_metric + config.training.early_stopping.min_delta
            
            if improved:
                best_metric = monitor_metric
                best_epoch = epoch
                epochs_without_improve = 0
                training_state.best_metric = best_metric
                training_state.best_epoch = best_epoch
                training_state.epochs_without_improve = epochs_without_improve
                
                # Save best checkpoint
                checkpoint_manager.save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    training_state=training_state,
                    config=config_dict,
                    experiment_name=experiment_name,
                    notes=f"Best model at epoch {epoch}",
                    is_best=True
                )
                status = "‚úì Improved"
            else:
                epochs_without_improve += 1
                training_state.epochs_without_improve = epochs_without_improve
                status = f"No improvement ({epochs_without_improve}/{config.training.early_stopping.patience})"
            
            # Save periodic checkpoint
            if epoch % config.training.save_every_n_epochs == 0:
                checkpoint_manager.save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    training_state=training_state,
                    config=config_dict,
                    experiment_name=experiment_name,
                    notes=f"Periodic checkpoint at epoch {epoch}",
                    is_best=False
                )
            
            # Update progress bar
            epoch_progress.set_postfix({
                "val_loss": f"{val_loss:.4f}",
                config.training.early_stopping.monitor: f"{monitor_metric:.4f}",
                "best": f"{best_metric:.4f}",
                "status": status
            })
            
            # Early stopping check
            if epochs_without_improve >= config.training.early_stopping.patience:
                tqdm.write(f"\nEarly stopping triggered at epoch {epoch}")
                tqdm.write(f"Best {config.training.early_stopping.monitor}: {best_metric:.4f} at epoch {best_epoch}")
                break
        
        epoch_progress.close()
        
        # Final evaluation on test set
        print("\nüß™ Final evaluation on test set...")
        
        # Load best model
        best_checkpoint_id = checkpoint_manager.find_best_checkpoint(experiment_name)
        if best_checkpoint_id:
            checkpoint_manager.load_checkpoint(
                best_checkpoint_id, model, optimizer, scheduler, device
            )
            print(f"Loaded best model from epoch {best_epoch}")
        
        # Use EMA for final evaluation if enabled
        if ema is not None:
            ema.apply_shadow()
        
        test_loss, test_metrics = evaluate(model, test_loader, device, config, head_thresholds)
        
        if ema is not None:
            ema.restore()
        
        # Log final metrics
        mlflow.log_metric("test_loss", test_loss)
        for name, value in test_metrics.items():
            if name != "val_loss":
                mlflow.log_metric(name.replace("val_", "test_"), value)
        
        mlflow.log_metric("best_metric", best_metric)
        mlflow.log_metric("best_epoch", best_epoch)
        
        results = {
            "best_metric": best_metric,
            "best_epoch": best_epoch,
            "test_metrics": test_metrics,
            "test_loss": test_loss,
        }
        
        print(f"\n‚úÖ Training completed!")
        print(f"   Best {config.training.early_stopping.monitor}: {best_metric:.4f} at epoch {best_epoch}")
        print(f"   Test loss: {test_loss:.4f}")
        
        return results

print("‚úÖ Enhanced training loop defined!")

## Training Execution

Execute the training with the selected configuration.

In [None]:
# Training execution cell
def start_training(ignore_issues: bool = False):
    """Validate the loaded configuration and run ``enhanced_train_loop``.

    Reads ``current_config`` and ``current_experiment_name`` from the notebook's
    global namespace; both are set by the configuration selection cell.

    Args:
        ignore_issues: When True, continue without prompting if
            ``config_manager.validate_config`` reports issues. This is useful when
            no interactive ``input()`` prompt is available, e.g. when called from
            a widget callback.

    Returns:
        The results dict from ``enhanced_train_loop``. Returns None when nothing
        was started, or when training was cancelled, interrupted or failed.
        Failures are printed with a traceback rather than raised.
    """
    # Check if configuration is loaded
    if 'current_config' not in globals() or current_config is None:
        print("‚ùå No configuration loaded. Please run the configuration selection cell first.")
        return

    if 'current_experiment_name' not in globals() or current_experiment_name is None:
        print("‚ùå No experiment name set. Please run the configuration selection cell first.")
        return

    try:
        # Validate configuration
        issues = config_manager.validate_config(current_config)
        if issues:
            print("‚ö†Ô∏è  Configuration issues found:")
            for issue in issues:
                print(f"   - {issue}")

            if not ignore_issues:
                try:
                    response = input("Continue anyway? (y/n): ")
                except (EOFError, NotImplementedError):
                    # No usable stdin. IPython's StdinNotImplementedError
                    # subclasses NotImplementedError.
                    print("Cannot prompt for confirmation here (no interactive input available).")
                    print("Run start_training(ignore_issues=True) in a cell to train anyway.")
                    print("Training cancelled.")
                    return
                if response.strip().lower() != 'y':
                    print("Training cancelled.")
                    return

        # Start training
        print(f"\nüöÄ Starting training: {current_experiment_name}")
        print("=" * 60)

        results = enhanced_train_loop(current_config, current_experiment_name)

        print("\n" + "=" * 60)
        print("üéâ Training completed successfully!")

        # Display results
        print(f"\nüìä Final Results:")
        for key, value in results.items():
            if isinstance(value, dict):
                print(f"   {key}:")
                for k, v in value.items():
                    print(f"     {k}: {v:.4f}")
            else:
                print(f"   {key}: {value}")

        # Display checkpoint summary
        print("\nüìÅ Checkpoint Summary:")
        checkpoints_df = checkpoint_manager.list_checkpoints(current_experiment_name)
        if not checkpoints_df.empty:
            display(checkpoints_df)

        return results

    except KeyboardInterrupt:
        print("\n‚èπÔ∏è  Training interrupted by user")
        # Only best/periodic checkpoints exist; nothing is saved on interrupt.
        print("   Checkpoints saved so far (best/periodic) can be used to resume training")
    except Exception as e:
        print(f"\n‚ùå Training failed with error: {e}")
        import traceback
        traceback.print_exc()

# Create training button
train_button = widgets.Button(
    description='üöÄ Start Training',
    button_style='success',
    layout=widgets.Layout(width='200px', height='40px')
)

# Widget callbacks run outside normal cell execution, so text they print may not
# appear under this cell (JupyterLab, for example, sends it to the log console).
# Capturing it in an Output widget keeps the training log visible.
train_output = widgets.Output()

def on_train_clicked(b):
    """Run ``start_training`` for a button click, showing its output below the button."""
    # Disable the button while training runs to discourage duplicate runs.
    b.disabled = True
    try:
        with train_output:
            start_training()
    finally:
        b.disabled = False

train_button.on_click(on_train_clicked)

print("Click the button below to start training:")
display(train_button, train_output)

## Training Monitoring and Visualization

Inspect experiments and plot their training progress from the history stored in saved checkpoints.

In [None]:
def plot_training_history(experiment_name: str):
    """Plot training history from checkpoints."""
    
    # Get latest checkpoint for the experiment
    latest_checkpoint = checkpoint_manager.find_latest_checkpoint(experiment_name)
    
    if not latest_checkpoint:
        print(f"No checkpoints found for experiment: {experiment_name}")
        return
    
    try:
        # Load training history
        checkpoint_info = checkpoint_manager.get_checkpoint_info(latest_checkpoint)
        training_history = checkpoint_info['training_state']['training_history']
        
        if not training_history:
            print("No training history found")
            return
        
        # Convert to DataFrame
        df = pd.DataFrame(training_history)
        
        # Create subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle(f'Training History: {experiment_name}', fontsize=16)
        
        # Loss plot
        axes[0, 0].plot(df['epoch'], df['train_loss'], label='Train Loss', color='blue')
        axes[0, 0].plot(df['epoch'], df['val_loss'], label='Val Loss', color='red')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        # Learning rate plot
        axes[0, 1].plot(df['epoch'], df['learning_rate'], color='green')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Learning Rate')
        axes[0, 1].set_title('Learning Rate Schedule')
        axes[0, 1].set_yscale('log')
        axes[0, 1].grid(True)
        
        # Extract validation metrics
        val_metrics_keys = []
        if training_history:
            val_metrics_keys = list(training_history[0]['val_metrics'].keys())
        
        # Plot main validation metric
        if val_metrics_keys:
            main_metric = val_metrics_keys[0]  # Use first metric
            metric_values = [epoch['val_metrics'][main_metric] for epoch in training_history]
            axes[1, 0].plot(df['epoch'], metric_values, color='purple')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel(main_metric)
            axes[1, 0].set_title(f'Validation {main_metric}')
            axes[1, 0].grid(True)
        
        # Plot multiple metrics if available
        if len(val_metrics_keys) > 1:
            for i, metric in enumerate(val_metrics_keys[:4]):  # Plot up to 4 metrics
                metric_values = [epoch['val_metrics'][metric] for epoch in training_history]
                axes[1, 1].plot(df['epoch'], metric_values, label=metric)
            
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Metric Value')
            axes[1, 1].set_title('Validation Metrics')
            axes[1, 1].legend()
            axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.show()
        
        # Display summary statistics
        print(f"\nüìä Training Summary for {experiment_name}:")
        print(f"   Total epochs: {len(training_history)}")
        print(f"   Final train loss: {training_history[-1]['train_loss']:.4f}")
        print(f"   Final val loss: {training_history[-1]['val_loss']:.4f}")
        
        if val_metrics_keys:
            print(f"   Final validation metrics:")
            for metric, value in training_history[-1]['val_metrics'].items():
                print(f"     {metric}: {value:.4f}")
        
    except Exception as e:
        print(f"Error plotting training history: {e}")

def create_monitoring_dashboard():
    """Build a widget panel for browsing experiments that have checkpoints.

    Experiments are the distinct values of the ``experiment`` column returned by
    ``checkpoint_manager.list_checkpoints()``. The panel has a dropdown and two
    buttons: one plots the selected experiment's history, the other lists its
    checkpoints. Both write into a shared output area.

    Returns:
        The dashboard ``widgets.VBox``, or None (after printing a hint) when no
        experiment has checkpoints yet.
    """
    all_checkpoints = checkpoint_manager.list_checkpoints()
    experiments = [] if all_checkpoints.empty else all_checkpoints['experiment'].unique().tolist()

    if not experiments:
        print("No experiments found. Start training first.")
        return

    picker = widgets.Dropdown(options=experiments, value=experiments[0], description='Experiment:')
    history_btn = widgets.Button(description='üìä Plot History', button_style='info')
    ckpt_btn = widgets.Button(description='üìÅ Show Checkpoints', button_style='info')
    panel_output = widgets.Output()

    def _show_history(_btn):
        with panel_output:
            panel_output.clear_output()
            plot_training_history(picker.value)

    def _show_checkpoints(_btn):
        with panel_output:
            panel_output.clear_output()
            selected = picker.value
            rows = checkpoint_manager.list_checkpoints(selected)
            if rows.empty:
                print(f"No checkpoints found for {selected}")
                return
            print(f"Checkpoints for {selected}:")
            display(rows)

    history_btn.on_click(_show_history)
    ckpt_btn.on_click(_show_checkpoints)

    return widgets.VBox([
        widgets.HTML("<h3>Training Monitoring Dashboard</h3>"),
        picker,
        widgets.HBox([history_btn, ckpt_btn]),
        panel_output,
    ])

# Show the dashboard. The builder returns None, having already printed a hint,
# when there is nothing to monitor yet.
print("Training Monitoring Dashboard:")
monitoring_dashboard = create_monitoring_dashboard()
if monitoring_dashboard:
    display(monitoring_dashboard)

## Training Utilities

Additional utilities for training management.

In [None]:
def resume_training_from_checkpoint(checkpoint_id: str):
    """Resume training for the experiment that owns ``checkpoint_id``.

    Steps:

    * Look up the checkpoint's experiment name.
    * Make it the current experiment.
    * Switch on ``training.auto_resume`` in the loaded ``current_config``;
      otherwise ``enhanced_train_loop`` would silently start from scratch.
    * Run ``enhanced_train_loop``.

    NOTE(review): ``enhanced_train_loop`` selects the checkpoint to resume from
    via ``auto_resume_manager.should_resume_training``, so the checkpoint actually
    resumed is not guaranteed to be ``checkpoint_id`` itself. The configuration
    is not rebuilt from the checkpoint; the globally loaded ``current_config`` is
    used, and its ``auto_resume`` flag stays enabled afterwards.

    Returns:
        The results dict from ``enhanced_train_loop``, or None on failure.
        Errors are printed rather than raised.
    """
    global current_config, current_experiment_name

    try:
        checkpoint_info = checkpoint_manager.get_checkpoint_info(checkpoint_id)
        experiment_name = checkpoint_info['metadata']['experiment_name']

        print(f"üîÑ Resuming training from checkpoint: {checkpoint_id}")
        print(f"   Experiment: {experiment_name}")
        print(f"   Epoch: {checkpoint_info['metadata']['epoch']}")

        if 'current_config' not in globals() or current_config is None:
            print("‚ùå No configuration loaded. Please run the configuration selection cell first.")
            return None

        # Without auto-resume the training loop ignores existing checkpoints.
        current_config.training.auto_resume = True
        current_experiment_name = experiment_name

        results = enhanced_train_loop(current_config, current_experiment_name)

        print("‚úÖ Training resumed and completed successfully!")
        return results

    except Exception as e:
        print(f"‚ùå Error resuming training: {e}")
        return None

def compare_experiments(experiment_names: List[str]):
    """Tabulate metadata from the latest checkpoint of each named experiment.

    For every experiment that has a latest checkpoint, the table collects:
    model type, epoch, best metric, best epoch and creation time. An experiment
    whose metadata cannot be read is reported and skipped.

    Returns:
        A DataFrame sorted by ``best_metric`` (descending), also displayed, or
        an empty DataFrame when nothing could be collected.
    """
    rows = []
    for name in experiment_names:
        checkpoint_id = checkpoint_manager.find_latest_checkpoint(name)
        if not checkpoint_id:
            continue
        try:
            meta = checkpoint_manager.get_checkpoint_info(checkpoint_id)['metadata']
            rows.append({
                'experiment': name,
                'model_type': meta['model_type'],
                'epochs': meta['epoch'],
                'best_metric': meta['best_metric'],
                'best_epoch': meta['best_epoch'],
                'created_at': meta['created_at'],
            })
        except Exception as e:
            print(f"Error loading experiment {name}: {e}")

    if not rows:
        print("No valid experiments found for comparison")
        return pd.DataFrame()

    table = pd.DataFrame(rows).sort_values('best_metric', ascending=False)
    print("üèÜ Experiment Comparison:")
    display(table)
    return table

def cleanup_experiments(keep_best_n: int = 3):
    """Delete checkpoints so that each experiment keeps only its ``keep_best_n`` best.

    Within each experiment, checkpoints are ranked by their ``best_metric``
    column (higher is better). Everything after the first ``keep_best_n`` is
    removed with ``checkpoint_manager.delete_checkpoint``. Experiments
    themselves are not removed.

    Args:
        keep_best_n: Checkpoints to keep per experiment; ``0`` deletes all of
            an experiment's checkpoints.

    Raises:
        ValueError: If ``keep_best_n`` is negative. Negative slicing would
            otherwise delete an arbitrary subset.
    """
    if keep_best_n < 0:
        raise ValueError(f"keep_best_n must be >= 0, got {keep_best_n}")

    checkpoints_df = checkpoint_manager.list_checkpoints()

    if checkpoints_df.empty:
        print("No checkpoints to clean up")
        return

    for experiment in checkpoints_df['experiment'].unique():
        exp_checkpoints = checkpoints_df[checkpoints_df['experiment'] == experiment]
        # Stable sort, so equal best_metric values keep their listed order and
        # deletions are deterministic.
        # NOTE(review): best_metric is presumably the running best stored at save
        # time, so later periodic checkpoints can tie with the one flagged best
        # and outrank it — consider preferring is_best checkpoints.
        exp_checkpoints = exp_checkpoints.sort_values(
            'best_metric', ascending=False, kind='mergesort'
        )

        if len(exp_checkpoints) <= keep_best_n:
            continue

        print(f"\nüßπ Cleaning up experiment: {experiment}")
        for checkpoint_id in exp_checkpoints['checkpoint_id'].iloc[keep_best_n:]:
            checkpoint_manager.delete_checkpoint(checkpoint_id)
            print(f"   Deleted: {checkpoint_id}")

# Closing usage instructions for the notebook user.
print("\n".join([
    "\n‚úÖ Training notebook setup complete!",
    "\nTo start training:",
    "1. Select a configuration using the configuration selector above",
    "2. Click the 'Start Training' button",
    "3. Monitor progress using the monitoring dashboard",
    "\nTraining will automatically resume from checkpoints if interrupted.",
]))