In [5]:
# Standard library imports
import json
import csv
import os
import sys
from pprint import pprint
import time
import gc 

# Third party imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch_geometric
from torch.nn.utils import clip_grad_norm_
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Local application imports
sys.path.append("/pbs/home/e/erodrigu/TesisPhDEzequielRodriguez/Code")
from UHECRs_gnn import(SD433UMDatasetHomogeneous,
                       GNNWithAttentionDiscriminator3Heads,
                       GNNWithAttentionDiscriminator3HeadsDualInput,
                       MaskNodes,
                       MaskMdCounters,
                       SilentPrunner,
                       MaskRandomNodes,
)

from my_utils.my_basic_utils import (
    create_bins,
    filter_dataframe,
)

# Project paths, resolved relative to the current working directory:
# the code directory is its parent, the project root is one level above that.
code_PATH = os.path.abspath("..")
project_PATH = os.path.dirname(code_PATH)
data_PATH = os.path.join(project_PATH, "data")

### Version check

In [6]:
# Record the library versions and select the compute device (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for version_line in (
    f"Torch version: {torch.__version__}",
    f"Torch CUDA version: {torch.version.cuda}",
    f"Cuda available: {torch.cuda.is_available()}",
    f"Torch geometric version: {torch_geometric.__version__}",
    f"Device is: {device}",
):
    print(version_line)

Torch version: 1.11.0
Torch CUDA version: 10.2
Cuda available: True
Torch geometric version: 2.4.0
Device is: cuda


In [7]:
# Inspect the GPU (model, driver/CUDA version, memory in use) before training.
!nvidia-smi

Thu Jan  2 00:39:21 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:04:00.0 Off |                    0 |
| N/A   23C    P8    27W / 149W |      3MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### Dataset Index Loading

In [8]:
folder_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/"
dir_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/JSONfiles/"

# One index file per (primary, energy bin, atmosphere model). Each
# hadron-reconstruction index is paired with the photon-reconstruction index
# of the same name (for both primaries).
primaries = ["Proton", "Photon"]
energy_bins = ["16.5_17.0", "17.0_17.5"]
atms = ["01", "03", "08", "09"]
indexes = [
    f"index_hadron_rec_{x}_{y}_{z}.csv"
    for x in primaries
    for y in energy_bins
    for z in atms
]

# Left-merge each hadron-rec index with its photon-rec counterpart on the event
# keys. The merged pieces are collected in a list and concatenated once, instead
# of re-copying an ever-growing DataFrame with pd.concat on every iteration.
merge_keys = ["filename", "atm_model", "shower_id", "use_id"]
merged_pieces = []
for index_name in indexes:
    hadron_rec_index = pd.read_csv(
        os.path.join(folder_path, index_name), on_bad_lines="skip"
    )
    photon_rec_index = pd.read_csv(
        os.path.join(folder_path, index_name.replace("hadron", "photon")),
        on_bad_lines="skip",
    )
    merged_pieces.append(
        pd.merge(hadron_rec_index, photon_rec_index, on=merge_keys, how="left")
    )
index = pd.concat(merged_pieces, ignore_index=True)

# Remove exact duplicates, then keep a single row per event file.
index = index.drop_duplicates()
index = index.drop_duplicates(subset=["filename"])
# The primary (mass group) is the filename prefix before the first "_".
index["mass_group"] = index["filename"].str.split(pat="_", expand=True)[0]
# we won't train using iron
index = index[index["mass_group"] != "Iron"]
print(f"Events before quality cuts: {len(index)}")

Events before quality cuts: 398227


In [9]:
# Summary statistics of the numeric columns. M1 contains -inf values (see its
# min/mean), which presumably triggers the numpy "subtract" RuntimeWarning.
index.describe(percentiles=[0.25, 0.5, 0.75])

  diff_b_a = subtract(b, a)


