# Model Evaluation and Analysis

This notebook provides comprehensive model evaluation, analysis, and comparison capabilities.
It includes performance metrics, per-class error analysis (ROC-AUC, average precision, confusion matrices), cross-experiment comparison, and a summary report.

## Setup and Imports

Import all necessary libraries for model evaluation and analysis.

In [None]:
import os
import sys
import json
import warnings
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm.auto import tqdm
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# ML and evaluation libraries
import torch
import torch.nn.functional as F
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score, 
    precision_recall_curve, roc_curve, average_precision_score,
    multilabel_confusion_matrix, hamming_loss, jaccard_score
)
from sklearn.preprocessing import label_binarize
import mlflow
import mlflow.pytorch

# Suppress warnings.
# NOTE(review): this silences *every* warning, including sklearn's
# UndefinedMetricWarning for ill-defined per-class scores; consider narrowing
# it to specific categories if those warnings matter.
warnings.filterwarnings('ignore')

# Add src to path.
# NOTE(review): the project imports in the next cell use the `src.` package
# prefix (e.g. `from src.models.model import ...`), which needs the directory
# *containing* `src` on sys.path rather than `src` itself — confirm this
# append is still needed.
if 'src' not in sys.path:
    sys.path.append('src')

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

print("✅ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Load Dependencies

Run the configuration and checkpoint-management notebooks, then import the project's model, data, and metric modules.

In [None]:
# Load configuration and checkpoint managers
%run 01_Configuration_Management.ipynb
%run 02_Enhanced_Checkpoint_System.ipynb

# Import project modules
from src.models.model import EvidenceModel
from src.data.dataset import DataModule
from src.utils import evaluate, prepare_thresholds
from src.utils.metrics import compute_metrics

print("‚úÖ Dependencies loaded!")

## Model Loading and Configuration

Select a trained checkpoint and the evaluation device. The model weights themselves are loaded later, when an evaluation is run.

In [None]:
def create_model_selector():
    """Build the interactive checkpoint/device selector used before evaluation.

    Checkpoints are listed through the global ``checkpoint_manager``. Clicking
    *Load Model* does not load any weights: it validates the selection and
    stores it in the module globals ``current_checkpoint_id``,
    ``current_device`` and ``current_experiment``, which the evaluation cell
    reads.

    Returns:
        An ``ipywidgets.VBox`` containing the selector, or ``None`` when no
        checkpoints exist.
    """
    # Get available checkpoints
    checkpoints_df = checkpoint_manager.list_checkpoints()

    if checkpoints_df.empty:
        print("No trained models found. Please train a model first.")
        return None

    # The first dropdown picks an experiment; the checkpoint dropdown is then
    # filled with that experiment's checkpoints.
    experiments = checkpoints_df['experiment'].unique().tolist()

    experiment_dropdown = widgets.Dropdown(
        options=experiments,
        value=experiments[0] if experiments else None,
        description='Experiment:'
    )

    checkpoint_dropdown = widgets.Dropdown(
        options=[],
        description='Checkpoint:'
    )

    # When checked, the checkpoint dropdown is ignored and the manager's best
    # checkpoint for the experiment is used instead.
    use_best = widgets.Checkbox(
        value=True,
        description='Use Best Checkpoint'
    )

    device_dropdown = widgets.Dropdown(
        options=['auto', 'cpu', 'cuda'],
        value='auto',
        description='Device:'
    )

    load_button = widgets.Button(
        description='Load Model',
        button_style='success'
    )

    output = widgets.Output()

    def update_checkpoints(change):
        """Refill the checkpoint dropdown for the newly selected experiment."""
        experiment = change['new']
        exp_checkpoints = checkpoints_df[checkpoints_df['experiment'] == experiment]
        # (label, value) pairs: the label shows epoch and metric, the value is the id.
        checkpoint_options = [
            (f"{row['checkpoint_id']} (epoch {row['epoch']}, metric: {row['best_metric']:.4f})",
             row['checkpoint_id'])
            for _, row in exp_checkpoints.iterrows()
        ]
        checkpoint_dropdown.options = checkpoint_options
        if checkpoint_options:
            checkpoint_dropdown.value = checkpoint_options[0][1]

    experiment_dropdown.observe(update_checkpoints, names='value')

    # The observer is attached after the initial value was set, so populate
    # the checkpoints for the first experiment by hand.
    if experiments:
        update_checkpoints({'new': experiments[0]})

    def on_load_clicked(b):
        """Validate the selection and publish it to the notebook globals."""
        global current_checkpoint_id, current_device, current_experiment

        with output:
            output.clear_output()

            try:
                experiment = experiment_dropdown.value

                # Determine checkpoint to use
                if use_best.value:
                    checkpoint_id = checkpoint_manager.find_best_checkpoint(experiment)
                    print(f"🏆 Using best checkpoint for {experiment}")
                else:
                    checkpoint_id = checkpoint_dropdown.value
                    print(f"📋 Using selected checkpoint: {checkpoint_id}")

                if not checkpoint_id:
                    print("❌ No checkpoint found")
                    return

                # Determine device
                if device_dropdown.value == 'auto':
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                else:
                    device = torch.device(device_dropdown.value)

                # An explicit 'cuda' choice on a machine without CUDA would
                # otherwise only fail later, in the middle of evaluation.
                if device.type == 'cuda' and not torch.cuda.is_available():
                    print("❌ CUDA was selected but is not available; choose 'cpu' or 'auto'.")
                    return

                print(f"   Device: {device}")

                # Load checkpoint info
                checkpoint_info = checkpoint_manager.get_checkpoint_info(checkpoint_id)
                print(f"   Checkpoint: {checkpoint_id}")
                print(f"   Epoch: {checkpoint_info['metadata']['epoch']}")
                print(f"   Best metric: {checkpoint_info['metadata']['best_metric']:.4f}")

                # Publish the selection for the evaluation cell.
                current_checkpoint_id = checkpoint_id
                current_device = device
                current_experiment = experiment

                print("\n✅ Model configuration ready!")

            except Exception as e:
                print(f"❌ Error loading model: {e}")
                import traceback
                traceback.print_exc()

    load_button.on_click(on_load_clicked)

    layout = widgets.VBox([
        widgets.HTML("<h3>Model Selection for Evaluation</h3>"),
        experiment_dropdown,
        widgets.HBox([use_best, device_dropdown]),
        checkpoint_dropdown,
        load_button,
        output
    ])

    return layout

# Display model selector
model_selector = create_model_selector()
if model_selector:
    display(model_selector)
else:
    print("⚠️  No trained models available for evaluation")

## Model Evaluation Pipeline

Comprehensive model evaluation with detailed metrics.

In [None]:
def _collect_eval_outputs(model, batches, device: torch.device,
                          head_thresholds: Dict[str, Any]):
    """Run ``model`` over ``batches`` and gather thresholded head outputs.

    Outputs are accumulated per head and concatenated over batches. The heads
    are then joined along the label axis, so each row of the returned tensors
    is one sample and each column is one label across all evaluated heads.
    Appending every head's rows to a single shared list would instead
    interleave heads as if they were extra samples, or fail when head widths
    differ.

    A head is evaluated only when the batch carries labels under the head's
    name with the same shape as its logits. 1-D outputs and labels are treated
    as a single label column. Predictions are ``sigmoid(logits) > threshold``,
    using ``head_thresholds[head]`` when present and 0.5 otherwise.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...,
            token_type_ids=...)``; its result must hold a ``"head_outputs"``
            dict of logits keyed by head name.
        batches: Iterable of batch dicts (e.g. a tqdm-wrapped DataLoader).
        device: Device the model inputs are moved to.
        head_thresholds: Per-head thresholds, scalar or per-label.

    Returns:
        ``(predictions, labels, logits, head_widths)``. The first three are CPU
        float tensors of shape ``(n_samples, n_labels)``. ``head_widths`` maps
        each evaluated head to its number of label columns, in column order.

    Raises:
        ValueError: If no head could be evaluated, or if heads were evaluated
            on different numbers of samples.
    """
    # Thresholds are loop-invariant. Convert them once to CPU float tensors so
    # they broadcast against the CPU probabilities, whether scalar or per-label.
    cutoffs = {
        name: torch.as_tensor(value, dtype=torch.float32).cpu()
        for name, value in head_thresholds.items()
    }
    per_head: Dict[str, Dict[str, List[torch.Tensor]]] = {}
    skipped = set()

    for batch in batches:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        for head_name, head_output in outputs["head_outputs"].items():
            if head_name not in batch:
                continue

            logits = head_output.detach().cpu().float()
            labels = batch[head_name].cpu().float()
            if logits.dim() == 1:
                logits = logits.unsqueeze(-1)
            if labels.dim() == 1:
                labels = labels.unsqueeze(-1)

            if labels.shape != logits.shape:
                # e.g. class-index labels for a softmax head: sigmoid
                # thresholding does not apply, so leave the head out.
                if head_name not in skipped:
                    skipped.add(head_name)
                    print(f"   ⚠️  Skipping head '{head_name}': labels {tuple(labels.shape)} "
                          f"do not match logits {tuple(logits.shape)}")
                continue

            probs = torch.sigmoid(logits)
            predictions = (probs > cutoffs.get(head_name, 0.5)).float()

            bucket = per_head.setdefault(head_name, {"logits": [], "labels": [], "preds": []})
            bucket["logits"].append(logits)
            bucket["labels"].append(labels)
            bucket["preds"].append(predictions)

    if not per_head:
        raise ValueError("No model head produced outputs with matching labels; nothing to evaluate.")

    head_widths: Dict[str, int] = {}
    logit_parts, label_parts, pred_parts = [], [], []
    for head_name, bucket in per_head.items():
        head_logits = torch.cat(bucket["logits"], dim=0)
        head_widths[head_name] = head_logits.shape[1]
        logit_parts.append(head_logits)
        label_parts.append(torch.cat(bucket["labels"], dim=0))
        pred_parts.append(torch.cat(bucket["preds"], dim=0))

    row_counts = {part.shape[0] for part in logit_parts}
    if len(row_counts) > 1:
        raise ValueError(
            f"Heads were evaluated on different numbers of samples {sorted(row_counts)}; "
            "check that every batch carries labels for every evaluated head."
        )

    return (torch.cat(pred_parts, dim=1), torch.cat(label_parts, dim=1),
            torch.cat(logit_parts, dim=1), head_widths)


def _per_class_ranking_metrics(y_true: np.ndarray, y_prob: np.ndarray,
                               class_names: List[str]) -> Dict[str, Dict[str, Any]]:
    """Compute ROC-AUC, average precision and support for each label column.

    Columns without any positive example are skipped, since neither score is
    defined there. ROC-AUC is also undefined when a column has no negative
    example; that case keeps the historical ``0.0`` sentinel. Average
    precision is computed separately because it is still well-defined
    whenever positives exist.
    """
    metrics = {}
    for i, name in enumerate(class_names):
        class_true = y_true[:, i]
        class_prob = y_prob[:, i]

        # Skip if no positive examples
        if class_true.sum() == 0:
            continue

        try:
            auc = roc_auc_score(class_true, class_prob)
        except ValueError:
            auc = 0.0
        try:
            ap = average_precision_score(class_true, class_prob)
        except ValueError:
            ap = 0.0

        metrics[name] = {
            'auc': auc,
            'average_precision': ap,
            'support': int(class_true.sum())
        }
    return metrics


def evaluate_model_comprehensive(checkpoint_id: str, device: torch.device,
                                split: str = 'test') -> Optional[Dict[str, Any]]:
    """Evaluate a saved checkpoint on one data split.

    The function works in these steps:

    - Rebuild the experiment config from the checkpoint's ``config.json``.
    - Build the data loaders and the model, and restore the checkpoint.
    - Run inference.
    - Compute multi-label metrics on sigmoid + threshold predictions.

    When several heads are evaluated, their labels are joined column-wise
    into one multi-label problem.

    Args:
        checkpoint_id: Id understood by the global ``checkpoint_manager``.
        device: Device to run inference on.
        split: ``'train'``, ``'val'`` or ``'test'``. Any other value falls
            back to the test loader.

    Returns:
        On success, a dict with these keys:

        - ``checkpoint_id``, ``split``
        - ``n_samples``, ``n_classes``
        - ``hamming_loss``, ``jaccard_score``
        - ``standard_metrics``, ``per_class_metrics``
        - ``predictions``, ``labels``, ``probabilities``
        - ``class_names``

        Returns ``None`` if any step fails. The error and traceback are
        printed rather than raised.
    """
    print("🧪 Comprehensive Model Evaluation")
    print(f"   Checkpoint: {checkpoint_id}")
    print(f"   Split: {split}")
    print(f"   Device: {device}")
    print("=" * 50)

    try:
        # NOTE(review): the result is unused; the call presumably serves as an
        # early check that the checkpoint exists — confirm.
        checkpoint_manager.get_checkpoint_info(checkpoint_id)

        # Load configuration from checkpoint
        config_path = checkpoint_manager.checkpoint_dir / checkpoint_id / "config.json"
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)

        # Reconstruct configuration object
        config = ExperimentConfig(**config_dict)

        print("📊 Setting up data...")

        data_module = DataModule(config.data, config.model)

        # NOTE(review): the train loader is also built with val_batch_size —
        # presumably intentional for inference-only use; confirm.
        train_loader, val_loader, test_loader = data_module.dataloaders(
            batch_size=config.training.val_batch_size,
            val_batch_size=config.training.val_batch_size,
            test_batch_size=config.training.test_batch_size,
            num_workers=config.training.num_workers
        )

        data_loader = {'train': train_loader, 'val': val_loader}.get(split, test_loader)

        print(f"   {split.capitalize()} samples: {len(data_loader.dataset)}")

        print("\n🤖 Loading model...")

        model = EvidenceModel(config.model)
        model.to(device)

        # The checkpoint loader takes an optimizer and scheduler to restore
        # into; these instances are never stepped here.
        from src.utils import get_optimizer, get_scheduler
        optimizer = get_optimizer(model, config.training.optimizer)
        scheduler = get_scheduler(optimizer, config.training.scheduler, 1000)

        training_state, _ = checkpoint_manager.load_checkpoint(
            checkpoint_id, model, optimizer, scheduler, device
        )

        print(f"   Model loaded from epoch {training_state.epoch}")
        print(f"   Best metric: {training_state.best_metric:.4f}")

        # Tuned thresholds exist only for multi-label heads; other heads use 0.5.
        head_thresholds = {
            head_name: prepare_thresholds(head_cfg)
            for head_name, head_cfg in model.head_configs.items()
            if head_cfg.get("type") == "multi_label"
        }

        print("\n🔍 Running evaluation...")

        model.eval()
        with torch.no_grad():
            all_predictions, all_labels, all_logits, head_widths = _collect_eval_outputs(
                model, tqdm(data_loader, desc="Evaluating"), device, head_thresholds
            )
        all_probs = torch.sigmoid(all_logits)

        print(f"   Predictions shape: {all_predictions.shape}")
        print(f"   Labels shape: {all_labels.shape}")

        print("\n📈 Computing metrics...")

        # Convert to numpy for sklearn metrics
        y_true = all_labels.numpy()
        y_pred = all_predictions.numpy()
        y_prob = all_probs.numpy()

        hamming = hamming_loss(y_true, y_pred)
        jaccard = jaccard_score(y_true, y_pred, average='samples')

        n_classes = y_true.shape[1]
        # Replace with actual class names if available. With several heads,
        # prefix the column index with the head name to keep names unique.
        if len(head_widths) == 1:
            class_names = [f"Class_{i}" for i in range(n_classes)]
        else:
            class_names = [
                f"{head_name}_{i}"
                for head_name, width in head_widths.items()
                for i in range(width)
            ]

        per_class_metrics = _per_class_ranking_metrics(y_true, y_prob, class_names)

        # NOTE(review): argument order (predictions, labels) kept as-is —
        # confirm it matches compute_metrics' signature.
        standard_metrics = compute_metrics(y_pred, y_true)

        results = {
            'checkpoint_id': checkpoint_id,
            'split': split,
            'n_samples': len(y_true),
            'n_classes': n_classes,
            'hamming_loss': hamming,
            'jaccard_score': jaccard,
            'standard_metrics': standard_metrics,
            'per_class_metrics': per_class_metrics,
            'predictions': y_pred,
            'labels': y_true,
            'probabilities': y_prob,
            'class_names': class_names
        }

        print("\n✅ Evaluation complete!")
        print(f"   Hamming Loss: {hamming:.4f}")
        print(f"   Jaccard Score: {jaccard:.4f}")

        return results

    except Exception as e:
        print(f"❌ Error during evaluation: {e}")
        import traceback
        traceback.print_exc()
        return None

