In [None]:
import os
import sys
import re

# Make the project's "src" directory (located two levels above the current working
# directory) importable. os.path.join picks the right separator for the host OS; a
# hard-coded "\\src" suffix only forms a valid path on Windows.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), "src"))

In [None]:
import argparse
from datetime import datetime
from data import *
from metrics import *
from model import *
import numpy as np
from spatial import *
from scipy import sparse
from scipy.stats import norm
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter
import time
from torch_geometric.nn import knn_graph
from torch.optim import Adam
from torch.utils.data import DataLoader
import pandas as pd
import torch
import torch.nn as nn

In [None]:
# Load the dataset once at notebook level and report the shape of each returned array
# (one shape per output line).
x, y, coords = get_australia_data()
print(x.shape, y.shape, coords.shape, sep="\n")

In [None]:
def _experiment_tag(args):
    """
    Build the hyper-parameter tag that labels a run (used in folder names and notes).

    Only attributes of ``args`` are read: model name, dataset, k, the GNN and PE
    dimensions, the mat / uw / lamb setting and the batch mode.
    """
    tag = "final" + "_" + args.model_name + "_" + args.dataset + "_k" + str(args.k)
    tag += "_ghid" + str(args.gnn_hidden_dim)
    tag += "_gemb" + str(args.gnn_emb_dim)
    tag += "_phid" + str(args.pe_hidden_dim)
    tag += "_pemb" + str(args.pe_emb_dim)
    if args.mat:
        tag += "_mat-uw" if args.uw else "_mat-lam" + str(args.lamb)
    tag += "_bs" + str(args.batch_size) if args.batched_training else "_bsn"
    return tag