Unnamed: 0,atm_model,shower_id,use_id,energyMC,zenithMC,showerSize,showerSizeError,isT5,is6T5,Xmax,...,nearestid,nCandidates,bLDF,isSaturated,muonNumber,electromagneticEnergy,photon_energy,s_250,equivalent_energy,M1
count,398227.0,398227.0,398227.0,398227.0,398227.0,398227.0,398227.0,398227.0,398227.0,398227.0,...,398227.0,398227.0,398227.0,398227.0,398227.0,398227.0,398187.0,398187.0,397965.0,397965.0
mean,5.250465,624.671403,10.489314,1.238382e+17,0.677721,16.22201,1.188039,0.671835,0.671835,731.26586,...,0.0,7.788909,0.761453,0.116549,489250.2,1.139049e+17,1.642059e+17,29.00653,1.202942e+17,-inf
std,3.345009,360.883823,5.766413,7.912391e+16,0.273104,17.278987,0.866914,0.469546,0.469546,97.612257,...,0.0,5.378215,0.426196,0.320883,596083.7,7.390093e+16,8.526831e+18,32.499657,6.294315e+17,
min,1.0,0.0,1.0,3.16277e+16,0.005836,0.0,0.0,0.0,0.0,541.34,...,0.0,0.0,0.0,0.0,3914.03,2.37889e+16,0.0,0.0,0.0,-inf
25%,1.0,312.0,5.0,5.62889e+16,0.470053,3.482985,0.674117,0.0,0.0,674.485,...,0.0,4.0,1.0,0.0,68356.6,5.13339e+16,0.0,5.067435,3.01216e+16,
50%,8.0,625.0,10.0,9.96571e+16,0.698197,10.6728,1.19316,1.0,1.0,726.16,...,0.0,8.0,1.0,0.0,246276.0,9.1395e+16,7.3041e+16,17.9387,8.34971e+16,0.0920853
75%,9.0,937.0,15.0,1.78841e+17,0.903215,23.7102,1.74992,1.0,1.0,776.31,...,0.0,12.0,1.0,0.0,681854.0,1.64577e+17,1.755365e+17,42.08015,1.73069e+17,0.909041
max,9.0,1249.0,20.0,3.16147e+17,1.13436,486.789,25.0415,1.0,1.0,7344.8,...,0.0,23.0,1.0,1.0,3569170.0,3.12735e+17,5.2109e+21,245.14,2.35976e+20,2.08977


### Quality Cuts and Binning

In [10]:
# Shuffle with a fixed seed (the same value used by the split cell). The
# stratified splits select rows by position, so an unseeded shuffle yields a
# different train/validation/test partition on every run, e.g. after a kernel
# restart to resume training.
index = index.sample(frac=1, random_state=42)

# Binary label from the primary encoded in the filename.
index.loc[index["filename"].str.contains("Photon"), "isPhoton"] = 1
index.loc[index["filename"].str.contains("Proton"), "isPhoton"] = 0

index["sin2zenith"] = np.sin(index["zenithMC"]) ** 2

# Disabled: photon efficiency from a fit to simulations and the matching cut in
# feature_filters (re-enabling it also requires importing `expit`).
#index["est_efficiency"] = (
#    15.4074
#    + 17.4996 * (np.log10(index["energyMC"]) - 17)
#    - 12.7485 * index["sin2zenith"]
#    - 20.7650 * index["sin2zenith"] ** 2
#    - 13.1239 * (np.log10(index["energyMC"]) - 17) * index["sin2zenith"]
#)
#index["est_efficiency"] = expit(index["est_efficiency"])

# Quality cuts: zenith up to 45 deg, photon_energy above 1, T5 events only
# (see filter_dataframe for whether the bounds are inclusive).
feature_filters = {
    "zenithMC": {"filter_type": "range", "max_cut": np.deg2rad(45)},
    "photon_energy": {"filter_type": "range", "min_cut": 1},
    #"est_efficiency": {"filter_type": "range", "min_cut": 0.9},
    "isT5": {"filter_type": "value", "value_to_keep": 1}
}
index = filter_dataframe(index, feature_filters)

# Energy bins (log-spaced) and sin^2(zenith) bins used to stratify the splits.
index, e_bin_centers, e_bin_edges, e_labels = create_bins(
    index,
    lower_val=10**16.5,
    upper_val=10**17.5,
    num=6,
    unbinned_col="energyMC",
    bin_column_name="e_bin",
    bin_width="equal_logarithmic",
)

index, z_bin_centers, z_bin_edges, z_labels = create_bins(
    index,
    lower_val=0,
    upper_val=np.sin(np.deg2rad(45)) ** 2,
    num=4,
    unbinned_col="sin2zenith",
    bin_column_name="z_bin",
    bin_width="equal",
)

# Drop events outside either binning range. A missing bin would otherwise be
# turned into its own "nan" stratum when the bins are cast to str for the
# stratified split (and a tiny stratum makes StratifiedShuffleSplit fail).
index = index.loc[~index["e_bin"].isnull() & ~index["z_bin"].isnull()]

# corrupted or problematic ADSTs
exclude_list = [
    "Photon_17.0_17.5_011102_11",
    "Photon_17.0_17.5_080595_20",
]
index = index[~index['filename'].isin(exclude_list)]

