## Packages


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from typing import Tuple
import sys
from pathlib import Path
from datetime import datetime
import os
import pyro

# Make the project packages (e.g. Models) importable.
# When the notebook is launched from the Experiments folder, the project root
# is one directory up; otherwise the current working directory is used as-is.
project_root = Path.cwd()
if project_root.name == 'Experiments':
    project_root = project_root.parent
sys.path.insert(0, str(project_root))

# Output layout: <root>/results/ood_parameter_comparison/{plots,statistics}
results_dir = project_root / "results" / "ood_parameter_comparison"
plots_dir = results_dir / "plots"
stats_dir = results_dir / "statistics"

# Create the parent first (with intermediate folders), then its two children.
results_dir.mkdir(parents=True, exist_ok=True)
plots_dir.mkdir(exist_ok=True)
stats_dir.mkdir(exist_ok=True)

print(f"Results will be saved to: {results_dir}")

# Import from Models folder
from Models.MC_Dropout import (
    MCDropoutRegressor,
    train_model,
    mc_dropout_predict,
    gaussian_nll,
    beta_nll,
    plot_toy_data,
    plot_uncertainties,
    normalize_x,
    normalize_x_data
)

from Models.Deep_Ensemble import (
    train_ensemble_deep,
    ensemble_predict_deep
)

from utils.device import get_device
from utils.plotting import (
    plot_toy_data, 
    plot_uncertainties_ood,
    plot_uncertainties_ood_normalized,
    plot_uncertainties_entropy_ood,
    plot_uncertainties_entropy_ood_normalized,
    plot_entropy_lines_ood
)
import utils.results_save as results_save_module
from utils.results_save import save_plot, save_statistics, save_model_outputs

# Import OOD helper functions
from utils.ood_experiments import (
    generate_data_with_ood,
    compute_and_save_statistics_ood,
    compute_and_save_statistics_entropy_ood
)

# Import entropy uncertainty functions
from utils.entropy_uncertainty import entropy_uncertainty_analytical

# Import metrics functions
from utils.metrics import (
    compute_predictive_aggregation,
    compute_gaussian_nll,
    compute_crps_gaussian,
    compute_true_noise_variance,
    compute_uncertainty_disentanglement
)

# Set the module-level directories for results_save: overwrite the `plots_dir` and
# `stats_dir` attributes of utils.results_save with this experiment's output folders.
# NOTE(review): presumably save_plot / save_statistics / save_model_outputs read these
# attributes at call time - confirm in utils/results_save.py.
results_save_module.plots_dir = plots_dir
results_save_module.stats_dir = stats_dir


## Device Setup


In [None]:
# Select the compute device through the project helper.
# NOTE(review): `device` is not referenced by the cells visible in this part of the
# notebook - confirm whether the model helpers pick the device up on their own.
device = get_device()


## Generate Toy Datasets


In [None]:
# Reproducibility: seed the global NumPy and PyTorch random number generators.
# generate_toy_regression samples from NumPy's global RNG, so this fixes the dataset.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# ----- Data generation for a linear or x*sin(x)+x function with homo/heteroscedastic noise -----
def generate_toy_regression(n_train=1000, train_range=(0.0, 10.0), train_ranges=None,
                           ood_ranges=None, grid_points=1000, noise_type='heteroscedastic', type = "linear"):
    """
    Generate toy regression data with support for multiple training ranges and OOD regions.

    Training inputs are drawn uniformly from the training range(s), targets are the clean
    function value plus Gaussian noise, and a dense evaluation grid spanning all training
    and OOD ranges is returned together with a boolean OOD mask.

    All randomness comes from NumPy's global RNG (call ``np.random.seed`` beforehand for
    reproducible output).

    Args:
        n_train: Number of training samples
        train_range: Single training range tuple (min, max) - for backward compatibility
        train_ranges: Sequence of training range tuples [(min1, max1), (min2, max2), ...]
                     If provided, overrides train_range. Samples are distributed
                     proportionally to range width; the last range receives the remainder
                     so that exactly n_train samples are produced.
        ood_ranges: Sequence of OOD range tuples [(min1, max1), (min2, max2), ...]
                   Used to extend the grid and force those grid points to be OOD.
                   If None, OOD is automatically everything NOT in training ranges
        grid_points: Number of grid points for evaluation
        noise_type: 'homoscedastic' (constant sigma = 2) or
                    'heteroscedastic' (sigma(x) = |2.5 * sin(0.5 * x + 5)|)
        type: 'linear' (f(x) = 0.7x + 0.5) or 'sin' (f(x) = x * sin(x) + x)

    Returns:
        (x_train, y_train, x_grid, y_grid_clean, ood_mask)
        The first four are float32 arrays of shape (n, 1); ood_mask is a 1-D boolean
        array over the grid, True where the grid point is out-of-distribution.

    Raises:
        ValueError: if type or noise_type is not one of the supported values, if no
            training range is given, or if several training ranges have no positive
            total width.
    """
    # Normalise the range arguments to lists so tuples (or any other sequence) work;
    # the original `train_ranges + [...]` concatenation failed for tuple inputs.
    train_ranges = [train_range] if train_ranges is None else list(train_ranges)
    ood_ranges = [] if ood_ranges is None else list(ood_ranges)
    if not train_ranges:
        raise ValueError("train_ranges must contain at least one (min, max) range")

    # Validate the string options up front, before any random numbers are consumed.
    if type == "linear":
        def f_clean(x):
            # Linear function: f(x) = 0.7x + 0.5
            return 0.7 * x + 0.5
    elif type == "sin":
        def f_clean(x):
            return x * np.sin(x) + x
    else:
        raise ValueError("type must be 'linear', 'sin'")
    if noise_type not in ('homoscedastic', 'heteroscedastic'):
        raise ValueError("noise_type must be 'homoscedastic' or 'heteroscedastic'")

    # Total width of all training ranges, used for proportional sample allocation.
    total_width = sum(high - low for low, high in train_ranges)
    if len(train_ranges) > 1 and total_width <= 0:
        raise ValueError("train_ranges must have a positive total width")

    # Sample from each range proportionally to its width.
    x_train_parts = []
    samples_allocated = 0
    last_idx = len(train_ranges) - 1
    for idx, (low, high) in enumerate(train_ranges):
        if idx == last_idx:
            # Last range gets remaining samples to ensure exact total
            n_samples = n_train - samples_allocated
        else:
            n_samples = int(n_train * (high - low) / total_width)
            samples_allocated += n_samples
        x_train_parts.append(np.random.uniform(low, high, size=(n_samples, 1)))

    # Shuffle to mix samples from different ranges
    x_train = np.vstack(x_train_parts)
    x_train = x_train[np.random.permutation(len(x_train))]
    y_clean_train = f_clean(x_train)

    # Noise standard deviation sigma(x) at the training inputs
    if noise_type == 'homoscedastic':
        # Homoscedastic: sigma(x) = 2
        sigma_train = np.full_like(x_train, 2)
    else:
        # Heteroscedastic: sigma(x) = |2.5 * sin(0.5 * x + 5)|
        sigma_train = np.abs(2.5 * np.sin(0.5 * x_train + 5))

    # Generate noise: eps | x ~ N(0, sigma(x)^2)
    epsilon = np.random.normal(0.0, sigma_train, size=x_train.shape)
    y_train = y_clean_train + epsilon

    # Dense evaluation grid spanning all training and OOD regions
    all_ranges = train_ranges + ood_ranges
    grid_start = min(r[0] for r in all_ranges)
    grid_end = max(r[1] for r in all_ranges)
    x_grid = np.linspace(grid_start, grid_end, grid_points).reshape(-1, 1)
    y_grid_clean = f_clean(x_grid)

    # OOD mask: start with everything OOD, clear the (inclusive) training ranges, then
    # re-assert explicit OOD ranges so they win on shared boundaries or overlaps.
    grid_x = x_grid[:, 0]
    ood_mask = np.ones(len(x_grid), dtype=bool)
    for low, high in train_ranges:
        ood_mask[(grid_x >= low) & (grid_x <= high)] = False
    for low, high in ood_ranges:
        ood_mask |= (grid_x >= low) & (grid_x <= high)

    return (x_train.astype(np.float32), y_train.astype(np.float32),
            x_grid.astype(np.float32), y_grid_clean.astype(np.float32), ood_mask)


### Set Parameters


In [None]:
# --- Reproducibility ---
seed = 42
torch.manual_seed(seed)

# --- Data set-up shared by every run ---
n_train = 1000
grid_points = 1000
train_range = (-5, 10)
ood_ranges = [(10,15)]  # List of (min, max) tuples for OOD regions
noise_type = 'heteroscedastic'
func_type = 'sin'  # or 'linear'

# Human-readable label for the chosen function type
if func_type == 'sin':
    function_name = "Sinusoidal"
else:
    function_name = "Linear"

# --- Hyper-parameters swept by the experiments ---
mc_samples_values = [10, 20, 50, 100, 200]  # MC Dropout forward passes
dropout_p_values = [0.1, 0.2, 0.25]  # MC Dropout dropout probabilities
K_values = [5, 10, 15, 20, 25, 30]  # Deep Ensemble number of nets
epochs_values = [100, 250, 500]  # Number of training epochs

# --- Training settings held fixed across the sweep ---
beta = 0.5
lr = 1e-3
batch_size = 32


## Helper Functions for Parameter Comparison