def train_single_seed(x, y, coords, args, random_state, model_folder_name=None):
    """
    Train one model with a single random seed and return its metrics.

    Parameters
    ----------
    x, y, coords :
        Forwarded to ``get_australia_data(aux=(x, y, coords))``.
    args :
        Namespace holding the experiment settings read at the top of the function.
    random_state :
        Seed passed to ``set_seed`` and to both ``train_test_split`` calls.
    model_folder_name : str, optional
        Name of the run folder under ``<args.path>/trained``. When omitted, a name is
        derived from the hyper-parameter tag plus the current timestamp.

    Returns
    -------
    dict
        Keys: random_state, model_folder, final_epoch, best_epoch, training_time,
        test_mse, test_mae, val_mse, mpe, calibration, madecp, coverage_95, variance.

    Raises
    ------
    ValueError
        If ``args.dataset`` or ``args.model_name`` is not one of the supported values.

    Notes
    -----
    The checkpoint is overwritten at every evaluation step that does not trigger early
    stopping, and the test metrics are computed with the model as it is when training
    ends, whereas ``val_mse`` reports the best validation MSE seen during training.
    """
    # Get args
    dataset = args.dataset
    model_name = args.model_name
    path = args.path
    train_size = args.train_size
    val_size = args.val_size
    test_size = 1 - (args.train_size + args.val_size)
    batched_training = args.batched_training
    batch_size = args.batch_size
    max_epochs = args.max_epochs
    patience_limit = args.patience_limit
    min_improvement = args.min_improvement
    train_crit = args.train_crit
    lr = args.lr
    gnn_hidden_dim = args.gnn_hidden_dim
    gnn_emb_dim = args.gnn_emb_dim
    pe_hidden_dim = args.pe_hidden_dim
    pe_emb_dim = args.pe_emb_dim
    k = args.k
    p_dropout = args.p_dropout
    MAT = args.mat
    uw = args.uw
    lamb = args.lamb
    save_freq = args.save_freq
    print_progress = args.print_progress

    # Set random seed
    set_seed(random_state)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Access and process data
    if dataset == "australia":
        x, y, c = get_australia_data(aux=(x, y, coords))
    else:
        # Fail early with a clear message instead of a NameError on `c` further down.
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # Split data: first train vs. (val + test), then val vs. test
    n = x.shape[0]
    indices = np.arange(n)
    _, _, _, _, idx_train, idx_val_test = train_test_split(
        x, y, indices, test_size=(1 - train_size), random_state=random_state
    )
    idx_val, idx_test = train_test_split(
        idx_val_test,
        test_size=(1 - train_size - val_size) / (1 - train_size),
        random_state=random_state,
    )

    # Separate x, y and c objects
    train_x, val_x, test_x = x[idx_train], x[idx_val], x[idx_test]
    train_y, val_y, test_y = y[idx_train], y[idx_val], y[idx_test]
    train_c, val_c, test_c = c[idx_train], c[idx_val], c[idx_test]

    # Create MyDataset objects
    train_dataset, val_dataset, test_dataset = (
        MyDataset(train_x, train_y, train_c),
        MyDataset(val_x, val_y, val_c),
        MyDataset(test_x, test_y, test_c),
    )

    # Define train loader
    if not batched_training:
        # Full-batch mode: one kNN graph per split, built once up front.
        batch_size = len(idx_train)
        train_edge_index = knn_graph(train_c, k=k).to(device)
        train_edge_weight = makeEdgeWeight(train_c, train_edge_index).to(device)
        val_edge_index = knn_graph(val_c, k=k).to(device)
        val_edge_weight = makeEdgeWeight(val_c, val_edge_index).to(device)
        test_edge_index = knn_graph(test_c, k=k).to(device)
        test_edge_weight = makeEdgeWeight(test_c, test_edge_index).to(device)
        train_moran_weight_matrix = knn_to_adj(train_edge_index, batch_size)
        with torch.enable_grad():
            train_y_moran = lw_tensor_local_moran(
                train_y, sparse.csr_matrix(train_moran_weight_matrix)
            ).to(device)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=False, drop_last=False
        )
    else:
        # Batched mode: no precomputed graphs; False is handed to the model and the
        # loss wrapper in their place.
        train_edge_index = False
        train_edge_weight = False
        val_edge_index = False
        val_edge_weight = False
        test_edge_index = False
        test_edge_weight = False
        train_y_moran = False
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
        )

    # Make model (all three architectures take the same constructor arguments)
    model_kwargs = dict(
        num_features_in=train_x.shape[1],
        gnn_hidden_dim=gnn_hidden_dim,
        gnn_emb_dim=gnn_emb_dim,
        pe_hidden_dim=pe_hidden_dim,
        pe_emb_dim=pe_emb_dim,
        k=k,
        p_dropout=p_dropout,
        MAT=MAT,
    )
    if model_name == "pegcn":
        model = PEGCN(**model_kwargs)
    elif model_name == "pegat":
        model = PEGAT(**model_kwargs)
    elif model_name == "pegsage":
        model = PEGSAGE(**model_kwargs)
    else:
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    model = model.to(device).float()

    # Number of tasks
    task_num = 2 if MAT else 1

    # Optimizer and loss function
    loss_wrapper = LossWrapperPEGNN(
        model,
        loss=train_crit,
        k=k,
        batch_size=batch_size,
        task_num=task_num,
        uw=uw,
        lamb=lamb,
    ).to(device)
    optimizer = Adam(loss_wrapper.parameters(), lr=lr)
    score1 = nn.MSELoss()
    score2 = nn.L1Loss()

    # Hyper-parameter tag (written to the notes file) and run folder name
    test_ = _experiment_tag(args)
    if model_folder_name is None:
        # %b is the portable month-abbreviation code; output looks like "Jan05-13h07m09s"
        saved_file = "{}_{}".format(test_, datetime.now().strftime("%b%d-%Hh%Mm%Ss"))
    else:
        saved_file = model_folder_name

    run_dir = os.path.join(path, "trained", saved_file)
    log_dir = os.path.join(run_dir, "log")
    ckpt_dir = os.path.join(run_dir, "ckpts")

    for sub_dir in ("data", "images"):
        os.makedirs(os.path.join(run_dir, sub_dir), exist_ok=True)
    with open(os.path.join(run_dir, "train_notes.txt"), "w") as f:
        f.write("Experiment notes: PE-GAT for Australia dataset \n\n")
        f.write("MODEL_DATA: {}\n".format(test_))
        f.write("DATASET: {}\n".format(dataset))
        f.write("RANDOM_STATE: {}\n".format(random_state))
        f.write(
            "[TRAIN_SIZE, VAL_SIZE, TEST_SIZE]: [{}, {}, {}]\n".format(
                train_size, val_size, test_size
            )
        )
        f.write(
            "BATCH_SIZE: {}\nTRAIN_CRIT: {}\nLEARNING_RATE: {}\n".format(
                batch_size, train_crit, lr
            )
        )
        f.write(
            "MAX_EPOCHS: {}\nPATIENCE_LIMIT: {}\nMIN_IMPROVEMENT: {}\n".format(
                max_epochs, patience_limit, min_improvement
            )
        )
        f.write(
            "GNN_HIDDEN_DIM: {}\nGNN_EMB_DIM: {}\n".format(gnn_hidden_dim, gnn_emb_dim)
        )
        f.write("PE_HIDDEN_DIM: {}\nPE_EMB_DIM: {}\n".format(pe_hidden_dim, pe_emb_dim))
        f.write("K: {}\nP_DROPOUT: {}\n".format(k, p_dropout))
        f.write("MAT: {}\nUW: {}\nLAMBDA: {}\n".format(MAT, uw, lamb))

    def _predict(split_dataset, edge_index, edge_weight):
        """Run the model on a whole split without gradients; caller sets eval mode."""
        with torch.no_grad():
            output = model(
                split_dataset.features.clone().detach().to(device),
                split_dataset.coords.clone().detach().to(device),
                edge_index,
                edge_weight,
            )
        if MAT:
            # With MAT enabled the model returns a pair; only the first item is scored.
            output, _ = output
        return output

    writer = SummaryWriter(log_dir)

    # Training loop
    start_time = time.time()
    it_counts = 0
    best_val_mse = float("inf")
    best_epoch = 0
    patience_counter = 0
    found = False
    final_epoch = 0
    pred_val = None  # stays None if no evaluation step is ever reached

    try:
        for epoch in range(max_epochs):
            for batch in train_loader:
                model.train()
                it_counts += 1
                # Distinct names so the function's x / y arguments are not rebound.
                batch_x = batch[0].to(device).float()
                batch_y = batch[1].to(device).float()
                batch_c = batch[2].to(device).float()

                optimizer.zero_grad()

                if MAT and uw:
                    # In this mode the wrapper returns a second value besides the loss.
                    loss, _ = loss_wrapper(
                        batch_x,
                        batch_y,
                        batch_c,
                        train_edge_index,
                        train_edge_weight,
                        train_y_moran,
                    )
                else:
                    loss = loss_wrapper(
                        batch_x,
                        batch_y,
                        batch_c,
                        train_edge_index,
                        train_edge_weight,
                        train_y_moran,
                    )
                loss.backward()
                optimizer.step()

                # Eval every `save_freq` iterations
                if it_counts % save_freq == 0:
                    model.eval()
                    pred_val = _predict(val_dataset, val_edge_index, val_edge_weight)
                    val_target = (
                        val_dataset.target.clone().detach().reshape(-1).to(device)
                    )
                    val_score1 = score1(val_target, pred_val.reshape(-1))
                    val_score2 = score2(val_target, pred_val.reshape(-1))

                    # Check for improvement (must beat the best by the relative margin)
                    if best_val_mse > val_score1.item() * (1 + min_improvement):
                        best_val_mse = val_score1.item()
                        best_epoch = epoch
                        patience_counter = 0  # Reset patience
                    else:
                        patience_counter += 1  # Increment patience

                    # Early stopping check (patience is counted in evaluation steps)
                    if patience_counter > patience_limit:
                        if print_progress:
                            print(
                                f"Stopping early at epoch {epoch}. Best validation MSE: {best_val_mse} at epoch {best_epoch}."
                            )
                        found = True
                        break

                    if print_progress:
                        print(
                            "Epoch [%d/%d] - Loss: %f - Valid. (MSE): %f - Valid. (MAE): %f"
                            % (
                                epoch,
                                max_epochs,
                                loss.item(),
                                val_score1.item(),
                                val_score2.item(),
                            )
                        )
                    os.makedirs(ckpt_dir, exist_ok=True)
                    torch.save(
                        model.state_dict(), os.path.join(ckpt_dir, "model_state.pt")
                    )
                    writer.add_scalar("Validation (MSE)", val_score1.item(), it_counts)
                    writer.add_scalar("Validation (MAE)", val_score2.item(), it_counts)
                writer.add_scalar("Training loss", loss.item(), it_counts)
                writer.flush()
            final_epoch = epoch
            if found:
                break
    finally:
        # Release the event-file handle even when training raises.
        writer.close()

    training_time = time.time() - start_time

    # Test eval
    model.eval()
    pred_test = _predict(test_dataset, test_edge_index, test_edge_weight)
    test_target = test_dataset.target.clone().detach().reshape(-1).to(device)
    test_mse = score1(test_target, pred_test.reshape(-1))
    test_mae = score2(test_target, pred_test.reshape(-1))

    # If training ended before any evaluation step, produce validation predictions now
    # so the variance estimate below is defined.
    if pred_val is None:
        pred_val = _predict(val_dataset, val_edge_index, val_edge_weight)

    # Calculate additional metrics. Predictions are moved to the CPU because the
    # targets indexed from y were never moved to `device`; mixing devices would raise.
    pred_val = pred_val.reshape(-1).cpu()
    pred_test = pred_test.reshape(-1).cpu()

    # Flatten the targets as well so the subtraction is element-wise.
    residuals = val_y.reshape(-1) - pred_val
    variance = torch.var(residuals).item()

    test_y_np = test_y.reshape(-1).numpy()
    pred_test_np = pred_test.numpy()

    # Calculate MPE - ensure all inputs are PyTorch tensors
    try:
        mpe_value = mpe(test_y, pred_test, var_pred=variance)
    except Exception as e:
        print(f"MPE calculation error: {e}")
        mpe_value = float("inf")

    # Calculate calibration metrics from Gaussian quantiles centred on each prediction,
    # using the validation-residual variance as the (constant) spread.
    try:
        taus = np.round(np.arange(0.01, 1, 0.01), 2)
        means = np.repeat(pred_test_np[:, np.newaxis], len(taus), axis=1)
        std_devs = np.sqrt(variance)
        quantiles = norm.ppf(taus, loc=means, scale=std_devs)

        # Empirical share of targets at or below each nominal quantile
        comparison = (test_y_np[:, np.newaxis] <= quantiles).astype(int)
        comparison_mean = pd.DataFrame(comparison, columns=taus).mean().reset_index()

        calibration = np.sum((comparison_mean["index"] - comparison_mean[0]) ** 2)
        madecp = np.mean(np.abs(comparison_mean["index"] - comparison_mean[0]))

        # Calculate 95% prediction interval coverage
        tau_lower, tau_upper = 0.025, 0.975
        q_lower = norm.ppf(tau_lower, loc=pred_test_np, scale=std_devs)
        q_upper = norm.ppf(tau_upper, loc=pred_test_np, scale=std_devs)
        coverage_95 = np.mean((test_y_np >= q_lower) & (test_y_np <= q_upper))

    except Exception as e:
        print(f"Calibration metrics calculation error: {e}")
        calibration = float("inf")
        madecp = float("inf")
        coverage_95 = 0.0

    # Return comprehensive results
    results = {
        "random_state": random_state,
        "model_folder": saved_file,
        "final_epoch": final_epoch,
        "best_epoch": best_epoch,
        "training_time": training_time,
        "test_mse": test_mse.item(),
        "test_mae": test_mae.item(),
        "val_mse": best_val_mse,
        "mpe": mpe_value,
        "calibration": calibration,
        "madecp": madecp,
        "coverage_95": coverage_95,
        "variance": variance,
    }

    return results