print("✅ Evaluation pipeline ready!")

## Model Evaluation Execution

Evaluate the checkpoint selected above on a chosen data split.

In [None]:
def create_evaluation_interface():
    """Build the widget that evaluates the checkpoint chosen in the selector.

    On click, it reads ``current_checkpoint_id`` and ``current_device``, which
    the model selector's *Load Model* button sets. It then runs
    ``evaluate_model_comprehensive`` on the chosen split. On success, the
    results dict is stored in the module global ``current_evaluation_results``
    for the visualization and report cells.

    Returns:
        An ``ipywidgets.VBox`` with the split dropdown, button and output area.
    """
    split_dropdown = widgets.Dropdown(
        options=['test', 'val', 'train'],
        value='test',
        description='Data Split:'
    )

    evaluate_button = widgets.Button(
        description='🧪 Evaluate Model',
        button_style='primary'
    )

    output = widgets.Output()

    def on_evaluate_clicked(b):
        """Run the evaluation and publish the results on success."""
        global current_evaluation_results

        with output:
            output.clear_output()

            # The selector cell creates this global only after a successful load.
            if 'current_checkpoint_id' not in globals():
                print("❌ No model loaded. Please load a model first.")
                return

            try:
                results = evaluate_model_comprehensive(
                    current_checkpoint_id,
                    current_device,
                    split_dropdown.value
                )

                if results:
                    current_evaluation_results = results

                    print("\n📊 Evaluation Results Summary:")
                    print(f"   Dataset: {results['split']} ({results['n_samples']} samples)")
                    print(f"   Classes: {results['n_classes']}")
                    print(f"   Hamming Loss: {results['hamming_loss']:.4f}")
                    print(f"   Jaccard Score: {results['jaccard_score']:.4f}")

                    # Only scalar entries of the project metrics are printed.
                    if 'standard_metrics' in results:
                        print("\n📈 Standard Metrics:")
                        for metric, value in results['standard_metrics'].items():
                            if isinstance(value, (int, float)):
                                print(f"   {metric}: {value:.4f}")

                    print("\n✅ Use the visualization cells below to analyze results in detail.")

            except Exception as e:
                print(f"❌ Evaluation failed: {e}")
                import traceback
                traceback.print_exc()

    evaluate_button.on_click(on_evaluate_clicked)

    layout = widgets.VBox([
        widgets.HTML("<h3>Model Evaluation</h3>"),
        split_dropdown,
        evaluate_button,
        output
    ])

    return layout