print(f"Events after quality cuts: {len(index)}")

Events after quality cuts: 209158


### Balanced Dataset Division

In [11]:
# Combine label and the two categorical variables for stratified sampling, so
# every split keeps the same photon/proton x energy bin x zenith bin mix.
index["categorical_balance"] = (
    index["isPhoton"].astype(str)
    + "_"
    + index["e_bin"].astype(str)
    + "_"
    + index["z_bin"].astype(str)
)

random_seed = 42
stratified_split = StratifiedShuffleSplit(
    n_splits=1, test_size=0.25, random_state=random_seed
)

# n_splits=1: take the single (development, test) split directly.
dev_index_, test_index_ = next(
    stratified_split.split(index, index["categorical_balance"])
)
# Original Training (development) set and Testing set
dev_index = index.iloc[dev_index_]
test_index = index.iloc[test_index_]

# Further split the development set into train and validation sets
validation_size = 0.25  # fraction of the development set (train = 0.75 * 0.75 of all events)
split = StratifiedShuffleSplit(
    n_splits=1, test_size=validation_size, random_state=random_seed
)
train_index_, validation_index_ = next(
    split.split(dev_index, dev_index["categorical_balance"])
)
train_index = dev_index.iloc[train_index_]
validation_index = dev_index.iloc[validation_index_]

# The three sets must partition the events with no overlap (filenames are
# unique after the drop_duplicates(subset=["filename"]) of the loading cell).
assert len(train_index) + len(validation_index) + len(test_index) == len(index)
_train_files = set(train_index["filename"])
assert _train_files.isdisjoint(validation_index["filename"]), "train/validation overlap"
assert _train_files.isdisjoint(test_index["filename"]), "train/test overlap"
assert set(validation_index["filename"]).isdisjoint(test_index["filename"]), "validation/test overlap"

# Print the size of each dataset
print("Train dataset size:", train_index.shape[0])
print("Validation dataset size:", validation_index.shape[0])
print("Test dataset size:", test_index.shape[0])

Train dataset size: 117651
Validation dataset size: 39217
Test dataset size: 52290


In [12]:
dir_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/JSONfiles/"
root_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/root/"

# Function to construct paths based on DataFrame columns
def construct_path(row, base_path):
    """Return the JSON file path of one event.

    Args:
        row: mapping-like record (e.g. a DataFrame row) with "mass_group" and
            "filename" entries.
        base_path: directory containing one sub-directory per mass group.

    Returns:
        ``<base_path>/<mass_group>/<filename>.json``. The parts are joined with
        os.path.join, so a base_path given with or without a trailing separator
        yields the same valid path.
    """
    return os.path.join(base_path, row["mass_group"], f"{row['filename']}.json")

# Set paths according to index
# One JSON path per event for the training and validation splits, in split order.
train_paths, val_paths = (
    split_df.apply(construct_path, axis=1, base_path=dir_path).tolist()
    for split_df in (train_index, validation_index)
)

### Generation of Normalization Dictionary

In [13]:
# When True, use the normalization constants hard-coded in the `else` branch.
# When False, rebuild the training dataset, recompute the constants and pprint
# them so they can be pasted into the `else` branch.
norm_dict_computed = True
# Augmentation / normalization switches, forwarded to SD433UMDatasetHomogeneous
# as `augmentation_options` here and in the loader cell.
augmentation_and_normalization_options={"mask_PMTs":True,
                                        "AoP_and_saturation":True,
                                        "log_normalize_traces":True,
                                        "log_normalize_signals":True,
                                        "mask_MD_mods":True}

if not norm_dict_computed:
    # process the dataset
    # NOTE(review): built without include_silent / n_time_bins, so the dataset's
    # defaults apply here — confirm they match the loaders used for training.
    train_PyG_ds = SD433UMDatasetHomogeneous(
        file_paths=train_paths,
        root=root_path,
        augmentation_options=augmentation_and_normalization_options)
    # compute statistics required for standardization
    normalization_dict = train_PyG_ds.compute_normalization_params(features=["x",
                                                                             "y",
                                                                             "z",
                                                                             "deltaTimeHottest",
                                                                             "WCD_signal"])
    # set values for min-max normalization
    normalization_dict["pmt_number"] = {"min": 1,
                                        "max": 3,
                                        "method": "min_max_scaling"}
    normalization_dict["effective_area"] = {"min": 0,
                                            "max": 3 * 10.46,
                                            "method": "min_max_scaling"}
    normalization_dict["rho_mu"] = {'min': -2.0,
                                    'method': 'min_max_scaling',
                                    'max': (64*3)/(3*10*np.cos(np.deg2rad(45)))}
    # print the dict to overwrite the code below
    pprint(normalization_dict)