In [None]:
def train_multiple_seeds(x, y, coords, args, seeds=None, num_seeds=10):
    """
    Train the model once per random seed and collect the per-seed result dicts.

    Parameters
    ----------
    x, y, coords, args :
        Forwarded unchanged to ``train_single_seed`` for every seed.
    seeds : iterable of int, optional
        Seeds to run. Defaults to ``range(42, 42 + num_seeds)``.
    num_seeds : int
        Number of default seeds; ignored when ``seeds`` is given.

    Returns
    -------
    list of dict
        One entry per seed, in order. A seed whose training raises is reported (with
        traceback) and represented by a placeholder dict whose error metrics are
        ``inf`` and whose remaining metrics are zero, so the list always has one entry
        per requested seed.
    """
    import traceback  # local import: only needed to report failed seeds

    if seeds is None:
        seeds = list(range(42, 42 + num_seeds))

    all_results = []
    base_model_name = None

    print(f"Starting training with {len(seeds)} different random seeds: {seeds}")
    print("=" * 80)

    for i, seed in enumerate(seeds):
        print(f"\n--- Training with seed {seed} ({i+1}/{len(seeds)}) ---")

        # Build the shared base folder name once, on the first run, so every seed of
        # this experiment carries the same tag and timestamp.
        if base_model_name is None:
            test_ = "final" + "_" + args.model_name + "_" + args.dataset
            test_ += "_k" + str(args.k)
            test_ += "_ghid" + str(args.gnn_hidden_dim)
            test_ += "_gemb" + str(args.gnn_emb_dim)
            test_ += "_phid" + str(args.pe_hidden_dim)
            test_ += "_pemb" + str(args.pe_emb_dim)
            if args.mat:
                test_ += "_mat-uw" if args.uw else "_mat-lam" + str(args.lamb)
            if args.batched_training:
                test_ += "_bs" + str(args.batch_size)
            else:
                test_ += "_bsn"

            # %b is the portable month-abbreviation code (e.g. "Jan05-13h07m09s")
            timestamp = datetime.now().strftime("%b%d-%Hh%Mm%Ss")
            base_model_name = "{}_{}".format(test_, timestamp)

        seed_folder = f"{base_model_name}_seed{seed}"

        # A failing seed must not abort the remaining seeds, so errors are caught here.
        try:
            print(f"Starting training for seed {seed}...")
            results = train_single_seed(
                x,
                y,
                coords,
                args,
                seed,
                model_folder_name=seed_folder,
            )
            all_results.append(results)

            print(f"Seed {seed} completed:")
            print(f"  - Final epoch: {results['final_epoch']}")
            print(f"  - Best epoch: {results['best_epoch']}")
            print(f"  - Training time: {results['training_time']:.2f}s")
            print(f"  - Test MSE: {results['test_mse']:.6f}")
            print(f"  - Test MAE: {results['test_mae']:.6f}")
            print(f"  - MPE: {results['mpe']:.6f}")
            print(f"  - Coverage 95%: {results['coverage_95']:.4f}")

        except Exception as e:
            print(f"Error training seed {seed}: {e}")
            # The message alone rarely locates the problem; show where it came from.
            traceback.print_exc()
            # Create a dummy result to maintain consistency
            dummy_result = {
                "random_state": seed,
                "model_folder": seed_folder,
                "final_epoch": 0,
                "best_epoch": 0,
                "training_time": 0.0,
                "test_mse": float("inf"),
                "test_mae": float("inf"),
                "val_mse": float("inf"),
                "mpe": float("inf"),
                "calibration": float("inf"),
                "madecp": float("inf"),
                "coverage_95": 0.0,
                "variance": 0.0,
            }
            all_results.append(dummy_result)
            print(f"Added dummy result for failed seed {seed}")

    return all_results