# Display evaluation interface
evaluation_interface = create_evaluation_interface()
display(evaluation_interface)

## Performance Visualization

Visualize model performance with comprehensive plots.

In [None]:
def plot_performance_metrics():
    """Plot per-class AUC, average precision and support, plus a summary panel.

    Uses the module global ``current_evaluation_results`` produced by the
    evaluation cell. If no evaluation has been run, it prints a message and
    returns. After the 2x2 figure, it displays the per-class metrics as a
    table sorted by AUC.
    """
    if 'current_evaluation_results' not in globals():
        print("❌ No evaluation results available. Please run evaluation first.")
        return

    results = current_evaluation_results
    per_class = results['per_class_metrics']

    print("📊 Performance Visualization")
    print(f"   Model: {results['checkpoint_id']}")
    print(f"   Split: {results['split']}")
    print("=" * 50)

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(f'Model Performance Analysis - {results["split"].title()} Set', fontsize=16)

    # (axis, per-class metric key, y label, title, colour) for the three bar charts.
    bar_panels = [
        (axes[0, 0], 'auc', 'AUC Score', 'Per-Class AUC Scores', 'skyblue'),
        (axes[0, 1], 'average_precision', 'Average Precision',
         'Per-Class Average Precision', 'lightcoral'),
        (axes[1, 0], 'support', 'Number of Positive Examples',
         'Class Support Distribution', 'lightgreen'),
    ]
    class_names = list(per_class.keys())
    positions = range(len(class_names))

    for ax, key, ylabel, title, color in bar_panels:
        ax.set_title(title)
        if not per_class:
            ax.text(0.5, 0.5, 'No per-class metrics\navailable',
                    ha='center', va='center', transform=ax.transAxes)
            continue
        ax.bar(positions, [per_class[name][key] for name in class_names], color=color)
        ax.set_xlabel('Classes')
        ax.set_ylabel(ylabel)
        ax.set_xticks(positions)
        ax.set_xticklabels(class_names, rotation=45, ha='right')

    # Only the AUC panel gets the chance-level reference line.
    if per_class:
        axes[0, 0].axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='Random')
        axes[0, 0].legend()

    # Overall metrics summary, rendered as monospace text in the last panel.
    summary_lines = [
        "Overall Performance Metrics:",
        "",
        f"Samples: {results['n_samples']:,}",
        f"Classes: {results['n_classes']}",
        "",
        f"Hamming Loss: {results['hamming_loss']:.4f}",
        f"Jaccard Score: {results['jaccard_score']:.4f}",
    ]
    if 'standard_metrics' in results:
        summary_lines.extend(
            f"{metric}: {value:.4f}"
            for metric, value in results['standard_metrics'].items()
            if isinstance(value, (int, float))
        )
    metrics_text = "\n".join(summary_lines) + "\n"

    summary_ax = axes[1, 1]
    summary_ax.text(0.1, 0.9, metrics_text, transform=summary_ax.transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace')
    summary_ax.set_xlim(0, 1)
    summary_ax.set_ylim(0, 1)
    summary_ax.axis('off')
    summary_ax.set_title('Performance Summary')

    plt.tight_layout()
    plt.show()

    # Additional detailed metrics table
    if per_class:
        print("\n📋 Detailed Per-Class Metrics:")
        metrics_df = (
            pd.DataFrame(per_class).T
            .round(4)
            .sort_values('auc', ascending=False)
        )
        display(metrics_df)

def plot_confusion_matrices():
    """Plot a 2x2 confusion matrix (negative/positive) for every label column.

    Uses the module global ``current_evaluation_results``. If no evaluation
    has been run, it prints a message and returns. Matrices come from
    ``multilabel_confusion_matrix`` and are laid out in a grid of at most
    four columns.
    """
    if 'current_evaluation_results' not in globals():
        print("❌ No evaluation results available. Please run evaluation first.")
        return

    results = current_evaluation_results
    y_true = results['labels']
    y_pred = results['predictions']
    class_names = results['class_names']

    print("🔍 Confusion Matrix Analysis")
    print("=" * 30)

    n_classes = len(class_names)
    if n_classes == 0:
        # Guard the grid arithmetic below, which would divide by zero.
        print("❌ No classes to plot.")
        return

    # One 2x2 matrix per label column, in class_names order.
    cm_multilabel = multilabel_confusion_matrix(y_true, y_pred)

    n_cols = min(4, n_classes)
    n_rows = (n_classes + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always returns a 2-D array of axes, so a single ravel()
    # handles the single-plot, single-row and full-grid cases alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows),
                             squeeze=False)
    axes = axes.ravel()

    for ax, class_name, cm in zip(axes, class_names, cm_multilabel):
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Negative', 'Positive'],
                    yticklabels=['Negative', 'Positive'],
                    ax=ax)
        ax.set_title(f'{class_name}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')

    # Hide unused subplots in the last grid row.
    for ax in axes[n_classes:]:
        ax.set_visible(False)

    plt.suptitle('Per-Class Confusion Matrices', fontsize=16)
    plt.tight_layout()
    plt.show()