else:
    # Hard-coded constants. The active x/y/z/deltaTimeHottest values are the
    # "without silent" ones; the commented-out "with silent" alternatives
    # presumably correspond to include_silent=True — TODO confirm they match
    # include_silent in the configuration cell.
    # NOTE(review): this dict differs from what the recompute branch produces:
    # rho_mu uses standardization here but min-max scaling there, and the
    # recompute branch also requests "WCD_signal", which is absent here.
    normalization_dict = {
                         # with silent
                         #'deltaTimeHottest': {'mean': -13.743452072143555,
                         #                     'method': 'standardization',
                         #                     'std': 446.0340270996094},
                         # without silent
                         'deltaTimeHottest': {'mean': -47.27531568592806,
                                              'method': 'standardization',
                                              'std': 603.6062483849028},
                         # 31.380000000000003 is the float value of 3 * 10.46
                         'effective_area': {'max': 31.380000000000003,
                                            'method': 'min_max_scaling',
                                            'min': 0},
                         'pmt_number': {'max': 3, 'method': 'min_max_scaling', 'min': 1},
                         #'rho_mu': {'min': -2.0,
                         #           'method': 'min_max_scaling',
                         #           'max': (64*3)/(3*10*np.cos(np.deg2rad(45)))},
                         'rho_mu':{'mean': 0.21510971141596952,
                                   'method': 'standardization',
                                   'std': 0.9057438827119321},
                         # with silent
                         #'x': {'mean': -2.081512212753296,
                         #      'method': 'standardization',
                         #      'std': 779.2147827148438},
                         #'y': {'mean': 0.3692401945590973,
                         #      'method': 'standardization',
                         #      'std': 779.95458984375},
                         #'z': {'mean': -0.17404018342494965,
                         #      'method': 'standardization',
                         #      'std': 9.372444152832031},
                         # without silent
                         'x': {'mean': -1.01887806305201,
                               'method': 'standardization',
                               'std': 364.5171135403138},
                         'y': {'mean': -0.06542848666299515,
                               'method': 'standardization',
                               'std': 356.8056580363697},
                         'z': {'mean': -0.16543275617686548,
                               'method': 'standardization',
                               'std': 6.410170534866934}
                        }

### Datasets and Loaders

In [14]:
# Run configuration
test_run = False  # True: train on a subset of the training paths to test the code quickly
include_silent = False  # forwarded to SD433UMDatasetHomogeneous as `include_silent`
dual_input = True  # True: dual-input three-headed model; False: single-input model
loss_strategy = "equal"  # Options: "equal", "prioritize_sdmd", "alternate"

# Model ID and file paths: every artifact of this run is named after model_id.
model_id = (
    f"dual_input_{dual_input}"
    f"_silent_{include_silent}"
    f"_loss_strategy_{loss_strategy}"
    "_post_UHECR"
)
model_filename = f'{model_id}.pth'
last_model_filename = f'last_{model_id}.pth'
csv_filename = f"learning_curves_{model_id}.csv"

if test_run:
    # subset for code testing; validation paths are kept whole
    # (slice val_paths[:1000] as well to also shrink validation)
    train_paths = train_paths[:55000]

In [None]:
# Divide the dataset paths into contiguous subsets, one Dataset/DataLoader each.
# This is needed to speed up training: one huge Dataset for training, even with
# the loader, seems to slow down the training process.
n_train_loaders = 5
n_val_loaders = 3
n_time_bins = 60
batch_size = 32
num_workers = 8


def _split_into_subsets(paths, n_subsets):
    """Split `paths` into `n_subsets` contiguous chunks that cover every element.

    The first n_subsets - 1 chunks hold len(paths) // n_subsets paths each and
    the last chunk also takes the remainder, so no path is silently dropped.
    """
    size = len(paths) // n_subsets
    chunks = [paths[i * size:(i + 1) * size] for i in range(n_subsets - 1)]
    chunks.append(paths[(n_subsets - 1) * size:])
    return chunks


def _build_loaders(subset_paths, **loader_kwargs):
    """Create one SD433UMDatasetHomogeneous + DataLoader per list of paths.

    Extra keyword arguments (e.g. shuffle=True) are forwarded to DataLoader.
    """
    return [
        DataLoader(
            SD433UMDatasetHomogeneous(
                file_paths=paths,
                root=root_path,
                augmentation_options=augmentation_and_normalization_options,
                normalization_dict=normalization_dict,
                include_silent=include_silent,
                n_time_bins=n_time_bins,
            ),
            batch_size=batch_size,
            num_workers=num_workers,
            **loader_kwargs,
        )
        for paths in subset_paths
    ]


