# GMM Model Evaluation Tutorial

## Part 1: Introduction and Setup

In this first part, we set up the environment: import the core libraries, make the project importable, and define the output directories. Keeping this setup in one place makes the evaluation of Gaussian Mixture Model (GMM) implementations easier to reproduce.

Let's start by importing all the necessary utilities and defining our project structure:

In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

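# Add project root to path if needed (assumes the notebook runs from the project root)
project_root = Path('.')
# sys.path holds strings, so compare against str(project_root), not the Path object
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))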
output_dir = project_root / 'output'
experiment_base_dir = output_dir / 'final_experiments'

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

import logging
logging.basicConfig(level=logging.WARNING)

## Step 2: Import Evaluation and Visualization Utilities

Next, we'll import specific functions required for evaluating and visualizing our Gaussian Mixture Models. These utilities are organized into three main categories:

- **Data Handling** (`io`): Loading models and data loaders.
- **Evaluation** (`eval_utils`): Evaluating model performance, running baseline comparisons like K-means, and computing metrics.
- **Visualization** (`visualization`): Tools for plotting, formatting figures, and creating animations for insightful comparisons.

In [None]:
from scripts.evaluation.tutorial.src.io import (
    load_model_from_experiment,
    create_data_loader
)
from scripts.evaluation.tutorial.src.eval_utils import (
    evaluate, 
    evaluate_dataset,
    run_kmeans,
    compute_metrics,
    get_flow_prediction
)
from scripts.evaluation.tutorial.src.visualization import (
    visualize_gmm_data,
    set_plotting_style,
    save_figure,
    create_comparison_grid,
    create_comparison_figure,
    format_axis_with_grid,
    format_legend,
    save_animation,
    VisualizationPipeline
)

# Part 2: Dataset Comparison

In this section, we will visualize and compare different GMM datasets to better understand how variations in dataset characteristics (such as cluster separability and complexity) affect model performance.

We'll start by initializing our visualization pipeline, which manages data handling and plotting in a convenient way.

In [3]:
# Initialize the visualization pipeline
import tempfile
temp_dir = Path(tempfile.mkdtemp())  # Create temporary directory for pipeline
pipeline = VisualizationPipeline(
    experiment_dir=experiment_base_dir,  # Use the already defined experiment directory
    output_dir=temp_dir,
    device=device,
    verbose=False  # Set to True to enable verbose output for progress tracking
)

## Basic Dataset Comparison

We'll generate and visualize examples from three types of synthetic Gaussian Mixture Model (GMM) datasets. Each instance contains 1000 points:

- **Simple Dataset**: High separability with SNR-dB values sampled from a truncated normal distribution (mean=14.0, std=1.5, min=12.0, max=17.0). Number of clusters sampled from truncated Poisson (mean=2.5, min=2, max=4).

- **Standard Dataset**: Moderate complexity with SNR-dB values from a truncated normal distribution (mean=9.0, std=1.5, min=7.0, max=12.0). Number of clusters sampled from truncated Poisson (mean=5.0, min=3, max=7).

- **Complex Dataset**: High complexity with SNR-dB values from a truncated normal distribution (mean=5.0, std=1.5, min=3.0, max=7.0). Number of clusters sampled from truncated Poisson (mean=8.0, min=5, max=15).

We'll visualize three samples from each dataset type to illustrate how cluster separation varies across these scenarios.
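
The per-instance parameters described above can be sketched with simple rejection sampling. This is a minimal illustration, not the tutorial's data generator: the helper names are hypothetical, and it assumes the quoted means and standard deviations refer to the untruncated distributions.

```python
import numpy as np

def sample_truncated(draw, low, high, rng, max_tries=10_000):
    """Rejection sampling: redraw until the value falls inside [low, high]."""
    for _ in range(max_tries):
        value = draw(rng)
        if low <= value <= high:
            return value
    raise RuntimeError("no sample accepted; check the truncation bounds")

def sample_simple_dataset_params(rng):
    """Draw (snr_db, num_clusters) using the 'Simple' settings quoted above."""
    # Assumption: mean/std describe the normal distribution before truncation
    snr_db = sample_truncated(lambda r: r.normal(14.0, 1.5), 12.0, 17.0, rng)
    # Assumption: mean describes the Poisson rate before truncation
    num_clusters = sample_truncated(lambda r: r.poisson(2.5), 2, 4, rng)
    return float(snr_db), int(num_clusters)

rng = np.random.default_rng(0)
params = [sample_simple_dataset_params(rng) for _ in range(5)]
for snr_db, k in params:
    print(f"SNR = {snr_db:.2f} dB, clusters = {k}")
```

Swapping in the Standard or Complex settings only requires changing the six numbers passed to `sample_truncated`.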

In [None]:
# Generate 3 unique samples per dataset
results = []
dataset_types = ['simple', 'standard', 'complex']
dataset_labels = ['Simple', 'Standard', 'Complex']

# For each dataset type, generate 3 unique samples
for dataset_type, label in zip(dataset_types, dataset_labels):
    dataset_results = pipeline._process_dataset_input(
        datasets=dataset_type,
        models=None,
        parameter_values=None,
        show=['points', 'true_centers'],
        num_samples=3
    )
    
    # Add proper titles to the results
    for i, result in enumerate(dataset_results):
        result['metadata']['title'] = f"{label} Sample {i+1}"
        results.append(result)

# Create the grid
titles = [result['metadata']['title'] for result in results]
fig, axes = create_comparison_grid(
    results=results,
    layout='3x3',
    show_predictions=False,
    show_kmeans=False,
    titles=titles,
    figsize=(10, 10),
    verbose=False,
    size_scale=0.6
)

## SNR Dataset Comparison

We'll now visualize GMM datasets at three fixed noise levels (SNR). Each instance contains between 500 and 2000 points, and every dataset keeps moderate cluster complexity, with the number of clusters sampled from a truncated Poisson distribution (mean=5.0, min=3, max=7):

- **High-SNR Dataset**: Fixed high SNR-dB at 15, very clear cluster separation.
- **Medium-SNR Dataset**: Fixed medium SNR-dB at 10, moderate cluster visibility.
- **Low-SNR Dataset**: Fixed low SNR-dB at 5, significant cluster overlap.

We'll generate three samples from each dataset type to show how the SNR level affects cluster visibility and clustering difficulty.
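
To build intuition for what a given SNR-dB value means, the sketch below converts SNR to a noise standard deviation. The tutorial does not state its SNR definition, so this assumes SNR is the ratio of the variance of the cluster centers to the within-cluster noise variance, with isotropic noise and equal cluster weights. All function names are hypothetical.

```python
import numpy as np

def noise_std_for_snr(centers, snr_db):
    """Noise std that yields the requested SNR under the assumed definition."""
    signal_var = centers.var(axis=0).mean()   # spread of the centers (assumed "signal")
    snr_linear = 10.0 ** (snr_db / 10.0)      # dB -> linear power ratio
    return np.sqrt(signal_var / snr_linear)

def sample_gmm(centers, snr_db, n_points, rng):
    """Draw points from an equal-weight isotropic GMM around `centers`."""
    labels = rng.integers(0, len(centers), size=n_points)
    sigma = noise_std_for_snr(centers, snr_db)
    noise = sigma * rng.standard_normal((n_points, centers.shape[1]))
    return centers[labels] + noise, labels, sigma

rng = np.random.default_rng(0)
centers = rng.standard_normal((5, 2))
sigmas = {}
for snr_db in (15, 10, 5):
    _, _, sigmas[snr_db] = sample_gmm(centers, snr_db, 1000, rng)
    print(f"{snr_db:>2} dB -> noise std {sigmas[snr_db]:.3f}")
```

Under this definition, dropping from 15 dB to 5 dB multiplies the noise standard deviation by the square root of 10, which is why the low-SNR clusters overlap so heavily.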


In [None]:
# Generate 3 unique samples per dataset
results = []
dataset_types = ['high_snr_fixed', 'average_snr_fixed', 'low_snr_fixed']
dataset_labels = ['High SNR', 'Moderate SNR', 'Low SNR']

# For each dataset type, generate 3 unique samples
for dataset_type, label in zip(dataset_types, dataset_labels):
    dataset_results = pipeline._process_dataset_input(
        datasets=dataset_type,
        models=None,
        parameter_values=None,
        show=['points', 'true_centers'],
        num_samples=3
    )
    
    # Add proper titles to the results
    for i, result in enumerate(dataset_results):
        result['metadata']['title'] = f"{label} Sample {i+1}"
        results.append(result)

# Create the grid
titles = [result['metadata']['title'] for result in results]
fig2, axes = create_comparison_grid(
    results=results,
    layout='3x3',
    show_predictions=False,
    show_kmeans=False,
    titles=titles,
    figsize=(10, 10),
    verbose=False,
    size_scale=0.6
)

## Dataset Comparison with KMeans Analysis

We'll now visualize how KMeans clustering performs as a baseline method on the previously defined datasets:

- **Simple**
- **Standard**
- **Complex**

Each plot illustrates the dataset points, true cluster centers (green stars), and cluster centers identified by KMeans (orange diamonds). This helps us intuitively assess how well KMeans recovers cluster structures at different levels of dataset complexity.
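
For readers who want to see what the baseline does internally, here is a minimal NumPy sketch of Lloyd's K-means iterations. It is not the `run_kmeans` implementation used by the pipeline. The deterministic farthest-point initialization is a choice made for this example so the result is repeatable.

```python
import numpy as np

def farthest_point_init(points, k):
    """Deterministic init: start at points[0], then repeatedly take the farthest point."""
    centers = [points[0]]
    for _ in range(k - 1):
        diffs = points[:, None, :] - np.array(centers)[None, :, :]
        dists = np.linalg.norm(diffs, axis=2).min(axis=1)
        centers.append(points[np.argmax(dists)])
    return np.array(centers)

def kmeans(points, k, num_iters=50):
    """Plain Lloyd iterations: assign to nearest center, then recompute means."""
    centers = farthest_point_init(points, k)
    for _ in range(num_iters):
        dists = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Keep the old center if a cluster ends up empty
        new_centers = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return centers, labels

rng = np.random.default_rng(0)
true_centers = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
points = np.concatenate([c + 0.3 * rng.standard_normal((200, 2)) for c in true_centers])
est_centers, _ = kmeans(points, k=3)

# Distance from each true center to its closest estimate
errors = np.linalg.norm(true_centers[:, None] - est_centers[None], axis=2).min(axis=1)
print(errors.round(3))
```

On well-separated data like this the estimated centers land close to the true ones. As clusters overlap more, the assignments become ambiguous and the recovered centers drift, which is the effect the plots in this section are meant to show.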

In [None]:
fig4 = pipeline.scatter_plot(
    datasets=['simple', 'standard', 'complex'],
    show=['points', 'true_centers', 'kmeans'],
    layout='1x3',
    titles=['Simple + KMeans', 'Standard + KMeans', 'Complex + KMeans'],
    figsize=(10, 3.33),
    save_path=None,
    verbose=False,
    size_scale=0.6
)

## Part 4: Training a Model

The cell below trains a model with `ExperimentManager`. The example builds a 16-layer model (a single transformer layer repeated 16 times) with the same configuration used for the baseline models in this tutorial.

In [7]:
# Direct training using ExperimentManager
import json
from datetime import datetime
from config import ExperimentConfig
from config.registry import (
    get_model_config,
    get_training_config,
    get_data_config,
    get_validation_config
)
from training.experiment import ExperimentManager

def create_and_train_model(
    experiment_name="tutorial_model",
    epochs=10,
    device="cuda:0",
    output_dir="./output/tutorial_experiments"
):
    """
    Create and train a GMM model using the same configuration as baseline models.
    
    Args:
        experiment_name: Name for the experiment
        epochs: Number of training epochs
        device: Device to train on
        output_dir: Where to save the model
    """
    
    # Get preset configurations
    model_config = get_model_config("medium")
    training_config = get_training_config("standard")
    data_config = get_data_config("diverse_clusters_snr")
    validation_config = get_validation_config("standard")
    
    # Build configuration dictionary
    config_dict = {
        "model": model_config.model_dump(),
        "training": training_config.model_dump(),
        "data": data_config.model_dump(),
        "validation": validation_config.model_dump(),
        "metadata": {
            "id": f"{experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            "experiment_name": experiment_name,
            "description": f"Tutorial experiment with 16-layer model trained for {epochs} epochs"
        },
        "device": {
            "device": device,
            "use_mixed_precision": False
        },
        "paths": {
            "base_dir": output_dir,
            "log_dir": "logs",
            "checkpoint_dir": "checkpoints",
            "data_dir": "data"
        }
    }
    
    # Configure model for 16 layers (1 layer repeated 16 times)
    config_dict["model"]["transformer"]["num_layers"] = 1
    config_dict["model"]["transformer"]["repeat_factor"] = 16
    config_dict["model"]["transformer"]["layer_repeat_mode"] = "layerwise"
    
    # Configure flow predictor
    config_dict["model"]["transformer"]["use_flow_predictor"] = True
    config_dict["model"]["transformer"]["flow_predictor"] = {
        "type": "monotonic",
        "num_basis": 100,
        "min_value": 0.0,
        "max_value": 1.0,
        "per_layer": False,
        "min_snr": 3.0,
        "max_snr": 15.0,
        "distribution_mode": "direct"
    }
    
    # Configure encoder/decoder
    config_dict["model"]["encoder"]["use_orthogonal"] = True
    config_dict["model"]["decoder"]["use_orthogonal"] = True
    
    # Configure loss
    config_dict["training"]["loss"]["loss_type"] = {
        "type": "wasserstein",
        "algorithm": "exact",
        "backend": "pot",
        "use_true_weights": False
    }
    config_dict["training"]["loss"]["normalization"] = "snr_power"
    config_dict["training"]["loss"]["snr_power"] = 1.0
    
    # Set number of epochs
    config_dict["training"]["num_epochs"] = epochs
    
    # Create and validate configuration
    config = ExperimentConfig.model_validate(config_dict)
    
    print(f"Creating experiment: {experiment_name}")
    print("Model: 16 layers (1 layer repeated 16 times)")
    print(f"Training for {epochs} epochs on {device}")
    print(f"Output directory: {output_dir}/{config.metadata.id}")
    
    # Create experiment manager
    experiment = ExperimentManager(config)
    
    # Setup experiment (creates directories, initializes model, etc.)
    print("\nSetting up experiment...")
    experiment.setup()
    
    # Run training
    print("\nStarting training...")
    history = experiment.run()
    
    print("\nTraining completed!")
    if 'val_loss' in history and len(history['val_loss']) > 0:
        print(f"Final validation loss: {history['val_loss'][-1]:.4f}")
    else:
        print(f"Final training loss: {history['train_loss'][-1]:.4f}")
    print(f"Model saved to: {experiment.experiment_dir}")
    
    return experiment, history

# Example: Train a model for 10 epochs
# Uncomment to run:
# experiment, history = create_and_train_model(
#     experiment_name="my_tutorial_model",
#     epochs=10,
#     device=device  # Uses the device defined earlier in the notebook
# )

### Quick Training Example

Here's how to train a smaller model for testing (fewer epochs, smaller configuration):

In [10]:
# Quick test: Train a small model for 3 epochs
def train_test_model():
    """Train a small test model to verify everything works."""
    
    # Get configurations
    model_config = get_model_config("small")  # Smaller model for faster training
    training_config = get_training_config("quick")  # Quick training preset
    data_config = get_data_config("standard")  # Standard data
    validation_config = get_validation_config("minimal")  # Minimal validation
    
    # Create simple config
    config_dict = {
        "model": model_config.model_dump(),
        "training": training_config.model_dump(),
        "data": data_config.model_dump(),
        "validation": validation_config.model_dump(),
        "metadata": {
            "id": f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            "experiment_name": "quick_test",
            "description": "Quick test model"
        },
        "device": {
            "device": device,
            "use_mixed_precision": False
        },
        "paths": {
            "base_dir": "./output/test",
            "log_dir": "logs",
            "checkpoint_dir": "checkpoints",
            "data_dir": "data"
        }
    }
    
    # Simple 4-layer model
    config_dict["model"]["transformer"]["num_layers"] = 4
    config_dict["training"]["num_epochs"] = 3
    
    # Disable validation by setting val_every > num_epochs
    config_dict["training"]["val_every"] = 1000  # Much larger than 3 epochs
    
    # Create experiment
    config = ExperimentConfig.model_validate(config_dict)
    experiment = ExperimentManager(config)
    
    print("Training small test model (4 layers, 3 epochs)...")
    print("Validation disabled for faster training demo")
    experiment.setup()
    
    # Run training and capture the history
    history = experiment.run()
    
    # Return both experiment and history
    return experiment, history

# Uncomment to run a quick test:
# experiment, history = train_test_model()

### Run Training

Let's actually train a small model as a demonstration. This will train a 4-layer model for 3 epochs, which should complete quickly.

In [None]:
# Actually run the training
print("Starting model training...")
print("This will train a small 4-layer model for 3 epochs as a demonstration.")
print("Validation is disabled for faster training.\n")

# Train the model and get both experiment and history
test_experiment, history = train_test_model()

# Plot the training history
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
if 'learning_rate' in history:
    plt.plot(history['learning_rate'])
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.grid(True)

plt.tight_layout()
plt.show()

print("\n✓ Model trained successfully!")
print(f"✓ Saved to: {test_experiment.experiment_dir}")
print(f"✓ Final training loss: {history['train_loss'][-1]:.4f}")

# Part 6: Model Training and Evaluation: High vs. Low SNR

We'll investigate how training the same model architecture on datasets with different Signal-to-Noise Ratios (SNR) impacts its performance. Specifically, we'll:

1. Take one model trained on a **High SNR** dataset (`simple_16_layers`) and another trained on a **Low SNR** dataset (`hard_16_layers`).
2. Evaluate each trained model on both **High** and **Low SNR** datasets.

To ensure reproducibility, we'll generate dataset samples using fixed seeds as defined below.
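
The idea behind fixed seeds can be illustrated with a small NumPy sketch. `make_sample` is a hypothetical stand-in for a seeded generator, not the tutorial's `create_data_samples`. A torch-based generator would typically be seeded through `torch.manual_seed` or a `torch.Generator` instead.

```python
import numpy as np

def make_sample(seed, n_points=1000):
    """Hypothetical stand-in for a seeded data generator."""
    rng = np.random.default_rng(seed)   # all randomness flows from this seed
    return rng.standard_normal((n_points, 2))

a = make_sample(seed=90)
b = make_sample(seed=90)
c = make_sample(seed=123)

print(np.array_equal(a, b))  # True: same seed, same sample
print(np.array_equal(a, c))  # False: different seed, different sample
```

Because each dataset gets its own seed, every run of the comparison evaluates the models on the same points.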


In [12]:
# Fixed seeds for reproducible results
FIXED_SEEDS = {
    'high_snr_fixed': 90,   # → 15.0 dB
    'low_snr_fixed': 123    # → 5.0 dB
}

# Generate samples with fixed seeds
from scripts.evaluation.tutorial.src.io import create_data_samples

samples = {}
for dataset_name, seed in FIXED_SEEDS.items():
    inputs, targets = create_data_samples(
        dataset_name=dataset_name,
        num_samples=1,
        points_per_gmm=1000,
        device=device,
        base_seed=seed,
        loader_id=f"{dataset_name}_{seed}"
    )
    
    samples[dataset_name] = {
        'points': inputs[0],
        'centers': targets['centers'][0],
        'labels': targets['labels'][0],
        'snr_db': targets['snr_db'][0].item()
    }

In [None]:
# Compare simple and hard models on fixed samples
models_to_compare = ['simple_16_layers', 'hard_16_layers']
results = []

for model_name in models_to_compare:
    for dataset_name, label in zip(['high_snr_fixed', 'low_snr_fixed'], ['High SNR', 'Low SNR']):
        sample = samples[dataset_name]
        
        result = pipeline._process_single_data_input(
            data=sample,
            models=model_name,
            parameter_values=None,
            show=['points', 'predictions'],
        )
        
        result['metadata']['title'] = f"{model_name}\n{label} ({sample['snr_db']:.1f} dB)"
        results.append(result)

# Drop the true centers so the grid shows only the points and model predictions
for result in results:
    result['targets'] = {k: v for k, v in result['targets'].items() if k != 'centers'}

# Create comparison grid
titles = [result['metadata']['title'] for result in results]
fig, axes = create_comparison_grid(
    results=results,
    layout='2x2',
    show_predictions=True,
    show_kmeans=False,
    titles=titles,
    figsize=(10, 10),
    size_scale=0.6,
    verbose=False
)

# Part 7: Flow Substitution Animation

In this section, we demonstrate the effect of manually substituting a fixed flow speed value into a model originally trained without an explicit flow predictor.

## Static Flow Comparison

First, we'll visualize this substitution to observe how introducing a fixed flow speed directly affects the model's behavior. Specifically, we'll:

- Take a baseline model trained without flow adjustment ("no flow").
- Manually set the flow speed to a fixed value (flow_speed = 0.2), without retraining or changing model parameters.

In [None]:
# Create static comparison of no_flow vs manual flow injection
static_fig = pipeline.create_static_flow_comparison(
    base_model='no_flow_16_layers',
    flow_value=0.2,
    comparison_type='no_flow_vs_manual',
    save_path=None  # Display inline instead of saving
)

## Basic Flow Animation (0 → 1)

Next, we'll visualize the impact of substituting a range of flow speeds (from 0 to 1) into a model that was trained without any explicit flow predictor.  

This animation smoothly transitions through different manually-set flow speeds, clearly illustrating how changing flow speed alone influences the model's clustering behavior.

In [None]:
# Create flow substitution animation
import os
from IPython.display import Image, display

# Save animation to a path
animation_path = temp_dir / "flow_substitution_animation.gif"
basic_anim = pipeline.create_flow_substitution_animation(
    base_model='no_flow_16_layers',
    flow_range=(0.0, 1.0),
    frames=100,  # Fewer frames for notebook display
    comparison_type='no_flow_vs_manual',
    save_path=animation_path,
    show_animation=False  # Don't show in separate window
)

# Display the GIF in the notebook
if animation_path.exists():
    display(Image(filename=str(animation_path)))
else:
    print("Animation not created")

## Extended Flow Animation (0 → 5)

Now, let's explore an extended range of flow speeds (from 0 to 5). We'll demonstrate two different approaches to achieve this increased flow effect:

1. **Direct Flow Adjustment**:  
   Keep the original model architecture fixed, substituting flow speeds ranging from 0 to 5.

2. **Layer Scaling Approach**:  
   Increase the number of layers by a factor of 5 and proportionally reduce the flow speed range to 0–1, effectively simulating the same cumulative flow.
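
The equivalence claimed by the two approaches can be checked on a toy problem. The sketch below assumes each repeated layer behaves like a residual update `x + flow * f(x)`, i.e. one Euler step of size `flow`. That is an assumption made for illustration, not a description of the actual transformer. The vector field `f(x) = -x` is chosen only because its exact solution is known.

```python
import numpy as np

def run_layers(x0, flow, num_layers, field):
    """Apply the same residual update `num_layers` times: x <- x + flow * field(x)."""
    x = x0
    for _ in range(num_layers):
        x = x + flow * field(x)
    return x

def field(x):
    """Toy vector field with exact solution x0 * exp(-t)."""
    return -x

x0, flow, layers = 1.0, 0.1, 16

coarse = run_layers(x0, flow, layers, field)          # 16 layers, full flow
fine = run_layers(x0, flow / 5, layers * 5, field)    # 80 layers, flow / 5
exact = x0 * np.exp(-flow * layers)                   # continuous-time limit

print(f"coarse={coarse:.4f}  fine={fine:.4f}  exact={exact:.4f}")
```

Both runs cover the same cumulative flow (`flow * layers`), but the finer discretization sits closer to the continuous limit. Under this assumption, the two panels of the animation should agree at small flow values and diverge as the per-layer step grows.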


In [None]:
# Extended range animation with regime comparison
extended_animation_path = temp_dir / "flow_comparison_0_to_5.gif"
custom_anim = pipeline.create_flow_substitution_animation(
    base_model='no_flow_16_layers',
    flow_range=(0.0, 5.0),
    frames=300,  # More frames for smoother animation
    comparison_type='regime_comparison',
    regime_settings={
        'left': {'repeat_factor': 16, 'flow_divisor': 1},    # 1x layers, full flow
        'right': {'repeat_factor': 80, 'flow_divisor': 5}    # 5x layers, flow÷5
    },
    save_path=extended_animation_path,
    show_animation=False
)

# Display the GIF
if extended_animation_path.exists():
    display(Image(filename=str(extended_animation_path)))
else:
    print("Animation not created")

## Uniform vs. Unit Flow Speed Comparison

We'll compare two different regimes for distributing flow speed across transformer layers:

- **Uniform regime** (`direct` mode in the code): Flow speed is uniformly distributed across all layers.
- **Unit regime** (`fractional` mode in the code): Flow speed is allocated entirely to a single layer at a time.
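
One way to picture the two regimes is as different ways of splitting the same total flow budget across layers. The sketch below is an interpretation, not the pipeline's implementation: it assumes a global flow value in [0, 1], a total budget of `flow * num_layers`, and, for the unit regime, layers that are filled to 1.0 one after another. The function name and mode strings are hypothetical.

```python
import numpy as np

def per_layer_flow(flow, num_layers, mode):
    """Split a global flow value in [0, 1] across layers (assumed semantics)."""
    if mode == "uniform":
        # Every layer receives the same flow speed
        return np.full(num_layers, flow)
    if mode == "unit":
        budget = flow * num_layers   # same total as the uniform case
        # Layer i receives whatever part of the budget falls in [i, i + 1]
        return np.clip(budget - np.arange(num_layers), 0.0, 1.0)
    raise ValueError(f"unknown mode: {mode}")

uniform = per_layer_flow(0.3, 4, "uniform")
unit = per_layer_flow(0.3, 4, "unit")
print(uniform)
print(unit)
```

With a flow of 0.3 over four layers, the uniform split gives every layer 0.3, while the unit split runs the first layer at full speed, the second at 0.2, and leaves the rest idle. Both sum to the same total.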


In [None]:
# Direct vs Fractional flow mode comparison
direct_frac_path = temp_dir / "direct_vs_fractional.gif"
frac_anim = pipeline.create_flow_substitution_animation(
    base_model='no_flow_16_layers',
    flow_range=(0.0, 1.0),
    frames=300,
    comparison_type='direct_vs_fractional',
    save_path=direct_frac_path,
    show_animation=False
)
# Display the GIF
if direct_frac_path.exists():
    display(Image(filename=str(direct_frac_path)))
else:
    print("Animation not created")

## Simple vs. Hard Model Comparison

In this section, we directly compare two transformer models trained on datasets with different complexities:

- **Simple model**: Trained on clearly separable (high-SNR) data.
- **Hard model**: Trained on challenging, overlapping (low-SNR) data.

We substitute a manual flow speed parameter ranging from 0 to 1 into both models to illustrate how their learned parameters respond differently to changes in flow speed.


In [None]:
# Helper function for model comparisons
def create_model_comparison_animation(pipeline, left_model_name, right_model_name, 
                                    left_name, right_name, flow_range, frames, 
                                    save_path, **model_settings):
    """Create side-by-side model comparison animation."""
    
    # Load models
    left_model, _ = pipeline._load_model(left_model_name)
    right_model, _ = pipeline._load_model(right_model_name)
    
    # Apply settings
    if 'left' in model_settings:
        for attr, value in model_settings['left'].items():
            if attr != 'model' and hasattr(left_model.transformer, attr):
                setattr(left_model.transformer, attr, value)
    
    if 'right' in model_settings:
        for attr, value in model_settings['right'].items():
            if attr != 'model' and hasattr(right_model.transformer, attr):
                setattr(right_model.transformer, attr, value)
    
    # Generate flow values
    flow_values = np.linspace(flow_range[0], flow_range[1], frames)
    
    # Create frame data
    frame_data = []
    for flow_speed in flow_values:
        left_flow = flow_speed
        right_flow = flow_speed
        
        if 'left' in model_settings and 'flow_divisor' in model_settings['left']:
            left_flow = flow_speed / model_settings['left']['flow_divisor']
        if 'right' in model_settings and 'flow_divisor' in model_settings['right']:
            right_flow = flow_speed / model_settings['right']['flow_divisor']
        
        frame_data.append({
            'left_model': left_model,
            'right_model': right_model,
            'left_flow': left_flow,
            'right_flow': right_flow,
            'titles': [f"{left_name}: {flow_speed:.2f}", f"{right_name}: {flow_speed:.2f}"],
            'parameter_value': flow_speed
        })
    
    return pipeline._create_side_by_side_animation(
        frame_data, 
        save_path=save_path, 
        show_animation=False
    )

# Simple vs Hard comparison
simple_hard_path = temp_dir / "simple_vs_hard_direct.gif"
simple_hard_anim = create_model_comparison_animation(
    pipeline=pipeline,
    left_model_name='simple_16_layers',
    right_model_name='hard_16_layers',
    left_name="Simple",
    right_name="Hard",
    flow_range=(0.0, 1.0),
    frames=100,
    save_path=simple_hard_path,
    left={'flow_distribution_mode': 'direct'},
    right={'flow_distribution_mode': 'direct'}
)

# Display the GIF
if simple_hard_path.exists():
    display(Image(filename=str(simple_hard_path)))
else:
    print("Animation not created")

## Part 8: Model Performance Comparison

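This part evaluates several model configurations (16, 32, and 64 layers, plus a 16-layer model without flow) against a K-means baseline on the fixed high-, average-, and low-SNR datasets, using the Wasserstein metric.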
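
As background for the metric, the sketch below computes a Wasserstein-2 distance between two small sets of cluster centers. It assumes equal set sizes and uniform weights, in which case optimal transport reduces to finding the best one-to-one matching. The tutorial's own metric (computed inside `evaluate_dataset`) may use a different order, weighting, or normalization. The brute-force search over permutations is only practical for a handful of centers.

```python
import itertools
import numpy as np

def wasserstein2_uniform(pred, true):
    """W2 between two equal-size point sets with uniform weights (brute force)."""
    k = len(true)
    # Pairwise squared Euclidean distances, shape (k, k)
    cost = ((pred[:, None, :] - true[None, :, :]) ** 2).sum(axis=2)
    best = min(
        sum(cost[i, perm[i]] for i in range(k))
        for perm in itertools.permutations(range(k))
    )
    return np.sqrt(best / k)

true = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
pred = true[[2, 0, 1]] + 0.1   # same centers, shuffled and shifted by (0.1, 0.1)

w2 = wasserstein2_uniform(pred, true)
print(f"W2 = {w2:.4f}, log W2 = {np.log(w2):.4f}")
```

The distance ignores the ordering of the centers, so a prediction is not penalized for listing clusters in a different order. The cells below store the metric in log space and exponentiate it to recover the distance itself.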

In [20]:
import pandas as pd
# Configuration
datasets = ["high_snr_fixed", "average_snr_fixed", "low_snr_fixed"]
batch_size = 32
total_samples = 4096

# Create cache directory matching original script structure
snr_cache_dir = project_root / 'scripts/evaluation/tutorial/output/snr_performance'
snr_cache_dir.mkdir(parents=True, exist_ok=True)

# Define models to evaluate
snr_model_configs = {
    "baseline_16_layers": {"name": "16 layers", "path": "baseline_16_layers"},
    "baseline_32_layers": {"name": "32 layers", "path": "baseline_32_layers"},
    "baseline_64_layers": {"name": "64 layers", "path": "baseline_64_layers"},
    "no_flow_16_layers": {"name": "No flow (16 layers)", "path": "no_flow_16_layers"},
}

def get_snr_cache_path(dataset_name, model_key=None):
    """Get cache file path for SNR evaluation results."""
    if model_key:
        return snr_cache_dir / f"{dataset_name}_{model_key}_results.pt"
    else:
        return snr_cache_dir / f"{dataset_name}_kmeans_results.pt"

def evaluate_model_on_dataset(model_key, model_config, dataset_name, device):
    """Evaluate a single model on a dataset, with caching."""
    cache_path = get_snr_cache_path(dataset_name, model_key)
    
    # Try to load cached results
    if cache_path.exists():
        return torch.load(cache_path, weights_only=False)
    
    print(f"  Evaluating {model_config['name']} on {dataset_name}...")
    model_path = experiment_base_dir / model_config["path"]
    model, _ = load_model_from_experiment(model_path, load_best=False, device=device)
    
    # Create data loader
    data_loader = create_data_loader(
        dataset_name=dataset_name,
        batch_size=batch_size,
        total_samples=total_samples,
        device=device
    )
    
    # Evaluate dataset
    eval_results_list = evaluate_dataset(
        model, 
        data_loader,
        kmeans_on_inputs=False,
        kmeans_on_predictions=False,
        metrics=['log_wasserstein'],
        device=device
    )
    
    # Aggregate results
    all_wass = []
    all_log_wass = []
    all_snr_db = []
    
    for batch_results in eval_results_list:
        if 'metrics' in batch_results and 'log_wasserstein' in batch_results['metrics']:
            batch_log_wass = batch_results['metrics']['log_wasserstein'].cpu().numpy()
            all_log_wass.extend(batch_log_wass)
            batch_wass = np.exp(batch_log_wass)
            all_wass.extend(batch_wass)
        if 'snr_values' in batch_results and batch_results['snr_values'] is not None:
            batch_snr = batch_results['snr_values'].cpu().numpy()
            all_snr_db.extend(batch_snr)
    
    results = {
        'wasserstein': np.array(all_wass),
        'log_wasserstein': np.array(all_log_wass),
        'snr_db': np.array(all_snr_db) if all_snr_db else None,
        'avg_wasserstein': np.mean(all_wass),
        'std_wasserstein': np.std(all_wass),
        'avg_log_wasserstein': np.mean(all_log_wass),
        'std_log_wasserstein': np.std(all_log_wass)
    }
    
    # Cache results
    torch.save(results, cache_path)
    return results

def evaluate_kmeans_baseline(dataset_name, device):
    """Evaluate K-means baseline on a dataset, with caching."""
    cache_path = get_snr_cache_path(dataset_name)
    
    # Try to load cached results
    if cache_path.exists():
        return torch.load(cache_path, weights_only=False)
    
    data_loader = create_data_loader(
        dataset_name=dataset_name,
        batch_size=batch_size,
        total_samples=total_samples,
        device=device
    )
    
    # Use any model just to get K-means results
    model_path = experiment_base_dir / "baseline_16_layers"
    model, _ = load_model_from_experiment(model_path, load_best=False, device=device)
    
    # Evaluate with K-means only
    eval_results_list = evaluate_dataset(
        model, 
        data_loader,
        kmeans_on_inputs=True,
        kmeans_on_predictions=False,
        metrics=['log_kmeans_wasserstein'],
        device=device
    )
    
    # Aggregate K-means results
    all_kmeans_wass = []
    all_log_kmeans_wass = []
    all_snr_db = []
    
    for batch_results in eval_results_list:
        if 'metrics' in batch_results and 'log_kmeans_wasserstein' in batch_results['metrics']:
            batch_log_kmeans_wass = batch_results['metrics']['log_kmeans_wasserstein'].cpu().numpy()
            all_log_kmeans_wass.extend(batch_log_kmeans_wass)
            batch_kmeans_wass = np.exp(batch_log_kmeans_wass)
            all_kmeans_wass.extend(batch_kmeans_wass)
        if 'snr_values' in batch_results and batch_results['snr_values'] is not None:
            batch_snr = batch_results['snr_values'].cpu().numpy()
            all_snr_db.extend(batch_snr)
    
    results = {
        'wasserstein': np.array(all_kmeans_wass),
        'log_wasserstein': np.array(all_log_kmeans_wass),
        'snr_db': np.array(all_snr_db) if all_snr_db else None,
        'avg_wasserstein': np.mean(all_kmeans_wass),
        'std_wasserstein': np.std(all_kmeans_wass),
        'avg_log_wasserstein': np.mean(all_log_kmeans_wass),
        'std_log_wasserstein': np.std(all_log_kmeans_wass)
    }
    
    # Cache results
    torch.save(results, cache_path)
    return results

In [None]:
# Create average performance plot for SNR analysis
def evaluate_models_for_plotting(model_configs, device):
    """Evaluate all models on diverse_snr_moderate dataset for plotting, with caching."""
    
    diverse_dataset = "diverse_snr_moderate"
    batch_size = 1
    total_samples = 4096
    
    # Store all results
    all_results = {}
    
    # Evaluate K-means baseline
    kmeans_cache_path = snr_cache_dir / f"{diverse_dataset}_kmeans_full_results.pt"
    
    if kmeans_cache_path.exists():
        print(f"Loading cached K-means results from: {kmeans_cache_path}")
        all_results['kmeans'] = torch.load(kmeans_cache_path, weights_only=False)
    else:
        print(f"Evaluating K-means baseline on {diverse_dataset}...")
        # Create data loader
        data_loader = create_data_loader(
            dataset_name=diverse_dataset,
            batch_size=batch_size,
            total_samples=total_samples,
            device=device
        )
        
        # Use any model just to get K-means results
        model_path = experiment_base_dir / "baseline_16_layers"
        model, _ = load_model_from_experiment(model_path, load_best=False, device=device)
        
        # Evaluate with K-means only
        eval_results_list = evaluate_dataset(
            model, 
            data_loader,
            kmeans_on_inputs=True,
            kmeans_on_predictions=False,
            metrics=['log_kmeans_wasserstein'],
            device=device
        )
        
        # Aggregate K-means results
        all_kmeans_wass = []
        all_log_kmeans_wass = []
        all_snr_db = []
        
        for batch_results in eval_results_list:
            if 'metrics' in batch_results and 'log_kmeans_wasserstein' in batch_results['metrics']:
                batch_log_kmeans_wass = batch_results['metrics']['log_kmeans_wasserstein'].cpu().numpy()
                all_log_kmeans_wass.extend(batch_log_kmeans_wass)
                batch_kmeans_wass = np.exp(batch_log_kmeans_wass)
                all_kmeans_wass.extend(batch_kmeans_wass)
            if 'snr_values' in batch_results and batch_results['snr_values'] is not None:
                batch_snr = batch_results['snr_values'].cpu().numpy()
                all_snr_db.extend(batch_snr)
        
        kmeans_results = {
            'wasserstein': np.array(all_kmeans_wass),
            'log_wasserstein': np.array(all_log_kmeans_wass),
            'snr_db': np.array(all_snr_db) if all_snr_db else None,
            'avg_wasserstein': np.mean(all_kmeans_wass),
            'avg_log_wasserstein': np.mean(all_log_kmeans_wass),
        }
        
        all_results['kmeans'] = kmeans_results
        # Save the full results under a separate file name so the simple cache is not overwritten
        torch.save(kmeans_results, kmeans_cache_path)
        print(f"Saved K-means results to: {kmeans_cache_path}")
    
    # Evaluate each model
    for model_key, model_config in model_configs.items():
        cache_path = snr_cache_dir / f"{diverse_dataset}_{model_key}_full_results.pt"
        
        if cache_path.exists():
            all_results[model_key] = torch.load(cache_path, weights_only=False)
        else:
            print(f"Evaluating {model_config['name']} on {diverse_dataset}...")
            # Create data loader
            data_loader = create_data_loader(
                dataset_name=diverse_dataset,
                batch_size=batch_size,
                total_samples=total_samples,
                device=device
            )
            
            # Load model
            model_path = experiment_base_dir / model_config["path"]
            model, _ = load_model_from_experiment(model_path, load_best=False, device=device)
            
            # Evaluate dataset
            eval_results_list = evaluate_dataset(
                model, 
                data_loader,
                kmeans_on_inputs=False,
                kmeans_on_predictions=False,
                metrics=['log_wasserstein'],
                device=device
            )
            
            # Aggregate results
            all_wass = []
            all_log_wass = []
            all_snr_db = []
            
            for batch_results in eval_results_list:
                if 'metrics' in batch_results and 'log_wasserstein' in batch_results['metrics']:
                    batch_log_wass = batch_results['metrics']['log_wasserstein'].cpu().numpy()
                    all_log_wass.extend(batch_log_wass)
                    batch_wass = np.exp(batch_log_wass)
                    all_wass.extend(batch_wass)
                if 'snr_values' in batch_results and batch_results['snr_values'] is not None:
                    batch_snr = batch_results['snr_values'].cpu().numpy()
                    all_snr_db.extend(batch_snr)
            
            model_results = {
                'wasserstein': np.array(all_wass),
                'log_wasserstein': np.array(all_log_wass),
                'snr_db': np.array(all_snr_db) if all_snr_db else None,
                'avg_wasserstein': np.mean(all_wass),
                'avg_log_wasserstein': np.mean(all_log_wass),
            }
            
            all_results[model_key] = model_results
            # Save the full results under a separate file name
            torch.save(model_results, cache_path)
            print(f"Saved results to: {cache_path}")
    
    return all_results

def plot_snr_performance(model_configs):
    """Plot the per-bin geometric mean Wasserstein distance vs SNR on the diverse dataset."""
    
    # Get cached results for diverse dataset
    diverse_results = evaluate_models_for_plotting(model_configs, device)
    
    # Create plot
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Define SNR bins
    snr_min, snr_max = 3.0, 15.0
    bin_edges = np.linspace(snr_min, snr_max, 10 + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    
    # Colors for different models
    colors = {
        'kmeans': 'orange',
        'baseline_16_layers': 'blue',
        'baseline_32_layers': 'red',
        'baseline_64_layers': 'green',
        'no_flow_16_layers': 'purple'
    }
    
    # Plot each model
    for model_key in ['kmeans'] + list(model_configs.keys()):
        if model_key in diverse_results:
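            log_wass = diverse_results[model_key]['log_wasserstein']
            snr_db = diverse_results[model_key]['snr_db']
            
            # Cached results store None when no SNR values were recorded
            if snr_db is None:
                print(f"No SNR values for {model_key}; skipping.")
                continue
            
            # Compute the geometric mean (exp of the mean log distance) in each bin
            avg_wass = []
            n_bins = len(bin_edges) - 1
            for i in range(n_bins):
                # Bins are half-open, except the last one, which also includes snr_max
                if i == n_bins - 1:
                    mask = (snr_db >= bin_edges[i]) & (snr_db <= bin_edges[i + 1])
                else:
                    mask = (snr_db >= bin_edges[i]) & (snr_db < bin_edges[i + 1])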
                if np.any(mask):
                    avg_wass.append(np.exp(log_wass[mask].mean()))
                else:
                    avg_wass.append(np.nan)
            
            avg_wass = np.array(avg_wass)
            valid_mask = ~np.isnan(avg_wass)
            
            if model_key == 'kmeans':
                label = 'K-means'
            else:
                label = model_configs[model_key]['name']
            
            color = colors.get(model_key, 'black')
            ax.plot(bin_centers[valid_mask], avg_wass[valid_mask], 
                    color=color, marker='o', markersize=6, 
                    linewidth=2, label=label)
    
    ax.set_xlabel('SNR (dB)')
    ax.set_ylabel('Geometric Mean Wasserstein Distance')
    ax.set_title('Model Performance vs SNR')
    ax.set_xlim(snr_min - 0.5, snr_max + 0.5)
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3)
    ax.legend()
    
    return fig

# Create the plot
snr_plot = plot_snr_performance(snr_model_configs)
plt.tight_layout()
plt.show()
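
The per-bin averaging in `plot_snr_performance` exponentiates the mean of the log distances, which gives a geometric mean rather than an arithmetic one. The sketch below reproduces that binning on a few synthetic values so the behavior is easy to check by hand. The helper name `binned_geometric_mean` and the sample numbers are hypothetical; they are not part of the tutorial's utilities or results, and the sketch assumes the log distances are natural logarithms.

```python
import numpy as np

def binned_geometric_mean(snr_db, log_values, bin_edges):
    """Return exp(mean(log_values)) for each SNR bin, or NaN for empty bins."""
    snr_db = np.asarray(snr_db, dtype=float)
    log_values = np.asarray(log_values, dtype=float)
    n_bins = len(bin_edges) - 1

    # np.digitize returns i with bin_edges[i-1] <= x < bin_edges[i],
    # so subtracting 1 gives a zero-based bin index.
    idx = np.digitize(snr_db, bin_edges) - 1
    # Values exactly on the top edge belong to the last bin.
    idx[snr_db == bin_edges[-1]] = n_bins - 1

    out = np.full(n_bins, np.nan)
    for i in range(n_bins):
        mask = idx == i
        if mask.any():
            out[i] = np.exp(log_values[mask].mean())
    return out

# Synthetic example (hypothetical values, three bins over 3 to 15 dB).
edges = np.linspace(3.0, 15.0, 4)
snr = np.array([3.0, 5.0, 8.0, 15.0])
log_w = np.log(np.array([1.0, 4.0, 2.0, 0.5]))

binned = binned_geometric_mean(snr, log_w, edges)
print(binned)
```

The first bin holds distances 1 and 4, whose geometric mean is 2, while their arithmetic mean would be 2.5. On a log-scaled y-axis the geometric mean is the natural summary, since it is less dominated by a few large distances.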