In [None]:
def run_single_mc_dropout_ood(generate_toy_regression_func, x_train, y_train, x_grid, y_grid_clean, ood_mask,
                              p, mc_samples, beta, epochs, lr, batch_size, seed,
                              function_name, noise_type, func_type, date, save_results=True):
    """Train one MC Dropout regressor and evaluate it on the ID and OOD parts of the grid.

    Order of operations: seed the NumPy/PyTorch global RNGs, train an
    MCDropoutRegressor with the beta-NLL loss, run `mc_samples` stochastic forward
    passes on `x_grid`, store the raw samples with save_model_outputs, then compute
    variance-based and entropy-based uncertainties plus MSE / NLL / CRPS / Spearman
    metrics for the ID region, the OOD region and the whole grid. All error metrics
    are computed against the noise-free targets `y_grid_clean`.

    Args:
        generate_toy_regression_func: Not used inside this function.
        x_train, y_train: Training arrays (converted with torch.from_numpy).
        x_grid, y_grid_clean: Evaluation inputs and noise-free targets.
        ood_mask: Boolean mask over the grid; True marks OOD points.
        p: Dropout probability of the model.
        mc_samples: Number of MC forward passes (passed as M).
        beta: beta of the beta-NLL loss.
        epochs, lr, batch_size: Training settings.
        seed: Seed for the NumPy and PyTorch global RNGs.
        function_name, noise_type, func_type: Labels forwarded to the saving/plotting helpers.
        date: Date string forwarded to the saving helpers.
        save_results: If True, statistics are saved and the five plots are produced.
            save_model_outputs is called regardless of this flag.

    Returns:
        dict with per-region uncertainty dicts ('ale'/'epi'/'tot'), MSE, NLL, CRPS and
        Spearman metrics (suffixes _id/_ood/_combined), the raw prediction arrays
        (mu_pred, *_var, *_entropy) and the per-region entropy dicts.

    NOTE(review): `epochs` is forwarded to save_model_outputs but not to the two
    statistics helpers - if their output names do not encode epochs, runs differing
    only in epochs may overwrite each other; confirm in utils/ood_experiments.py.
    """
    id_mask = ~ood_mask

    def _pick(values, mask):
        # Region slice; flattened unless the source array is already 1-D.
        chosen = values[mask]
        return chosen if values.ndim == 1 else chosen.flatten()

    def _flat(values):
        # Whole-grid view; flattened only when the array has more than one dimension.
        return values.flatten() if values.ndim > 1 else values

    def _by_region(ale, epi, tot):
        # Build the (ID, OOD, combined) dicts for one uncertainty decomposition.
        named = (('ale', ale), ('epi', epi), ('tot', tot))
        return (
            {key: _pick(arr, id_mask) for key, arr in named},
            {key: _pick(arr, ood_mask) for key, arr in named},
            {key: _flat(arr) for key, arr in named},
        )

    np.random.seed(seed)
    torch.manual_seed(seed)

    # --- Train ---
    loader = DataLoader(
        TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)),
        batch_size=batch_size,
        shuffle=True,
    )
    model = MCDropoutRegressor(p=p)
    train_model(model, loader, epochs=epochs, lr=lr, loss_type='beta_nll', beta=beta)

    # --- Predict (raw per-pass arrays are needed for the metrics below) ---
    mu_pred, ale_var, epi_var, tot_var, (mu_samples, sigma2_samples) = mc_dropout_predict(
        model, x_grid, M=mc_samples, return_raw_arrays=True
    )

    # Persist raw outputs so metrics can be recomputed later without retraining.
    save_model_outputs(
        mu_samples=mu_samples,
        sigma2_samples=sigma2_samples,
        x_grid=x_grid,
        y_grid_clean=y_grid_clean,
        x_train_subset=x_train,
        y_train_subset=y_train,
        model_name='MC_Dropout',
        noise_type=noise_type,
        func_type=func_type,
        subfolder='ood_parameter_comparison',
        dropout_p=p,
        mc_samples=mc_samples,
        date=date,
        epochs=epochs
    )

    # --- Variance-based uncertainties per region ---
    uncertainties_id, uncertainties_ood, uncertainties_combined = _by_region(ale_var, epi_var, tot_var)

    # --- MSE of the predictive mean against the clean function ---
    mu_pred_flat = mu_pred.squeeze() if mu_pred.ndim > 1 else mu_pred
    y_grid_clean_flat = y_grid_clean.squeeze() if y_grid_clean.ndim > 1 else y_grid_clean

    def _mse(mask):
        return np.mean((mu_pred_flat[mask] - y_grid_clean_flat[mask])**2)

    mse_id = _mse(id_mask)
    mse_ood = _mse(ood_mask)
    mse_combined = np.mean((mu_pred_flat - y_grid_clean_flat)**2)

    # --- Aggregated predictive distribution and reference noise variance ---
    mu_star, sigma2_star = compute_predictive_aggregation(mu_samples, sigma2_samples)
    true_noise_var = compute_true_noise_variance(x_grid, noise_type, func_type)

    def _score(metric_fn, mask):
        # Evaluate a (y, mu, sigma2) metric on one region of the grid.
        return metric_fn(y_grid_clean_flat[mask], mu_star[mask], sigma2_star[mask])

    def _disentangle(mask):
        return compute_uncertainty_disentanglement(
            y_grid_clean_flat[mask], mu_star[mask],
            _pick(ale_var, mask), _pick(epi_var, mask),
            true_noise_var[mask]
        )

    nll_id = _score(compute_gaussian_nll, id_mask)
    nll_ood = _score(compute_gaussian_nll, ood_mask)
    nll_combined = compute_gaussian_nll(y_grid_clean_flat, mu_star, sigma2_star)

    crps_id = _score(compute_crps_gaussian, id_mask)
    crps_ood = _score(compute_crps_gaussian, ood_mask)
    crps_combined = compute_crps_gaussian(y_grid_clean_flat, mu_star, sigma2_star)

    disentangle_id = _disentangle(id_mask)
    disentangle_ood = _disentangle(ood_mask)
    disentangle_combined = compute_uncertainty_disentanglement(
        y_grid_clean_flat, mu_star, _flat(ale_var), _flat(epi_var), true_noise_var
    )

    # --- Entropy-based uncertainties per region ---
    entropy_results = entropy_uncertainty_analytical(mu_samples, sigma2_samples)
    ale_entropy = entropy_results['aleatoric']
    epi_entropy = entropy_results['epistemic']
    tot_entropy = entropy_results['total']
    (uncertainties_entropy_id,
     uncertainties_entropy_ood,
     uncertainties_entropy_combined) = _by_region(ale_entropy, epi_entropy, tot_entropy)

    # Scalar metrics shared by both statistics calls and by the returned dict.
    metric_fields = {
        'nll_id': nll_id,
        'nll_ood': nll_ood,
        'nll_combined': nll_combined,
        'crps_id': crps_id,
        'crps_ood': crps_ood,
        'crps_combined': crps_combined,
        'spearman_aleatoric_id': disentangle_id['spearman_aleatoric'],
        'spearman_aleatoric_ood': disentangle_ood['spearman_aleatoric'],
        'spearman_aleatoric_combined': disentangle_combined['spearman_aleatoric'],
        'spearman_epistemic_id': disentangle_id['spearman_epistemic'],
        'spearman_epistemic_ood': disentangle_ood['spearman_epistemic'],
        'spearman_epistemic_combined': disentangle_combined['spearman_epistemic'],
    }

    if save_results:
        # Variance-based statistics
        compute_and_save_statistics_ood(
            uncertainties_id, uncertainties_ood, uncertainties_combined,
            mse_id, mse_ood, mse_combined,
            function_name, noise_type, func_type, 'MC_Dropout',
            date=date, dropout_p=p, mc_samples=mc_samples,
            **metric_fields
        )

        # Entropy-based statistics
        compute_and_save_statistics_entropy_ood(
            uncertainties_entropy_id, uncertainties_entropy_ood, uncertainties_entropy_combined,
            mse_id, mse_ood, mse_combined,
            function_name, noise_type, func_type, 'MC_Dropout',
            date=date, dropout_p=p, mc_samples=mc_samples,
            **metric_fields
        )

        # Positional arguments shared by the variance plots and by the entropy plots.
        variance_args = (x_train, y_train, x_grid, y_grid_clean,
                         mu_pred, ale_var, epi_var, tot_var, ood_mask)
        entropy_args = (x_train, y_train, x_grid, y_grid_clean,
                        mu_pred, ale_entropy, epi_entropy, tot_entropy, ood_mask)

        # Std-based variance uncertainties
        plot_uncertainties_ood(
            *variance_args,
            title=f"MC Dropout (p={p}, M={mc_samples}, E={epochs}) - Variance (Std)",
            noise_type=noise_type, func_type=func_type
        )

        # Normalized variance-based uncertainties
        plot_uncertainties_ood_normalized(
            *variance_args,
            title=f"MC Dropout (p={p}, M={mc_samples}, E={epochs}) - Normalized Variance",
            noise_type=noise_type, func_type=func_type,
            scale_factor=1
        )

        # Entropy-based uncertainties (as std-equivalent bands)
        plot_uncertainties_entropy_ood(
            *entropy_args,
            title=f"MC Dropout (p={p}, M={mc_samples}, E={epochs}) - Entropy (Std-Equivalent)",
            noise_type=noise_type, func_type=func_type
        )

        # Normalized entropy-based uncertainties
        plot_uncertainties_entropy_ood_normalized(
            *entropy_args,
            title=f"MC Dropout (p={p}, M={mc_samples}, E={epochs}) - Normalized Entropy",
            noise_type=noise_type, func_type=func_type,
            scale_factor=1
        )

        # Entropy values directly as line plots (in nats)
        plot_entropy_lines_ood(
            *entropy_args,
            title=f"MC Dropout (p={p}, M={mc_samples}, E={epochs}) - Entropy Lines",
            noise_type=noise_type, func_type=func_type
        )

    return {
        'uncertainties_id': uncertainties_id,
        'uncertainties_ood': uncertainties_ood,
        'uncertainties_combined': uncertainties_combined,
        'mse_id': mse_id,
        'mse_ood': mse_ood,
        'mse_combined': mse_combined,
        **metric_fields,
        'mu_pred': mu_pred,
        'ale_var': ale_var,
        'epi_var': epi_var,
        'tot_var': tot_var,
        'ale_entropy': ale_entropy,
        'epi_entropy': epi_entropy,
        'tot_entropy': tot_entropy,
        'uncertainties_entropy_id': uncertainties_entropy_id,
        'uncertainties_entropy_ood': uncertainties_entropy_ood,
        'uncertainties_entropy_combined': uncertainties_entropy_combined
    }