train_subset_paths = _split_into_subsets(train_paths, n_train_loaders)
val_subset_paths = _split_into_subsets(val_paths, n_val_loaders)

# Training batches are shuffled within each subset; validation keeps file order.
train_loaders = _build_loaders(train_subset_paths, shuffle=True)
val_loaders = _build_loaders(val_subset_paths)

## Model definition (three-headed GNN discriminator)

In [3]:
# Trace-analysis (TA) convolutional part. TA_input_length is tied to the number
# of time bins the datasets were built with; n_time_bins is defined in the
# loader cell, which therefore has to be executed before this one.
TA_input_length = n_time_bins  # Length of the input trace
TA_initial_kernel_size = 7  # Initial kernel size for the first convolutional layer
# NOTE(review): TA_input_length and TA_initial_kernel_size are not part of nn_args.

# Settings previously used with 120 time bins:
#   TA_filters = [64, 32, 4], TA_kernel_sizes = [7, 7, 8], TA_strides = [3, 4, 1]
# Settings for 60 time bins:
TA_filters = [64, 32, 4]
TA_kernel_sizes = [7, 6, 5]
TA_strides = [3, 3, 1]

# Graph part
sd_node_features = 5
sdmd_node_features = 7
md_node_features = 6
GCN_filters = [16, 32, 4]
dense_units = [16, 8]
num_heads = 3  # Number of attention heads to be used in GATv2Conv layers

nn_args = dict(
    TA_filters=TA_filters,
    TA_kernel_sizes=TA_kernel_sizes,
    TA_strides=TA_strides,
    dense_units=dense_units,
    GCN_filters=GCN_filters,
    sd_node_features=sd_node_features,
    md_node_features=md_node_features,
    sdmd_node_features=sdmd_node_features,
    num_heads=num_heads,
)

# Pick the three-headed discriminator matching the input mode of the run.
model_class = (
    GNNWithAttentionDiscriminator3HeadsDualInput
    if dual_input
    else GNNWithAttentionDiscriminator3Heads
)
triheaded_model = model_class(**nn_args)

print("Parameters: ", sum(param.numel() for param in triheaded_model.parameters()))
triheaded_model.to(device)

NameError: name 'n_time_bins' is not defined

In [13]:
# Resume from the best-model checkpoint if one exists, otherwise start fresh.
checkpoint_filename = f'{model_id}_checkpoint.pth'

# Optimizer and scheduler are built the same way in both cases; when resuming,
# their states are then overwritten from the checkpoint.
optimizer = torch.optim.Adam(triheaded_model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

if os.path.exists(checkpoint_filename):
    # The checkpoint is written together with model_filename at each new best
    # epoch, so its model_state_dict holds the same weights. Loading it into the
    # model built above avoids unpickling a whole model object, and
    # map_location lets the checkpoint be restored on whichever device is in use.
    checkpoint = torch.load(checkpoint_filename, map_location=device)
    triheaded_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_ROC_AUC = checkpoint['best_val_ROC_AUC']
    print(f"Resuming training from epoch {start_epoch} with best validation ROC-AUC {best_val_ROC_AUC:.4f}.")
else:
    start_epoch = 0
    best_val_ROC_AUC = 0.0
    print("Starting training from scratch.")

    # Initialize CSV file and write header if starting fresh
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['epoch', 'time',
                         'train_roc_auc_sd', 'train_roc_auc_sdmd', 'train_roc_auc_md',
                         'train_loss_sd', 'train_loss_sdmd', 'train_loss_md',
                         'val_roc_auc_sd', 'val_roc_auc_sdmd', 'val_roc_auc_md',
                         'val_loss_sd', 'val_loss_sdmd', 'val_loss_md'])

triheaded_model.train()

Starting training from scratch.


In [14]:
# Early stopping parameters
num_epochs = 120
# patience equals num_epochs, so early stopping is effectively disabled.
patience = 120
counter = 0  # epochs since the last improvement of the SDMD validation ROC-AUC

# Node masking applied to every batch (both input modes)
masking_transformation_sd = MaskNodes(max_nodes2prune=2)
# MD-counter masking, used only by the single-input model.
# NOTE(review): silent_value equals 1 / (max - min) of the rho_mu min-max
# scaling (min=-2.0), while normalization_dict currently standardizes rho_mu —
# confirm the value is still appropriate.
masking_transformation_md = MaskMdCounters(
    rho_mu_column_idx=7,
    effective_area_column_idx=6,
    silent_value=1/((64*3)/(3*10*np.cos(np.deg2rad(45)))-(-2.0)),
)

