# Model Performance Evaluation
Detailed analysis of model performance metrics and cross-validation results.

In [None]:
import sys
sys.path.append('../..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import mlflow
from sklearn.metrics import confusion_matrix
from src.utils.metrics import MetricsCalculator

# matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8*' and
# 3.8 removed the old 'seaborn' alias (it now raises OSError). Prefer the new
# name when this matplotlib ships it, otherwise fall back to the legacy name.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_context('talk')

In [None]:
def load_mlflow_results(experiment_name):
    """Load every run of an MLflow experiment into a DataFrame.

    Runs are requested ordered by their logged ``f1_score`` metric, highest first,
    and all result pages are fetched (``search_runs`` only returns one page per call).

    Parameters
    ----------
    experiment_name : str
        Name of the MLflow experiment to read.

    Returns
    -------
    pandas.DataFrame
        One row per run with columns ``run_id``, ``model_type`` (the run's
        ``model_type`` param, ``None`` if it was not logged) and one column per
        logged metric. When the experiment has no runs, an empty frame that still
        has the ``run_id`` and ``model_type`` columns is returned.

    Raises
    ------
    ValueError
        If no experiment named ``experiment_name`` exists.
    """
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        # get_experiment_by_name returns None instead of raising; without this
        # check the failure shows up as an opaque AttributeError below.
        raise ValueError(f"MLflow experiment {experiment_name!r} not found")

    # search_runs is paginated; keep following the page token so large
    # experiments are not silently truncated to the first page.
    runs = []
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            order_by=['metrics.f1_score DESC'],
            page_token=page_token,
        )
        runs.extend(page)
        page_token = page.token
        if not page_token:
            break

    results = [
        {
            'run_id': run.info.run_id,
            'model_type': run.data.params.get('model_type'),
            **run.data.metrics,
        }
        for run in runs
    ]

    if not results:
        # Keep the identifying columns so downstream code that selects
        # 'model_type' fails gracefully rather than with a KeyError.
        return pd.DataFrame(columns=['run_id', 'model_type'])
    return pd.DataFrame(results)

results_df = load_mlflow_results('eeg_classification')

In [None]:
def plot_metrics_comparison(
    results,
    metrics=('accuracy', 'precision', 'recall', 'f1_score', 'roc_auc'),
    n_cols=3,
):
    """Box-plot each metric's distribution across runs, one panel per metric.

    Parameters
    ----------
    results : pandas.DataFrame
        One row per run with a ``model_type`` column and one column per metric.
    metrics : sequence of str, optional
        Metric columns to plot, in panel order. Metrics that are absent from
        ``results`` are skipped (and reported) rather than raising KeyError.
    n_cols : int, optional
        Maximum number of panels per row of the figure grid.
    """
    missing = [metric for metric in metrics if metric not in results.columns]
    if missing:
        print(f"Skipping metrics not present in results: {missing}")
    metrics = [metric for metric in metrics if metric in results.columns]
    if not metrics:
        return

    n_cols = min(n_cols, len(metrics))
    n_rows = (len(metrics) + n_cols - 1) // n_cols  # ceiling division
    # squeeze=False keeps `axes` 2-D for any grid shape so ravel() always works.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
    )
    axes = axes.ravel()

    for ax, metric in zip(axes, metrics):
        sns.boxplot(data=results, x='model_type', y=metric, ax=ax)
        ax.tick_params(axis='x', labelrotation=45)
        ax.set_title(metric.replace('_', ' ').title())

    # Hide grid cells left over when the metrics don't fill the last row
    # (e.g. 5 metrics on a 2x3 grid would otherwise show an empty axes).
    for ax in axes[len(metrics):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

plot_metrics_comparison(results_df)

In [None]:
def plot_confusion_matrices(predictions, model_types):
    """Plot one confusion-matrix heatmap per model, side by side.

    Parameters
    ----------
    predictions : pandas.DataFrame
        Must contain ``model_type``, ``true_label`` and ``predicted_label`` columns.
    model_types : iterable
        Model identifiers to plot, in panel order; each is matched against
        ``predictions['model_type']``. Nothing is drawn if it is empty.
    """
    model_types = list(model_types)
    if not model_types:
        return

    # One shared, sorted label set so every matrix has the same shape and
    # row/column order, even if a model never predicts (or never sees) a class.
    labels = np.union1d(
        predictions['true_label'], predictions['predicted_label']
    ).tolist()

    n_models = len(model_types)
    # squeeze=False: with a single model plt.subplots would otherwise return a
    # bare Axes, which is not iterable and breaks the zip below.
    fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 4), squeeze=False)

    for ax, model in zip(axes[0], model_types):
        model_preds = predictions[predictions['model_type'] == model]
        cm = confusion_matrix(
            model_preds['true_label'],
            model_preds['predicted_label'],
            labels=labels,
        )

        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            ax=ax,
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels,
        )
        ax.set_title(f'{model}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')

    plt.tight_layout()
    plt.show()