# Create visualization buttons.
# Plots and prints from button callbacks are routed into an Output widget.
# Otherwise JupyterLab / ipywidgets 8 do not show callback output in the
# notebook at all. This matches how the other interactive cells work.
plot_metrics_button = widgets.Button(
    description='📊 Plot Performance Metrics',
    button_style='info'
)

plot_confusion_button = widgets.Button(
    description='🔍 Plot Confusion Matrices',
    button_style='info'
)

visualization_output = widgets.Output()

def on_plot_metrics_clicked(b):
    """Render the performance-metrics figure into the output area."""
    with visualization_output:
        visualization_output.clear_output()
        plot_performance_metrics()

def on_plot_confusion_clicked(b):
    """Render the confusion-matrix grid into the output area."""
    with visualization_output:
        visualization_output.clear_output()
        plot_confusion_matrices()

plot_metrics_button.on_click(on_plot_metrics_clicked)
plot_confusion_button.on_click(on_plot_confusion_clicked)

print("📊 Visualization Tools:")
display(widgets.VBox([
    widgets.HBox([plot_metrics_button, plot_confusion_button]),
    visualization_output
]))

## Model Comparison

Compare experiments by evaluating each experiment's best checkpoint on the same data split.

In [None]:
def create_model_comparison_tool():
    """Build a widget that evaluates and compares several experiments.

    For each selected experiment, the best checkpoint (as reported by
    ``checkpoint_manager.find_best_checkpoint``) is evaluated on the chosen
    split with ``evaluate_model_comprehensive``. The widget then displays:

    - a summary table sorted by Jaccard score;
    - Hamming-loss and Jaccard bar charts, when at least two experiments
      succeeded.

    The per-experiment summaries are stored in the module global
    ``current_comparison_results``.

    Returns:
        An ``ipywidgets.VBox``, or ``None`` when no checkpoints exist.
    """
    checkpoints_df = checkpoint_manager.list_checkpoints()

    if checkpoints_df.empty:
        print("No models available for comparison.")
        return None

    available_experiments = checkpoints_df['experiment'].unique().tolist()

    experiment_selector = widgets.SelectMultiple(
        options=available_experiments,
        value=[available_experiments[0]] if available_experiments else [],
        description='Experiments:',
        rows=min(10, len(available_experiments))
    )

    split_dropdown = widgets.Dropdown(
        options=['test', 'val'],
        value='test',
        description='Data Split:'
    )

    compare_button = widgets.Button(
        description='🔄 Compare Models',
        button_style='warning'
    )

    output = widgets.Output()

    def on_compare_clicked(b):
        """Evaluate every selected experiment and show the comparison."""
        global current_comparison_results

        with output:
            output.clear_output()

            selected = experiment_selector.value
            if not selected:
                print("❌ Please select at least one experiment to compare.")
                return

            print(f"🔄 Comparing {len(selected)} experiments...")
            print("=" * 50)

            # The same automatic device and split apply to every experiment,
            # so resolve them once.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            split = split_dropdown.value

            comparison_results = []

            for experiment in selected:
                try:
                    print(f"\n📊 Evaluating {experiment}...")

                    checkpoint_id = checkpoint_manager.find_best_checkpoint(experiment)
                    if not checkpoint_id:
                        print(f"   ❌ No checkpoint found for {experiment}")
                        continue

                    results = evaluate_model_comprehensive(checkpoint_id, device, split)

                    if results:
                        comparison_results.append({
                            'experiment': experiment,
                            'checkpoint_id': checkpoint_id,
                            'hamming_loss': results['hamming_loss'],
                            'jaccard_score': results['jaccard_score'],
                            'n_samples': results['n_samples'],
                            'standard_metrics': results.get('standard_metrics', {})
                        })

                        print(f"   ✅ {experiment}: Hamming={results['hamming_loss']:.4f}, Jaccard={results['jaccard_score']:.4f}")

                except Exception as e:
                    print(f"   ❌ Error evaluating {experiment}: {e}")

            if not comparison_results:
                print("❌ No successful evaluations completed.")
                return

            print("\n📊 Comparison Results:")
            print("=" * 50)

            comparison_data = []
            for result in comparison_results:
                # Shorten long ids for the table; the full id is kept in
                # current_comparison_results.
                full_id = str(result['checkpoint_id'])
                short_id = full_id if len(full_id) <= 12 else full_id[:12] + '...'
                row = {
                    'Experiment': result['experiment'],
                    'Checkpoint': short_id,
                    'Hamming Loss': result['hamming_loss'],
                    'Jaccard Score': result['jaccard_score'],
                    'Samples': result['n_samples']
                }
                # Add scalar standard metrics as extra columns.
                for metric, value in result['standard_metrics'].items():
                    if isinstance(value, (int, float)):
                        row[metric] = value
                comparison_data.append(row)

            # Sort by Jaccard score (higher is better)
            comparison_df = (
                pd.DataFrame(comparison_data)
                .round(4)
                .sort_values('Jaccard Score', ascending=False)
            )
            display(comparison_df)

            if len(comparison_results) > 1:
                fig, axes = plt.subplots(1, 2, figsize=(15, 6))

                names = [r['experiment'] for r in comparison_results]
                hamming_losses = [r['hamming_loss'] for r in comparison_results]
                jaccard_scores = [r['jaccard_score'] for r in comparison_results]

                # Hamming Loss (lower is better)
                axes[0].bar(names, hamming_losses, color='lightcoral')
                axes[0].set_title('Hamming Loss (Lower is Better)')
                axes[0].set_ylabel('Hamming Loss')
                axes[0].tick_params(axis='x', rotation=45)

                # Jaccard Score (higher is better)
                axes[1].bar(names, jaccard_scores, color='lightgreen')
                axes[1].set_title('Jaccard Score (Higher is Better)')
                axes[1].set_ylabel('Jaccard Score')
                axes[1].tick_params(axis='x', rotation=45)

                plt.tight_layout()
                plt.show()

            current_comparison_results = comparison_results

    compare_button.on_click(on_compare_clicked)

    layout = widgets.VBox([
        widgets.HTML("<h3>Model Comparison</h3>"),
        experiment_selector,
        split_dropdown,
        compare_button,
        output
    ])

    return layout