In [None]:
def calculate_aggregated_results(all_results):
    """
    Summarise every metric across seeds with mean, std, min and max.

    Runs whose ``test_mse`` is not finite are treated as failed and left out of all
    metrics: this is how ``train_multiple_seeds`` marks a seed that raised, and the
    zeros it fills in for the other fields would otherwise skew the statistics.
    Within the remaining runs, non-finite values of an individual metric (e.g. an
    ``inf`` fallback for a metric that could not be computed) are skipped for that
    metric only.

    Parameters
    ----------
    all_results : list of dict
        Per-seed result dicts containing the metric keys listed below.

    Returns
    -------
    dict
        ``{metric: {"mean", "std", "min", "max", "n"}}`` where ``n`` is the number of
        values that entered the statistics. When no usable value exists for a metric,
        the four statistics are ``nan`` and ``n`` is 0.
    """

    # Metrics to aggregate
    metrics = [
        "final_epoch",
        "best_epoch",
        "training_time",
        "test_mse",
        "test_mae",
        "val_mse",
        "mpe",
        "calibration",
        "madecp",
        "coverage_95",
        "variance",
    ]

    # Drop failed runs as a whole before looking at individual metrics.
    completed_runs = [run for run in all_results if np.isfinite(run["test_mse"])]

    results_dict = {}
    for metric in metrics:
        values = np.asarray([run[metric] for run in completed_runs], dtype=float)
        values = values[np.isfinite(values)]
        if values.size == 0:
            # np.min / np.max raise on empty input; report "no data" explicitly.
            summary = dict.fromkeys(("mean", "std", "min", "max"), float("nan"))
        else:
            summary = {
                "mean": np.mean(values),
                "std": np.std(values),
                "min": np.min(values),
                "max": np.max(values),
            }
        summary["n"] = int(values.size)
        results_dict[metric] = summary

    return results_dict