# transformations for dual input (silent_prunner is currently unused: its calls
# in the training loop are commented out)
silent_prunner = SilentPrunner(silent_col_index = 4, silent_value=0.0)
random_masking = MaskRandomNodes(max_nodes2prune=2)

# loss (applied to raw logits)
criterion = nn.BCEWithLogitsLoss()

In [15]:
# Weighting configuration based on the chosen strategy
def get_loss_weights(epoch, strategy):
    """Return the per-head loss weights for `epoch` under `strategy`.

    Args:
        epoch: zero-based epoch index (only used by "alternate").
        strategy: "equal", "prioritize_sdmd" or "alternate".

    Returns:
        A new dict with keys 'sd', 'sdmd' and 'md'. "alternate" cycles with
        period 4 through SD, SDMD, MD, SDMD (weight 1.0 on the focused head,
        0.0 elsewhere), so the SDMD head is trained every other epoch.

    Raises:
        ValueError: if `strategy` is not one of the options above.
    """
    if strategy == "equal":
        # Equal weighting for all heads
        return {'sd': 1.0, 'sdmd': 1.0, 'md': 1.0}

    if strategy == "prioritize_sdmd":
        # More importance to SDMD (middle head)
        return {'sd': 0.2, 'sdmd': 0.6, 'md': 0.2}

    if strategy == "alternate":
        focus = ("sd", "sdmd", "md", "sdmd")[epoch % 4]
        return {head: (1.0 if head == focus else 0.0) for head in ("sd", "sdmd", "md")}

    # Fail early instead of returning None, which would only surface later as
    # "'NoneType' object is not subscriptable" inside the training loop.
    raise ValueError(
        f"Unknown loss strategy {strategy!r}; expected 'equal', 'prioritize_sdmd' or 'alternate'."
    )

In [None]:
###################
## Training loop ##
###################

# Wrap loss.backward() in autograd anomaly detection. It raises on NaN/inf
# gradients but slows down every step; set to False once training is stable.
detect_anomaly = True


def _prepare_batch(data):
    """Apply the masking transformations to one batch and move it to `device`.

    `masking_transformation_sd` is always applied. Then `random_masking` builds
    the SDMD graph for the dual-input model (`dual_input=True`), or
    `masking_transformation_md` does so for the single-input model.

    Returns:
        (inputs, targets): `[sd_graph, sdmd_graph]` and `sd_graph.y` for the
        dual-input model; the masked graph and its `y` otherwise.

    Raises:
        AssertionError: if node features, traces or targets are not all finite.
    """
    data_sd = masking_transformation_sd.collate(data)
    if dual_input:
        # dual input
        #data_sd_no_silent = silent_prunner.collate(data_sd)
        #data_sdmd = random_masking.collate(data_sd_no_silent)
        data_sdmd = random_masking.collate(data_sd)
        gpu_sd_graph = data_sd.to(device)
        gpu_sdmd_graph = data_sdmd.to(device)
        inputs, targets = [gpu_sd_graph, gpu_sdmd_graph], gpu_sd_graph.y
    else:
        # single input
        data_sdmd = masking_transformation_md.collate(data_sd)
        gpu_graph = data_sdmd.to(device)
        inputs, targets = gpu_graph, gpu_graph.y

    # Sanity checks on inputs and targets
    assert torch.all(torch.isfinite(data_sd.x)), "NaN detected in data_sd.x"
    assert torch.all(torch.isfinite(data_sdmd.x)), "NaN detected in data_sdmd.x"
    assert torch.all(torch.isfinite(data_sd.x_traces)), "NaN detected in data_sd.x_traces"
    assert torch.all(torch.isfinite(data_sdmd.x_traces)), "NaN detected in data_sdmd.x_traces"
    assert torch.all(torch.isfinite(targets)), "NaN detected in targets"
    return inputs, targets


def _cat_batches(chunks, what):
    """Concatenate a list of tensors, raising a clear error if every batch was skipped."""
    if not chunks:
        raise RuntimeError(f"No valid batches were collected for {what}; every batch was skipped.")
    return torch.cat(chunks)