def run_single_deep_ensemble_ood(generate_toy_regression_func, x_train, y_train, x_grid, y_grid_clean, ood_mask,
                                 K, beta, batch_size, epochs, seed,
                                 function_name, noise_type, func_type, date, save_results=True):
    """Train one Deep Ensemble and evaluate it on the ID and OOD parts of the grid.

    Order of operations: seed the NumPy/PyTorch global RNGs, normalise the inputs
    with statistics taken from `x_train`, train K networks with the beta-NLL loss,
    predict on the normalised grid, store the raw member outputs with
    save_model_outputs, then compute variance-based and entropy-based uncertainties
    plus MSE / NLL / CRPS / Spearman metrics for the ID region, the OOD region and
    the whole grid. Saving, the true-noise reference and the plots all use the
    un-normalised `x_grid`. All error metrics are computed against the noise-free
    targets `y_grid_clean`.

    Args:
        generate_toy_regression_func: Not used inside this function.
        x_train, y_train: Training arrays.
        x_grid, y_grid_clean: Evaluation inputs and noise-free targets.
        ood_mask: Boolean mask over the grid; True marks OOD points.
        K: Number of ensemble members.
        beta: beta of the beta-NLL loss.
        batch_size, epochs: Training settings.
        seed: Seed for the NumPy and PyTorch global RNGs.
        function_name, noise_type, func_type: Labels forwarded to the saving/plotting helpers.
        date: Date string forwarded to the saving helpers.
        save_results: If True, statistics are saved and the five plots are produced.
            save_model_outputs is called regardless of this flag.

    Returns:
        dict with per-region uncertainty dicts ('ale'/'epi'/'tot'), MSE, NLL, CRPS and
        Spearman metrics (suffixes _id/_ood/_combined), the raw prediction arrays
        (mu_pred, *_var, *_entropy) and the per-region entropy dicts.

    NOTE(review): `epochs` is forwarded to save_model_outputs but not to the two
    statistics helpers - if their output names do not encode epochs, runs differing
    only in epochs may overwrite each other; confirm in utils/ood_experiments.py.
    """
    id_mask = ~ood_mask

    def _pick(values, mask):
        # Region slice; flattened unless the source array is already 1-D.
        chosen = values[mask]
        return chosen if values.ndim == 1 else chosen.flatten()

    def _flat(values):
        # Whole-grid view; flattened only when the array has more than one dimension.
        return values.flatten() if values.ndim > 1 else values

    def _by_region(ale, epi, tot):
        # Build the (ID, OOD, combined) dicts for one uncertainty decomposition.
        named = (('ale', ale), ('epi', epi), ('tot', tot))
        return (
            {key: _pick(arr, id_mask) for key, arr in named},
            {key: _pick(arr, ood_mask) for key, arr in named},
            {key: _flat(arr) for key, arr in named},
        )

    np.random.seed(seed)
    torch.manual_seed(seed)

    # --- Normalise inputs with training-set statistics ---
    x_mean, x_std = normalize_x(x_train)
    x_train_norm = normalize_x_data(x_train, x_mean, x_std)
    x_grid_norm = normalize_x_data(x_grid, x_mean, x_std)

    # --- Train ---
    ensemble = train_ensemble_deep(
        x_train_norm, y_train,
        batch_size=batch_size, K=K,
        loss_type='beta_nll', beta=beta, parallel=True, epochs=epochs
    )

    # --- Predict (raw per-member arrays are needed for the metrics below) ---
    mu_pred, ale_var, epi_var, tot_var, (mu_samples, sigma2_samples) = ensemble_predict_deep(
        ensemble, x_grid_norm, return_raw_arrays=True
    )

    # Persist raw outputs so metrics can be recomputed later without retraining.
    save_model_outputs(
        mu_samples=mu_samples,
        sigma2_samples=sigma2_samples,
        x_grid=x_grid,
        y_grid_clean=y_grid_clean,
        x_train_subset=x_train,
        y_train_subset=y_train,
        model_name='Deep_Ensemble',
        noise_type=noise_type,
        func_type=func_type,
        subfolder='ood_parameter_comparison',
        n_nets=K,
        date=date,
        epochs=epochs
    )

    # --- Variance-based uncertainties per region ---
    uncertainties_id, uncertainties_ood, uncertainties_combined = _by_region(ale_var, epi_var, tot_var)

    # --- MSE of the predictive mean against the clean function ---
    mu_pred_flat = mu_pred.squeeze() if mu_pred.ndim > 1 else mu_pred
    y_grid_clean_flat = y_grid_clean.squeeze() if y_grid_clean.ndim > 1 else y_grid_clean

    def _mse(mask):
        return np.mean((mu_pred_flat[mask] - y_grid_clean_flat[mask])**2)

    mse_id = _mse(id_mask)
    mse_ood = _mse(ood_mask)
    mse_combined = np.mean((mu_pred_flat - y_grid_clean_flat)**2)

    # --- Aggregated predictive distribution and reference noise variance ---
    mu_star, sigma2_star = compute_predictive_aggregation(mu_samples, sigma2_samples)
    true_noise_var = compute_true_noise_variance(x_grid, noise_type, func_type)

    def _score(metric_fn, mask):
        # Evaluate a (y, mu, sigma2) metric on one region of the grid.
        return metric_fn(y_grid_clean_flat[mask], mu_star[mask], sigma2_star[mask])

    def _disentangle(mask):
        return compute_uncertainty_disentanglement(
            y_grid_clean_flat[mask], mu_star[mask],
            _pick(ale_var, mask), _pick(epi_var, mask),
            true_noise_var[mask]
        )

    nll_id = _score(compute_gaussian_nll, id_mask)
    nll_ood = _score(compute_gaussian_nll, ood_mask)
    nll_combined = compute_gaussian_nll(y_grid_clean_flat, mu_star, sigma2_star)

    crps_id = _score(compute_crps_gaussian, id_mask)
    crps_ood = _score(compute_crps_gaussian, ood_mask)
    crps_combined = compute_crps_gaussian(y_grid_clean_flat, mu_star, sigma2_star)

    disentangle_id = _disentangle(id_mask)
    disentangle_ood = _disentangle(ood_mask)
    disentangle_combined = compute_uncertainty_disentanglement(
        y_grid_clean_flat, mu_star, _flat(ale_var), _flat(epi_var), true_noise_var
    )

    # --- Entropy-based uncertainties per region ---
    entropy_results = entropy_uncertainty_analytical(mu_samples, sigma2_samples)
    ale_entropy = entropy_results['aleatoric']
    epi_entropy = entropy_results['epistemic']
    tot_entropy = entropy_results['total']
    (uncertainties_entropy_id,
     uncertainties_entropy_ood,
     uncertainties_entropy_combined) = _by_region(ale_entropy, epi_entropy, tot_entropy)

    # Scalar metrics shared by both statistics calls and by the returned dict.
    metric_fields = {
        'nll_id': nll_id,
        'nll_ood': nll_ood,
        'nll_combined': nll_combined,
        'crps_id': crps_id,
        'crps_ood': crps_ood,
        'crps_combined': crps_combined,
        'spearman_aleatoric_id': disentangle_id['spearman_aleatoric'],
        'spearman_aleatoric_ood': disentangle_ood['spearman_aleatoric'],
        'spearman_aleatoric_combined': disentangle_combined['spearman_aleatoric'],
        'spearman_epistemic_id': disentangle_id['spearman_epistemic'],
        'spearman_epistemic_ood': disentangle_ood['spearman_epistemic'],
        'spearman_epistemic_combined': disentangle_combined['spearman_epistemic'],
    }

    if save_results:
        # Variance-based statistics
        compute_and_save_statistics_ood(
            uncertainties_id, uncertainties_ood, uncertainties_combined,
            mse_id, mse_ood, mse_combined,
            function_name, noise_type, func_type, 'Deep_Ensemble',
            date=date, n_nets=K,
            **metric_fields
        )

        # Entropy-based statistics
        compute_and_save_statistics_entropy_ood(
            uncertainties_entropy_id, uncertainties_entropy_ood, uncertainties_entropy_combined,
            mse_id, mse_ood, mse_combined,
            function_name, noise_type, func_type, 'Deep_Ensemble',
            date=date, n_nets=K,
            **metric_fields
        )

        # Positional arguments shared by the variance plots and by the entropy plots.
        variance_args = (x_train, y_train, x_grid, y_grid_clean,
                         mu_pred, ale_var, epi_var, tot_var, ood_mask)
        entropy_args = (x_train, y_train, x_grid, y_grid_clean,
                        mu_pred, ale_entropy, epi_entropy, tot_entropy, ood_mask)

        # Std-based variance uncertainties
        plot_uncertainties_ood(
            *variance_args,
            title=f"Deep Ensemble (K={K}, E={epochs}) - Variance (Std)",
            noise_type=noise_type, func_type=func_type
        )

        # Normalized variance-based uncertainties
        plot_uncertainties_ood_normalized(
            *variance_args,
            title=f"Deep Ensemble (K={K}, E={epochs}) - Normalized Variance",
            noise_type=noise_type, func_type=func_type,
            scale_factor=1
        )

        # Entropy-based uncertainties (as std-equivalent bands)
        plot_uncertainties_entropy_ood(
            *entropy_args,
            title=f"Deep Ensemble (K={K}, E={epochs}) - Entropy (Std-Equivalent)",
            noise_type=noise_type, func_type=func_type
        )

        # Normalized entropy-based uncertainties
        plot_uncertainties_entropy_ood_normalized(
            *entropy_args,
            title=f"Deep Ensemble (K={K}, E={epochs}) - Normalized Entropy",
            noise_type=noise_type, func_type=func_type,
            scale_factor=1
        )

        # Entropy values directly as line plots (in nats)
        plot_entropy_lines_ood(
            *entropy_args,
            title=f"Deep Ensemble (K={K}, E={epochs}) - Entropy Lines",
            noise_type=noise_type, func_type=func_type
        )

    return {
        'uncertainties_id': uncertainties_id,
        'uncertainties_ood': uncertainties_ood,
        'uncertainties_combined': uncertainties_combined,
        'mse_id': mse_id,
        'mse_ood': mse_ood,
        'mse_combined': mse_combined,
        **metric_fields,
        'mu_pred': mu_pred,
        'ale_var': ale_var,
        'epi_var': epi_var,
        'tot_var': tot_var,
        'ale_entropy': ale_entropy,
        'epi_entropy': epi_entropy,
        'tot_entropy': tot_entropy,
        'uncertainties_entropy_id': uncertainties_entropy_id,
        'uncertainties_entropy_ood': uncertainties_entropy_ood,
        'uncertainties_entropy_combined': uncertainties_entropy_combined
    }


## Generate Data (Once)


In [None]:
# Generate data once (same for all parameter variations)
# The global RNGs are re-seeded first; `seed` is additionally handed to the helper.
np.random.seed(seed)
torch.manual_seed(seed)

# NOTE(review): presumably generate_data_with_ood forwards these arguments to
# generate_toy_regression and returns its 5-tuple unchanged - confirm in
# utils/ood_experiments.py.
x_train, y_train, x_grid, y_grid_clean, ood_mask = generate_data_with_ood(
    generate_toy_regression, n_train, train_range, ood_ranges,
    grid_points, noise_type, func_type, seed
)

# Summary of the generated dataset: ranges, grid extent and ID/OOD grid-point counts.
print(f"Training range: {train_range}")
print(f"OOD ranges: {ood_ranges}")
print(f"Grid spans: [{x_grid[0, 0]:.2f}, {x_grid[-1, 0]:.2f}]")
print(f"ID points: {np.sum(~ood_mask)}, OOD points: {np.sum(ood_mask)}")
print(f"Function type: {function_name} ({func_type})")
print(f"Noise type: {noise_type}\n")


## MC Dropout - Vary dropout_p, mc_samples and epochs


In [None]:
# Date stamp (YYYYMMDD) shared by everything saved during this sweep
date = datetime.now().strftime('%Y%m%d')

# Maps "p<p>_M<mc_samples>_E<epochs>" -> result dict of run_single_mc_dropout_ood
results_mc_dropout = {}