def print_aggregated_results(all_results, results_dict):
    """
    Print the experiment report: a header, the mean/std/range of every aggregated
    metric grouped into three sections, and a table with one row per seed.

    ``all_results`` is the list of per-seed result dicts; ``results_dict`` maps each
    metric name to a dict with "mean", "std", "min" and "max".
    """
    heavy_rule = "=" * 100
    light_rule = "-" * 80

    def _banner(title, rule):
        # A blank line, then the title framed by two rules.
        print("\n" + rule)
        print(title)
        print(rule)

    def _metric_block(label, key, mean_spec, range_spec, is_first):
        # The first metric of a section follows the banner directly; later ones are
        # preceded by a blank line.
        print(("" if is_first else "\n") + label + ":")
        stats = results_dict[key]
        print(f"  Mean ± Std: {stats['mean']:{mean_spec}} ± {stats['std']:{mean_spec}}")
        print(f"  Range: [{stats['min']:{range_spec}}, {stats['max']:{range_spec}}]")

    # (section title, ((label, metric key, mean/std format, range format), ...))
    sections = (
        (
            "TRAINING METRICS",
            (
                ("Final Epochs", "final_epoch", ".1f", ".0f"),
                ("Best Epochs", "best_epoch", ".1f", ".0f"),
                ("Training Time (seconds)", "training_time", ".2f", ".2f"),
            ),
        ),
        (
            "PERFORMANCE METRICS",
            (
                ("Test MSE", "test_mse", ".6f", ".6f"),
                ("Test MAE", "test_mae", ".6f", ".6f"),
                ("Validation MSE", "val_mse", ".6f", ".6f"),
            ),
        ),
        (
            "UNCERTAINTY QUANTIFICATION METRICS",
            (
                ("MPE (Mean Prediction Error)", "mpe", ".6f", ".6f"),
                ("95% Prediction Interval Coverage", "coverage_95", ".4f", ".4f"),
                ("Calibration Score", "calibration", ".6f", ".6f"),
                (
                    "MADECP (Mean Absolute Deviation from Expected Coverage Probability)",
                    "madecp",
                    ".6f",
                    ".6f",
                ),
                ("Prediction Variance", "variance", ".6f", ".6f"),
            ),
        ),
    )

    # (column title, result key, column width, numeric format suffix)
    table_columns = (
        ("Seed", "random_state", 6, ""),
        ("Epochs", "final_epoch", 8, ""),
        ("Time(s)", "training_time", 10, ".2f"),
        ("Test MSE", "test_mse", 12, ".6f"),
        ("Test MAE", "test_mae", 12, ".6f"),
        ("MPE", "mpe", 12, ".6f"),
        ("Coverage", "coverage_95", 10, ".4f"),
    )

    # Report header
    _banner("COMPREHENSIVE RESULTS SUMMARY", heavy_rule)
    seeds_used = [entry["random_state"] for entry in all_results]
    print(f"\nNumber of seeds: {len(all_results)}")
    print(f"Seeds used: {seeds_used}")
    base_folder = all_results[0]["model_folder"].split("_seed")[0]
    print(f"\nBase model folder: {base_folder}")

    # Aggregated statistics, section by section
    for section_title, section_metrics in sections:
        _banner(section_title, light_rule)
        for position, (label, key, mean_spec, range_spec) in enumerate(section_metrics):
            _metric_block(label, key, mean_spec, range_spec, position == 0)

    # Per-seed table: left-aligned, fixed-width columns separated by one space
    _banner("INDIVIDUAL SEED RESULTS", heavy_rule)
    print("\n" + " ".join(format(title, f"<{width}") for title, _, width, _ in table_columns))
    print(light_rule)
    for entry in all_results:
        print(
            " ".join(
                format(entry[key], f"<{width}{suffix}")
                for _, key, width, suffix in table_columns
            )
        )

    print("\n" + heavy_rule)

