In [None]:
import os
import sys
import re

# Make the project's `src` directory (two levels above the notebook's working
# directory) importable. The path is assembled with os.path.join so the
# separator matches the host OS; a hard-coded "\\src" suffix only produces a
# valid directory name on Windows.
sys.path.insert(
    0, os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), "src")
)

In [None]:
import argparse
from datetime import datetime
from data import *
from dataloader_smacnp import DatasetGP, DatasetGP_test
from math import sqrt
from metrics import *
from model_smacnp import SpatialNeuralProcess, Criterion
import numpy as np
import pandas as pd
from scipy.stats import norm
import time
import torch as torch
import torch.optim as optim
from torch.utils.data import DataLoader
from train_configs import train_runner, test_runner

In [None]:
def _epoch_point_metrics(target, pred):
    """Return ``(mae, mse, rmse, r2)`` as Python floats for two torch tensors.

    The absolute and squared error sums are divided by ``len(target)``, i.e.
    the size of the first dimension of ``target``.
    """
    residual = target - pred
    squared_error = torch.sum(residual ** 2)
    mse = (squared_error / len(target)).item()
    mae = (torch.sum(torch.absolute(residual)) / len(target)).item()
    r2 = (1 - squared_error / torch.sum((target - target.mean()) ** 2)).item()
    return mae, mse, sqrt(mse), r2


def _compute_test_metrics(target, pred, var, print_progress=True):
    """Compute accuracy and uncertainty metrics for one set of test predictions.

    Args:
        target: NumPy array of observed values.
        pred: NumPy array of predicted means.
        var: NumPy array of predicted variances.
            # assumes all three are 1-D and equally long (np.corrcoef and the
            # [:, np.newaxis] indexing below rely on it) - TODO confirm
        print_progress: when True, failures of the optional metrics are printed.

    Returns:
        dict with the keys ``test_mse``, ``test_mae``, ``test_rmse``,
        ``test_r2``, ``mpe``, ``calibration``, ``madecp``, ``coverage_95``,
        ``coverage``, ``smis``, ``ccc``, ``corr_t_p``, ``corr_t_v``,
        ``corr_p_v`` and ``average_var``.

    Metrics computed inside a ``try`` fall back to a sentinel when they raise:
    ``inf`` for error-like scores and NaN for the two coverage values (NaN is
    used instead of 0.0 so a failed computation cannot be mistaken for a real
    coverage of zero).
    """
    # Point accuracy
    residual = target - pred
    squared_error = np.sum(residual ** 2)
    test_mse = squared_error / len(target)
    test_rmse = np.sqrt(test_mse)
    test_mae = np.sum(np.absolute(residual)) / len(target)
    test_r2 = 1 - squared_error / np.sum((target - target.mean()) ** 2)

    # Pairwise Pearson correlations between target, prediction and variance
    corr_t_p = np.corrcoef(target, pred)[0, 1]
    corr_t_v = np.corrcoef(target, var)[0, 1]
    corr_p_v = np.corrcoef(pred, var)[0, 1]

    # Concordance correlation coefficient (reuses corr_t_p)
    ccc = (2 * corr_t_p * np.std(pred) * np.std(target)) / (
        np.var(target) + np.var(pred) + (target.mean() - pred.mean()) ** 2
    )

    # MPE - `mpe` comes from the project's metrics module
    try:
        mpe_value = mpe(
            torch.tensor(target),
            torch.tensor(pred),
            var_pred=torch.tensor(var),
        )
    except Exception as e:
        if print_progress:
            print(f"MPE calculation error: {e}")
        mpe_value = float("inf")

    # Central (1 - alpha) Gaussian prediction interval for every test point.
    # norm.ppf broadcasts over loc/scale, so one call per bound replaces the
    # former per-row Python loop.
    alpha = 0.05
    std = np.sqrt(var)
    lower = norm.ppf(alpha / 2, loc=pred, scale=std)
    upper = norm.ppf(1 - alpha / 2, loc=pred, scale=std)

    # Coverage and SMIS - both come from the project's metrics module
    try:
        cov = coverage(target, lower, upper)
        smis_value = smis(target, lower, upper, alpha=alpha)
    except Exception as e:
        if print_progress:
            print(f"Coverage/SMIS calculation error: {e}")
        cov = float("nan")
        smis_value = float("inf")

    # Calibration: for each nominal level tau, compare tau with the fraction
    # of targets that fall at or below the predicted tau-quantile.
    try:
        taus = np.round(np.arange(0.01, 1, 0.01), 2)
        quantiles = norm.ppf(
            taus, loc=pred[:, np.newaxis], scale=std[:, np.newaxis]
        )
        empirical = (target[:, np.newaxis] <= quantiles).astype(int).mean(axis=0)

        calibration = np.sum((taus - empirical) ** 2)
        madecp = np.mean(np.abs(taus - empirical))

        # Share of targets inside the 95% interval computed above
        coverage_95 = np.mean((target >= lower) & (target <= upper))
    except Exception as e:
        if print_progress:
            print(f"Calibration metrics calculation error: {e}")
        calibration = float("inf")
        madecp = float("inf")
        coverage_95 = float("nan")

    return {
        "test_mse": test_mse,
        "test_mae": test_mae,
        "test_rmse": test_rmse,
        "test_r2": test_r2,
        "mpe": mpe_value,
        "calibration": calibration,
        "madecp": madecp,
        "coverage_95": coverage_95,
        "coverage": cov,
        "smis": smis_value,
        "ccc": ccc,
        "corr_t_p": corr_t_p,
        "corr_t_v": corr_t_v,
        "corr_p_v": corr_p_v,
        "average_var": np.mean(var),
    }