print(f"\n{'='*80}")
print(f"MC Dropout Parameter Comparison - Varying mc_samples and epochs")
print(f"{'='*80}\n")

# Full grid, iterated in the order: dropout rate (outermost), MC forward passes,
# training epochs (innermost). Every combination trains a fresh model.
for p, mc_samples, epochs_val in (
    (p_candidate, m_candidate, e_candidate)
    for p_candidate in dropout_p_values
    for m_candidate in mc_samples_values
    for e_candidate in epochs_values
):
    param_key = f"p{p}_M{mc_samples}_E{epochs_val}"
    print(f"\n{'='*60}")
    print(f"Testing: p={p}, mc_samples={mc_samples}, epochs={epochs_val}")
    print(f"{'='*60}")

    result = run_single_mc_dropout_ood(
        generate_toy_regression,
        x_train, y_train, x_grid, y_grid_clean, ood_mask,
        p=p,
        mc_samples=mc_samples,
        beta=beta,
        epochs=epochs_val,
        lr=lr,
        batch_size=batch_size,
        seed=seed,
        function_name=function_name,
        noise_type=noise_type,
        func_type=func_type,
        date=date,
        save_results=True,
    )
    results_mc_dropout[param_key] = result

    # Short per-run summary: mean aleatoric/epistemic variance and MSE per region
    print(f"  ID - Avg Ale: {np.mean(result['uncertainties_id']['ale']):.6f}, "
          f"Avg Epi: {np.mean(result['uncertainties_id']['epi']):.6f}, "
          f"MSE: {result['mse_id']:.6f}")
    print(f"  OOD - Avg Ale: {np.mean(result['uncertainties_ood']['ale']):.6f}, "
          f"Avg Epi: {np.mean(result['uncertainties_ood']['epi']):.6f}, "
          f"MSE: {result['mse_ood']:.6f}")

print(f"\n{'='*80}")
print("MC Dropout experiments completed!")
print(f"{'='*80}\n")


## MC Dropout - Comparison Plots


In [None]:
# Collect one tuple of summary numbers per MC Dropout run, order the tuples,
# then fan them out into the parallel lists that the plotting code reads.
rows = []
for param_key, result in results_mc_dropout.items():
    # param_key looks like "p0.2_M20_E250"; strip each one-letter prefix.
    parts = param_key.split('_')
    var_id = result['uncertainties_id']
    var_ood = result['uncertainties_ood']
    ent_id = result['uncertainties_entropy_id']
    ent_ood = result['uncertainties_entropy_ood']
    rows.append((
        float(parts[0][1:]),  # dropout p
        int(parts[1][1:]),    # mc_samples
        int(parts[2][1:]),    # epochs
        # Mean uncertainties (ale / epi / tot), ID then OOD
        np.mean(var_id['ale']),
        np.mean(var_id['epi']),
        np.mean(var_id['tot']),
        np.mean(var_ood['ale']),
        np.mean(var_ood['epi']),
        np.mean(var_ood['tot']),
        # Mean entropy-based uncertainties, ID then OOD
        np.mean(ent_id['ale']),
        np.mean(ent_id['epi']),
        np.mean(ent_id['tot']),
        np.mean(ent_ood['ale']),
        np.mean(ent_ood['epi']),
        np.mean(ent_ood['tot']),
        # Scalar metrics, stored as returned
        result['mse_id'],
        result['mse_ood'],
        result['nll_id'],
        result['nll_ood'],
        result['crps_id'],
        result['crps_ood'],
        result['spearman_aleatoric_id'],
        result['spearman_aleatoric_ood'],
        result['spearman_epistemic_id'],
        result['spearman_epistemic_ood'],
    ))

# Order by mc_samples first, then epochs, then p.
# np.lexsort treats the LAST key as the primary one.
sorted_indices = np.lexsort((
    [row[0] for row in rows],
    [row[2] for row in rows],
    [row[1] for row in rows],
))
rows = [rows[i] for i in sorted_indices]

# Transpose rows -> columns; with no runs every column is simply empty.
n_columns = 25
columns = [list(col) for col in zip(*rows)] if rows else [[] for _ in range(n_columns)]
(
    p_list,
    mc_samples_list,
    epochs_list,
    avg_ale_id_list,
    avg_epi_id_list,
    avg_tot_id_list,
    avg_ale_ood_list,
    avg_epi_ood_list,
    avg_tot_ood_list,
    avg_ale_entropy_id_list,
    avg_epi_entropy_id_list,
    avg_tot_entropy_id_list,
    avg_ale_entropy_ood_list,
    avg_epi_entropy_ood_list,
    avg_tot_entropy_ood_list,
    mse_id_list,
    mse_ood_list,
    nll_id_list,
    nll_ood_list,
    crps_id_list,
    crps_ood_list,
    spearman_aleatoric_id_list,
    spearman_aleatoric_ood_list,
    spearman_epistemic_id_list,
    spearman_epistemic_ood_list,
) = columns

# The mc_samples plots need exactly one point per mc_samples value, so keep
# only the runs made with the first epochs value and the first dropout p.
first_epoch = epochs_values[0]
first_p = dropout_p_values[0]
mask_first_epoch = [e == first_epoch for e in epochs_list]
mask_first_p = [p == first_p for p in p_list]
mask_filtered = [epoch_ok and p_ok for epoch_ok, p_ok in zip(mask_first_epoch, mask_first_p)]

# Positions of the runs that survive the filter.
kept_positions = [pos for pos, keep in enumerate(mask_filtered) if keep]


def _keep_first_config(values):
    """Return the entries of ``values`` at the positions kept by ``mask_filtered``."""
    return [values[pos] for pos in kept_positions]


mc_samples_filtered = _keep_first_config(mc_samples_list)

# Mean uncertainties
avg_ale_id_filtered = _keep_first_config(avg_ale_id_list)
avg_epi_id_filtered = _keep_first_config(avg_epi_id_list)
avg_tot_id_filtered = _keep_first_config(avg_tot_id_list)
avg_ale_ood_filtered = _keep_first_config(avg_ale_ood_list)
avg_epi_ood_filtered = _keep_first_config(avg_epi_ood_list)
avg_tot_ood_filtered = _keep_first_config(avg_tot_ood_list)

# Mean entropy-based uncertainties
avg_ale_entropy_id_filtered = _keep_first_config(avg_ale_entropy_id_list)
avg_epi_entropy_id_filtered = _keep_first_config(avg_epi_entropy_id_list)
avg_tot_entropy_id_filtered = _keep_first_config(avg_tot_entropy_id_list)
avg_ale_entropy_ood_filtered = _keep_first_config(avg_ale_entropy_ood_list)
avg_epi_entropy_ood_filtered = _keep_first_config(avg_epi_entropy_ood_list)
avg_tot_entropy_ood_filtered = _keep_first_config(avg_tot_entropy_ood_list)

# Scalar metrics
mse_id_filtered = _keep_first_config(mse_id_list)
mse_ood_filtered = _keep_first_config(mse_ood_list)
nll_id_filtered = _keep_first_config(nll_id_list)
nll_ood_filtered = _keep_first_config(nll_ood_list)
crps_id_filtered = _keep_first_config(crps_id_list)
crps_ood_filtered = _keep_first_config(crps_ood_list)
spearman_aleatoric_id_filtered = _keep_first_config(spearman_aleatoric_id_list)
spearman_aleatoric_ood_filtered = _keep_first_config(spearman_aleatoric_ood_list)
spearman_epistemic_id_filtered = _keep_first_config(spearman_epistemic_id_list)
spearman_epistemic_ood_filtered = _keep_first_config(spearman_epistemic_ood_list)

# Separate copies of the epistemic series, read by the ID-vs-OOD bar chart.
avg_epi_id_filtered_plot = _keep_first_config(avg_epi_id_list)
avg_epi_ood_filtered_plot = _keep_first_config(avg_epi_ood_list)

# 2x2 comparison figure against mc_samples (first p / first epochs only):
# three line panels described as data, plus one ID-vs-OOD bar panel.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
panel_suffix = f'{function_name} Function ({noise_type.capitalize()})'

# (axes, y-label, title, log-scaled y?, [(values, marker, label, color), ...])
line_panels = [
    (
        axes[0, 0],
        'Average Uncertainty',
        f'MC Dropout: Average Uncertainties (ID) vs MC Samples\n{panel_suffix}',
        False,
        [
            (avg_ale_id_filtered, 'o-', 'Aleatoric (ID)', 'green'),
            (avg_epi_id_filtered, 's-', 'Epistemic (ID)', 'orange'),
            (avg_tot_id_filtered, '^-', 'Total (ID)', 'blue'),
        ],
    ),
    (
        axes[0, 1],
        'Average Entropy Uncertainty',
        f'MC Dropout: Average Entropy Uncertainties (ID) vs MC Samples\n{panel_suffix}',
        False,
        [
            (avg_ale_entropy_id_filtered, 'o-', 'Aleatoric Entropy (ID)', 'green'),
            (avg_epi_entropy_id_filtered, 's-', 'Epistemic Entropy (ID)', 'orange'),
            (avg_tot_entropy_id_filtered, '^-', 'Total Entropy (ID)', 'blue'),
        ],
    ),
    (
        axes[1, 0],
        'MSE',
        f'MC Dropout: MSE vs MC Samples\n{panel_suffix}',
        True,
        [
            (mse_id_filtered, 'o-', 'MSE (ID)', 'blue'),
            (mse_ood_filtered, 's-', 'MSE (OOD)', 'red'),
        ],
    ),
]