In [None]:
# Experiment settings for the multi-seed run, collected in one mapping and exposed as
# an argparse-style namespace for the training functions.
experiment_config = {
    # data and bookkeeping
    "dataset": "australia",
    "model_name": "pegat",
    "path": "../../",
    "train_size": 0.8,
    "val_size": 0.1,
    # optimisation
    "batched_training": True,
    "batch_size": 2048,
    "max_epochs": 1000,
    "patience_limit": 50,
    "min_improvement": 0.01,
    "train_crit": "mse",
    "lr": 1e-3,
    # architecture
    "gnn_hidden_dim": 32,
    "gnn_emb_dim": 32,
    "pe_hidden_dim": 128,
    "pe_emb_dim": 64,
    "k": 5,
    "p_dropout": 0.5,
    # mat / uw / lamb options
    "mat": True,
    "uw": False,
    "lamb": 0.25,
    # evaluation cadence and logging
    "save_freq": 5,
    "print_progress": True,
}
args = argparse.Namespace(**experiment_config)

# Announce the experiment
print("Starting Multi-Seed Training Experiment")
print("=" * 80)

# Time the whole experiment: ten seeds trained one after another
start_time = time.time()
all_results = train_multiple_seeds(x, y, coords, args, num_seeds=10)
end_time = time.time()
total_execution_time = end_time - start_time