def train_single_seed(args, random_state, model_folder_name=None, print_progress=True):
    """
    Train model with a single random seed and return comprehensive results

    The model is trained with early stopping on the validation loss. The
    weights of the best epoch are checkpointed to
    ``<args.path>/trained/<args.model_name>/<run name>`` and reloaded before
    the test split is evaluated.

    Args:
        args: namespace providing ``dataset``, ``model_name``, ``path``,
            ``train_size``, ``val_size``, ``train_sub_sample``,
            ``max_epoches``, ``patience_limit``, ``n_tasks``, ``batch_size``,
            ``y_size``, ``start_lr`` and ``num_hidden``.
        random_state: seed passed to ``set_seed`` and to the dataset classes.
        model_folder_name: checkpoint file name; when None a timestamped name
            ``final_<model_name>_<dataset>_<MonDD-HHhMMmSSs>`` is generated.
        print_progress: print per-improvement metrics, the early-stop notice
            and metric-computation failures.

    Returns:
        dict with run metadata (``random_state``, ``model_folder``,
        ``final_epoch``, ``best_epoch`` - both zero-based -, ``training_time``
        in seconds, ``num_parameters``), ``val_mse`` (validation MSE at the
        best epoch) and the test metrics produced by ``_compute_test_metrics``.

    Raises:
        ValueError: ``args.dataset`` is not ``"california_housing"``.
        RuntimeError: the validation loss never improved, so no checkpoint of
            this run exists to evaluate.
    """
    # Get args
    dataset = args.dataset
    model_name = args.model_name
    path = args.path
    train_size = args.train_size
    val_size = args.val_size
    train_sub_sample = args.train_sub_sample
    max_epoches = args.max_epoches
    patience_limit = args.patience_limit
    n_tasks = args.n_tasks
    batch_size = args.batch_size
    y_size = args.y_size
    start_lr = args.start_lr
    num_hidden = args.num_hidden

    # Get x_size. Only one dataset is wired in; fail loudly for anything else
    # instead of hitting an unbound local a few lines later.
    if dataset != "california_housing":
        raise ValueError(
            f"Unsupported dataset {dataset!r}: train_single_seed only handles "
            "'california_housing'."
        )
    x, _, _ = get_california_housing_data(norm_x=False)
    x_size = x.shape[1]

    # Set random seed
    set_seed(random_state)

    # Set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Datasets: both are built with eval="validation" and identical split settings
    split_kwargs = dict(
        dataset=dataset,
        train_size=train_size,
        val_size=val_size,
        random_state=random_state,
        train_sub_sample=train_sub_sample,
    )
    trainset = DatasetGP(
        n_tasks=n_tasks, batch_size=batch_size, eval="validation", **split_kwargs
    )
    valset = DatasetGP_test(
        n_tasks=n_tasks, batch_size=batch_size, eval="validation", **split_kwargs
    )
    model = SpatialNeuralProcess(
        x_size=x_size, y_size=y_size, num_hidden=num_hidden
    ).to(device)

    # Calculate number of parameters
    num_parameters = sum(p.numel() for p in model.parameters())

    criterion = Criterion()
    optimizer = optim.Adam(model.parameters(), lr=start_lr)
    # NOTE(review): train mode is enabled only once, here. If test_runner
    # switches the model to eval mode and train_runner does not switch it
    # back, every epoch after the first trains in eval mode - confirm.
    model.train()

    # Create model folder name if not provided
    if model_folder_name is None:
        # %b (abbreviated month name) is the documented strftime directive.
        stamp = datetime.now().strftime("%b%d-%Hh%Mm%Ss")
        saved_file = f"final_{model_name}_{dataset}_{stamp}"
    else:
        saved_file = model_folder_name

    save_dir = os.path.join(path, "trained", model_name)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, saved_file)

    # Training loop
    start_time = time.time()
    best_val_loss = float("inf")
    best_val_mse = float("inf")
    best_epoch = 0
    patience_counter = 0
    final_epoch = 0
    checkpoint_written = False

    # The loaders do not change between epochs, so build them once.
    trainloader = DataLoader(trainset, shuffle=True)
    valloader = DataLoader(valset, shuffle=True)

    for epoch in range(max_epoches):
        final_epoch = epoch

        # train_runner yields a 7-tuple and test_runner a 6-tuple; only the
        # predictions, targets and losses are used here.
        mean_y, _, _, target_y, _, loss, _ = train_runner(
            model, trainloader, criterion, optimizer
        )
        val_pred_y, _, _, val_target_y, val_loss, _ = test_runner(
            model, valloader, criterion
        )

        # Check for improvement
        current_val_loss = val_loss.item()
        if current_val_loss < best_val_loss:
            val_mae, val_mse, val_rmse, val_r2 = _epoch_point_metrics(
                val_target_y, val_pred_y
            )
            best_val_loss = current_val_loss
            best_val_mse = val_mse
            best_epoch = epoch
            patience_counter = 0  # Reset patience
            torch.save(
                {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
                save_path,
            )
            checkpoint_written = True
            if print_progress:
                train_mae, train_mse, train_rmse, train_r2 = _epoch_point_metrics(
                    target_y, mean_y
                )
                current_lr = optimizer.param_groups[0]["lr"]
                print(
                    "Train Epoch: {} Lr: {:.4f}, train_loss: {:.4f}, train_mae: {:.4f}, train_mse: {:.4f}, train_rmse: {:.4f}, train_r2: {:.4f}".format(
                        epoch + 1,
                        current_lr,
                        loss.item(),
                        train_mae,
                        train_mse,
                        train_rmse,
                        train_r2,
                    )
                )
                print(
                    "Val Epoch: {} Lr: {:.4f}, val_loss: {:.4f}, val_mae: {:.4f}, val_mse: {:.4f}, val_rmse: {:.4f}, val_r2: {:.4f}".format(
                        epoch + 1,
                        current_lr,
                        current_val_loss,
                        val_mae,
                        val_mse,
                        val_rmse,
                        val_r2,
                    )
                )
        else:
            patience_counter += 1  # Increment patience

        # Early stopping check
        if patience_counter > patience_limit:
            if print_progress:
                # Both epoch numbers are reported 1-based, like the progress lines.
                print(
                    f"Stopping early after epoch {epoch + 1}. Best validation loss: {best_val_loss} at epoch {best_epoch + 1}."
                )
            break

    training_time = time.time() - start_time

    # Test evaluation
    # IMPORTANT: Reload the best model from checkpoint before evaluation
    # The model in memory is from the last epoch, not necessarily the best epoch
    if not checkpoint_written:
        # Without this guard torch.load would either fail with a file error
        # or silently load a stale checkpoint that happens to share the name.
        raise RuntimeError(
            "Validation loss never improved; no checkpoint was written to "
            f"{save_path!r}."
        )
    state_dict = torch.load(save_path, map_location=device)
    model.load_state_dict(state_dict["model"])
    model.eval()
    testset = DatasetGP_test(n_tasks=1, batch_size=1, eval="test", **split_kwargs)
    testloader = DataLoader(testset, batch_size=1, shuffle=False)

    test_pred_y, test_var_y, _, test_target_y, _, _ = test_runner(
        model, testloader, criterion
    )
    test_metrics = _compute_test_metrics(
        test_target_y.cpu().detach().numpy(),
        test_pred_y.cpu().detach().numpy(),
        test_var_y.cpu().detach().numpy(),
        print_progress=print_progress,
    )

    # Return comprehensive results (key order matches the previous layout)
    results = {
        "random_state": random_state,
        "model_folder": saved_file,
        "final_epoch": final_epoch,
        "best_epoch": best_epoch,
        "training_time": training_time,
        "num_parameters": num_parameters,
    }
    for key in ("test_mse", "test_mae", "test_rmse", "test_r2"):
        results[key] = test_metrics[key]
    # Validation MSE at the best epoch (the key used to hold the best
    # validation *loss*, which is a different quantity).
    results["val_mse"] = best_val_mse
    for key in (
        "mpe",
        "calibration",
        "madecp",
        "coverage_95",
        "coverage",
        "smis",
        "ccc",
        "corr_t_p",
        "corr_t_v",
        "corr_p_v",
        "average_var",
    ):
        results[key] = test_metrics[key]

    return results

In [None]:
def train_multiple_seeds(args, seeds=None, num_seeds=10):
    """
    Train model with multiple random seeds and return the per-seed results

    Every seed is trained through ``train_single_seed`` with the checkpoint
    name ``<base>_seed<seed>``, where ``<base>`` is one timestamped name
    (``final_<model_name>_<dataset>_<MonDD-HHhMMmSSs>``) shared by all seeds
    of this call.

    Args:
        args: namespace forwarded to ``train_single_seed``; ``dataset`` and
            ``model_name`` are also read here to build the base name.
        seeds: iterable of seeds with a length; defaults to
            ``42 .. 42 + num_seeds - 1``.
        num_seeds: number of default seeds, ignored when ``seeds`` is given.

    Returns:
        list with one result dict per seed, in seed order. A seed whose
        training raised gets a placeholder dict with the same keys: error-like
        metrics are ``inf``, score-like metrics ``-inf``, and every other
        numeric field is NaN (``num_parameters`` stays 0). NaN is used rather
        than 0 because ``calculate_aggregated_results`` only drops inf/NaN
        values, so a literal 0 for epochs, time or coverage would be averaged
        in as if it were a real measurement.
    """
    if seeds is None:
        seeds = list(range(42, 42 + num_seeds))

    all_results = []

    print(f"Starting training with {len(seeds)} different random seeds: {seeds}")
    print("=" * 80)

    # One base name for the whole call so all checkpoints sort together.
    # %b (abbreviated month name) is the documented strftime directive.
    base_model_name = "final_{}_{}_{}".format(
        args.model_name,
        args.dataset,
        datetime.now().strftime("%b%d-%Hh%Mm%Ss"),
    )

    for i, seed in enumerate(seeds):
        print(f"\n--- Training with seed {seed} ({i+1}/{len(seeds)}) ---")
        model_folder = f"{base_model_name}_seed{seed}"

        # A failing seed must not abort the remaining ones: the error is
        # reported and a placeholder result is recorded instead.
        try:
            print(f"Starting training for seed {seed}...")
            results = train_single_seed(args, seed, model_folder_name=model_folder)
            all_results.append(results)

            print(f"Seed {seed} completed:")
            print(f"  - Model folder: {results['model_folder']}")
            print(f"  - Final epoch: {results['final_epoch']}")
            print(f"  - Best epoch: {results['best_epoch']}")
            print(f"  - Training time: {results['training_time']:.2f}s")
            print(f"  - Test MSE: {results['test_mse']:.6f}")
            print(f"  - Test MAE: {results['test_mae']:.6f}")
            print(f"  - Test RMSE: {results['test_rmse']:.6f}")
            print(f"  - Test R2: {results['test_r2']:.6f}")
            print(f"  - MPE: {results['mpe']:.6f}")
            print(f"  - Coverage 95%: {results['coverage_95']:.4f}")
            print(f"  - SMIS: {results['smis']:.6f}")
            print(f"  - Calibration: {results['calibration']:.6f}")
            print(f"  - MADECP: {results['madecp']:.6f}")

        except Exception as e:
            print(f"Error training seed {seed}: {e}")
            # Placeholder keeps the list aligned with `seeds`; every numeric
            # field is non-finite so aggregation ignores the failed run.
            missing = float("nan")
            worst_high = float("inf")
            worst_low = float("-inf")
            dummy_result = {
                "random_state": seed,
                "model_folder": model_folder,
                "final_epoch": missing,
                "best_epoch": missing,
                "training_time": missing,
                "num_parameters": 0,
                "test_mse": worst_high,
                "test_mae": worst_high,
                "test_rmse": worst_high,
                "test_r2": worst_low,
                "val_mse": worst_high,
                "mpe": worst_high,
                "calibration": worst_high,
                "madecp": worst_high,
                "coverage_95": missing,
                "coverage": missing,
                "smis": worst_high,
                "ccc": worst_low,
                "corr_t_p": worst_low,
                "corr_t_v": worst_low,
                "corr_p_v": worst_low,
                "average_var": missing,
            }
            all_results.append(dummy_result)
            print(f"Added dummy result for failed seed {seed}")

    return all_results

In [None]:
def calculate_aggregated_results(all_results):
    """
    Summarise every tracked metric across seeds

    For each metric the values of all runs are collected, infinite and NaN
    entries are discarded, and the mean, standard deviation (``np.std``),
    minimum and maximum of the remaining values are reported. A metric left
    with no usable value gets ``inf`` for all four statistics.

    Args:
        all_results: list of per-seed result dicts; each one must contain
            every metric name listed in ``tracked_metrics`` below.

    Returns:
        dict mapping metric name to ``{"mean", "std", "min", "max"}``, in the
        order of ``tracked_metrics``.
    """
    tracked_metrics = (
        "final_epoch",
        "best_epoch",
        "training_time",
        "test_mse",
        "test_mae",
        "test_rmse",
        "test_r2",
        "val_mse",
        "mpe",
        "calibration",
        "madecp",
        "coverage_95",
        "coverage",
        "smis",
        "ccc",
        "corr_t_p",
        "corr_t_v",
        "corr_p_v",
        "average_var",
    )

    def summarise(samples):
        # Non-finite entries mark failed computations and carry no information
        usable = [s for s in samples if not (np.isinf(s) or np.isnan(s))]
        if not usable:
            return dict.fromkeys(("mean", "std", "min", "max"), float("inf"))
        return {
            "mean": np.mean(usable),
            "std": np.std(usable),
            "min": np.min(usable),
            "max": np.max(usable),
        }

    return {
        name: summarise([run[name] for run in all_results])
        for name in tracked_metrics
    }

In [None]:
def print_aggregated_results(all_results, results_dict):
    """
    Print comprehensive results summary

    The report has three parts: a header (seed list, base model folder and
    parameter count), one "Mean ± Std / Range" block per aggregated metric
    grouped into training, performance and uncertainty sections, and a table
    with one row per seed.

    Args:
        all_results: non-empty list of per-seed result dicts.
        results_dict: mapping metric name -> {"mean", "std", "min", "max"},
            as returned by ``calculate_aggregated_results``.
    """
    heavy_rule = "=" * 100
    light_rule = "-" * 80

    # (section title, ((heading, metric key, mean/std format, range format), ...))
    sections = (
        (
            "TRAINING METRICS",
            (
                ("Final Epochs", "final_epoch", ".1f", ".0f"),
                ("Best Epochs", "best_epoch", ".1f", ".0f"),
                ("Training Time (seconds)", "training_time", ".2f", ".2f"),
            ),
        ),
        (
            "PERFORMANCE METRICS",
            (
                ("Test MSE", "test_mse", ".6f", ".6f"),
                ("Test MAE", "test_mae", ".6f", ".6f"),
                ("Test RMSE", "test_rmse", ".6f", ".6f"),
                ("Test R2", "test_r2", ".6f", ".6f"),
                ("Validation MSE", "val_mse", ".6f", ".6f"),
            ),
        ),
        (
            "UNCERTAINTY QUANTIFICATION METRICS",
            (
                ("MPE (Mean Prediction Error)", "mpe", ".6f", ".6f"),
                ("95% Prediction Interval Coverage", "coverage_95", ".4f", ".4f"),
                ("Coverage", "coverage", ".4f", ".4f"),
                ("SMIS", "smis", ".6f", ".6f"),
                ("Calibration Score", "calibration", ".6f", ".6f"),
                (
                    "MADECP (Mean Absolute Deviation from Expected Coverage Probability)",
                    "madecp",
                    ".6f",
                    ".6f",
                ),
                ("CCC (Concordance Correlation Coefficient)", "ccc", ".6f", ".6f"),
                ("Average Variance", "average_var", ".6f", ".6f"),
            ),
        ),
    )

    # (column title, result key, column width, precision suffix)
    table_columns = (
        ("Seed", "random_state", 6, ""),
        ("Epochs", "final_epoch", 8, ""),
        ("Time(s)", "training_time", 10, ".2f"),
        ("Test MSE", "test_mse", 12, ".6f"),
        ("Test MAE", "test_mae", 12, ".6f"),
        ("Test RMSE", "test_rmse", 12, ".6f"),
        ("Test R2", "test_r2", 10, ".6f"),
        ("MPE", "mpe", 12, ".6f"),
        ("Coverage", "coverage_95", 10, ".4f"),
    )

    # ---- Header -----------------------------------------------------------
    print("\n" + heavy_rule)
    print("COMPREHENSIVE RESULTS SUMMARY")
    print(heavy_rule)

    print("\nNumber of seeds: {}".format(len(all_results)))
    seeds_used = [run["random_state"] for run in all_results]
    print("Seeds used: {}".format(seeds_used))

    # Everything before the first "_seed" marker is the shared base name
    base_folder = all_results[0]["model_folder"].partition("_seed")[0]
    print("\nBase model folder: {}".format(base_folder))

    # Parameter count is taken from the seed-42 run when present, otherwise
    # from the first run in the list.
    reference_run = all_results[0]
    for run in all_results:
        if run["random_state"] == 42:
            reference_run = run
            break
    print("\nModel Parameters: {:,}".format(reference_run.get("num_parameters", 0)))

    # ---- Aggregated metrics -------------------------------------------------
    for section_title, entries in sections:
        print("\n" + light_rule)
        print(section_title)
        print(light_rule)
        for position, (heading, key, mean_spec, range_spec) in enumerate(entries):
            # Blocks after the first one in a section are preceded by a blank line
            print(("\n" if position else "") + heading + ":")
            stats = results_dict[key]
            print(
                "  Mean ± Std: {} ± {}".format(
                    format(stats["mean"], mean_spec), format(stats["std"], mean_spec)
                )
            )
            print(
                "  Range: [{}, {}]".format(
                    format(stats["min"], range_spec), format(stats["max"], range_spec)
                )
            )

    # ---- Per-seed table -----------------------------------------------------
    print("\n" + heavy_rule)
    print("INDIVIDUAL SEED RESULTS")
    print(heavy_rule)

    print(
        "\n"
        + " ".join(format(title, "<{}".format(width)) for title, _, width, _ in table_columns)
    )
    print("-" * 100)
    for run in all_results:
        print(
            " ".join(
                format(run[key], "<{}{}".format(width, precision))
                for _, key, width, precision in table_columns
            )
        )

    print("\n" + heavy_rule)

In [None]:
# Experiment configuration: dataset, output location, split sizes and the
# training hyper-parameters shared by every seed.
args = argparse.Namespace(
    **{
        "dataset": "california_housing",
        "model_name": "smacnp",
        "path": "../../",
        "train_size": 0.8,
        "val_size": 0.1,
        "train_sub_sample": 0.1,
        "max_epoches": 300,
        "patience_limit": 50,
        "n_tasks": 30,
        "batch_size": 1,
        "y_size": 1,
        "start_lr": 0.001,
        "num_hidden": 128,
    }
)

# Announce the run
print("Starting Multi-Seed Training Experiment")
print("=" * 80)

# Train with 10 different seeds; the wall-clock measurement covers training only
start_time = time.time()
all_results = train_multiple_seeds(args, num_seeds=10)
end_time = time.time()
total_execution_time = end_time - start_time

# Aggregate across seeds, then print the full report
results_dict = calculate_aggregated_results(all_results)
print_aggregated_results(all_results, results_dict)

# Split the elapsed seconds into days / hours / minutes / seconds for display
days, remainder = divmod(total_execution_time, 24 * 60 * 60)
hours, remainder = divmod(remainder, 60 * 60)
minutes, seconds = divmod(remainder, 60)
print(
    "\nTOTAL EXECUTION TIME: {} days, {} hours, {} minutes, {:.2f} seconds".format(
        int(days), int(hours), int(minutes), seconds
    )
)