for ax, y_label, title, log_y, series in line_panels:
    for values, marker, label, color in series:
        ax.plot(mc_samples_filtered, values, marker, label=label, color=color,
                linewidth=2, markersize=8)
    ax.set_xlabel('MC Samples', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    if log_y:
        ax.set_yscale('log')
    ax.set_xticks(mc_samples_filtered)

# Bottom-right panel: grouped bars, one ID/OOD pair per mc_samples value.
x_pos = np.arange(len(mc_samples_filtered))
width = 0.35
bar_ax = axes[1, 1]
bar_ax.bar(x_pos - width/2, avg_epi_id_filtered_plot, width, label='Epistemic (ID)', color='orange', alpha=0.7)
bar_ax.bar(x_pos + width/2, avg_epi_ood_filtered_plot, width, label='Epistemic (OOD)', color='red', alpha=0.7)
bar_ax.set_xlabel('MC Samples', fontsize=12)
bar_ax.set_ylabel('Average Epistemic Uncertainty', fontsize=12)
bar_ax.set_title(f'MC Dropout: Epistemic Uncertainty - ID vs OOD\n{panel_suffix}', fontsize=13, fontweight='bold')
bar_ax.set_xticks(x_pos)
bar_ax.set_xticklabels(mc_samples_filtered)
bar_ax.legend(fontsize=10)
bar_ax.grid(True, alpha=0.3, axis='y')

plt.suptitle(
    f'MC Dropout Parameter Comparison: Varying MC Samples (p={dropout_p_values[0]}, epochs={epochs_values[0]})',
    fontsize=14, fontweight='bold', y=0.995,
)
plt.tight_layout()

# Persist the figure, then display and release it.
save_plot(
    fig,
    f"MC_Dropout_mc_samples_comparison_{function_name}_{noise_type}",
    subfolder=f"comparisons/{noise_type}/{func_type}",
)
plt.show()
plt.close(fig)

# 2x2 figure of MC Dropout results against the number of training epochs,
# with one curve family per mc_samples value.
#
# Every panel is restricted to runs made with the first dropout p (first_p),
# matching the "(p=...)" in the figure title. Previously the bottom-right
# ID-vs-OOD panel selected rows by mc_samples only, so with several dropout
# values it drew multiple points per epoch on one line.
unique_mc_samples = sorted(set(mc_samples_list))
fig_epochs, axes_epochs = plt.subplots(2, 2, figsize=(16, 12))
panel_suffix = f'{function_name} Function ({noise_type.capitalize()})'

for mc_val in unique_mc_samples:
    # Row positions for this mc_samples value at first_p, ordered by epochs.
    row_idx = sorted(
        (
            i
            for i, (ms, p) in enumerate(zip(mc_samples_list, p_list))
            if ms == mc_val and p == first_p
        ),
        key=lambda i: epochs_list[i],
    )
    epochs_subset = [epochs_list[i] for i in row_idx]

    # (axes, source list, line format, legend label)
    curves = (
        (axes_epochs[0, 0], avg_ale_id_list, 'o-', f'Aleatoric (M={mc_val})'),
        (axes_epochs[0, 0], avg_epi_id_list, 's-', f'Epistemic (M={mc_val})'),
        (axes_epochs[0, 1], avg_ale_entropy_id_list, 'o-', f'Aleatoric Entropy (M={mc_val})'),
        (axes_epochs[0, 1], avg_epi_entropy_id_list, 's-', f'Epistemic Entropy (M={mc_val})'),
        (axes_epochs[1, 0], mse_id_list, 'o-', f'MSE ID (M={mc_val})'),
        (axes_epochs[1, 0], mse_ood_list, 's-', f'MSE OOD (M={mc_val})'),
        (axes_epochs[1, 1], avg_epi_id_list, 'o-', f'Epistemic ID (M={mc_val})'),
        (axes_epochs[1, 1], avg_epi_ood_list, 's--', f'Epistemic OOD (M={mc_val})'),
    )
    for ax, values, line_format, label in curves:
        ax.plot(epochs_subset, [values[i] for i in row_idx], line_format,
                label=label, linewidth=2, markersize=6)

# (axes, y-label, title, log-scaled y?)
panel_decor = (
    (axes_epochs[0, 0], 'Average Uncertainty',
     f'MC Dropout: Average Uncertainties (ID) vs Epochs\n{panel_suffix}', False),
    (axes_epochs[0, 1], 'Average Entropy Uncertainty',
     f'MC Dropout: Average Entropy Uncertainties (ID) vs Epochs\n{panel_suffix}', False),
    (axes_epochs[1, 0], 'MSE',
     f'MC Dropout: MSE vs Epochs\n{panel_suffix}', True),
    (axes_epochs[1, 1], 'Average Epistemic Uncertainty',
     f'MC Dropout: Epistemic Uncertainty - ID vs OOD\n{panel_suffix}', False),
)
for ax, y_label, title, log_y in panel_decor:
    ax.set_xlabel('Epochs', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=9, ncol=2)
    ax.grid(True, alpha=0.3)
    if log_y:
        ax.set_yscale('log')

plt.suptitle(f'MC Dropout Parameter Comparison: Varying Epochs (p={dropout_p_values[0]})',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()

# Save plot
save_plot(fig_epochs, f"MC_Dropout_epochs_comparison_{function_name}_{noise_type}",
          subfolder=f"comparisons/{noise_type}/{func_type}")
plt.show()
plt.close(fig_epochs)

# 2x2 figure of scoring metrics (NLL, CRPS, Spearman correlations) against
# mc_samples, using the runs filtered to the first p / first epochs value.
fig_metrics, axes_metrics = plt.subplots(2, 2, figsize=(16, 12))
panel_suffix = f'{function_name} Function ({noise_type.capitalize()})'

# (axes, metric name used in title and legend, y-label,
#  ID values, ID color, OOD values, OOD color)
metric_panels = (
    (axes_metrics[0, 0], 'NLL', 'NLL',
     nll_id_filtered, 'blue', nll_ood_filtered, 'red'),
    (axes_metrics[0, 1], 'CRPS', 'CRPS',
     crps_id_filtered, 'blue', crps_ood_filtered, 'red'),
    (axes_metrics[1, 0], 'Spearman Aleatoric', 'Spearman Correlation',
     spearman_aleatoric_id_filtered, 'green', spearman_aleatoric_ood_filtered, 'darkgreen'),
    (axes_metrics[1, 1], 'Spearman Epistemic', 'Spearman Correlation',
     spearman_epistemic_id_filtered, 'orange', spearman_epistemic_ood_filtered, 'red'),
)

for ax, metric, y_label, id_values, id_color, ood_values, ood_color in metric_panels:
    ax.plot(mc_samples_filtered, id_values, 'o-', label=f'{metric} (ID)',
            color=id_color, linewidth=2, markersize=8)
    ax.plot(mc_samples_filtered, ood_values, 's-', label=f'{metric} (OOD)',
            color=ood_color, linewidth=2, markersize=8)
    ax.set_xlabel('MC Samples', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(f'MC Dropout: {metric} vs MC Samples\n{panel_suffix}', fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(mc_samples_filtered)

plt.suptitle(
    f'MC Dropout Parameter Comparison: Metrics vs MC Samples (p={dropout_p_values[0]}, epochs={epochs_values[0]})',
    fontsize=14, fontweight='bold', y=0.995,
)
plt.tight_layout()

# Persist the figure, then display and release it.
save_plot(
    fig_metrics,
    f"MC_Dropout_metrics_mc_samples_comparison_{function_name}_{noise_type}",
    subfolder=f"comparisons/{noise_type}/{func_type}",
)
plt.show()
plt.close(fig_metrics)

# 2x2 figure of scoring metrics (NLL, CRPS, Spearman correlations) against the
# number of training epochs, with one ID/OOD curve pair per mc_samples value.
#
# Rows are restricted to the first dropout p so the curves match the "(p=...)"
# in the figure title. Previously rows were selected by mc_samples only, so
# with several dropout values each line contained multiple points per epoch.
fig_metrics_epochs, axes_metrics_epochs = plt.subplots(2, 2, figsize=(16, 12))
panel_suffix = f'{function_name} Function ({noise_type.capitalize()})'
reference_p = dropout_p_values[0]

# (axes, legend prefix, title metric name, y-label, ID source list, OOD source list)
metric_panels = (
    (axes_metrics_epochs[0, 0], 'NLL', 'NLL', 'NLL',
     nll_id_list, nll_ood_list),
    (axes_metrics_epochs[0, 1], 'CRPS', 'CRPS', 'CRPS',
     crps_id_list, crps_ood_list),
    (axes_metrics_epochs[1, 0], 'Spear Ale', 'Spearman Aleatoric', 'Spearman Correlation',
     spearman_aleatoric_id_list, spearman_aleatoric_ood_list),
    (axes_metrics_epochs[1, 1], 'Spear Epi', 'Spearman Epistemic', 'Spearman Correlation',
     spearman_epistemic_id_list, spearman_epistemic_ood_list),
)

for mc_val in unique_mc_samples:
    # Row positions for this mc_samples value at reference_p, ordered by epochs.
    row_idx = sorted(
        (
            i
            for i, (ms, p) in enumerate(zip(mc_samples_list, p_list))
            if ms == mc_val and p == reference_p
        ),
        key=lambda i: epochs_list[i],
    )
    epochs_subset = [epochs_list[i] for i in row_idx]

    for ax, legend_prefix, _, _, id_values, ood_values in metric_panels:
        ax.plot(epochs_subset, [id_values[i] for i in row_idx], 'o-',
                label=f'{legend_prefix} ID (M={mc_val})', linewidth=2, markersize=6)
        ax.plot(epochs_subset, [ood_values[i] for i in row_idx], 's-',
                label=f'{legend_prefix} OOD (M={mc_val})', linewidth=2, markersize=6)

for ax, _, metric, y_label, _, _ in metric_panels:
    ax.set_xlabel('Epochs', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(f'MC Dropout: {metric} vs Epochs\n{panel_suffix}', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9, ncol=2)
    ax.grid(True, alpha=0.3)

plt.suptitle(f'MC Dropout Parameter Comparison: Metrics vs Epochs (p={dropout_p_values[0]})',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()

# Save plot
save_plot(fig_metrics_epochs, f"MC_Dropout_metrics_epochs_comparison_{function_name}_{noise_type}",
          subfolder=f"comparisons/{noise_type}/{func_type}")
plt.show()
plt.close(fig_metrics_epochs)

# Summary table limited to the first epochs value and first dropout p
# (the rows flagged in mask_filtered), so it stays a manageable size.
# Column name -> full-length source list, in output column order.
summary_sources = {
    'MC_Samples': mc_samples_list,
    'Epochs': epochs_list,
    'Dropout_p': p_list,
    'Avg_Ale_ID': avg_ale_id_list,
    'Avg_Epi_ID': avg_epi_id_list,
    'Avg_Tot_ID': avg_tot_id_list,
    'Avg_Ale_OOD': avg_ale_ood_list,
    'Avg_Epi_OOD': avg_epi_ood_list,
    'Avg_Tot_OOD': avg_tot_ood_list,
    'Avg_Ale_Entropy_ID': avg_ale_entropy_id_list,
    'Avg_Epi_Entropy_ID': avg_epi_entropy_id_list,
    'Avg_Tot_Entropy_ID': avg_tot_entropy_id_list,
    'Avg_Ale_Entropy_OOD': avg_ale_entropy_ood_list,
    'Avg_Epi_Entropy_OOD': avg_epi_entropy_ood_list,
    'Avg_Tot_Entropy_OOD': avg_tot_entropy_ood_list,
    'MSE_ID': mse_id_list,
    'MSE_OOD': mse_ood_list,
    'NLL_ID': nll_id_list,
    'NLL_OOD': nll_ood_list,
    'CRPS_ID': crps_id_list,
    'CRPS_OOD': crps_ood_list,
    'Spear_Ale_ID': spearman_aleatoric_id_list,
    'Spear_Ale_OOD': spearman_aleatoric_ood_list,
    'Spear_Epi_ID': spearman_epistemic_id_list,
    'Spear_Epi_OOD': spearman_epistemic_ood_list,
}

comparison_df = pd.DataFrame({
    column: [value for value, keep in zip(values, mask_filtered) if keep]
    for column, values in summary_sources.items()
})

print("\nSummary Table:")
print(comparison_df.to_string(index=False))

# Write the table next to the comparison plots.
save_statistics(
    comparison_df,
    f"MC_Dropout_mc_samples_comparison_{function_name}_{noise_type}",
    subfolder=f"comparisons/{noise_type}/{func_type}",
)


## Deep Ensemble - Vary K and epochs


In [None]:
# Results keyed by "K{K}_E{epochs}".
# NOTE: `date` is not assigned here; it is the stamp created by the
# MC Dropout experiment code earlier in the notebook.
results_deep_ensemble = {}

wide_rule = '=' * 80
narrow_rule = '=' * 60

print(f"\n{wide_rule}")
print("Deep Ensemble Parameter Comparison - Varying K and epochs")
print(f"{wide_rule}\n")

# Full grid of configurations: ensemble size K (outer), then epochs.
deep_ensemble_grid = [(K, epochs_val) for K in K_values for epochs_val in epochs_values]

for K, epochs_val in deep_ensemble_grid:
    param_key = f"K{K}_E{epochs_val}"
    print(f"\n{narrow_rule}")
    print(f"Testing: K={K}, epochs={epochs_val}")
    print(narrow_rule)

    result = run_single_deep_ensemble_ood(
        generate_toy_regression, x_train, y_train, x_grid, y_grid_clean, ood_mask,
        K=K, beta=beta, batch_size=batch_size, epochs=epochs_val, seed=seed,
        function_name=function_name, noise_type=noise_type, func_type=func_type,
        date=date, save_results=True
    )
    results_deep_ensemble[param_key] = result

    # Per-run summary: mean aleatoric/epistemic uncertainty and MSE, ID then OOD.
    unc_id = result['uncertainties_id']
    print(f"  ID - Avg Ale: {np.mean(unc_id['ale']):.6f}, "
          f"Avg Epi: {np.mean(unc_id['epi']):.6f}, "
          f"MSE: {result['mse_id']:.6f}")
    unc_ood = result['uncertainties_ood']
    print(f"  OOD - Avg Ale: {np.mean(unc_ood['ale']):.6f}, "
          f"Avg Epi: {np.mean(unc_ood['epi']):.6f}, "
          f"MSE: {result['mse_ood']:.6f}")

print(f"\n{wide_rule}")
print("Deep Ensemble experiments completed!")
print(f"{wide_rule}\n")


## Deep Ensemble - Comparison Plots


In [None]:
# Collect one tuple of summary numbers per Deep Ensemble run (in the order the
# runs were stored), then fan the tuples out into the parallel *_de lists.
rows_de = []
for param_key, result in results_deep_ensemble.items():
    # param_key looks like "K5_E250"; strip each one-letter prefix.
    parts = param_key.split('_')
    K_val = int(parts[0][1:])
    epochs_val = int(parts[1][1:])
    var_id = result['uncertainties_id']
    var_ood = result['uncertainties_ood']
    ent_id = result['uncertainties_entropy_id']
    ent_ood = result['uncertainties_entropy_ood']
    rows_de.append((
        K_val,
        epochs_val,
        # Mean uncertainties (ale / epi / tot), ID then OOD
        np.mean(var_id['ale']),
        np.mean(var_id['epi']),
        np.mean(var_id['tot']),
        np.mean(var_ood['ale']),
        np.mean(var_ood['epi']),
        np.mean(var_ood['tot']),
        # Mean entropy-based uncertainties, ID then OOD
        np.mean(ent_id['ale']),
        np.mean(ent_id['epi']),
        np.mean(ent_id['tot']),
        np.mean(ent_ood['ale']),
        np.mean(ent_ood['epi']),
        np.mean(ent_ood['tot']),
        # Scalar metrics, stored as returned
        result['mse_id'],
        result['mse_ood'],
        result['nll_id'],
        result['nll_ood'],
        result['crps_id'],
        result['crps_ood'],
        result['spearman_aleatoric_id'],
        result['spearman_aleatoric_ood'],
        result['spearman_epistemic_id'],
        result['spearman_epistemic_ood'],
    ))

# Transpose rows -> columns; with no runs every column is simply empty.
n_columns_de = 24
columns_de = [list(col) for col in zip(*rows_de)] if rows_de else [[] for _ in range(n_columns_de)]
(
    K_list,
    epochs_list_de,
    avg_ale_id_list_de,
    avg_epi_id_list_de,
    avg_tot_id_list_de,
    avg_ale_ood_list_de,
    avg_epi_ood_list_de,
    avg_tot_ood_list_de,
    avg_ale_entropy_id_list_de,
    avg_epi_entropy_id_list_de,
    avg_tot_entropy_id_list_de,
    avg_ale_entropy_ood_list_de,
    avg_epi_entropy_ood_list_de,
    avg_tot_entropy_ood_list_de,
    mse_id_list_de,
    mse_ood_list_de,
    nll_id_list_de,
    nll_ood_list_de,
    crps_id_list_de,
    crps_ood_list_de,
    spearman_aleatoric_id_list_de,
    spearman_aleatoric_ood_list_de,
    spearman_epistemic_id_list_de,
    spearman_epistemic_ood_list_de,
) = columns_de
# Sort by K first, then by epochs
sorted_indices = np.lexsort((epochs_list_de, K_list))
K_list = [K_list[i] for i in sorted_indices]
epochs_list_de = [epochs_list_de[i] for i in sorted_indices]
avg_ale_id_list_de = [avg_ale_id_list_de[i] for i in sorted_indices]
avg_epi_id_list_de = [avg_epi_id_list_de[i] for i in sorted_indices]
avg_tot_id_list_de = [avg_tot_id_list_de[i] for i in sorted_indices]
avg_ale_ood_list_de = [avg_ale_ood_list_de[i] for i in sorted_indices]
avg_epi_ood_list_de = [avg_epi_ood_list_de[i] for i in sorted_indices]
avg_tot_ood_list_de = [avg_tot_ood_list_de[i] for i in sorted_indices]
avg_ale_entropy_id_list_de = [avg_ale_entropy_id_list_de[i] for i in sorted_indices]
avg_epi_entropy_id_list_de = [avg_epi_entropy_id_list_de[i] for i in sorted_indices]
avg_tot_entropy_id_list_de = [avg_tot_entropy_id_list_de[i] for i in sorted_indices]
avg_ale_entropy_ood_list_de = [avg_ale_entropy_ood_list_de[i] for i in sorted_indices]
avg_epi_entropy_ood_list_de = [avg_epi_entropy_ood_list_de[i] for i in sorted_indices]
avg_tot_entropy_ood_list_de = [avg_tot_entropy_ood_list_de[i] for i in sorted_indices]
mse_id_list_de = [mse_id_list_de[i] for i in sorted_indices]
mse_ood_list_de = [mse_ood_list_de[i] for i in sorted_indices]
nll_id_list_de = [nll_id_list_de[i] for i in sorted_indices]
nll_ood_list_de = [nll_ood_list_de[i] for i in sorted_indices]
crps_id_list_de = [crps_id_list_de[i] for i in sorted_indices]
crps_ood_list_de = [crps_ood_list_de[i] for i in sorted_indices]
spearman_aleatoric_id_list_de = [spearman_aleatoric_id_list_de[i] for i in sorted_indices]
spearman_aleatoric_ood_list_de = [spearman_aleatoric_ood_list_de[i] for i in sorted_indices]
spearman_epistemic_id_list_de = [spearman_epistemic_id_list_de[i] for i in sorted_indices]
spearman_epistemic_ood_list_de = [spearman_epistemic_ood_list_de[i] for i in sorted_indices]

# Filter for first epoch value for K plots (to avoid multiple points per K)
first_epoch = epochs_values[0]
mask_first_epoch_de = [e == first_epoch for e in epochs_list_de]
K_filtered = [K_list[i] for i in range(len(K_list)) if mask_first_epoch_de[i]]
avg_ale_id_filtered_de = [avg_ale_id_list_de[i] for i in range(len(avg_ale_id_list_de)) if mask_first_epoch_de[i]]
avg_epi_id_filtered_de = [avg_epi_id_list_de[i] for i in range(len(avg_epi_id_list_de)) if mask_first_epoch_de[i]]
avg_tot_id_filtered_de = [avg_tot_id_list_de[i] for i in range(len(avg_tot_id_list_de)) if mask_first_epoch_de[i]]
avg_ale_ood_filtered_de = [avg_ale_ood_list_de[i] for i in range(len(avg_ale_ood_list_de)) if mask_first_epoch_de[i]]
avg_epi_ood_filtered_de = [avg_epi_ood_list_de[i] for i in range(len(avg_epi_ood_list_de)) if mask_first_epoch_de[i]]
avg_tot_ood_filtered_de = [avg_tot_ood_list_de[i] for i in range(len(avg_tot_ood_list_de)) if mask_first_epoch_de[i]]
avg_ale_entropy_id_filtered_de = [avg_ale_entropy_id_list_de[i] for i in range(len(avg_ale_entropy_id_list_de)) if mask_first_epoch_de[i]]
avg_epi_entropy_id_filtered_de = [avg_epi_entropy_id_list_de[i] for i in range(len(avg_epi_entropy_id_list_de)) if mask_first_epoch_de[i]]
avg_tot_entropy_id_filtered_de = [avg_tot_entropy_id_list_de[i] for i in range(len(avg_tot_entropy_id_list_de)) if mask_first_epoch_de[i]]
avg_ale_entropy_ood_filtered_de = [avg_ale_entropy_ood_list_de[i] for i in range(len(avg_ale_entropy_ood_list_de)) if mask_first_epoch_de[i]]
avg_epi_entropy_ood_filtered_de = [avg_epi_entropy_ood_list_de[i] for i in range(len(avg_epi_entropy_ood_list_de)) if mask_first_epoch_de[i]]
avg_tot_entropy_ood_filtered_de = [avg_tot_entropy_ood_list_de[i] for i in range(len(avg_tot_entropy_ood_list_de)) if mask_first_epoch_de[i]]
mse_id_filtered_de = [mse_id_list_de[i] for i in range(len(mse_id_list_de)) if mask_first_epoch_de[i]]
mse_ood_filtered_de = [mse_ood_list_de[i] for i in range(len(mse_ood_list_de)) if mask_first_epoch_de[i]]
nll_id_filtered_de = [nll_id_list_de[i] for i in range(len(nll_id_list_de)) if mask_first_epoch_de[i]]
nll_ood_filtered_de = [nll_ood_list_de[i] for i in range(len(nll_ood_list_de)) if mask_first_epoch_de[i]]
crps_id_filtered_de = [crps_id_list_de[i] for i in range(len(crps_id_list_de)) if mask_first_epoch_de[i]]
crps_ood_filtered_de = [crps_ood_list_de[i] for i in range(len(crps_ood_list_de)) if mask_first_epoch_de[i]]
spearman_aleatoric_id_filtered_de = [spearman_aleatoric_id_list_de[i] for i in range(len(spearman_aleatoric_id_list_de)) if mask_first_epoch_de[i]]
spearman_aleatoric_ood_filtered_de = [spearman_aleatoric_ood_list_de[i] for i in range(len(spearman_aleatoric_ood_list_de)) if mask_first_epoch_de[i]]
spearman_epistemic_id_filtered_de = [spearman_epistemic_id_list_de[i] for i in range(len(spearman_epistemic_id_list_de)) if mask_first_epoch_de[i]]
spearman_epistemic_ood_filtered_de = [spearman_epistemic_ood_list_de[i] for i in range(len(spearman_epistemic_ood_list_de)) if mask_first_epoch_de[i]]
avg_epi_id_filtered_plot_de = [avg_epi_id_list_de[i] for i in range(len(avg_epi_id_list_de)) if mask_first_epoch_de[i]]
avg_epi_ood_filtered_plot_de = [avg_epi_ood_list_de[i] for i in range(len(avg_epi_ood_list_de)) if mask_first_epoch_de[i]]

# Create comparison plots: 2x2 figure of first-epoch results versus ensemble size K
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Second title line shared by every panel of this figure
k_panel_caption = f'{function_name} Function ({noise_type.capitalize()})'

# Line panels (plots 1-3):
# (axis, y-label, title headline, [(values, format, legend label, color), ...])
k_line_panels = [
    # Plot 1: Average Uncertainties - ID region
    (axes[0, 0], 'Average Uncertainty', 'Deep Ensemble: Average Uncertainties (ID) vs K', [
        (avg_ale_id_filtered_de, 'o-', 'Aleatoric (ID)', 'green'),
        (avg_epi_id_filtered_de, 's-', 'Epistemic (ID)', 'orange'),
        (avg_tot_id_filtered_de, '^-', 'Total (ID)', 'blue'),
    ]),
    # Plot 2: Average Entropy Uncertainties - ID region
    (axes[0, 1], 'Average Entropy Uncertainty', 'Deep Ensemble: Average Entropy Uncertainties (ID) vs K', [
        (avg_ale_entropy_id_filtered_de, 'o-', 'Aleatoric Entropy (ID)', 'green'),
        (avg_epi_entropy_id_filtered_de, 's-', 'Epistemic Entropy (ID)', 'orange'),
        (avg_tot_entropy_id_filtered_de, '^-', 'Total Entropy (ID)', 'blue'),
    ]),
    # Plot 3: MSE comparison
    (axes[1, 0], 'MSE', 'Deep Ensemble: MSE vs K', [
        (mse_id_filtered_de, 'o-', 'MSE (ID)', 'blue'),
        (mse_ood_filtered_de, 's-', 'MSE (OOD)', 'red'),
    ]),
]
for panel_ax, y_label, headline, series in k_line_panels:
    for values, line_fmt, legend_label, line_color in series:
        panel_ax.plot(K_filtered, values, line_fmt, label=legend_label,
                      color=line_color, linewidth=2, markersize=8)
    panel_ax.set_xlabel('Number of Nets (K)', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(f'{headline}\n{k_panel_caption}', fontsize=13, fontweight='bold')
    panel_ax.legend(fontsize=10)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.set_xticks(K_filtered)

# MSE is drawn on a logarithmic y-axis
axes[1, 0].set_yscale('log')

# Plot 4: ID vs OOD comparison (bar chart), one pair of bars per K value
x_pos = np.arange(len(K_filtered))
width = 0.35
bar_ax = axes[1, 1]
for bar_shift, bar_heights, legend_label, bar_color in (
    (-width / 2, avg_epi_id_filtered_plot_de, 'Epistemic (ID)', 'orange'),
    (width / 2, avg_epi_ood_filtered_plot_de, 'Epistemic (OOD)', 'red'),
):
    bar_ax.bar(x_pos + bar_shift, bar_heights, width, label=legend_label,
               color=bar_color, alpha=0.7)
bar_ax.set_xlabel('Number of Nets (K)', fontsize=12)
bar_ax.set_ylabel('Average Epistemic Uncertainty', fontsize=12)
bar_ax.set_title(f'Deep Ensemble: Epistemic Uncertainty - ID vs OOD\n{k_panel_caption}',
                 fontsize=13, fontweight='bold')
# Bars sit at integer positions; label them with the actual K values
bar_ax.set_xticks(x_pos)
bar_ax.set_xticklabels(K_filtered)
bar_ax.legend(fontsize=10)
bar_ax.grid(True, alpha=0.3, axis='y')

plt.suptitle(
    f'Deep Ensemble Parameter Comparison: Varying K (epochs={epochs_values[0]})',
    fontsize=14, fontweight='bold', y=0.995,
)
plt.tight_layout()

# Save plot
save_plot(
    fig,
    f"Deep_Ensemble_K_comparison_{function_name}_{noise_type}",
    subfolder=f"comparisons/{noise_type}/{func_type}",
)
plt.show()
plt.close(fig)

# Create plots showing epochs variation (for fixed K)
# One curve per K value and per quantity; each K group is ordered by epochs.
# All four panels are filled in a single pass over the K values, and only the
# series that are actually drawn are extracted.
unique_K = sorted(set(K_list))
fig_epochs_de, axes_epochs_de = plt.subplots(2, 2, figsize=(16, 12))

# Second title line shared by every panel of this figure
epochs_panel_caption_de = f'{function_name} Function ({noise_type.capitalize()})'

# (axis, y-label, title headline,
#  [(values over all configurations, format, legend prefix), ...])
epoch_panels_de = [
    # ID uncertainties vs epochs
    (axes_epochs_de[0, 0], 'Average Uncertainty',
     'Deep Ensemble: Average Uncertainties (ID) vs Epochs', [
         (avg_ale_id_list_de, 'o-', 'Aleatoric'),
         (avg_epi_id_list_de, 's-', 'Epistemic'),
     ]),
    # ID entropy uncertainties vs epochs
    (axes_epochs_de[0, 1], 'Average Entropy Uncertainty',
     'Deep Ensemble: Average Entropy Uncertainties (ID) vs Epochs', [
         (avg_ale_entropy_id_list_de, 'o-', 'Aleatoric Entropy'),
         (avg_epi_entropy_id_list_de, 's-', 'Epistemic Entropy'),
     ]),
    # MSE vs epochs
    (axes_epochs_de[1, 0], 'MSE',
     'Deep Ensemble: MSE vs Epochs', [
         (mse_id_list_de, 'o-', 'MSE ID'),
         (mse_ood_list_de, 's-', 'MSE OOD'),
     ]),
    # Epistemic uncertainty, ID vs OOD, for different epochs
    (axes_epochs_de[1, 1], 'Average Epistemic Uncertainty',
     'Deep Ensemble: Epistemic Uncertainty - ID vs OOD', [
         (avg_epi_id_list_de, 'o-', 'Epistemic ID'),
         (avg_epi_ood_list_de, 's--', 'Epistemic OOD'),
     ]),
]

for K_val in unique_K:
    # Positions of this K's configurations, ordered by their epoch count
    config_idx = [i for i, k in enumerate(K_list) if k == K_val]
    epoch_order = np.argsort([epochs_list_de[i] for i in config_idx])
    config_idx = [config_idx[j] for j in epoch_order]
    epochs_subset = [epochs_list_de[i] for i in config_idx]

    for panel_ax, _, _, series in epoch_panels_de:
        for values, line_fmt, legend_prefix in series:
            panel_ax.plot(epochs_subset, [values[i] for i in config_idx], line_fmt,
                          label=f'{legend_prefix} (K={K_val})', linewidth=2, markersize=6)

# Axis decoration is applied once, after every curve (and legend entry) exists
for panel_ax, y_label, headline, _ in epoch_panels_de:
    panel_ax.set_xlabel('Epochs', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(f'{headline}\n{epochs_panel_caption_de}', fontsize=13, fontweight='bold')
    panel_ax.legend(fontsize=9, ncol=2)
    panel_ax.grid(True, alpha=0.3)

# MSE is drawn on a logarithmic y-axis
axes_epochs_de[1, 0].set_yscale('log')

plt.suptitle('Deep Ensemble Parameter Comparison: Varying Epochs',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()

# Save plot
save_plot(fig_epochs_de, f"Deep_Ensemble_epochs_comparison_{function_name}_{noise_type}",
          subfolder=f"comparisons/{noise_type}/{func_type}")
plt.show()
plt.close(fig_epochs_de)

# Create metrics evolution plots (NLL, CRPS, Spearman) vs K
# 2x2 figure of first-epoch results; every panel shows an ID and an OOD curve.
fig_metrics_de, axes_metrics_de = plt.subplots(2, 2, figsize=(16, 12))

# Second title line shared by every panel of this figure
metrics_k_caption_de = f'{function_name} Function ({noise_type.capitalize()})'

# Panels in row-major order (matches axes_metrics_de.flat):
# (y-label, title headline, (ID values, ID label, ID color), (OOD values, OOD label, OOD color))
metrics_k_panels_de = [
    # Plot 1: NLL vs K
    ('NLL', 'Deep Ensemble: NLL vs K',
     (nll_id_filtered_de, 'NLL (ID)', 'blue'),
     (nll_ood_filtered_de, 'NLL (OOD)', 'red')),
    # Plot 2: CRPS vs K
    ('CRPS', 'Deep Ensemble: CRPS vs K',
     (crps_id_filtered_de, 'CRPS (ID)', 'blue'),
     (crps_ood_filtered_de, 'CRPS (OOD)', 'red')),
    # Plot 3: Spearman Aleatoric vs K
    ('Spearman Correlation', 'Deep Ensemble: Spearman Aleatoric vs K',
     (spearman_aleatoric_id_filtered_de, 'Spearman Aleatoric (ID)', 'green'),
     (spearman_aleatoric_ood_filtered_de, 'Spearman Aleatoric (OOD)', 'darkgreen')),
    # Plot 4: Spearman Epistemic vs K
    ('Spearman Correlation', 'Deep Ensemble: Spearman Epistemic vs K',
     (spearman_epistemic_id_filtered_de, 'Spearman Epistemic (ID)', 'orange'),
     (spearman_epistemic_ood_filtered_de, 'Spearman Epistemic (OOD)', 'red')),
]

for metric_ax, (y_label, headline, id_series, ood_series) in zip(axes_metrics_de.flat, metrics_k_panels_de):
    id_values, id_label, id_color = id_series
    ood_values, ood_label, ood_color = ood_series
    metric_ax.plot(K_filtered, id_values, 'o-', label=id_label, color=id_color,
                   linewidth=2, markersize=8)
    metric_ax.plot(K_filtered, ood_values, 's-', label=ood_label, color=ood_color,
                   linewidth=2, markersize=8)
    metric_ax.set_xlabel('Number of Nets (K)', fontsize=12)
    metric_ax.set_ylabel(y_label, fontsize=12)
    metric_ax.set_title(f'{headline}\n{metrics_k_caption_de}', fontsize=13, fontweight='bold')
    metric_ax.legend(fontsize=10)
    metric_ax.grid(True, alpha=0.3)
    metric_ax.set_xticks(K_filtered)

plt.suptitle(
    f'Deep Ensemble Parameter Comparison: Metrics vs K (epochs={epochs_values[0]})',
    fontsize=14, fontweight='bold', y=0.995,
)
plt.tight_layout()

# Save plot
save_plot(
    fig_metrics_de,
    f"Deep_Ensemble_metrics_K_comparison_{function_name}_{noise_type}",
    subfolder=f"comparisons/{noise_type}/{func_type}",
)
plt.show()
plt.close(fig_metrics_de)

# Create metrics evolution plots vs epochs (grouped by K)
# Every panel gets an ID and an OOD curve per K value, ordered by epochs.
fig_metrics_epochs_de, axes_metrics_epochs_de = plt.subplots(2, 2, figsize=(16, 12))

# Second title line shared by every panel of this figure
metrics_epochs_caption_de = f'{function_name} Function ({noise_type.capitalize()})'

# (axis, y-label, title headline, legend prefix,
#  ID values over all configurations, OOD values over all configurations)
metrics_epochs_panels_de = [
    # NLL vs epochs
    (axes_metrics_epochs_de[0, 0], 'NLL', 'Deep Ensemble: NLL vs Epochs',
     'NLL', nll_id_list_de, nll_ood_list_de),
    # CRPS vs epochs
    (axes_metrics_epochs_de[0, 1], 'CRPS', 'Deep Ensemble: CRPS vs Epochs',
     'CRPS', crps_id_list_de, crps_ood_list_de),
    # Spearman Aleatoric vs epochs
    (axes_metrics_epochs_de[1, 0], 'Spearman Correlation', 'Deep Ensemble: Spearman Aleatoric vs Epochs',
     'Spear Ale', spearman_aleatoric_id_list_de, spearman_aleatoric_ood_list_de),
    # Spearman Epistemic vs epochs
    (axes_metrics_epochs_de[1, 1], 'Spearman Correlation', 'Deep Ensemble: Spearman Epistemic vs Epochs',
     'Spear Epi', spearman_epistemic_id_list_de, spearman_epistemic_ood_list_de),
]

for K_val in unique_K:
    # Positions of this K's configurations, sorted by epochs
    member_idx = [i for i, k in enumerate(K_list) if k == K_val]
    by_epochs = np.argsort([epochs_list_de[i] for i in member_idx])
    member_idx = [member_idx[j] for j in by_epochs]
    epochs_subset = [epochs_list_de[i] for i in member_idx]

    for metric_ax, _, _, legend_prefix, id_values, ood_values in metrics_epochs_panels_de:
        metric_ax.plot(epochs_subset, [id_values[i] for i in member_idx], 'o-',
                       label=f'{legend_prefix} ID (K={K_val})', linewidth=2, markersize=6)
        metric_ax.plot(epochs_subset, [ood_values[i] for i in member_idx], 's-',
                       label=f'{legend_prefix} OOD (K={K_val})', linewidth=2, markersize=6)

# Decorate the axes once all curves (and their legend entries) are in place
for metric_ax, y_label, headline, _, _, _ in metrics_epochs_panels_de:
    metric_ax.set_xlabel('Epochs', fontsize=12)
    metric_ax.set_ylabel(y_label, fontsize=12)
    metric_ax.set_title(f'{headline}\n{metrics_epochs_caption_de}', fontsize=13, fontweight='bold')
    metric_ax.legend(fontsize=9, ncol=2)
    metric_ax.grid(True, alpha=0.3)

plt.suptitle(
    'Deep Ensemble Parameter Comparison: Metrics vs Epochs',
    fontsize=14, fontweight='bold', y=0.995,
)
plt.tight_layout()

# Save plot
save_plot(
    fig_metrics_epochs_de,
    f"Deep_Ensemble_metrics_epochs_comparison_{function_name}_{noise_type}",
    subfolder=f"comparisons/{noise_type}/{func_type}",
)
plt.show()
plt.close(fig_metrics_epochs_de)

# Create summary table (for first epoch value to keep it manageable)
# Reuses the already filtered first-epoch lists; only the epochs column is
# filtered here, with the same mask.
epochs_first_only_de = [
    epoch for epoch, keep in zip(epochs_list_de, mask_first_epoch_de) if keep
]

# (column name, values) pairs in the order the columns appear in the table
summary_columns_de = [
    ('K', K_filtered),
    ('Epochs', epochs_first_only_de),
    ('Avg_Ale_ID', avg_ale_id_filtered_de),
    ('Avg_Epi_ID', avg_epi_id_filtered_de),
    ('Avg_Tot_ID', avg_tot_id_filtered_de),
    ('Avg_Ale_OOD', avg_ale_ood_filtered_de),
    ('Avg_Epi_OOD', avg_epi_ood_filtered_de),
    ('Avg_Tot_OOD', avg_tot_ood_filtered_de),
    ('Avg_Ale_Entropy_ID', avg_ale_entropy_id_filtered_de),
    ('Avg_Epi_Entropy_ID', avg_epi_entropy_id_filtered_de),
    ('Avg_Tot_Entropy_ID', avg_tot_entropy_id_filtered_de),
    ('Avg_Ale_Entropy_OOD', avg_ale_entropy_ood_filtered_de),
    ('Avg_Epi_Entropy_OOD', avg_epi_entropy_ood_filtered_de),
    ('Avg_Tot_Entropy_OOD', avg_tot_entropy_ood_filtered_de),
    ('MSE_ID', mse_id_filtered_de),
    ('MSE_OOD', mse_ood_filtered_de),
    ('NLL_ID', nll_id_filtered_de),
    ('NLL_OOD', nll_ood_filtered_de),
    ('CRPS_ID', crps_id_filtered_de),
    ('CRPS_OOD', crps_ood_filtered_de),
    ('Spear_Ale_ID', spearman_aleatoric_id_filtered_de),
    ('Spear_Ale_OOD', spearman_aleatoric_ood_filtered_de),
    ('Spear_Epi_ID', spearman_epistemic_id_filtered_de),
    ('Spear_Epi_OOD', spearman_epistemic_ood_filtered_de),
]
# dict() keeps insertion order, so the column order above is the table order
comparison_df_de = pd.DataFrame(dict(summary_columns_de))

print("\nSummary Table:")
print(comparison_df_de.to_string(index=False))

# Save comparison table
save_statistics(
    comparison_df_de,
    f"Deep_Ensemble_K_comparison_{function_name}_{noise_type}",
    subfolder=f"comparisons/{noise_type}/{func_type}",
)


## Overall Comparison Summary


In [None]:
# Create a combined summary
# Prints the configuration with the lowest OOD MSE for each method, then draws
# ID vs OOD epistemic uncertainty side by side for MC Dropout and Deep Ensemble.
print(f"\n{'='*80}")
print("OVERALL COMPARISON SUMMARY")
print(f"{'='*80}\n")

print("MC Dropout - Best Parameters (lowest OOD MSE):")
best_mc_idx = np.argmin(mse_ood_list)
best_mc_samples = mc_samples_list[best_mc_idx]
print(f"  MC Samples: {best_mc_samples}")
print(f"  OOD MSE: {mse_ood_list[best_mc_idx]:.6f}")
print(f"  OOD Epistemic Uncertainty: {avg_epi_ood_list[best_mc_idx]:.6f}")

print("\nDeep Ensemble - Best Parameters (lowest OOD MSE):")
# The argmin runs over every (K, epochs) configuration, so both parameters are
# reported; K alone does not identify the winning run.
best_de_idx = np.argmin(mse_ood_list_de)
best_K = K_list[best_de_idx]
best_epochs_de = epochs_list_de[best_de_idx]
print(f"  K: {best_K}")
print(f"  Epochs: {best_epochs_de}")
print(f"  OOD MSE: {mse_ood_list_de[best_de_idx]:.6f}")
print(f"  OOD Epistemic Uncertainty: {avg_epi_ood_list_de[best_de_idx]:.6f}")

# Create side-by-side comparison plot
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# MC Dropout: Epistemic uncertainty comparison
axes[0].plot(mc_samples_list, avg_epi_id_list, 'o-', label='Epistemic (ID)', color='blue', linewidth=2, markersize=8)
axes[0].plot(mc_samples_list, avg_epi_ood_list, 's-', label='Epistemic (OOD)', color='red', linewidth=2, markersize=8)
axes[0].set_xlabel('MC Samples', fontsize=12)
axes[0].set_ylabel('Average Epistemic Uncertainty', fontsize=12)
axes[0].set_title(f'MC Dropout: Epistemic Uncertainty\n{function_name} Function ({noise_type.capitalize()})', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)
axes[0].set_xticks(mc_samples_list)

# Deep Ensemble: Epistemic uncertainty comparison
# Uses the first-epoch series (one point per K), like the other "vs K" figures.
# K_list holds one entry per (K, epochs) pair, so plotting it directly would
# repeat each K and make the line double back on itself whenever more than one
# epoch value was swept.
axes[1].plot(K_filtered, avg_epi_id_filtered_de, 'o-', label='Epistemic (ID)', color='blue', linewidth=2, markersize=8)
axes[1].plot(K_filtered, avg_epi_ood_filtered_de, 's-', label='Epistemic (OOD)', color='red', linewidth=2, markersize=8)
axes[1].set_xlabel('Number of Nets (K)', fontsize=12)
axes[1].set_ylabel('Average Epistemic Uncertainty', fontsize=12)
axes[1].set_title(f'Deep Ensemble: Epistemic Uncertainty (epochs={first_epoch})\n{function_name} Function ({noise_type.capitalize()})', fontsize=13, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)
axes[1].set_xticks(K_filtered)

plt.suptitle('Parameter Comparison: ID vs OOD Epistemic Uncertainty', 
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()

# Save plot
save_plot(fig, f"Overall_comparison_{function_name}_{noise_type}", 
          subfolder=f"comparisons/{noise_type}/{func_type}")
plt.show()
plt.close(fig)