# Aggregate across seeds and print the report
results_dict = calculate_aggregated_results(all_results)
print_aggregated_results(all_results, results_dict)

# Split the elapsed seconds into days / hours / minutes / seconds for display
days, remainder = divmod(total_execution_time, 24 * 60 * 60)
hours, remainder = divmod(remainder, 60 * 60)
minutes, seconds = divmod(remainder, 60)
print(
    f"\nTOTAL EXECUTION TIME: {int(days)} days, {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds"
)

In [None]:
# Access the trained models (all seeds will be saved with the base model name)
base_model_folder = all_results[0]["model_folder"].split("_seed")[0]
print(f"Base model folder: {base_model_folder}")

# List all model folders for this experiment
models_lst = os.listdir(f"{args.path}trained/")
experiment_folders = [m for m in models_lst if base_model_folder in m]
print(f"Experiment folders: {experiment_folders}")

# Load the best performing model (lowest test MSE) for detailed analysis
best_result = min(all_results, key=lambda result: result["test_mse"])
print(
    f"Best performing model: Seed {best_result['random_state']} with Test MSE: {best_result['test_mse']:.6f}"
)

# Load the best model for analysis
model_folder = best_result["model_folder"]
model_path = f"{args.path}trained/{model_folder}/ckpts/model_state.pt"

