In [8]:
import numpy as np
import subprocess
import tqdm
from tqdm import tqdm
import pandas as pd

import os
import os.path as osp

import glob

import h5py
import uproot

import torch
from torch import nn


from torch_geometric.data import Data
from torch_geometric.data import Dataset
from torch_geometric.data import DataLoader

import awkward as ak
import random

#singularity shell --bind /afs/cern.ch/user/p/pkakhand/public/CL/  /afs/cern.ch/user/p/pkakhand/geometricdl.sif

#singularity shell --bind /eos/project/c/contrast/public/solar/  /afs/cern.ch/user/p/pkakhand/geometricdl.sif
#source /cvmfs/sft.cern.ch/lcg/views/LCG_103cuda/x86_64-centos9-gcc11-opt/setup.sh

def find_highest_branch(path, base_name):
    with uproot.open(path) as f:
        # Find keys that exactly match the base_name (not containing other variations)
        branches = [k for k in f.keys() if k.startswith(base_name + ';')]
        
        # Sort and select the highest-numbered branch
        sorted_branches = sorted(branches, key=lambda x: int(x.split(';')[-1]))
        return sorted_branches[-1] if sorted_branches else None

def remove_duplicates(A, B, threshold=0.9):    
    """
    A: awkward array of cluster IDs (or similar)
    B: awkward array of the same shape indicating 'fractions'
    threshold: fraction cutoff above which to keep the cluster
    """
    all_masks = []
    for event_idx in range(len(A)):
        flat_A = np.array(ak.flatten(A[event_idx]))
        flat_B = np.array(ak.flatten(B[event_idx]))

        # Initialize a mask to keep track of which values to keep
        mask = np.zeros_like(flat_A, dtype=bool)

        # For each unique cluster ID, check if its maximum fraction > threshold
        for elem in np.unique(flat_A):
            indices = np.where(flat_A == elem)[0]
            max_b = np.max(flat_B[indices])
            
            # Only keep one occurrence if max_b > threshold; otherwise, keep none
            if max_b > threshold:
                max_index = indices[np.argmax(flat_B[indices])]
                mask[max_index] = True
        
        # Reshape mask to the original (unflattened) shape
        unflattened_mask = ak.unflatten(mask, ak.num(A[event_idx]))
        all_masks.append(unflattened_mask)
        
    return ak.Array(all_masks)