In [None]:
def plot_learning_curves(cv_results):
    """Plot train and validation score against training-set size for each model.

    Parameters
    ----------
    cv_results : pandas.DataFrame
        Must contain ``model_type``, ``train_size``, ``train_score`` and
        ``val_score`` columns. Rows sharing a ``(model_type, train_size)`` pair
        (e.g. several CV folds) are averaged into one point.

    Each model's train curve is solid and its validation curve is dashed in the
    same colour, so the two curves of a model are easy to pair up.
    """
    plt.figure(figsize=(10, 6))

    for model_type in cv_results['model_type'].unique():
        model_results = cv_results[cv_results['model_type'] == model_type]

        # groupby both averages repeated train_size rows and sorts by size, so
        # each curve is a single left-to-right line instead of zig-zagging over
        # unsorted or duplicated x values.
        curve = model_results.groupby('train_size')[['train_score', 'val_score']].mean()

        (train_line,) = plt.plot(
            curve.index,
            curve['train_score'],
            label=f'{model_type} (train)',
        )
        plt.plot(
            curve.index,
            curve['val_score'],
            linestyle='--',
            color=train_line.get_color(),
            label=f'{model_type} (val)',
        )

    plt.xlabel('Training Examples')
    plt.ylabel('Score')
    plt.title('Learning Curves')
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
def analyze_prediction_errors(predictions):
    """Analyze characteristics of prediction errors.

    Produces two kinds of figures:

    1. A bar chart of the error rate per ``Participant``, worst first.
    2. For every numeric feature column, a box plot of its distribution per
       ``true_label``, split into correct vs. incorrect predictions.

    Parameters
    ----------
    predictions : pandas.DataFrame
        Must contain ``Participant``, ``true_label`` and ``predicted_label``
        columns; any other numeric columns are treated as features, except known
        non-feature columns (see ``_plot_feature_distributions``).
    """
    is_error = predictions['true_label'] != predictions['predicted_label']
    _plot_participant_error_rates(predictions, is_error)
    _plot_feature_distributions(predictions, is_error)


def _plot_participant_error_rates(predictions, is_error):
    """Bar-plot the fraction of misclassified rows per participant, worst first."""
    # The mean of a boolean per group is that group's error rate. Participants
    # with no errors correctly get 0.0; dividing an errors-only groupby by the
    # totals would drop them and yield NaN instead.
    error_rates = (
        is_error.groupby(predictions['Participant']).mean().sort_values(ascending=False)
    )

    plt.figure(figsize=(12, 6))
    error_rates.plot(kind='bar')
    plt.title('Error Rate by Participant')
    plt.xlabel('Participant')
    plt.ylabel('Error Rate')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()


def _plot_feature_distributions(predictions, is_error):
    """Box-plot each numeric feature by true label, split into correct vs. incorrect."""
    # Columns that describe predictions or bookkeeping rather than input
    # features; these names are used on the same frame elsewhere in this notebook.
    non_features = {
        'Participant', 'true_label', 'predicted_label',
        'model_type', 'probability', 'window_idx',
    }
    feature_cols = [
        col for col in predictions.select_dtypes(include='number').columns
        if col not in non_features
    ]

    # Work on a copy with an explicit outcome column so the hue matches the
    # plot title; the caller's DataFrame is left untouched.
    plot_df = predictions.assign(
        Prediction=is_error.map({False: 'Correct', True: 'Incorrect'})
    )

    for feature in feature_cols:
        plt.figure(figsize=(10, 6))
        sns.boxplot(
            data=plot_df,
            x='true_label',
            y=feature,
            hue='Prediction',
            hue_order=['Correct', 'Incorrect'],
        )
        plt.title(f'{feature} Distribution for Correct/Incorrect Predictions')
        plt.show()