def _epoch_metrics(preds, targets):
    """Return (ROC-AUC, BCE-with-logits loss) of one head over a whole epoch.

    Both tensors are flattened, matching the per-batch loss computation, so the
    result does not depend on whether the head outputs (N, 1) or (N,).
    """
    flat_preds = preds.flatten()
    flat_targets = targets.flatten()
    loss = criterion(flat_preds, flat_targets.float())
    roc_auc = roc_auc_score(flat_targets.numpy(), flat_preds.numpy())
    return roc_auc, loss


for epoch in tqdm(range(start_epoch, num_epochs), desc="Epoch", position=0, leave=True):

    start_time = time.time()  # Record start time of epoch
    # Get the current loss weights based on the chosen strategy
    loss_weights = get_loss_weights(epoch, loss_strategy)

    # ---------------- Training ----------------
    # The validation pass below switches the model to eval mode; switch back.
    triheaded_model.train()
    train_preds_sd, train_preds_sdmd, train_preds_md, train_targets = [], [], [], []

    # Training over subsets
    for i, train_loader in enumerate(train_loaders):
        subset_preds_sd, subset_preds_sdmd, subset_preds_md, subset_targets = [], [], [], []

        for batch_idx, data in enumerate(tqdm(train_loader, desc=f"Subset {i+1} Training", position=0, leave=True)):
            try:
                inputs, targets = _prepare_batch(data)
            except AssertionError as e:
                # Print the error message and skip this batch
                print(f"Skipping batch {batch_idx} in subset {i+1} due to assertion error: {e}")
                continue

            # Training step
            optimizer.zero_grad()
            output1, output2, output3 = triheaded_model(inputs)

            # Compute the weighted losses for each head based on the current strategy
            flat_targets = targets.flatten().float()
            loss_sd = criterion(output1.flatten(), flat_targets) * loss_weights['sd']
            loss_sdmd = criterion(output2.flatten(), flat_targets) * loss_weights['sdmd']
            loss_md = criterion(output3.flatten(), flat_targets) * loss_weights['md']

            # Additional NaN checks
            if torch.isnan(loss_sd) or torch.isnan(loss_sdmd) or torch.isnan(loss_md):
                print(f"NaN detected in loss components: SD: {loss_sd}, SDMD: {loss_sdmd}, MD: {loss_md}")
                continue  # Skip this batch

            # Total loss
            total_loss = loss_sd + loss_sdmd + loss_md
            with torch.autograd.set_detect_anomaly(detect_anomaly):
                total_loss.backward()
            clip_grad_norm_(triheaded_model.parameters(), max_norm=1.0)
            optimizer.step()

            # Save predictions and targets
            subset_preds_sd.append(output1.detach().cpu())
            subset_preds_sdmd.append(output2.detach().cpu())
            subset_preds_md.append(output3.detach().cpu())
            subset_targets.append(targets.detach().cpu())

        # Append to the epoch-level containers; a subset whose batches were all
        # skipped contributes nothing instead of crashing torch.cat.
        if subset_targets:
            train_preds_sd.append(torch.cat(subset_preds_sd))
            train_preds_sdmd.append(torch.cat(subset_preds_sdmd))
            train_preds_md.append(torch.cat(subset_preds_md))
            train_targets.append(torch.cat(subset_targets))

    # Convert predictions and targets to tensors
    train_preds_sd = _cat_batches(train_preds_sd, "training")
    train_preds_sdmd = _cat_batches(train_preds_sdmd, "training")
    train_preds_md = _cat_batches(train_preds_md, "training")
    train_targets = _cat_batches(train_targets, "training")

    # Calculate loss and ROC-AUC for the train set. ROC-AUC values are printed
    # without .item(): roc_auc_score may return a plain Python float.
    train_ROC_AUC_sd, train_loss_sd = _epoch_metrics(train_preds_sd, train_targets)
    train_ROC_AUC_sdmd, train_loss_sdmd = _epoch_metrics(train_preds_sdmd, train_targets)
    train_ROC_AUC_md, train_loss_md = _epoch_metrics(train_preds_md, train_targets)
    print(f"Epoch {epoch + 1}/{num_epochs}: Train ROC-AUC SD: {train_ROC_AUC_sd:.4f}, Train Loss SD: {train_loss_sd.item():.4f}")
    print(f"Epoch {epoch + 1}/{num_epochs}: Train ROC-AUC SDMD: {train_ROC_AUC_sdmd:.4f}, Train Loss SDMD: {train_loss_sdmd.item():.4f}")
    print(f"Epoch {epoch + 1}/{num_epochs}: Train ROC-AUC MD: {train_ROC_AUC_md:.4f}, Train Loss MD: {train_loss_md.item():.4f}")

    del train_preds_sd, train_preds_sdmd, train_preds_md, train_targets

    # ---------------- Validation ----------------
    # Eval mode so that any dropout / batch-norm layers behave deterministically.
    # NOTE(review): the masking transformations (MaskNodes, MaskRandomNodes) are
    # still applied to validation batches, so validation inputs are augmented
    # too — confirm this is intended.
    triheaded_model.eval()
    val_preds_sd, val_preds_sdmd, val_preds_md, val_targets = [], [], [], []

    with torch.no_grad():
        for i, val_loader in enumerate(val_loaders):
            subset_preds_sd, subset_preds_sdmd, subset_preds_md, subset_targets = [], [], [], []

            for batch_idx, data in enumerate(val_loader):
                try:
                    inputs, targets = _prepare_batch(data)
                    output1, output2, output3 = triheaded_model(inputs)
                except AssertionError as e:
                    # Print the error message and skip this batch
                    print(f"Skipping batch {batch_idx} in subset {i+1} due to assertion error: {e}")
                    continue
                subset_preds_sd.append(output1.detach().cpu())
                subset_preds_sdmd.append(output2.detach().cpu())
                subset_preds_md.append(output3.detach().cpu())
                subset_targets.append(targets.detach().cpu())

            if subset_targets:
                val_preds_sd.append(torch.cat(subset_preds_sd))
                val_preds_sdmd.append(torch.cat(subset_preds_sdmd))
                val_preds_md.append(torch.cat(subset_preds_md))
                val_targets.append(torch.cat(subset_targets))

    # Concatenate validation predictions and targets
    val_preds_sd = _cat_batches(val_preds_sd, "validation")
    val_preds_sdmd = _cat_batches(val_preds_sdmd, "validation")
    val_preds_md = _cat_batches(val_preds_md, "validation")
    val_targets = _cat_batches(val_targets, "validation")

    # Calculate validation metrics
    val_ROC_AUC_sd, val_loss_sd = _epoch_metrics(val_preds_sd, val_targets)
    val_ROC_AUC_sdmd, val_loss_sdmd = _epoch_metrics(val_preds_sdmd, val_targets)
    val_ROC_AUC_md, val_loss_md = _epoch_metrics(val_preds_md, val_targets)
    print(f"Epoch {epoch + 1}/{num_epochs}: Val ROC-AUC SD: {val_ROC_AUC_sd:.4f}, Val Loss SD: {val_loss_sd.item():.4f}")
    print(f"Epoch {epoch + 1}/{num_epochs}: Val ROC-AUC SDMD: {val_ROC_AUC_sdmd:.4f}, Val Loss SDMD: {val_loss_sdmd.item():.4f}")
    print(f"Epoch {epoch + 1}/{num_epochs}: Val ROC-AUC MD: {val_ROC_AUC_md:.4f}, Val Loss MD: {val_loss_md.item():.4f}")

    del val_preds_sd, val_preds_sdmd, val_preds_md, val_targets

    # Append training and validation metrics to the CSV file
    with open(csv_filename, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, time.time() - start_time,
                         float(train_ROC_AUC_sd), float(train_ROC_AUC_sdmd), float(train_ROC_AUC_md),
                         train_loss_sd.item(), train_loss_sdmd.item(), train_loss_md.item(),
                         float(val_ROC_AUC_sd), float(val_ROC_AUC_sdmd), float(val_ROC_AUC_md),
                         val_loss_sd.item(), val_loss_sdmd.item(), val_loss_md.item()])

    # Step the scheduler on the SDMD validation loss
    scheduler.step(val_loss_sdmd.item())

    # Early stopping and saving the best model (selected on SDMD validation ROC-AUC)
    if val_ROC_AUC_sdmd > best_val_ROC_AUC:
        best_val_ROC_AUC = val_ROC_AUC_sdmd
        counter = 0
        torch.save(triheaded_model, model_filename)
        torch.save({
            'epoch': epoch,
            'model_state_dict': triheaded_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_ROC_AUC': best_val_ROC_AUC
        }, f'{model_id}_checkpoint.pth')
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping: No improvement in validation ROC-AUC.")
            break

    gc.collect()
    torch.save(triheaded_model, last_model_filename)

print("Training complete")

Subset 1 Training: 100%|██████████| 736/736 [08:13<00:00,  1.49it/s]
Subset 2 Training: 100%|██████████| 736/736 [08:19<00:00,  1.47it/s]
Subset 3 Training: 100%|██████████| 736/736 [08:23<00:00,  1.46it/s]
Subset 4 Training:  49%|████▊     | 358/736 [04:02<02:55,  2.16it/s]