# Display model comparison tool
print("🔄 Model Comparison Tool:")
comparison_tool = create_model_comparison_tool()
if comparison_tool:
    display(comparison_tool)

## Summary and Export

Generate a summary report of the most recent evaluation. The report is printed in the notebook; exporting results to files is not implemented yet.

In [None]:
def generate_evaluation_report():
    """Print a text report of the most recent evaluation results.

    Reads the module global ``current_evaluation_results``, which the
    evaluation cell sets. The report prints:

    - model and split information;
    - Hamming loss and Jaccard score;
    - the scalar entries of ``standard_metrics``;
    - when available, per-class AUC/AP mean and standard deviation, with the
      best and worst classes by AUC.

    Returns:
        The results dict that was reported, or ``None`` if no evaluation has
        been run yet.
    """
    if 'current_evaluation_results' not in globals():
        print("❌ No evaluation results available. Please run evaluation first.")
        return None

    results = current_evaluation_results

    print("📋 Evaluation Report")
    print("=" * 50)

    print("\n🤖 Model Information:")
    print(f"   Checkpoint ID: {results['checkpoint_id']}")
    print(f"   Evaluation Split: {results['split']}")
    print(f"   Number of Samples: {results['n_samples']:,}")
    print(f"   Number of Classes: {results['n_classes']}")

    print("\n📊 Overall Performance:")
    print(f"   Hamming Loss: {results['hamming_loss']:.4f}")
    print(f"   Jaccard Score: {results['jaccard_score']:.4f}")

    if 'standard_metrics' in results:
        print("\n📈 Standard Metrics:")
        for metric, value in results['standard_metrics'].items():
            if isinstance(value, (int, float)):
                print(f"   {metric}: {value:.4f}")

    per_class = results['per_class_metrics']
    if per_class:
        print("\n🏷️  Per-Class Performance:")

        auc_scores = [m['auc'] for m in per_class.values()]
        ap_scores = [m['average_precision'] for m in per_class.values()]

        print(f"   Mean AUC: {np.mean(auc_scores):.4f} (±{np.std(auc_scores):.4f})")
        print(f"   Mean AP: {np.mean(ap_scores):.4f} (±{np.std(ap_scores):.4f})")

        ranked = sorted(((name, m['auc']) for name, m in per_class.items()),
                        key=lambda item: item[1], reverse=True)
        best = ranked[:3]
        # Draw the worst classes only from those not already listed as best,
        # so no class appears in both lists when there are fewer than six.
        # Show the lowest AUC first.
        worst = ranked[len(best):][-3:][::-1]

        print("\n   Best performing classes (by AUC):")
        for name, auc in best:
            print(f"     {name}: {auc:.4f}")

        if worst:
            print("\n   Worst performing classes (by AUC):")
            for name, auc in worst:
                print(f"     {name}: {auc:.4f}")

    print("\n✅ Evaluation report complete!")

    return results

# Create summary button.
# The report is printed into an Output widget; otherwise JupyterLab /
# ipywidgets 8 do not show output from button callbacks in the notebook.
report_button = widgets.Button(
    description='📋 Generate Report',
    button_style='success'
)

report_output = widgets.Output()

def on_report_clicked(b):
    """Print a fresh evaluation report into the output area."""
    with report_output:
        report_output.clear_output()
        generate_evaluation_report()

report_button.on_click(on_report_clicked)

print("\n📋 Evaluation Summary:")
display(widgets.VBox([report_button, report_output]))

print("\n✅ Model Evaluation notebook complete!")
print("\nThis notebook provides:")
print("• Interactive model loading and selection")
print("• Comprehensive model evaluation with detailed metrics")
print("• Performance visualization and analysis")
print("• Model comparison across experiments")
print("• Evaluation report generation")