In [None]:
def analyze_threshold_sensitivity(predictions, thresholds=None):
    """Sweep the decision threshold and plot classification metrics at each value.

    For every threshold ``t``, rows with ``probability >= t`` are labelled 1 and
    all others 0. The resulting labels are scored against ``true_label`` with
    ``MetricsCalculator.calculate_classification_metrics``.

    Parameters
    ----------
    predictions : pandas.DataFrame
        Must contain ``probability`` and ``true_label`` columns. The 0/1
        labelling assumes a binary task with integer labels (TODO confirm).
    thresholds : array-like of float, optional
        Cut-offs to evaluate. Defaults to ``np.linspace(0.1, 0.9, 9)``
        (0.1, 0.2, ..., 0.9), which was the previously hard-coded grid.

    Returns
    -------
    pandas.DataFrame
        One row per threshold: a ``threshold`` column plus every metric
        returned by the calculator.
    """
    if thresholds is None:
        thresholds = np.linspace(0.1, 0.9, 9)
    metrics_calculator = MetricsCalculator()

    rows = []
    for threshold in thresholds:
        pred_labels = (predictions['probability'] >= threshold).astype(int)
        scores = metrics_calculator.calculate_classification_metrics(
            predictions['true_label'],
            pred_labels
        )
        rows.append({'threshold': threshold, **scores})

    # Named sweep_df (not results_df) so it is not confused with the notebook's
    # global results_df of MLflow runs.
    sweep_df = pd.DataFrame(rows)

    plt.figure(figsize=(12, 6))
    for metric in ['accuracy', 'precision', 'recall', 'f1_score']:
        plt.plot(sweep_df['threshold'], sweep_df[metric], label=metric)

    plt.xlabel('Classification Threshold')
    plt.ylabel('Score')
    plt.title('Metric Scores vs Classification Threshold')
    plt.legend()
    plt.grid(True)
    plt.show()

    return sweep_df

In [None]:
def analyze_confidence_distribution(predictions):
    """Compare the ``probability`` distribution of correct vs. incorrect predictions.

    Draws two overlaid KDE curves of ``predictions['probability']``: first the
    rows where ``true_label == predicted_label``, then the remaining rows. It
    then prints ``describe()`` statistics of ``probability`` grouped by that same
    correctness flag (index ``False`` = incorrect, ``True`` = correct).

    NOTE(review): elsewhere in this notebook ``probability`` is thresholded with
    ``>=`` to produce the positive class, which suggests it is P(positive) rather
    than the model's confidence in its predicted class. Confirm this before
    reading the x-axis as "confidence".
    """
    is_correct = predictions['true_label'] == predictions['predicted_label']

    plt.figure(figsize=(12, 6))

    curve_specs = (
        (is_correct, 'Correct Predictions'),
        (~is_correct, 'Incorrect Predictions'),
    )
    for row_mask, legend_label in curve_specs:
        sns.kdeplot(data=predictions[row_mask], x='probability', label=legend_label)

    plt.xlabel('Model Confidence')
    plt.ylabel('Density')
    plt.title('Distribution of Model Confidence by Prediction Correctness')
    plt.legend()
    plt.show()

    # Summary statistics of the probability column, split by correctness.
    stats_by_correctness = predictions.groupby(is_correct)['probability'].describe()

    print("\nConfidence Statistics:")
    print(stats_by_correctness)

In [None]:
def analyze_temporal_patterns(predictions):
    """Plot accuracy as a function of each window's position within its participant.

    A window's position is its 0-based index among rows with the same
    ``Participant``, taken in the DataFrame's current row order. This assumes
    rows are already in temporal order per participant (TODO confirm upstream).
    Accuracy at position k is averaged over every participant that has a k-th
    window.

    Unlike earlier versions, this no longer adds a ``window_idx`` column to the
    caller's ``predictions`` DataFrame. That column leaked into later analyses,
    e.g. as a spurious feature.

    Parameters
    ----------
    predictions : pandas.DataFrame
        Must contain ``Participant``, ``true_label`` and ``predicted_label`` columns.
    """
    window_idx = predictions.groupby('Participant').cumcount()
    is_correct = predictions['true_label'] == predictions['predicted_label']

    # Mean of a boolean per window position = accuracy at that position. Group by
    # the raw positions (not an index-aligned Series) so duplicate index labels
    # in `predictions` cannot cause misalignment.
    window_accuracy = (
        is_correct.groupby(window_idx.to_numpy()).mean().rename_axis('window_idx')
    )

    plt.figure(figsize=(12, 6))
    window_accuracy.plot()
    plt.xlabel('Window Sequence')
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy Over Window Sequence')
    plt.grid(True)
    plt.show()