In [None]:
import os
import sys
import re

# Make the project's ``src`` directory (two levels above the working directory)
# importable. os.path.join keeps the separator correct on every OS.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), "src"))

In [None]:
import argparse
from datetime import datetime
from data import *
from metrics import *
from model import *
import monotonicnetworks as lmn
import numpy as np
from spatial import *
from scipy import sparse
from scipy.stats import norm
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from tensorboardX import SummaryWriter
import time
from torch_geometric.nn import knn_graph
from torch.optim import Adam
from torch.utils.data import DataLoader
import pandas as pd
import torch
import torch.nn as nn

In [None]:
class PEGQSAGE(nn.Module):
    """
    GraphSAGE with a positional encoder, quantile regression (as proposed by Si, 2020),
    and optional ybar input.

    Node features pass through two SAGEConv layers. Node coordinates pass through
    a grid-cell spatial encoder followed by an MLP. The two embeddings are
    concatenated and compressed to ``final_emb_dim``. The result is fed, together
    with the quantile input (and ``ybar`` when ``KNN=True``), into a Lipschitz
    network wrapped by ``lmn.MonotonicWrapper``, with a monotonic constraint of
    ``1`` on the quantile input.
    """

    def __init__(
        self,
        num_features_in: int = 6,
        num_features_out: int = 1,
        gnn_hidden_dim: int = 32,
        gnn_emb_dim: int = 32,
        pe_hidden_dim: int = 128,
        pe_emb_dim: int = 64,
        final_emb_dim: int = 8,
        k: int = 5,
        p_dropout: float = 0.5,
        MAT: bool = False,
        KNN: bool = True,
    ) -> None:
        """
        Initialize the PEGQSAGE model.

        Parameters
        ----------
        num_features_in : int, optional
            Number of input node features for the GraphSAGE layers, by default 6.
        num_features_out : int, optional
            Number of outputs of the monotonic subnet (and of the Moran's I head),
            by default 1.
        gnn_hidden_dim : int, optional
            Output dimension of the first GraphSAGE layer, by default 32.
        gnn_emb_dim : int, optional
            Output dimension of the second GraphSAGE layer (the GNN embedding),
            by default 32.
        pe_hidden_dim : int, optional
            Output dimension of the grid-cell spatial encoder, which is also the
            input width of the positional MLP, by default 128.
        pe_emb_dim : int, optional
            Dimension of the positional embedding, by default 64.
        final_emb_dim : int, optional
            Dimension of the merged embedding fed to the monotonic subnet,
            by default 8.
        k : int, optional
            Number of nearest neighbors used when the KNN graph is built on the fly,
            by default 5.
        p_dropout : float, optional
            Dropout probability applied after each GraphSAGE layer, by default 0.5.
        MAT : bool, optional
            If True, add a linear head on the merged embedding for an auxiliary
            Moran's I task, by default False.
        KNN : bool, optional
            If True, include `ybar` as an additional input to the monotonic subnet.
            By default True.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gnn_hidden_dim = gnn_hidden_dim
        self.gnn_emb_dim = gnn_emb_dim
        self.pe_hidden_dim = pe_hidden_dim
        self.pe_emb_dim = pe_emb_dim
        self.final_emb_dim = final_emb_dim
        self.k = k
        self.p_dropout = p_dropout
        self.MAT = MAT
        self.KNN = KNN

        # GraphSAGE layers
        self.conv1 = SAGEConv(num_features_in, gnn_hidden_dim)
        self.conv2 = SAGEConv(gnn_hidden_dim, gnn_emb_dim)

        # Spatial encoder
        self.spenc = GridCellSpatialRelationEncoder(
            spa_embed_dim=pe_hidden_dim, ffn=True, min_radius=1e-6, max_radius=360
        )
        self.dec_pe = nn.Sequential(
            nn.Linear(pe_hidden_dim, pe_hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(pe_hidden_dim // 2, pe_hidden_dim // 4),
            nn.Tanh(),
            nn.Linear(pe_hidden_dim // 4, pe_emb_dim),
        )

        # Merge GraphSAGE and positional embeddings
        self.dec = nn.Sequential(
            nn.Linear(pe_emb_dim + gnn_emb_dim, final_emb_dim * 4),
            nn.Tanh(),
            nn.Linear(final_emb_dim * 4, final_emb_dim * 2),
            nn.Tanh(),
            nn.Linear(final_emb_dim * 2, final_emb_dim),
        )

        # Monotonic constraints: the embedding is unconstrained (0), tau is
        # constrained with 1, and ybar (if present) is unconstrained (0).
        if KNN:
            in_dim = final_emb_dim + 2  # includes tau and ybar
            monotonic_constraints = [0] * final_emb_dim + [1, 0]
        else:
            in_dim = final_emb_dim + 1  # only includes tau
            monotonic_constraints = [0] * final_emb_dim + [1]

        net = nn.Sequential(
            lmn.LipschitzLinear(in_dim, 32, kind="one-inf"),
            lmn.GroupSort(2),
            lmn.LipschitzLinear(32, num_features_out, kind="inf"),
        )

        # Monotonic network for quantile regression
        self.monotonic_subnet = lmn.MonotonicWrapper(
            lipschitz_module=net,
            lipschitz_const=1.0,
            monotonic_constraints=monotonic_constraints,
        )

        # Optional auxiliary task for Moran's I
        if MAT:
            self.fc_morans = lmn.LipschitzLinear(
                final_emb_dim, num_features_out, kind="inf"
            )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        ei: torch.Tensor | None,
        ew: torch.Tensor | None,
        tau: torch.Tensor,
        ybar: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the PEGQSAGE model.

        Parameters
        ----------
        x : torch.Tensor
            Node features, shape [num_nodes, num_features_in].
        c : torch.Tensor
            Node coordinates, shape [num_nodes, coord_dim].
        ei : torch.Tensor | None
            Edge index. If it is a tensor it is used as given. Otherwise a KNN graph
            with ``self.k`` neighbors is built from ``c``.
        ew : torch.Tensor | None
            Edge weights. Accepted for interface compatibility with the sibling
            models. SAGEConv is called without edge weights, so this is ignored.
        tau : torch.Tensor
            Quantile input, one value per node (reshaped to [num_nodes, 1]).
            The training code in this file passes ``probit(tau)`` here rather
            than raw quantile levels.
        ybar : torch.Tensor
            Per-node auxiliary feature (e.g., local average of neighbor targets),
            shape [num_nodes]. Only fed to the monotonic subnet when ``KNN=True``.

        Returns
        -------
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]
            If MAT is False, returns the quantile regression output. If MAT is True,
            returns (quantile output, Moran's I output).
        """
        x = x.float().to(self.device)
        c = c.float().to(self.device)
        tau = tau.float().to(self.device)
        ybar = ybar.float().to(self.device)

        # SAGEConv is called without edge weights, so only the edge index is
        # needed. A precomputed index is used whenever one is given.
        if torch.is_tensor(ei):
            edge_index = ei.to(self.device)
        else:
            edge_index = knn_graph(c, k=self.k).to(self.device)

        # GraphSAGE forward pass
        x_emb = F.relu(self.conv1(x, edge_index))
        x_emb = F.dropout(x_emb, self.p_dropout, training=self.training)
        x_emb = F.relu(self.conv2(x_emb, edge_index))
        x_emb = F.dropout(x_emb, self.p_dropout, training=self.training)

        # Positional encoder forward pass. The spatial encoder is called with a
        # NumPy array of shape [1, num_nodes, coord_dim].
        c_reshaped = c.reshape(1, c.shape[0], c.shape[1])
        c_emb = self.spenc(c_reshaped.detach().cpu().numpy())
        c_emb = c_emb.reshape(c_emb.shape[1], c_emb.shape[2])
        c_emb = self.dec_pe(c_emb).float().to(self.device)

        # Merge GraphSAGE and positional embeddings
        l_emb = torch.cat((c_emb, x_emb), dim=1)
        phi_emb = self.dec(l_emb).float()

        # Build monotonic input: [embedding, tau(, ybar)]
        tau = tau.view(-1, 1)
        phi_til_emb = torch.cat((phi_emb, tau), dim=1)
        if self.KNN:
            ybar = ybar.view(-1, 1)
            phi_til_emb = torch.cat((phi_til_emb, ybar), dim=1)

        # Monotonic regression output (quantile)
        output = self.monotonic_subnet(phi_til_emb)

        # Auxiliary task (Moran's I) if enabled
        if self.MAT:
            morans_output = self.fc_morans(phi_emb)
            return output, morans_output
        return output