# Set up data and model for analysis (using the same seed as best result)
set_seed(best_result["random_state"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Access and process data (same as training)
# NOTE(review): this rebinds the notebook-level x and y to the values returned here, so
# re-running this cell (or the training cell afterwards) passes those returned values
# back in as `aux` instead of the originally loaded ones - confirm that is intended.
if args.dataset == "australia":
    x, y, c = get_australia_data(aux=(x, y, coords))

# Split data (same as training): identical seed and sizes reproduce the same indices
n = x.shape[0]
indices = np.arange(n)
_, _, _, _, idx_train, idx_val_test = train_test_split(
    x,
    y,
    indices,
    test_size=(1 - args.train_size),
    random_state=best_result["random_state"],
)
idx_val, idx_test = train_test_split(
    idx_val_test,
    test_size=(1 - args.train_size - args.val_size) / (1 - args.train_size),
    random_state=best_result["random_state"],
)

# Separate x, y and c objects
train_x, val_x, test_x = x[idx_train], x[idx_val], x[idx_test]
train_y, val_y, test_y = y[idx_train], y[idx_val], y[idx_test]
train_c, val_c, test_c = c[idx_train], c[idx_val], c[idx_test]

# Create MyDataset objects
train_dataset, val_dataset, test_dataset = (
    MyDataset(train_x, train_y, train_c),
    MyDataset(val_x, val_y, val_c),
    MyDataset(test_x, test_y, test_c),
)

# Define edge indices and weights
if not args.batched_training:
    # Full-batch mode: one kNN graph per split
    batch_size = len(idx_train)
    train_edge_index = knn_graph(train_c, k=args.k).to(device)
    train_edge_weight = makeEdgeWeight(train_c, train_edge_index).to(device)
    val_edge_index = knn_graph(val_c, k=args.k).to(device)
    val_edge_weight = makeEdgeWeight(val_c, val_edge_index).to(device)
    test_edge_index = knn_graph(test_c, k=args.k).to(device)
    test_edge_weight = makeEdgeWeight(test_c, test_edge_index).to(device)
    train_moran_weight_matrix = knn_to_adj(train_edge_index, batch_size)
    with torch.enable_grad():
        train_y_moran = lw_tensor_local_moran(
            train_y, sparse.csr_matrix(train_moran_weight_matrix)
        ).to(device)
else:
    # Batched mode: no precomputed graphs; False is used as the placeholder
    train_edge_index = False
    train_edge_weight = False
    val_edge_index = False
    val_edge_weight = False
    test_edge_index = False
    test_edge_weight = False
    train_y_moran = False

# Make model (all three architectures take the same constructor arguments)
model_kwargs = dict(
    num_features_in=train_x.shape[1],
    gnn_hidden_dim=args.gnn_hidden_dim,
    gnn_emb_dim=args.gnn_emb_dim,
    pe_hidden_dim=args.pe_hidden_dim,
    pe_emb_dim=args.pe_emb_dim,
    k=args.k,
    p_dropout=args.p_dropout,
    MAT=args.mat,
)
if args.model_name == "pegcn":
    model = PEGCN(**model_kwargs)
elif args.model_name == "pegat":
    model = PEGAT(**model_kwargs)
elif args.model_name == "pegsage":
    model = PEGSAGE(**model_kwargs)
else:
    # Without this, an unknown name would leave `model` undefined for the lines below.
    raise ValueError(f"Unsupported model_name: {args.model_name!r}")
model = model.to(device)

# Load the best model. map_location remaps the stored tensors onto the current device,
# so a checkpoint written on a GPU machine can also be opened on a CPU-only one.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.float()

# Model analysis
model.eval()
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")