class CCV1(Dataset):
    r'''
        input: layer clusters

    '''

    url = '/dummy/'

    def __init__(self, root, transform=None, max_events=1e8, inp = 'train'):
        super(CCV1, self).__init__(root, transform)
        self.step_size = 500
        self.inp = inp
        self.max_events = max_events
        self.fill_data(max_events)
    

    def fill_data(self,max_events):
        counter = 0
        arrLens0 = []
        arrLens1 = []

        print("### Loading data")
        for fi,path in enumerate(tqdm(self.raw_paths)):


            if self.inp == 'train':
                cluster_path = find_highest_branch(path, 'clusters')
                sim_path = find_highest_branch(path, 'simtrackstersCP')
                track_path = find_highest_branch(path, 'tracksters')
            elif self.inp == 'val':
                cluster_path = find_highest_branch(path, 'clusters')
                sim_path = find_highest_branch(path, 'simtrackstersCP')
                track_path = find_highest_branch(path, 'tracksters')
            else:
                cluster_path = find_highest_branch(path, 'clusters')
                sim_path = find_highest_branch(path, 'simtrackstersCP')
                track_path = find_highest_branch(path, 'tracksters')
            
            crosstree =  uproot.open(path)[cluster_path]
            crosscounter = 0
            
            # Create an iterator for the tracksters branch to load its 'vertices_indexes'
            tracksters_iter = uproot.iterate(
                f"{path}:{track_path}",
                ["vertices_indexes"],
                step_size=self.step_size
            )
            for array in uproot.iterate(f"{path}:{sim_path}", ["vertices_x", "vertices_y", "vertices_z", 
            "vertices_energy", "vertices_multiplicity", "vertices_time", "vertices_indexes", "barycenter_x", "barycenter_y", "barycenter_z"], step_size=self.step_size):
                
                tmp_tracksters_data = next(tracksters_iter)
                tmp_tracksters_vertices_indexes = tmp_tracksters_data["vertices_indexes"]
            
                tmp_stsCP_vertices_x = array['vertices_x']
                tmp_stsCP_vertices_y = array['vertices_y']
                tmp_stsCP_vertices_z = array['vertices_z']
                tmp_stsCP_vertices_energy = array['vertices_energy']
                tmp_stsCP_vertices_time = array['vertices_time']
                tmp_stsCP_vertices_indexes = array['vertices_indexes']
                tmp_stsCP_barycenter_x = array['barycenter_x']
                tmp_stsCP_barycenter_y = array['barycenter_y']
                tmp_stsCP_barycenter_z = array['barycenter_z']


                tmp_stsCP_vertices_multiplicity = array['vertices_multiplicity']
                
                # weighted energies (A LC appears in its caloparticle assignment array as the energy it contributes not full energy)
                #tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy * tmp_stsCP_vertices_multiplicity
                
                self.step_size = min(self.step_size,len(tmp_stsCP_vertices_x))


                # Code block for reading from other tree
                tmp_all_vertices_layer_id = crosstree['cluster_layer_id'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                #tmp_all_vertices_radius = crosstree['cluster_radius'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                tmp_all_vertices_noh = crosstree['cluster_number_of_hits'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                tmp_all_vertices_eta = crosstree['position_eta'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                tmp_all_vertices_phi = crosstree['position_phi'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                crosscounter += 1

                layer_id_list = []
                radius_list = []
                noh_list = []
                eta_list = []
                phi_list = []
                for evt_row in range(len(tmp_all_vertices_noh)):
                    #print("Event no: %i"%evt_row)
                    #print("There are %i particles in this event"%len(tmp_stsCP_vertices_indexes[evt_row]))
                    layer_id_list_one_event = []
                    #radius_list_one_event = []
                    noh_list_one_event = []
                    eta_list_one_event = []
                    phi_list_one_event = []
                    for particle in range(len(tmp_stsCP_vertices_indexes[evt_row])):
                        #print("Particle no: %i"%particle)
                        #print("A")
                        #print(np.array(tmp_all_vertices_radius[evt_row]).shape)
                        #print("B")
                        #print(np.array(tmp_stsCP_vertices_indexes[evt_row][particle]).shape)
                        #print("C")
                        tmp_stsCP_vertices_layer_id_one_particle = tmp_all_vertices_layer_id[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        #tmp_stsCP_vertices_radius_one_particle = tmp_all_vertices_radius[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        tmp_stsCP_vertices_noh_one_particle = tmp_all_vertices_noh[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        tmp_stsCP_vertices_eta_one_particle = tmp_all_vertices_eta[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        tmp_stsCP_vertices_phi_one_particle = tmp_all_vertices_phi[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        #print(tmp_stsCP_vertices_radius_one_particle)
                        layer_id_list_one_event.append(tmp_stsCP_vertices_layer_id_one_particle)
                        #radius_list_one_event.append(tmp_stsCP_vertices_radius_one_particle)
                        noh_list_one_event.append(tmp_stsCP_vertices_noh_one_particle)
                        eta_list_one_event.append(tmp_stsCP_vertices_eta_one_particle)
                        phi_list_one_event.append(tmp_stsCP_vertices_phi_one_particle)
                    layer_id_list.append(layer_id_list_one_event)
                    #radius_list.append(radius_list_one_event)
                    noh_list.append(noh_list_one_event)
                    eta_list.append(eta_list_one_event)
                    phi_list.append(phi_list_one_event)
                tmp_stsCP_vertices_layer_id = ak.Array(layer_id_list)                
                #tmp_stsCP_vertices_radius = ak.Array(radius_list)                
                tmp_stsCP_vertices_noh = ak.Array(noh_list)                
                tmp_stsCP_vertices_eta = ak.Array(eta_list)                
                tmp_stsCP_vertices_phi = ak.Array(phi_list)                
                
                # NEW FILTERING: For each event, remove simtracksters entries whose index is not found
                # in any sub-array of the tracksters branch.
                mask_list = []
                for sim_evt, track_evt in zip(tmp_stsCP_vertices_indexes, tmp_tracksters_vertices_indexes):
                    # Flatten all tracksters indices into a single set
                    track_flat = ak.flatten(track_evt)
                    track_set = set(ak.to_list(track_flat))
                    # Convert the sim event to a nested Python list
                    sim_evt_list = ak.to_list(sim_evt)
                    # Build a nested mask preserving the structure:
                    # For each sub-array in the sim event, check each element for membership in track_set.
                    mask_evt = [[elem in track_set for elem in subarr] for subarr in sim_evt_list]
                    mask_list.append(mask_evt)
                mask_track = ak.Array(mask_list)

                # Apply the new mask to all simtracksters arrays:
                tmp_stsCP_vertices_x = tmp_stsCP_vertices_x[mask_track]
                tmp_stsCP_vertices_y = tmp_stsCP_vertices_y[mask_track]
                tmp_stsCP_vertices_z = tmp_stsCP_vertices_z[mask_track]
                tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy[mask_track]
                tmp_stsCP_vertices_time = tmp_stsCP_vertices_time[mask_track]
                tmp_stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id[mask_track]
                tmp_stsCP_vertices_noh = tmp_stsCP_vertices_noh[mask_track]
                tmp_stsCP_vertices_eta = tmp_stsCP_vertices_eta[mask_track]
                tmp_stsCP_vertices_phi = tmp_stsCP_vertices_phi[mask_track]
                tmp_stsCP_vertices_indexes = tmp_stsCP_vertices_indexes[mask_track]
                tmp_stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity[mask_track]

                
                # Remove duplicates by only allowing the caloparticle that contributed the most energy to a LC to actually contribute.
                # Further filtering: remove events with fewer than 2 vertices.
                skim_mask = [len(e) >= 2 for e in tmp_stsCP_vertices_x]
                tmp_stsCP_vertices_x = tmp_stsCP_vertices_x[skim_mask]
                tmp_stsCP_vertices_y = tmp_stsCP_vertices_y[skim_mask]
                tmp_stsCP_vertices_z = tmp_stsCP_vertices_z[skim_mask]
                tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy[skim_mask]
                tmp_stsCP_vertices_time = tmp_stsCP_vertices_time[skim_mask]
                tmp_stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id[skim_mask]
                tmp_stsCP_vertices_noh = tmp_stsCP_vertices_noh[skim_mask]
                tmp_stsCP_vertices_eta = tmp_stsCP_vertices_eta[skim_mask]
                tmp_stsCP_vertices_phi = tmp_stsCP_vertices_phi[skim_mask]
                tmp_stsCP_vertices_indexes = tmp_stsCP_vertices_indexes[skim_mask]
                tmp_stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity[skim_mask]

                if counter == 0:
                    self.stsCP_vertices_indexes_unfilt = tmp_stsCP_vertices_indexes
                    self.stsCP_vertices_multiplicity_unfilt = tmp_stsCP_vertices_multiplicity
                else:
                    self.stsCP_vertices_indexes_unfilt = ak.concatenate(
                        (self.stsCP_vertices_indexes_unfilt, tmp_stsCP_vertices_indexes))
                    self.stsCP_vertices_multiplicity_unfilt = ak.concatenate(
                        (self.stsCP_vertices_multiplicity_unfilt, tmp_stsCP_vertices_multiplicity))
                
                energyPercent = 1 / tmp_stsCP_vertices_multiplicity
                skim_mask_energyPercent = remove_duplicates(tmp_stsCP_vertices_indexes, energyPercent)
                tmp_stsCP_vertices_x = tmp_stsCP_vertices_x[skim_mask_energyPercent]
                tmp_stsCP_vertices_y = tmp_stsCP_vertices_y[skim_mask_energyPercent]
                tmp_stsCP_vertices_z = tmp_stsCP_vertices_z[skim_mask_energyPercent]
                tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy[skim_mask_energyPercent]
                tmp_stsCP_vertices_time = tmp_stsCP_vertices_time[skim_mask_energyPercent]
                tmp_stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id[skim_mask_energyPercent]
                tmp_stsCP_vertices_noh = tmp_stsCP_vertices_noh[skim_mask_energyPercent]
                tmp_stsCP_vertices_eta = tmp_stsCP_vertices_eta[skim_mask_energyPercent]
                tmp_stsCP_vertices_phi = tmp_stsCP_vertices_phi[skim_mask_energyPercent]
                tmp_stsCP_vertices_indexes_filt = tmp_stsCP_vertices_indexes[skim_mask_energyPercent]
                tmp_stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity[skim_mask_energyPercent]


                if counter == 0:
                    self.stsCP_vertices_x = tmp_stsCP_vertices_x
                    self.stsCP_vertices_y = tmp_stsCP_vertices_y
                    self.stsCP_vertices_z = tmp_stsCP_vertices_z
                    self.stsCP_vertices_energy = tmp_stsCP_vertices_energy
                    self.stsCP_vertices_time = tmp_stsCP_vertices_time
                    self.stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id
                    self.stsCP_vertices_noh = tmp_stsCP_vertices_noh
                    self.stsCP_vertices_eta = tmp_stsCP_vertices_eta
                    self.stsCP_vertices_phi = tmp_stsCP_vertices_phi
                    self.stsCP_vertices_indexes = tmp_stsCP_vertices_indexes
                    self.stsCP_barycenter_x = tmp_stsCP_barycenter_x
                    self.stsCP_barycenter_y = tmp_stsCP_barycenter_y
                    self.stsCP_barycenter_z = tmp_stsCP_barycenter_z
                    self.stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity
                    self.stsCP_vertices_indexes_filt = tmp_stsCP_vertices_indexes_filt
                else:
                    self.stsCP_vertices_x = ak.concatenate((self.stsCP_vertices_x, tmp_stsCP_vertices_x))
                    self.stsCP_vertices_y = ak.concatenate((self.stsCP_vertices_y, tmp_stsCP_vertices_y))
                    self.stsCP_vertices_z = ak.concatenate((self.stsCP_vertices_z, tmp_stsCP_vertices_z))
                    self.stsCP_vertices_energy = ak.concatenate((self.stsCP_vertices_energy, tmp_stsCP_vertices_energy))
                    self.stsCP_vertices_time = ak.concatenate((self.stsCP_vertices_time, tmp_stsCP_vertices_time))
                    self.stsCP_vertices_layer_id = ak.concatenate((self.stsCP_vertices_layer_id, tmp_stsCP_vertices_layer_id))
                    self.stsCP_vertices_noh = ak.concatenate((self.stsCP_vertices_noh, tmp_stsCP_vertices_noh))
                    self.stsCP_vertices_eta = ak.concatenate((self.stsCP_vertices_eta, tmp_stsCP_vertices_eta))
                    self.stsCP_vertices_phi = ak.concatenate((self.stsCP_vertices_phi, tmp_stsCP_vertices_phi))
                    self.stsCP_vertices_indexes = ak.concatenate((self.stsCP_vertices_indexes, tmp_stsCP_vertices_indexes))
                    self.stsCP_barycenter_x = ak.concatenate((self.stsCP_barycenter_x, tmp_stsCP_barycenter_x))
                    self.stsCP_barycenter_y = ak.concatenate((self.stsCP_barycenter_y, tmp_stsCP_barycenter_y))
                    self.stsCP_barycenter_z = ak.concatenate((self.stsCP_barycenter_z, tmp_stsCP_barycenter_z))
                    self.stsCP_vertices_multiplicity = ak.concatenate((self.stsCP_vertices_multiplicity, tmp_stsCP_vertices_multiplicity))
                    self.stsCP_vertices_indexes_filt = ak.concatenate((self.stsCP_vertices_indexes_filt, tmp_stsCP_vertices_indexes_filt))
                
                
                counter += 1
                if len(self.stsCP_vertices_x) > max_events:
                    print(f"Reached {max_events}!")
                    break
            if len(self.stsCP_vertices_x) > max_events:
                break
    
 
            
            
    def download(self):
        raise RuntimeError(
            'Dataset not found. Please download it from {} and move all '
            '*.z files to {}'.format(self.url, self.raw_dir))

    def len(self):
        return len(self.stsCP_vertices_x)

    @property
    def raw_file_names(self):
        raw_files = sorted(glob.glob(osp.join(self.raw_dir, '*.root')))
        
        #raw_files = [osp.join(self.raw_dir, 'step3_NTUPLE.root')]

        return raw_files

    @property
    def processed_file_names(self):
        return []

    def get(self, idx):
        # 1) Flatten node features for the event
        lc_x = self.stsCP_vertices_x[idx]
        lc_y = self.stsCP_vertices_y[idx]
        lc_z = self.stsCP_vertices_z[idx]
        lc_e = self.stsCP_vertices_energy[idx]
        lc_layer_id = self.stsCP_vertices_layer_id[idx]
        lc_noh = self.stsCP_vertices_noh[idx]
        lc_eta = self.stsCP_vertices_eta[idx]
        lc_phi = self.stsCP_vertices_phi[idx]

        flat_lc_x = np.expand_dims(np.array(ak.flatten(lc_x)), axis=1)
        flat_lc_y = np.expand_dims(np.array(ak.flatten(lc_y)), axis=1)
        flat_lc_z = np.expand_dims(np.array(ak.flatten(lc_z)), axis=1)
        flat_lc_e = np.expand_dims(np.array(ak.flatten(lc_e)), axis=1)
        flat_lc_layer_id = np.expand_dims(np.array(ak.flatten(lc_layer_id)), axis=1)
        flat_lc_noh = np.expand_dims(np.array(ak.flatten(lc_noh)), axis=1)
        flat_lc_eta = np.expand_dims(np.array(ak.flatten(lc_eta)), axis=1)
        flat_lc_phi = np.expand_dims(np.array(ak.flatten(lc_phi)), axis=1)
        flat_lc_feats = np.concatenate(
            (flat_lc_x, flat_lc_y, flat_lc_z, flat_lc_e,
             flat_lc_layer_id, flat_lc_noh, flat_lc_eta, flat_lc_phi),
            axis=-1
        )
        total_lc = flat_lc_feats.shape[0]
        x = torch.from_numpy(flat_lc_feats).float()
        
        result = np.concatenate([np.full(len(subarr), i) for i, subarr in enumerate(lc_x)])
        result_list = result.tolist() 

        data = Data(
            x=x,
            assoc = result_list
        )
        return data


In [9]:
ipath = "/vols/cms/mm1221/Data/mix/train/"
vpath = "/vols/cms/mm1221/Data/mix/val/"
data_train = CCV1(ipath, max_events=10, inp = 'train')
data_val = CCV1(vpath, max_events=10, inp='val')

### Loading data


  0%|                                                                                                                                                             | 0/3 [00:57<?, ?it/s]


Reached 10!
### Loading data


  0%|                                                                                                                                                             | 0/1 [00:54<?, ?it/s]

Reached 10!





In [6]:
print(data_train[1].assoc)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 

In [7]:
print(data_train[1].x)

tensor([[-22.7473,  88.3267, 335.3770,  ...,  27.0000,   2.0132,   1.8229],
        [-22.7024,  87.9500, 334.3830,  ...,  22.0000,   2.0144,   1.8234],
        [-22.9461,  88.6947, 337.4400,  ...,  36.0000,   2.0149,   1.8240],
        ...,
        [-68.1731,  83.6942, 337.4400,  ...,   2.0000,   1.8576,   2.2543],
        [-51.0651,  87.9233, 334.3830,  ...,   2.0000,   1.9060,   2.0970],
        [-54.2263,  88.5393, 334.3830,  ...,   2.0000,   1.8860,   2.1203]])


In [14]:
print(len(data_train[1].assoc))

265


In [12]:
print(len(data_train[1].x))

265


In [4]:
import torch
import torch.nn as nn
from torch_geometric.nn import DynamicEdgeConv

class Net(nn.Module):
    def __init__(self, hidden_dim=64, num_layers=4, dropout=0.3, contrastive_dim=8, k=20):
        """
        Initializes the neural network with DynamicEdgeConv layers.

        Args:
            hidden_dim (int): Dimension of hidden layers.
            num_layers (int): Total number of DynamicEdgeConv layers.
            dropout (float): Dropout rate.
            contrastive_dim (int): Dimension of the contrastive output.
            k (int): Number of nearest neighbors to use in DynamicEdgeConv.
        """
        super(Net, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.contrastive_dim = contrastive_dim
        self.k = k

        # Input feature normalization (BatchNorm1d normalizes feature-wise across samples)
        # self.input_norm = nn.BatchNorm1d(8)

        # Input encoder
        self.lc_encode = nn.Sequential(
            nn.Linear(8, 32),
            nn.ELU(),
            nn.Linear(32, hidden_dim),
            nn.ELU()
        )

        # Define the network's convolutional layers using DynamicEdgeConv layers
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            # For odd-numbered layers (1-based index: i+1), use k//2
            # For even-numbered layers, use k
            if (i + 1) % 2 == 0:
                current_k = self.k
            else:
                current_k = self.k 

            mlp = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim),
                nn.ELU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(p=dropout)
            )
            conv = DynamicEdgeConv(mlp, k=current_k, aggr="max")
            self.convs.append(conv)

        # Output layer
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 32),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, contrastive_dim)
        )

    def forward(self, x, batch):
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input node features of shape (N, 8).
            batch (torch.Tensor): Batch vector that assigns each node to an example in the batch.

        Returns:
            torch.Tensor: Output features after processing.
            torch.Tensor: Batch vector.
        """
        # Normalize input features
        # x = self.input_norm(x)

        # Input encoding
        x_lc_enc = self.lc_encode(x)  # Shape: (N, hidden_dim)

        # Apply DynamicEdgeConv layers with residual connections
        feats = x_lc_enc
        for conv in self.convs:
            feats = conv(feats, batch) + feats

        # Final output
        out = self.output(feats)
        return out, batch

In [7]:
import torch
import torch.nn.functional as F

def supcon_loss(embeddings, group_ids, temperature=0.1):
    """
    Computes the Supervised Contrastive (SupCon) loss as described in:
    "Supervised Contrastive Learning" (Khosla et al.).
    
    For each anchor, all other examples in the batch that share the same group id
    (excluding the anchor itself) are considered positives.
    
    Args:
        embeddings: Tensor of shape (N, D) (raw outputs; they will be normalized inside).
        group_ids: 1D Tensor (length N) of group or class identifiers.
        temperature: Temperature scaling factor.
        
    Returns:
        Scalar loss (mean over anchors).
    """
    # Normalize embeddings so that cosine similarity becomes the dot product.
    norm_emb = F.normalize(embeddings, p=2, dim=1)  # shape: (N, D)
    
    # Compute cosine similarity matrix, scaled by temperature.
    sim_matrix = torch.div(torch.matmul(norm_emb, norm_emb.t()), temperature)  # shape: (N, N)
    
    # Create a mask to remove self-similarity (diagonal elements).
    batch_size = embeddings.size(0)
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=embeddings.device)
    sim_matrix_masked = sim_matrix.masked_fill(self_mask, -float('inf'))
    
    # Create positive mask: positive if group_ids match and not self.
    positive_mask = (group_ids.unsqueeze(1) == group_ids.unsqueeze(0)) & ~self_mask
    
    # Compute the denominator: sum over all exponentiated similarities for j != i.
    exp_sim = torch.exp(sim_matrix_masked)
    denominator = exp_sim.sum(dim=1)  # shape: (N,)
    
    # Numerator: sum over the positives.
    pos_exp_sum = (exp_sim * positive_mask.float()).sum(dim=1)
    
    # Count positives per anchor to average the loss.
    pos_counts = positive_mask.sum(dim=1).float()
    
    # For anchors with no positives, we set the loss to 0.
    loss = torch.zeros_like(denominator)
    valid = pos_counts > 0
    loss[valid] = - (1 / pos_counts[valid]) * torch.log(pos_exp_sum[valid] / denominator[valid])
    
    return loss.mean()


def contrastive_loss_curriculum(embeddings, group_ids, temperature = 0.1):
    """
    Curriculum loss that uses both positive and negative blending.
    
    Delegates to contrastive_loss_curriculum_both.
    
    Args:
        embeddings: Tensor of shape (N, D).
        pos_indices: 1D Tensor (length N).
        group_ids: 1D Tensor (length N).
        temperature: Temperature scaling factor.
        alpha: Blending parameter.
        
    Returns:
        Scalar loss.
    """
    return supcon_loss(embeddings, group_ids, temperature)



#################################
# Training and Testing Functions
#################################

def train_new(train_loader, model, optimizer, device):
    model.train()
    total_loss = torch.zeros(1, device=device)
    for data in tqdm(train_loader, desc="Training"):
        data = data.to(device)
        optimizer.zero_grad()
        
        # Convert data.assoc to tensor if needed.
        if isinstance(data.assoc, list):
            if isinstance(data.assoc[0], list):
                assoc_tensor = torch.cat([torch.tensor(a, dtype=torch.int64, device=data.x.device)
                                          for a in data.assoc])
            else:
                assoc_tensor = torch.tensor(data.assoc, device=data.x.device)
        else:
            assoc_tensor = data.assoc

        #edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        embeddings, _ = model(data.x, data.x_batch)
        
        # Partition batch by event.
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = torch.zeros(1, device=device)
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_group_ids = assoc_tensor[start_idx:end_idx]
            loss_event = contrastive_loss_curriculum(event_embeddings,
                                                     event_group_ids, temperature=0.1)
            loss_event_total += loss_event
            start_idx = end_idx
        
        loss = loss_event_total / len(counts)
        loss.backward()
        total_loss += loss
        optimizer.step()
    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test_new(test_loader, model, device):
    model.eval()
    total_loss = torch.zeros(1, device=device)
    for data in tqdm(test_loader, desc="Validation"):
        data = data.to(device)
        
        if isinstance(data.assoc, list):
            if isinstance(data.assoc[0], list):
                assoc_tensor = torch.cat([torch.tensor(a, dtype=torch.int64, device=data.x.device)
                                          for a in data.assoc])
            else:
                assoc_tensor = torch.tensor(data.assoc, device=data.x.device)
        else:
            assoc_tensor = data.assoc
        
        #edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        embeddings, _ = model(data.x, data.x_batch)
        
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = torch.zeros(1, device=device)
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_group_ids = assoc_tensor[start_idx:end_idx]
            loss_event = contrastive_loss_curriculum(event_embeddings,
                                                     event_group_ids, temperature=0.1)
            loss_event_total += loss_event
            start_idx = end_idx
        total_loss += loss_event_total / len(counts)
    return total_loss / len(test_loader.dataset)

In [8]:
# Initialize model with passed hyperparameters
model = Net(
    hidden_dim=128,
    num_layers=3,
    dropout=0.3,
    contrastive_dim=128,
    k=48
)

k_value = 48
BS = 1
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

device = torch.device('cpu')
# Load DataLoader with current batch_size
train_loader = DataLoader(data_train, batch_size=BS, shuffle=True, follow_batch=['x'])
val_loader = DataLoader(data_val, batch_size=BS, shuffle=False, follow_batch=['x'])

# Train and evaluate the model for the specified number of epochs
best_val_loss = float('inf')

# Store train and validation losses for all epochs
train_losses = []
val_losses = []

output_dir = '/vols/cms/mm1221/hgcal/Mixed/LC/Full/resultstest/'

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
patience =30    
epochs = 5
for epoch in range(epochs):
    # For epochs 1 to 150, gradually increase alpha from 0 to 1.


    print(f"Epoch {epoch+1}/{epochs}")
    train_loss = train_new(train_loader, model, optimizer, device)
    val_loss = test_new(val_loader, model, device)

    train_losses.append(train_loss.item())
    val_losses.append(val_loss.item())
    scheduler.step()

    # Save best model if validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
        torch.save(model.state_dict(), os.path.join(output_dir, 'best_model.pt'))
    else:
        no_improvement_epochs += 1

    # Save intermediate checkpoint.
    state_dicts = {'model': model.state_dict(),
                   'opt': optimizer.state_dict(),
                   'lr': scheduler.state_dict()}
    torch.save(state_dicts, os.path.join(output_dir, f'epoch-{epoch+1}.pt'))

    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss.item():.8f}, Validation Loss: {val_loss.item():.8f}")
    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered. No improvement for {patience} epochs.")
        break

# Save training history.
import pandas as pd
results_df = pd.DataFrame({
    'epoch': list(range(1, len(train_losses) + 1)),
    'train_loss': train_losses,
    'val_loss': val_losses
})
results_df.to_csv(os.path.join(output_dir, 'continued_training_loss.csv'), index=False)
print(f"Saved loss curves to {os.path.join(output_dir, 'continued_training_loss.csv')}")

# Save final model.
torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pt'))
print("Training complete. Final model saved.")

Epoch 1/5


Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 456/456 [01:36<00:00,  4.73it/s]
Validation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 449/449 [00:28<00:00, 16.01it/s]


Epoch 1/5 - Train Loss: 0.01637866, Validation Loss: 0.03010262
Epoch 2/5


Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 456/456 [01:39<00:00,  4.57it/s]
Validation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 449/449 [00:27<00:00, 16.46it/s]


Epoch 2/5 - Train Loss: 0.01258302, Validation Loss: 0.03196397
Epoch 3/5


Training:  20%|███████████████████████████▋                                                                                                            | 93/456 [00:21<01:22,  4.41it/s]


KeyboardInterrupt: 