In [None]:
class LossWrapperQuantile(nn.Module):
    """
    A wrapper that computes quantile loss (pinball loss) for a single-task
    quantile regression GNN model.
    """

    def __init__(
        self,
        model: nn.Module,
        task_num: int = 1,
        uw: bool = False,
        lamb: float = 0.0,
        k: int = 5,
        batch_size: int = 2048,
    ) -> None:
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.task_num = task_num
        self.uw = uw
        self.lamb = lamb
        self.k = k
        self.batch_size = batch_size

        # For multi-task settings (not used in this wrapper)
        if self.task_num > 1:
            self.log_vars = nn.Parameter(torch.zeros(self.task_num))

    def forward(
        self,
        input_data: torch.Tensor,
        targets: torch.Tensor,
        coords: torch.Tensor,
        edge_index: torch.Tensor | None,
        edge_weight: torch.Tensor | None,
        morans_input: torch.Tensor | None,
        tau: torch.Tensor,
        ybar: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the quantile regression loss (pinball loss) for a single-task scenario.

        Parameters
        ----------
        input_data : torch.Tensor
            Node features.
        targets : torch.Tensor
            Ground-truth values.
        coords : torch.Tensor
            Node coordinates.
        edge_index : torch.Tensor | None
            Precomputed edge indices (if any).
        edge_weight : torch.Tensor | None
            Precomputed edge weights (if any).
        morans_input : torch.Tensor | None
            Unused in single-task mode.
        tau : torch.Tensor
            Quantile levels for each sample.
        ybar : torch.Tensor
            An auxiliary feature (e.g., local average) for the model, if used.

        Returns
        -------
        torch.Tensor
            The mean pinball loss.
        """
        if self.task_num == 1:
            outputs = self.model(
                input_data, coords, edge_index, edge_weight, probit(tau), ybar
            )
            return self.pinball_loss(
                targets.float().view(-1), outputs.float().view(-1), tau.float().view(-1)
            )
        else:
            raise ValueError("PEGQNN can only be used with task_num=1.")

    def pinball_loss(
        self, y_true: torch.Tensor, y_pred: torch.Tensor, tau: torch.Tensor
    ) -> torch.Tensor:
        """
        Pinball loss for quantile regression.

        Parameters
        ----------
        y_true : torch.Tensor
            Ground-truth values.
        y_pred : torch.Tensor
            Model predictions.
        tau : torch.Tensor
            Quantile levels.

        Returns
        -------
        torch.Tensor
            The mean pinball loss.
        """
        if y_true.size() != tau.size():
            raise ValueError("The size of y_true and tau must match.")

        delta = y_true - y_pred
        loss = torch.where(delta > 0, tau * delta, (tau - 1.0) * delta)
        return loss.mean()

In [None]:
def _build_run_label(args):
    """Return the run label encoding model, dataset and architecture settings."""
    label = "final" + "_" + args.model_name + "_" + args.dataset + "_k" + str(args.k)
    label += "_ghid" + str(args.gnn_hidden_dim)
    label += "_gemb" + str(args.gnn_emb_dim)
    label += "_phid" + str(args.pe_hidden_dim)
    label += "_pemb" + str(args.pe_emb_dim)
    if args.batched_training:
        label += "_bs" + str(args.batch_size)
    else:
        label += "_bsn"
    return label


def _timestamped_name(label):
    """Append the current local date/time to ``label`` (``<label>_<Mon><dd>-HHhMMmSSs``)."""
    now = datetime.now()
    return "{}_{}{}-{}h{}m{}s".format(
        label,
        now.strftime("%h"),
        now.strftime("%d"),
        now.strftime("%H"),
        now.strftime("%M"),
        now.strftime("%S"),
    )


def _predict(model, dataset, edge_index, edge_weight, tau, MAT, device):
    """Run ``model`` on a whole dataset split and return its quantile output.

    ``tau`` holds raw quantile levels. They are probit-transformed here, which is
    the same input transform ``LossWrapperQuantile`` applies during training.
    When ``MAT`` is True the model returns a tuple, and the auxiliary Moran's I
    output is dropped.
    """
    output = model(
        dataset.features.clone().detach().to(device),
        dataset.coords.clone().detach().to(device),
        edge_index,
        edge_weight,
        probit(tau),
        dataset.ybar.clone().detach().to(device),
    )
    if MAT:
        output, _ = output
    return output


def _predict_quantile_frame(
    model, dataset, edge_index, edge_weight, taus, MAT, device, cache_file
):
    """Return predictions at every level in ``taus`` (one DataFrame column per level).

    A readable parquet file at ``cache_file`` is reused. Otherwise predictions are
    computed and the cache is written on a best-effort basis. Parquet stores column
    labels as strings, so they are written with ``str`` and restored to floats on
    read. Columns can therefore always be indexed by the float quantile level.
    """
    try:
        frame = pd.read_parquet(cache_file)
        frame.columns = frame.columns.astype(float)
        return frame
    except (OSError, ValueError, TypeError, ImportError):
        pass  # no usable cache: recompute below

    target = dataset.target.clone().detach().reshape(-1).to(device)
    columns = {}
    with torch.no_grad():
        for tau in taus:
            tau_i = torch.full_like(target, tau).float()
            pred = _predict(model, dataset, edge_index, edge_weight, tau_i, MAT, device)
            # .cpu() is required before .numpy() when running on a GPU.
            columns[tau] = pred.reshape(-1).cpu().numpy()
    frame = pd.DataFrame(columns)

    try:
        frame.rename(columns=str).to_parquet(cache_file)
    except (OSError, ValueError, TypeError, ImportError) as err:
        print(f"Could not cache quantile predictions to {cache_file}: {err}")
    return frame


def _calibration_metrics(pred_quantiles, y_true, taus):
    """Return (sum of squared, mean absolute) calibration errors over ``taus``.

    ``pred_quantiles`` is an array of shape [N, len(taus)] and ``y_true`` has shape
    [N]. For each level, the observed coverage is the fraction of targets at or
    below the predicted quantile, and it is compared with the nominal level.
    """
    nominal = np.asarray(taus, dtype=float)
    observed = (pred_quantiles >= y_true[:, np.newaxis]).astype(int).mean(axis=0)
    gap = nominal - observed
    return np.sum(gap**2), np.mean(np.abs(gap))


def train_single_seed(args, random_state, model_folder_name=None):
    """
    Train model with a single random seed and return comprehensive results.

    Parameters
    ----------
    args : argparse.Namespace
        Experiment configuration. Attributes read include dataset, model_name,
        path, split sizes, batching, optimisation and architecture settings.
    random_state : int
        Seed passed to ``set_seed`` and to both train/val/test splits.
    model_folder_name : str | None, optional
        Run folder name under ``<path>//trained``. When None, a timestamped name
        is generated from the run label.

    Returns
    -------
    dict
        Seed, run folder, final/best epoch, training time, test MSE/MAE/MPE, best
        validation MSE, calibration errors (squared-sum and mean-absolute), 95%
        interval coverage, and a placeholder ``variance`` of 0.0.

    Raises
    ------
    ValueError
        If ``args.dataset`` or ``args.model_name`` is not supported.
    """
    # Get args
    dataset = args.dataset
    model_name = args.model_name
    path = args.path
    train_size = args.train_size
    val_size = args.val_size
    test_size = 1 - (args.train_size + args.val_size)
    batched_training = args.batched_training
    batch_size = args.batch_size
    max_epochs = args.max_epochs
    patience_limit = args.patience_limit
    min_improvement = args.min_improvement
    train_crit = args.train_crit
    lr = args.lr
    gnn_hidden_dim = args.gnn_hidden_dim
    gnn_emb_dim = args.gnn_emb_dim
    pe_hidden_dim = args.pe_hidden_dim
    pe_emb_dim = args.pe_emb_dim
    final_emb_dim = args.final_emb_dim
    k = args.k
    p_dropout = args.p_dropout
    MAT = args.mat
    uw = args.uw
    lamb = args.lamb
    KNN = args.knn
    save_freq = args.save_freq
    print_progress = args.print_progress

    set_seed(random_state)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Access and process data
    if dataset == "california_housing":
        x, y, c = get_california_housing_data()
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # Split data: train vs. (val + test), then val vs. test
    n = x.shape[0]
    indices = np.arange(n)
    _, _, _, _, idx_train, idx_val_test = train_test_split(
        x, y, indices, test_size=(1 - train_size), random_state=random_state
    )
    idx_val, idx_test = train_test_split(
        idx_val_test,
        test_size=(1 - train_size - val_size) / (1 - train_size),
        random_state=random_state,
    )

    train_x, val_x, test_x = x[idx_train], x[idx_val], x[idx_test]
    train_y, val_y, test_y = y[idx_train], y[idx_val], y[idx_test]
    train_c, val_c, test_c = c[idx_train], c[idx_val], c[idx_test]

    # ybar for the training set
    train_ybar = torch.tensor(compute_ybar(train_c, train_y, k))

    # ybar for validation/test: mean of the k nearest *training* targets
    # (haversine distance on coordinates converted to radians).
    nbrs = NearestNeighbors(n_neighbors=k, algorithm="brute", metric="haversine").fit(
        torch.deg2rad(train_c)
    )
    _, val_indices = nbrs.kneighbors(torch.deg2rad(val_c))
    _, test_indices = nbrs.kneighbors(torch.deg2rad(test_c))
    val_ybar = torch.tensor(np.array([train_y[row].mean() for row in val_indices]))
    test_ybar = torch.tensor(np.array([train_y[row].mean() for row in test_indices]))

    train_dataset = MyDataset(train_x, train_y, train_c, train_ybar)
    val_dataset = MyDataset(val_x, val_y, val_c, val_ybar)
    test_dataset = MyDataset(test_x, test_y, test_c, test_ybar)

    if not batched_training:
        # Full-batch training: precompute the KNN graph of every split once.
        batch_size = len(idx_train)
        train_edge_index = knn_graph(train_c, k=k).to(device)
        train_edge_weight = makeEdgeWeight(train_c, train_edge_index).to(device)
        val_edge_index = knn_graph(val_c, k=k).to(device)
        val_edge_weight = makeEdgeWeight(val_c, val_edge_index).to(device)
        test_edge_index = knn_graph(test_c, k=k).to(device)
        test_edge_weight = makeEdgeWeight(test_c, test_edge_index).to(device)
        train_moran_weight_matrix = knn_to_adj(train_edge_index, batch_size)
        with torch.enable_grad():
            train_y_moran = lw_tensor_local_moran(
                train_y, sparse.csr_matrix(train_moran_weight_matrix)
            ).to(device)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=False, drop_last=False
        )
    else:
        # Mini-batch training: no precomputed graphs. The model is expected to
        # build them itself (PEGQSAGE does when it gets a non-tensor edge index).
        train_edge_index = train_edge_weight = False
        val_edge_index = val_edge_weight = False
        test_edge_index = test_edge_weight = False
        train_y_moran = False
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
        )

    # Make model
    if model_name == "pegqcn-ybar":
        model_cls = PEGQCN
    elif model_name == "pegqat-ybar":
        model_cls = PEGQAT
    elif model_name == "pegqsage-ybar":
        model_cls = PEGQSAGE
    else:
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    model = model_cls(
        num_features_in=train_x.shape[1],
        gnn_hidden_dim=gnn_hidden_dim,
        gnn_emb_dim=gnn_emb_dim,
        pe_hidden_dim=pe_hidden_dim,
        pe_emb_dim=pe_emb_dim,
        final_emb_dim=final_emb_dim,
        k=k,
        p_dropout=p_dropout,
        MAT=MAT,
        KNN=KNN,
    ).to(device)
    model = model.float()

    task_num = 2 if MAT else 1

    # Optimizer and loss function
    loss_wrapper = LossWrapperQuantile(
        model,
        k=k,
        batch_size=batch_size,
        task_num=task_num,
        uw=uw,
        lamb=lamb,
    ).to(device)
    optimizer = Adam(loss_wrapper.parameters(), lr=lr)
    score1 = nn.MSELoss()
    score2 = nn.L1Loss()

    # Run folder layout: <path>//trained//<saved_file>//{log,data,images,ckpts}
    test_ = _build_run_label(args)
    if model_folder_name is None:
        saved_file = _timestamped_name(test_)
    else:
        saved_file = model_folder_name
    run_dir = path + "//trained//{}".format(saved_file)
    log_dir = run_dir + "//log"
    save_path = run_dir + "//ckpts"

    os.makedirs(run_dir + "//data", exist_ok=True)
    os.makedirs(run_dir + "//images", exist_ok=True)
    with open(run_dir + "//train_notes.txt", "w") as f:
        f.write("Experiment notes: PE-GQSAGE for California Housing dataset \n\n")
        f.write("MODEL_DATA: {}\n".format(test_))
        f.write("DATASET: {}\n".format(dataset))
        f.write("RANDOM_STATE: {}\n".format(random_state))
        f.write(
            "[TRAIN_SIZE, VAL_SIZE, TEST_SIZE]: [{}, {}, {}]\n".format(
                train_size, val_size, test_size
            )
        )
        f.write(
            "BATCH_SIZE: {}\nTRAIN_CRIT: {}\nLEARNING_RATE: {}\n".format(
                batch_size, train_crit, lr
            )
        )
        f.write(
            "MAX_EPOCHS: {}\nPATIENCE_LIMIT: {}\nMIN_IMPROVEMENT: {}\n".format(
                max_epochs, patience_limit, min_improvement
            )
        )
        f.write(
            "GNN_HIDDEN_DIM: {}\nGNN_EMB_DIM: {}\n".format(gnn_hidden_dim, gnn_emb_dim)
        )
        f.write("PE_HIDDEN_DIM: {}\nPE_EMB_DIM: {}\n".format(pe_hidden_dim, pe_emb_dim))
        f.write("FINAL_EMB_DIM: {}\n".format(final_emb_dim))
        f.write("K: {}\nP_DROPOUT: {}\n".format(k, p_dropout))
        f.write("KNN: {}\n".format(KNN))

    # Training loop
    start_time = time.time()
    it_counts = 0
    best_val_mse = float("inf")
    best_epoch = 0
    patience_counter = 0
    found = False
    final_epoch = 0
    val_target = val_dataset.target.clone().detach().reshape(-1).to(device)

    writer = SummaryWriter(log_dir)
    try:
        for epoch in range(max_epochs):
            for batch in train_loader:
                model.train()
                it_counts += 1
                batch_x = batch[0].to(device).float()
                batch_y = batch[1].to(device).float()
                batch_c = batch[2].to(device).float()
                batch_ybar = batch[3].to(device).float()
                # Sample a fresh quantile level per training target.
                batch_tau = torch.rand_like(batch_y)

                optimizer.zero_grad()
                wrapper_out = loss_wrapper(
                    batch_x,
                    batch_y,
                    batch_c,
                    train_edge_index,
                    train_edge_weight,
                    train_y_moran,
                    batch_tau,
                    batch_ybar,
                )
                if MAT and uw:
                    loss, log_vars = wrapper_out
                else:
                    loss = wrapper_out
                loss.backward()
                optimizer.step()

                # Periodic validation, early stopping and checkpointing
                if it_counts % save_freq == 0:
                    model.eval()
                    with torch.no_grad():
                        tau_median = torch.full_like(val_target, 0.5).float()
                        tau_random = torch.rand_like(val_target).float()
                        pred_val_median = _predict(
                            model, val_dataset, val_edge_index, val_edge_weight,
                            tau_median, MAT, device,
                        )
                        pred_val_random = _predict(
                            model, val_dataset, val_edge_index, val_edge_weight,
                            tau_random, MAT, device,
                        )
                    val_score1 = score1(val_target, pred_val_median.reshape(-1))
                    val_score2 = score2(val_target, pred_val_median.reshape(-1))
                    val_score3 = mpe(
                        val_target, pred_val_random.reshape(-1), tau=tau_random
                    )

                    # An improvement must beat the best MSE by a relative margin.
                    if best_val_mse > val_score1.item() * (1 + min_improvement):
                        best_val_mse = val_score1.item()
                        best_epoch = epoch
                        patience_counter = 0
                    else:
                        patience_counter += 1

                    if patience_counter > patience_limit:
                        if print_progress:
                            print(
                                f"Stopping early at epoch {epoch}. Best validation MSE: {best_val_mse} at epoch {best_epoch}."
                            )
                        found = True
                        break

                    if print_progress:
                        print(
                            "Epoch [%d/%d] - Loss: %f - Valid. (MSE): %f - Valid. (MAE): %f- Valid. (MPE): %f"
                            % (
                                epoch,
                                max_epochs,
                                loss.item(),
                                val_score1.item(),
                                val_score2.item(),
                                val_score3.item(),
                            )
                        )
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(model.state_dict(), save_path + "//" + "model_state.pt")
                    writer.add_scalar("Validation (MSE)", val_score1.item(), it_counts)
                    writer.add_scalar("Validation (MAE)", val_score2.item(), it_counts)
                    writer.add_scalar("Validation (MPE)", val_score3.item(), it_counts)
                writer.add_scalar("Training loss", loss.item(), it_counts)
                if MAT and uw:
                    writer.add_scalar(
                        "Uncertainty weight: main task", log_vars[0], it_counts
                    )
                    writer.add_scalar(
                        "Uncertainty weight: Morans aux task", log_vars[1], it_counts
                    )
                writer.flush()
            final_epoch = epoch
            if found:
                break
    finally:
        writer.close()

    training_time = time.time() - start_time

    # Test eval.
    # NOTE(review): this evaluates the weights at the end of training, not the
    # weights from ``best_epoch``. Confirm this is intended.
    model.eval()
    test_target = test_dataset.target.clone().detach().reshape(-1).to(device)
    with torch.no_grad():
        tau_median = torch.full_like(test_target, 0.5).float()
        tau_random = torch.rand_like(test_target).float()
        pred_test_median = _predict(
            model, test_dataset, test_edge_index, test_edge_weight,
            tau_median, MAT, device,
        )
        pred_test_random = _predict(
            model, test_dataset, test_edge_index, test_edge_weight,
            tau_random, MAT, device,
        )
    test_mse = score1(test_target, pred_test_median.reshape(-1))
    test_mae = score2(test_target, pred_test_median.reshape(-1))
    test_mpe = mpe(test_target, pred_test_random.reshape(-1), tau=tau_random)

    # Calibration metrics are best-effort: any failure yields sentinel values.
    try:
        # Flatten so the comparisons below broadcast per sample even if y is [N, 1].
        test_y_np = torch.as_tensor(test_y).detach().cpu().numpy().reshape(-1)

        taus = np.round(np.arange(0.01, 1, 0.01), 2)
        pred_test_quantile = _predict_quantile_frame(
            model, test_dataset, test_edge_index, test_edge_weight, taus, MAT,
            device, f"{run_dir}/pred_test_quantile.parquet",
        )
        calibration, madecp = _calibration_metrics(
            pred_test_quantile.to_numpy(dtype=float), test_y_np, taus
        )

        # 95% prediction interval coverage from the 2.5% / 97.5% quantiles
        pred_test_interval = _predict_quantile_frame(
            model, test_dataset, test_edge_index, test_edge_weight, [0.025, 0.975],
            MAT, device, f"{run_dir}/pred_test_quantile_95.parquet",
        )
        q_lower = pred_test_interval[0.025].to_numpy(dtype=float)
        q_upper = pred_test_interval[0.975].to_numpy(dtype=float)
        coverage_95 = np.mean((test_y_np >= q_lower) & (test_y_np <= q_upper))
    except Exception as e:
        print(f"Calibration metrics calculation error: {e}")
        calibration = float("inf")
        madecp = float("inf")
        coverage_95 = 0.0

    return {
        "random_state": random_state,
        "model_folder": saved_file,
        "final_epoch": final_epoch,
        "best_epoch": best_epoch,
        "training_time": training_time,
        "test_mse": test_mse.item(),
        "test_mae": test_mae.item(),
        "val_mse": best_val_mse,
        "mpe": test_mpe.item(),
        "calibration": calibration,
        "madecp": madecp,
        "coverage_95": coverage_95,
        "variance": 0.0,
    }

In [None]:
def _build_base_model_name(args):
    """Return the timestamped experiment name shared by every seed's folder.

    The name encodes the model, dataset, ``k``, the GNN / positional-encoder
    layer sizes and the batching mode, followed by the current local time
    formatted as ``<Mon><DD>-<HH>h<MM>m<SS>s``.
    """
    name = (
        f"final_{args.model_name}_{args.dataset}_k{args.k}"
        f"_ghid{args.gnn_hidden_dim}"
        f"_gemb{args.gnn_emb_dim}"
        f"_phid{args.pe_hidden_dim}"
        f"_pemb{args.pe_emb_dim}"
    )
    name += f"_bs{args.batch_size}" if args.batched_training else "_bsn"
    # %b is the portable spelling of %h (abbreviated month name).
    return f"{name}_{datetime.now().strftime('%b%d-%Hh%Mm%Ss')}"


def train_multiple_seeds(args, seeds=None, num_seeds=10):
    """
    Train the model once per random seed and collect the per-seed results.

    Parameters
    ----------
    args : argparse.Namespace
        Experiment configuration forwarded to ``train_single_seed``; also used
        to build the shared model-folder name.
    seeds : list of int, optional
        Explicit seeds to run. Defaults to ``42 .. 42 + num_seeds - 1``.
    num_seeds : int, optional
        Number of consecutive seeds generated when ``seeds`` is None,
        by default 10.

    Returns
    -------
    list of dict
        One result dict per seed, in seed order. A seed whose training raised
        an exception contributes a placeholder dict (``inf`` errors, zero
        coverage) so the list always has one entry per seed.
    """
    import traceback

    if seeds is None:
        seeds = list(range(42, 42 + num_seeds))

    all_results = []
    # All seeds share one timestamped base name so their folders group together.
    base_model_name = _build_base_model_name(args)

    print(f"Starting training with {len(seeds)} different random seeds: {seeds}")
    print("=" * 80)

    for i, seed in enumerate(seeds):
        print(f"\n--- Training with seed {seed} ({i+1}/{len(seeds)}) ---")
        model_folder_name = f"{base_model_name}_seed{seed}"

        try:
            print(f"Starting training for seed {seed}...")
            results = train_single_seed(
                args, seed, model_folder_name=model_folder_name
            )
            all_results.append(results)

            print(f"Seed {seed} completed:")
            print(f"  - Final epoch: {results['final_epoch']}")
            print(f"  - Best epoch: {results['best_epoch']}")
            print(f"  - Training time: {results['training_time']:.2f}s")
            print(f"  - Test MSE: {results['test_mse']:.6f}")
            print(f"  - Test MAE: {results['test_mae']:.6f}")
            print(f"  - MPE: {results['mpe']:.6f}")
            print(f"  - Coverage 95%: {results['coverage_95']:.4f}")

        except Exception as e:
            # Deliberate best-effort: one failing seed must not abort the run,
            # but keep the traceback so the failure can be diagnosed.
            print(f"Error training seed {seed}: {e}")
            traceback.print_exc()
            # Placeholder result with the same keys as a successful run.
            dummy_result = {
                "random_state": seed,
                "model_folder": model_folder_name,
                "final_epoch": 0,
                "best_epoch": 0,
                "training_time": 0.0,
                "test_mse": float("inf"),
                "test_mae": float("inf"),
                "val_mse": float("inf"),
                "mpe": float("inf"),
                "calibration": float("inf"),
                "madecp": float("inf"),
                "coverage_95": 0.0,
                "variance": 0.0,
            }
            all_results.append(dummy_result)
            print(f"Added dummy result for failed seed {seed}")

    return all_results

In [None]:
def calculate_aggregated_results(all_results, metrics=None):
    """
    Summarise each metric across seeds with its mean, std, min and max.

    Parameters
    ----------
    all_results : list of dict
        Per-seed result dicts, each containing every key in ``metrics``.
    metrics : sequence of str, optional
        Keys to aggregate. Defaults to the metric keys of the per-seed result
        dicts (epochs, timing, errors, calibration and coverage metrics).

    Returns
    -------
    dict
        ``{metric: {"mean": ..., "std": ..., "min": ..., "max": ...}}``.
        ``std`` is the population standard deviation (``np.std``, ddof=0).
        Non-finite inputs (e.g. the ``inf`` placeholders of failed seeds)
        propagate into the statistics. When ``all_results`` is empty every
        statistic is ``nan`` instead of raising.
    """
    if metrics is None:
        metrics = [
            "final_epoch",
            "best_epoch",
            "training_time",
            "test_mse",
            "test_mae",
            "val_mse",
            "mpe",
            "calibration",
            "madecp",
            "coverage_95",
            "variance",
        ]

    results_dict = {}
    for metric in metrics:
        values = np.asarray([result[metric] for result in all_results])
        if values.size == 0:
            # np.min / np.max have no identity for empty input and would raise.
            results_dict[metric] = {
                stat: float("nan") for stat in ("mean", "std", "min", "max")
            }
            continue
        results_dict[metric] = {
            "mean": np.mean(values),
            "std": np.std(values),
            "min": np.min(values),
            "max": np.max(values),
        }

    return results_dict


def _print_metric_section(results_dict, heading, rows):
    """Print a dashed section header followed by a summary of each metric row.

    ``rows`` holds ``(metric_key, title, stat_fmt, range_fmt)`` tuples; the
    mean/std are formatted with ``stat_fmt`` and min/max with ``range_fmt``.
    """
    print("\n" + "-" * 80)
    print(heading)
    print("-" * 80)

    for i, (metric, title, stat_fmt, range_fmt) in enumerate(rows):
        stats = results_dict[metric]
        # Every metric after the first is separated by a blank line.
        print(title if i == 0 else "\n" + title)
        print(
            f"  Mean ± Std: {stats['mean']:{stat_fmt}} ± {stats['std']:{stat_fmt}}"
        )
        print(
            f"  Range: [{stats['min']:{range_fmt}}, {stats['max']:{range_fmt}}]"
        )


def print_aggregated_results(all_results, results_dict):
    """
    Print a summary of the multi-seed experiment.

    Parameters
    ----------
    all_results : list of dict
        Per-seed result dicts (as returned by ``train_multiple_seeds``).
    results_dict : dict
        Per-metric ``mean``/``std``/``min``/``max`` statistics (as returned by
        ``calculate_aggregated_results``).
    """
    print("\n" + "=" * 100)
    print("COMPREHENSIVE RESULTS SUMMARY")
    print("=" * 100)

    print(f"\nNumber of seeds: {len(all_results)}")
    print(f"Seeds used: {[r['random_state'] for r in all_results]}")

    if not all_results:
        # Nothing else to report; avoids indexing all_results[0] below.
        print("\nNo results to summarize.")
        print("\n" + "=" * 100)
        return

    # Folder names are "<base>_seed<N>"; strip the seed suffix.
    print(f"\nBase model folder: {all_results[0]['model_folder'].split('_seed')[0]}")

    _print_metric_section(
        results_dict,
        "TRAINING METRICS",
        [
            ("final_epoch", "Final Epochs:", ".1f", ".0f"),
            ("best_epoch", "Best Epochs:", ".1f", ".0f"),
            ("training_time", "Training Time (seconds):", ".2f", ".2f"),
        ],
    )
    _print_metric_section(
        results_dict,
        "PERFORMANCE METRICS",
        [
            ("test_mse", "Test MSE:", ".6f", ".6f"),
            ("test_mae", "Test MAE:", ".6f", ".6f"),
            ("val_mse", "Validation MSE:", ".6f", ".6f"),
        ],
    )
    _print_metric_section(
        results_dict,
        "UNCERTAINTY QUANTIFICATION METRICS",
        [
            ("mpe", "MPE (Mean Prediction Error):", ".6f", ".6f"),
            ("coverage_95", "95% Prediction Interval Coverage:", ".4f", ".4f"),
            ("calibration", "Calibration Score:", ".6f", ".6f"),
            (
                "madecp",
                "MADECP (Mean Absolute Deviation from Expected Coverage Probability):",
                ".6f",
                ".6f",
            ),
            ("variance", "Prediction Variance:", ".6f", ".6f"),
        ],
    )

    print("\n" + "=" * 100)
    print("INDIVIDUAL SEED RESULTS")
    print("=" * 100)

    # Individual results table
    print(
        f"\n{'Seed':<6} {'Epochs':<8} {'Time(s)':<10} {'Test MSE':<12} {'Test MAE':<12} {'MPE':<12} {'Coverage':<10}"
    )
    print("-" * 80)
    for result in all_results:
        print(
            f"{result['random_state']:<6} {result['final_epoch']:<8} {result['training_time']:<10.2f} "
            f"{result['test_mse']:<12.6f} {result['test_mae']:<12.6f} {result['mpe']:<12.6f} "
            f"{result['coverage_95']:<10.4f}"
        )

    print("\n" + "=" * 100)

In [None]:
# Set up arguments for multi-seed training. This Namespace stands in for
# parsed command-line arguments and is passed straight to train_multiple_seeds.
args = argparse.Namespace(
    dataset="california_housing",
    model_name="pegqsage-ybar",
    path="../../",
    train_size=0.8,
    val_size=0.1,
    batched_training=True,
    batch_size=512,
    max_epochs=1000,
    patience_limit=50,
    min_improvement=0.01,
    train_crit="pinball",
    lr=1e-3,
    gnn_hidden_dim=32,
    gnn_emb_dim=32,
    pe_hidden_dim=128,
    pe_emb_dim=64,
    final_emb_dim=8,
    k=5,
    p_dropout=0.5,
    mat=False,
    uw=False,
    lamb=0.0,
    knn=True,
    save_freq=5,
    print_progress=True,
)

# Run multi-seed training
print("Starting Multi-Seed Training Experiment")
print("=" * 80)

# perf_counter is monotonic, so the elapsed time is not skewed by system
# clock adjustments during a long run (time.time() is wall-clock).
start_time = time.perf_counter()

# Train with 10 different seeds
all_results = train_multiple_seeds(args, num_seeds=10)

total_execution_time = time.perf_counter() - start_time

# Calculate aggregated results
results_dict = calculate_aggregated_results(all_results)

# Print comprehensive results
print_aggregated_results(all_results, results_dict)

# Print total execution time broken down into days/hours/minutes/seconds
days, remainder = divmod(total_execution_time, 60 * 60 * 24)
hours, remainder = divmod(remainder, 60 * 60)
minutes, seconds = divmod(remainder, 60)
print(
    f"\nTOTAL EXECUTION TIME: {int(days)} days, {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds"
)

In [None]:
# Access the trained models (all seeds are saved under the shared base model name)
base_model_folder = all_results[0]["model_folder"].split("_seed")[0]
print(f"Base model folder: {base_model_folder}")

# List all model folders for this experiment
models_lst = os.listdir(f"{args.path}trained/")
experiment_folders = [m for m in models_lst if base_model_folder in m]
print(f"Experiment folders: {experiment_folders}")

# Pick the best performing seed (lowest test MSE) for detailed analysis
best_result = min(all_results, key=lambda x: x["test_mse"])
print(
    f"Best performing model: Seed {best_result['random_state']} with Test MSE: {best_result['test_mse']:.6f}"
)
# Failed seeds are recorded with test_mse = inf. If even the minimum is not
# finite, every seed failed and there is presumably no usable checkpoint.
if not np.isfinite(best_result["test_mse"]):
    raise RuntimeError(
        "Every seed failed during training; no model is available for analysis."
    )

# Location of the best model's checkpoint
model_folder = best_result["model_folder"]
model_path = f"{args.path}trained/{model_folder}/ckpts/model_state.pt"

# Set up data and model for analysis (using the same seed as best result)
set_seed(best_result["random_state"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Access and process data
if args.dataset == "california_housing":
    x, y, c = get_california_housing_data()
else:
    raise ValueError(f"Unsupported dataset for analysis: {args.dataset!r}")

# Split data. NOTE(review): assumed to reproduce the split done during
# training (same sizes and random_state) — keep in sync with train_single_seed.
n = x.shape[0]
indices = np.arange(n)
_, _, _, _, idx_train, idx_val_test = train_test_split(
    x,
    y,
    indices,
    test_size=(1 - args.train_size),
    random_state=best_result["random_state"],
)
idx_val, idx_test = train_test_split(
    idx_val_test,
    test_size=(1 - args.train_size - args.val_size) / (1 - args.train_size),
    random_state=best_result["random_state"],
)

# Separate x, y and c objects
train_x, val_x, test_x = x[idx_train], x[idx_val], x[idx_test]
train_y, val_y, test_y = y[idx_train], y[idx_val], y[idx_test]
train_c, val_c, test_c = c[idx_train], c[idx_val], c[idx_test]

# Make model: the constructor arguments are shared by all supported variants.
model_kwargs = dict(
    num_features_in=train_x.shape[1],
    gnn_hidden_dim=args.gnn_hidden_dim,
    gnn_emb_dim=args.gnn_emb_dim,
    pe_hidden_dim=args.pe_hidden_dim,
    pe_emb_dim=args.pe_emb_dim,
    final_emb_dim=args.final_emb_dim,
    k=args.k,
    p_dropout=args.p_dropout,
    MAT=args.mat,
    KNN=args.knn,
)
# if/elif (rather than a dict of classes) so only the selected class name
# has to be defined.
if args.model_name == "pegqcn-ybar":
    model_cls = PEGQCN
elif args.model_name == "pegqat-ybar":
    model_cls = PEGQAT
elif args.model_name == "pegqsage-ybar":
    model_cls = PEGQSAGE
else:
    raise ValueError(f"Unknown model_name: {args.model_name!r}")
model = model_cls(**model_kwargs).to(device)

# Load the best model. map_location lets a checkpoint written on a GPU be
# restored on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.float()

# Model analysis
model.eval()
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")