In [1]:
import numpy as np
import subprocess
import tqdm
from tqdm import tqdm
import pandas as pd

import os
import os.path as osp
import glob

import h5py
import uproot

import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.data import Dataset
from torch_geometric.data import DataLoader

import awkward as ak
import random

def find_highest_branch(path, base_name):
    """Return the key ``"<base_name>;<cycle>"`` with the largest cycle number.

    Args:
        path: Location of the file handed to ``uproot.open``.
        base_name: Object name to look for, without the ``;cycle`` suffix.

    Returns:
        The matching key with the highest integer after the last ``;``, or
        ``None`` when no key of the file starts with ``base_name + ';'``.
    """
    prefix = base_name + ';'
    best_key = None
    best_cycle = None
    with uproot.open(path) as root_file:
        for key in root_file.keys():
            if not key.startswith(prefix):
                continue
            cycle = int(key.split(';')[-1])
            # ">=" keeps the later key on a tie, like taking the last element
            # of a stable ascending sort.
            if best_cycle is None or cycle >= best_cycle:
                best_key, best_cycle = key, cycle
    return best_key

def remove_duplicates(A, B):
    """Build a keep-mask that retains one entry per distinct value of ``A``.

    For every event, ``A[event]`` and ``B[event]`` are flattened; among all
    positions that share the same value in ``A`` only the position with the
    largest ``B`` is marked ``True`` (the first one on ties). The flat mask is
    then folded back to the per-particle layout of ``A[event]``.

    Args:
        A: Nested (event -> particle -> entry) array of identifiers.
        B: Array with the same layout as ``A`` holding the ranking score.

    Returns:
        ``ak.Array`` of booleans with the same nesting as ``A``.
    """
    event_masks = []
    for event_idx in range(len(A)):
        event_ids = A[event_idx]
        ids = np.array(ak.flatten(event_ids))
        scores = np.array(ak.flatten(B[event_idx]))
        keep = np.zeros_like(ids, dtype=bool)
        for value in np.unique(ids):
            positions = np.where(ids == value)[0]
            # argmax over a single position returns that position, so one
            # expression covers both the unique and the duplicated case.
            keep[positions[np.argmax(scores[positions])]] = True
        event_masks.append(ak.unflatten(keep, ak.num(event_ids)))
    return ak.Array(event_masks)

class CCV1(Dataset):
    r'''
    Dataset for layer clusters.
    For each event it builds:
      - x: node features, shape (N, D)
      - groups: top-3 contributing group IDs per node (N, 3)
      - fractions: corresponding energy fractions (N, 3)
      - x_pe: positive edge pairs, shape (N, 2)
      - x_ne: negative edge pairs, shape (N, 2)

    All ``stsCP_*`` attributes are awkward arrays indexed by event first.
    ``stsCP_vertices_indexes`` and the ``*_unfilt`` arrays hold the entries
    before duplicate removal; every other ``stsCP_vertices_*`` array (and
    ``stsCP_vertices_indexes_filt``) holds the entries after it.
    '''
    url = '/dummy/'

    # Branches read from the simtrackstersCP tree for every chunk.
    _SIM_BRANCHES = ["vertices_x", "vertices_y", "vertices_z",
                     "vertices_energy", "vertices_multiplicity", "vertices_time",
                     "vertices_indexes", "barycenter_x", "barycenter_y", "barycenter_z"]

    # Per-layer-cluster branches of the clusters tree, keyed by the short
    # name used for the resulting ``stsCP_vertices_<name>`` attribute.
    _CLUSTER_BRANCHES = {
        'layer_id': 'cluster_layer_id',
        'noh': 'cluster_number_of_hits',
        'eta': 'position_eta',
        'phi': 'position_phi',
    }

    # Per-vertex quantities stored after duplicate removal.
    _FILTERED_FIELDS = ('x', 'y', 'z', 'energy', 'time',
                        'layer_id', 'noh', 'eta', 'phi', 'multiplicity')

    def __init__(self, root, transform=None, max_events=1e8, inp='train'):
        """
        Args:
            root: Dataset root directory passed to the PyG ``Dataset`` base.
            transform: Optional transform forwarded to the base class.
            max_events: Soft limit on the number of events; it is checked
                after each chunk of ``step_size`` events, so the dataset may
                hold more events than this value.
            inp: Split label ('train', 'val', ...). Stored on the instance;
                the same trees are read for every value.
        """
        super(CCV1, self).__init__(root, transform)
        self.step_size = 500
        self.inp = inp
        self.max_events = max_events

        # These will hold the precomputed groups and fractions arrays (one per event)
        self.precomputed_groups = []
        self.precomputed_fractions = []

        self.fill_data(max_events)
        self.precompute_pairings()

    @staticmethod
    def _gather_per_particle(cluster_values, vertices_indexes):
        """Pick, for every particle of every event, the cluster values at its vertex indexes."""
        gathered = []
        for evt_row in range(len(cluster_values)):
            event_values = cluster_values[evt_row]
            event_indexes = vertices_indexes[evt_row]
            gathered.append([event_values[event_indexes[particle]]
                             for particle in range(len(event_indexes))])
        return ak.Array(gathered)

    def _accumulate(self, attr_name, chunk, first_chunk):
        """Store ``chunk`` as ``self.<attr_name>`` or append it to the existing array."""
        if first_chunk:
            setattr(self, attr_name, chunk)
        else:
            setattr(self, attr_name, ak.concatenate((getattr(self, attr_name), chunk)))

    def _append_chunk(self, array, crosstree, entry_start, entry_stop, first_chunk):
        """Skim one chunk of simtrackster data and append it to the dataset arrays.

        ``entry_start``/``entry_stop`` is the entry range of this chunk inside
        the current file; the clusters tree is read over exactly that range so
        both trees describe the same events.
        """
        vertices_indexes = array['vertices_indexes']
        fields = {
            'x': array['vertices_x'],
            'y': array['vertices_y'],
            'z': array['vertices_z'],
            'energy': array['vertices_energy'],
            'time': array['vertices_time'],
            'indexes': vertices_indexes,
            'multiplicity': array['vertices_multiplicity'],
        }
        for name, branch in self._CLUSTER_BRANCHES.items():
            cluster_values = crosstree[branch].array(entry_start=entry_start,
                                                     entry_stop=entry_stop)
            fields[name] = self._gather_per_particle(cluster_values, vertices_indexes)

        # 1) Vertex-level skim: keep entries whose number of hits is > 1.
        skim_mask_noh = fields['noh'] > 1.0
        fields = {name: values[skim_mask_noh] for name, values in fields.items()}

        # 2) Event-level skim: keep events with at least two particles.
        skim_mask = [len(event) >= 2 for event in fields['x']]
        fields = {name: values[skim_mask] for name, values in fields.items()}
        # The barycenters carry one value per particle, so only the event-level
        # mask applies; without it they would not line up with the other arrays.
        barycenters = {axis: array[f'barycenter_{axis}'][skim_mask] for axis in ('x', 'y', 'z')}

        # Arrays kept before duplicate removal (used by precompute_pairings).
        self._accumulate('stsCP_vertices_indexes_unfilt', fields['indexes'], first_chunk)
        self._accumulate('stsCP_vertices_multiplicity_unfilt', fields['multiplicity'], first_chunk)
        self._accumulate('stsCP_vertices_indexes', fields['indexes'], first_chunk)

        # 3) For every cluster index shared by several particles keep only the
        #    entry with the highest 1/multiplicity.
        energyPercent = 1 / fields['multiplicity']
        skim_mask_energyPercent = remove_duplicates(fields['indexes'], energyPercent)
        filtered = {name: values[skim_mask_energyPercent] for name, values in fields.items()}

        for name in self._FILTERED_FIELDS:
            self._accumulate(f'stsCP_vertices_{name}', filtered[name], first_chunk)
        self._accumulate('stsCP_vertices_indexes_filt', filtered['indexes'], first_chunk)
        for axis, values in barycenters.items():
            self._accumulate(f'stsCP_barycenter_{axis}', values, first_chunk)

    def fill_data(self, max_events):
        """Read all raw ROOT files chunk by chunk and fill the ``stsCP_*`` arrays.

        Loading stops after the first chunk that brings the number of stored
        events above ``max_events``.

        Raises:
            RuntimeError: if a file lacks the 'clusters' or 'simtrackstersCP'
                tree, or if no chunk could be read at all.
        """
        counter = 0  # number of chunks accumulated so far
        print("### Loading data")
        for path in tqdm(self.raw_file_names):
            # The same trees are read whatever the value of self.inp.
            cluster_path = find_highest_branch(path, 'clusters')
            sim_path = find_highest_branch(path, 'simtrackstersCP')
            if cluster_path is None or sim_path is None:
                raise RuntimeError(
                    f"Could not find 'clusters' and 'simtrackstersCP' trees in {path}")

            with uproot.open(path) as root_file:
                crosstree = root_file[cluster_path]
                # Running entry offset inside this file. The last chunk of a
                # file may be shorter than step_size, so the offset is summed
                # from the actual chunk lengths instead of being derived from
                # a chunk counter.
                entry_start = 0
                for array in uproot.iterate(f"{path}:{sim_path}",
                                            self._SIM_BRANCHES,
                                            step_size=self.step_size):
                    entry_stop = entry_start + len(array['vertices_x'])
                    self._append_chunk(array, crosstree, entry_start, entry_stop,
                                       first_chunk=(counter == 0))
                    entry_start = entry_stop

                    counter += 1
                    if len(self.stsCP_vertices_x) > max_events:
                        print(f"Reached {max_events}!")
                        break
            if counter > 0 and len(self.stsCP_vertices_x) > max_events:
                break

        if counter == 0:
            raise RuntimeError("No events could be loaded from the raw files.")

    def precompute_pairings(self):
        """
        Precompute the groups and fractions arrays for each event and store them as attributes.
        (Positive and negative edges will be computed dynamically in get().)

        For every layer cluster kept after duplicate removal, the contributing
        particles (taken from the unfiltered arrays) are ranked by
        ``1 / multiplicity``; the three best are stored. Fewer than three
        contributors are padded by repeating the last one; a cluster without
        any contributor gets group 0 with fraction 0.
        """
        self.precomputed_groups = []
        self.precomputed_fractions = []
        n_events = len(self.stsCP_vertices_x)
        for idx in range(n_events):
            unfilt_cp_array = self.stsCP_vertices_indexes_unfilt[idx]
            unfilt_mult_array = self.stsCP_vertices_multiplicity_unfilt[idx]
            cluster_contributors = {}
            for cp_id in range(len(unfilt_cp_array)):
                # Convert once per particle; element-wise awkward indexing in
                # the inner loop is much slower.
                cluster_ids = np.array(unfilt_cp_array[cp_id]).tolist()
                fracs = (1.0 / np.array(unfilt_mult_array[cp_id])).tolist()
                for cluster_id, frac in zip(cluster_ids, fracs):
                    cluster_contributors.setdefault(cluster_id, []).append((cp_id, frac))

            final_cp_array = self.stsCP_vertices_indexes_filt[idx]
            flattened_cluster_ids = np.array(ak.flatten(final_cp_array)).tolist()
            total_lc = len(flattened_cluster_ids)
            groups_np = np.zeros((total_lc, 3), dtype=np.int64)
            fractions_np = np.zeros((total_lc, 3), dtype=np.float32)
            for global_lc, cluster_id in enumerate(flattened_cluster_ids):
                contribs = cluster_contributors.get(cluster_id, [])
                contribs_sorted = sorted(contribs, key=lambda x: x[1], reverse=True)
                if len(contribs_sorted) == 0:
                    top3 = [(0, 0.0)] * 3
                else:
                    top3 = contribs_sorted[:3]
                    while len(top3) < 3:
                        top3.append(top3[-1])
                for i in range(3):
                    groups_np[global_lc, i] = top3[i][0]
                    fractions_np[global_lc, i] = top3[i][1]
            self.precomputed_groups.append(groups_np)
            self.precomputed_fractions.append(fractions_np)

    def download(self):
        raise RuntimeError(
            'Dataset not found. Please download it from {} and move all '
            '*.z files to {}'.format(self.url, self.raw_dir))

    def len(self):
        """Number of events currently stored."""
        return len(self.stsCP_vertices_x)

    @property
    def raw_file_names(self):
        """Sorted list of the ``*.root`` files found in ``self.raw_dir``."""
        raw_files = sorted(glob.glob(osp.join(self.raw_dir, '*.root')))
        return raw_files

    @property
    def processed_file_names(self):
        return []

    def get(self, idx):
        """Build the ``Data`` object of event ``idx``.

        Node features are (x, y, z, energy, layer_id, noh, eta, phi). The
        positive and negative partner of every node are drawn with the
        ``random`` module, so repeated calls can return different edges.
        """
        # 1) Flatten node features for the event
        feature_arrays = (
            self.stsCP_vertices_x[idx],
            self.stsCP_vertices_y[idx],
            self.stsCP_vertices_z[idx],
            self.stsCP_vertices_energy[idx],
            self.stsCP_vertices_layer_id[idx],
            self.stsCP_vertices_noh[idx],
            self.stsCP_vertices_eta[idx],
            self.stsCP_vertices_phi[idx],
        )
        flat_lc_feats = np.column_stack(
            [np.array(ak.flatten(feature)) for feature in feature_arrays])
        total_lc = flat_lc_feats.shape[0]
        x = torch.from_numpy(flat_lc_feats).float()

        # 2) Retrieve precomputed groups and fractions
        groups_np = self.precomputed_groups[idx]
        fractions_np = self.precomputed_fractions[idx]

        # 3) Dynamically compute positive and negative pairs based on groups and fractions
        pos_pairs = np.empty(total_lc, dtype=np.int64)
        neg_pairs = np.empty(total_lc, dtype=np.int64)
        for i in range(total_lc):
            pos_candidate = None
            first_group = groups_np[i][0]
            first_frac = fractions_np[i][0]
            if first_frac == 1:
                # Node owned by a single group: prefer partners that are also
                # fully owned by that group.
                candidates = [j for j in range(total_lc) if j != i and groups_np[j][0] == first_group]
                if candidates:
                    candidates_exact = [j for j in candidates if fractions_np[j][0] == 1]
                    if candidates_exact:
                        pos_candidate = random.choice(candidates_exact)
                    else:
                        pos_candidate = random.choice(candidates)
            else:
                # Shared node: prefer the node with the same two leading groups
                # whose mean leading fraction is closest.
                candidates = [j for j in range(total_lc) if j != i and set(groups_np[j][:2]) == set(groups_np[i][:2])]
                if candidates:
                    mean2_i = np.mean(fractions_np[i][:2])
                    pos_candidate = min(candidates, key=lambda j: abs(np.mean(fractions_np[j][:2]) - mean2_i))
                else:
                    candidates = [j for j in range(total_lc) if j != i and groups_np[j][0] == first_group]
                    if candidates:
                        pos_candidate = random.choice(candidates)
            if pos_candidate is None:
                # Last resort: any other node, or the node itself if it is alone.
                remaining = [j for j in range(total_lc) if j != i]
                pos_candidate = random.choice(remaining) if remaining else i
            pos_pairs[i] = pos_candidate

            orig_set = set(groups_np[i])
            neg_candidate = None
            candidates = [j for j in range(total_lc) if j != i and
                          (groups_np[j][0] not in orig_set or fractions_np[j][0] == 0)]
            if candidates:
                neg_candidate = random.choice(candidates)
            else:
                candidates = [j for j in range(total_lc) if j != i and groups_np[j][0] != groups_np[i][0]]
                if candidates:
                    neg_candidate = random.choice(candidates)
                else:
                    neg_candidate = (i + 1) % total_lc
            neg_pairs[i] = neg_candidate

        pos_edge = torch.from_numpy(np.column_stack((np.arange(total_lc), pos_pairs))).long()
        neg_edge = torch.from_numpy(np.column_stack((np.arange(total_lc), neg_pairs))).long()

        data = Data(
            x=x,
            groups=torch.from_numpy(groups_np),
            fractions=torch.from_numpy(fractions_np),
            x_pe=pos_edge,
            x_ne=neg_edge
        )
        return data


In [2]:
# Build the training and validation datasets.
# NOTE: max_events is a soft limit. CCV1.fill_data only checks it after a whole
# chunk (step_size = 500 events) has been appended, so each dataset can hold
# many more than 10 events.
# NOTE(review): CCV1 globs '*.root' inside self.raw_dir; presumably that is
# <root>/raw as defined by the PyG Dataset base class — confirm.
ipath = "/vols/cms/mm1221/Data/mix/train/"
vpath = "/vols/cms/mm1221/Data/mix/val/"
data_train = CCV1(ipath, max_events=10, inp = 'train')
data_val = CCV1(vpath, max_events=10, inp='val')

### Loading data


  0%|                                                     | 0/3 [00:35<?, ?it/s]

Reached 10!
Precomputing groups for event 0
Precomputing groups for event 1
Precomputing groups for event 2





Precomputing groups for event 3
Precomputing groups for event 4
Precomputing groups for event 5
Precomputing groups for event 6
Precomputing groups for event 7
Precomputing groups for event 8
Precomputing groups for event 9
Precomputing groups for event 10
Precomputing groups for event 11
Precomputing groups for event 12
Precomputing groups for event 13
Precomputing groups for event 14
Precomputing groups for event 15
Precomputing groups for event 16
Precomputing groups for event 17
Precomputing groups for event 18
Precomputing groups for event 19
Precomputing groups for event 20
Precomputing groups for event 21
Precomputing groups for event 22
Precomputing groups for event 23
Precomputing groups for event 24
Precomputing groups for event 25
Precomputing groups for event 26
Precomputing groups for event 27
Precomputing groups for event 28
Precomputing groups for event 29
Precomputing groups for event 30
Precomputing groups for event 31
Precomputing groups for event 32
Precomputing grou

KeyboardInterrupt: 

In [27]:

# Inspect the top-3 contributing groups and their fractions for node 0 of event 1.
# A node with fewer than three contributors is padded by repeating the last one
# (see CCV1.precompute_pairings), hence repeated values.
print(data_train[1].groups[0])
print(data_train[1].fractions[0])

tensor([0, 0, 0])
tensor([1., 1., 1.])


In [28]:
# Same inspection for node 96 of event 1.
print(data_train[1].groups[96])
print(data_train[1].fractions[96])

tensor([1, 1, 1])
tensor([1., 1., 1.])


In [36]:
# Negative edges of event 3: each row is [node, negative_node]. The partners are
# drawn with random.choice in CCV1.get, so they are not fixed per event.
print(data_train[3].x_ne)

tensor([[  0, 338],
        [  1, 203],
        [  2, 324],
        [  3, 211],
        [  4, 147],
        [  5, 104],
        [  6,  84],
        [  7, 325],
        [  8, 162],
        [  9, 322],
        [ 10, 297],
        [ 11, 190],
        [ 12, 178],
        [ 13, 195],
        [ 14, 341],
        [ 15, 156],
        [ 16, 134],
        [ 17, 103],
        [ 18, 220],
        [ 19, 218],
        [ 20, 172],
        [ 21, 111],
        [ 22, 193],
        [ 23, 118],
        [ 24, 313],
        [ 25, 223],
        [ 26,  97],
        [ 27, 204],
        [ 28, 146],
        [ 29, 237],
        [ 30, 217],
        [ 31, 265],
        [ 32, 284],
        [ 33, 361],
        [ 34, 186],
        [ 35, 116],
        [ 36, 342],
        [ 37, 298],
        [ 38, 347],
        [ 39, 297],
        [ 40, 244],
        [ 41, 291],
        [ 42, 311],
        [ 43, 112],
        [ 44, 102],
        [ 45, 317],
        [ 46, 337],
        [ 47, 335],
        [ 48, 273],
        [ 49, 279],


In [31]:
import torch
import torch.nn as nn
from torch_geometric.nn import DynamicEdgeConv

class Net(nn.Module):
    """Encoder -> stack of residual DynamicEdgeConv layers -> contrastive head."""

    def __init__(self, hidden_dim=64, num_layers=4, dropout=0.3, contrastive_dim=8, k=20,
                 input_dim=8):
        """
        Initializes the neural network with DynamicEdgeConv layers.

        Args:
            hidden_dim (int): Dimension of hidden layers.
            num_layers (int): Total number of DynamicEdgeConv layers.
            dropout (float): Dropout rate.
            contrastive_dim (int): Dimension of the contrastive output.
            k (int): Number of nearest neighbors to use in DynamicEdgeConv.
            input_dim (int): Number of input features per node (default 8).
        """
        super(Net, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.contrastive_dim = contrastive_dim
        self.k = k
        self.input_dim = input_dim

        # Input encoder
        self.lc_encode = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ELU(),
            nn.Linear(32, hidden_dim),
            nn.ELU()
        )

        # Convolutional stack. Every layer uses the same k.
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            # The edge MLP receives two hidden_dim-sized vectors concatenated,
            # hence the 2 * hidden_dim input width.
            mlp = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim),
                nn.ELU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(p=dropout)
            )
            self.convs.append(DynamicEdgeConv(mlp, k=self.k, aggr="max"))

        # Output layer
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 32),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, contrastive_dim)
        )

    def forward(self, x, batch):
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input node features of shape (N, input_dim).
            batch (torch.Tensor): Batch vector that assigns each node to an example in the batch.

        Returns:
            torch.Tensor: Output features of shape (N, contrastive_dim).
            torch.Tensor: The batch vector, returned unchanged.
        """
        # Input encoding -> (N, hidden_dim)
        feats = self.lc_encode(x)

        # DynamicEdgeConv layers with residual connections
        for conv in self.convs:
            feats = conv(feats, batch) + feats

        # Final output
        out = self.output(feats)
        return out, batch

In [32]:
import torch
import torch.nn.functional as F
import random

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

def contrastive_loss_edges(embeddings, pos_edge, neg_edge, temperature=0.1, eps=1e-6):
    """
    Computes contrastive loss using precomputed positive and negative edges.

    For each node i, let pos = pos_edge[i, 1] and neg = neg_edge[i, 1].
    The loss is computed as:

         L(i) = - log ( exp(sim(i, pos)/temperature) /
                        (exp(sim(i, pos)/temperature) + exp(sim(i, neg)/temperature)) )
              = - log(sigmoid((sim(i, pos)-sim(i, neg))/temperature))

    The log-sigmoid is evaluated with ``F.logsigmoid``. Adding a constant
    inside the log instead would cap the loss near ``-log(eps)`` and flatten
    the gradient for the pairs that are ranked most wrongly.

    Args:
        embeddings: Tensor of shape (N, D) containing node embeddings.
        pos_edge:     Tensor of shape (N, 2) where each row is [node, positive_node].
        neg_edge:     Tensor of shape (N, 2) where each row is [node, negative_node].
        temperature:  Temperature scaling factor.
        eps:          Unused; kept so existing callers that pass it keep working.

    Returns:
        Scalar tensor representing the average loss.
    """
    # Normalize embeddings to unit length so dot products are cosine similarities.
    norm_emb = F.normalize(embeddings, p=2, dim=1)

    # Cosine similarity for positive and negative pairs.
    pos_sim = (norm_emb[pos_edge[:, 0]] * norm_emb[pos_edge[:, 1]]).sum(dim=1)
    neg_sim = (norm_emb[neg_edge[:, 0]] * norm_emb[neg_edge[:, 1]]).sum(dim=1)

    # Difference scaled by temperature, then -log(sigmoid(.)).
    logits_diff = (pos_sim - neg_sim) / temperature
    loss = -F.logsigmoid(logits_diff)
    return loss.mean()


def train_new(train_loader, model, optimizer, device, temperature=0.1):
    """
    Training loop that uses the new contrastive loss based on precomputed edges.

    For each batch (partitioned by event using data.x_batch), the loss is computed for each event
    separately based on the node embeddings and the stored positive (x_pe) and negative (x_ne) edges.
    The optimizer step uses the mean of the per-event losses of the batch.

    Args:
        train_loader: PyTorch DataLoader yielding Data objects.
        model:        The network model.
        optimizer:    Optimizer.
        device:       Torch device.
        temperature:  Temperature scaling factor for the loss.

    Returns:
        Average loss per event over the training set (float); 0.0 if no event was seen.
    """
    model.train()
    total_loss = 0.0
    n_events = 0
    for data in tqdm(train_loader, desc="Training"):
        data = data.to(device)

        # Partition by event using data.x_batch. The slicing below relies on
        # the nodes of one event being contiguous and in ascending event order.
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        if len(counts) == 0:
            # Batch without nodes: nothing to learn from.
            continue

        optimizer.zero_grad()

        # Compute embeddings.
        embeddings, _ = model(data.x, data.x_batch)

        loss_event_total = 0.0
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            # x_pe / x_ne hold one row per node with event-local indices, so
            # the same slice selects the rows of this event.
            event_pos_edge = data.x_pe[start_idx:end_idx]
            event_neg_edge = data.x_ne[start_idx:end_idx]

            loss_event = contrastive_loss_edges(event_embeddings, event_pos_edge, event_neg_edge,
                                                  temperature=temperature)
            loss_event_total += loss_event
            start_idx = end_idx

        loss = loss_event_total / len(counts)
        loss.backward()
        optimizer.step()

        # Accumulate the SUM of per-event losses so that dividing by the
        # number of events gives a true per-event average for any batch size.
        total_loss += loss_event_total.item()
        n_events += len(counts)

    return total_loss / n_events if n_events > 0 else 0.0


@torch.no_grad()
def test_new(test_loader, model, device, temperature=0.1):
    """
    Validation loop that uses the new contrastive loss based on precomputed edges.
    
    Args:
        test_loader: PyTorch DataLoader yielding Data objects.
        model:       The network model.
        device:      Torch device.
        temperature: Temperature scaling factor.
        
    Returns:
        Average loss per sample over the validation set.
    """
    model.eval()
    total_loss = 0.0
    n_events = 0
    for data in tqdm(test_loader, desc="Validation"):
        data = data.to(device)
        embeddings, _ = model(data.x, data.x_batch)
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = 0.0
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_pos_edge = data.x_pe[start_idx:end_idx]
            event_neg_edge = data.x_ne[start_idx:end_idx]
            
            loss_event = contrastive_loss_edges(event_embeddings, event_pos_edge, event_neg_edge,
                                                  temperature=temperature)
            loss_event_total += loss_event
            n_events += 1
            start_idx = end_idx
        total_loss += loss_event_total / (len(counts) if len(counts) > 0 else 1)
    return total_loss / n_events if n_events > 0 else 0.0



In [33]:
# Initialize model with passed hyperparameters
model = Net(
    hidden_dim=64,
    num_layers=3,
    dropout=0.3,
    contrastive_dim=64,
    k=16
)

device = torch.device('cpu')
# Move the model to the target device before the optimizer is built, so the
# optimizer holds references to the on-device parameters.
model = model.to(device)

BS = 1
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# Load DataLoader with current batch_size
train_loader = DataLoader(data_train, batch_size=BS, shuffle=False, follow_batch=['x'])
val_loader = DataLoader(data_val, batch_size=BS, shuffle=False, follow_batch=['x'])

# Train and evaluate the model for the specified number of epochs
best_val_loss = float('inf')
# Must exist before the loop: if the first validation loss is NaN the
# "improved" branch is skipped and the counter is incremented right away.
no_improvement_epochs = 0

# Store train and validation losses for all epochs
train_losses = []
val_losses = []

output_dir = '/vols/cms/mm1221/hgcal/elec5New/LC/Fraction/resultstest/'
os.makedirs(output_dir, exist_ok=True)

patience = 30
epochs = 5
for epoch in range(epochs):
    # alpha is fixed at 0 here and only shows up in the log line below; it is
    # not passed to train_new / test_new.
    alpha = 0

    print(f"Epoch {epoch+1}/{epochs} | Alpha: {alpha:.2f}")
    # train_new returns a python float while test_new returns a 0-dim tensor;
    # float() normalises both (calling .item() on the float would raise).
    train_loss = float(train_new(train_loader, model, optimizer, device, temperature=0.1))
    val_loss = float(test_new(val_loader, model, device, temperature=0.1))

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    scheduler.step()

    # Save best model if validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
        torch.save(model.state_dict(), os.path.join(output_dir, 'best_model.pt'))
    else:
        no_improvement_epochs += 1

    # Save intermediate checkpoint.
    state_dicts = {'model': model.state_dict(),
                   'opt': optimizer.state_dict(),
                   'lr': scheduler.state_dict()}
    torch.save(state_dicts, os.path.join(output_dir, f'epoch-{epoch+1}.pt'))

    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.8f}, Validation Loss: {val_loss:.8f}")
    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered. No improvement for {patience} epochs.")
        break

# Save training history (pandas is imported as pd in the notebook's import cell).
results_df = pd.DataFrame({
    'epoch': list(range(1, len(train_losses) + 1)),
    'train_loss': train_losses,
    'val_loss': val_losses
})
results_df.to_csv(os.path.join(output_dir, 'continued_training_loss.csv'), index=False)
print(f"Saved loss curves to {os.path.join(output_dir, 'continued_training_loss.csv')}")

# Save final model.
torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pt'))
print("Training complete. Final model saved.")

Epoch 1/5 | Alpha: 0.00


Training:  10%|███                             | 44/456 [00:29<04:33,  1.51it/s]


KeyboardInterrupt: 

In [141]:
# Debug probe: dump the `groups` attribute of dataset item 32 (per the CCV1
# docstring, the top-3 contributing group IDs per node, shape (N, 3)).
print(data_train[32].groups)

tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 

In [137]:
# Debug probe: dump the node-feature tensor `x` of dataset item 32 (per the
# CCV1 docstring, shape (N, D)).
print(data_train[32].x)

tensor([[ 9.9798e+01,  5.9121e+00,  3.4050e+02,  ...,  2.8000e+01,
          1.9395e+00,  5.9172e-02],
        [ 1.0117e+02,  6.8384e+00,  3.4355e+02,  ...,  2.8000e+01,
          1.9345e+00,  6.7492e-02],
        [ 1.0017e+02,  5.9966e+00,  3.4149e+02,  ...,  3.1000e+01,
          1.9387e+00,  5.9793e-02],
        ...,
        [ 8.8137e+01, -4.5364e+00,  3.5338e+02,  ...,  2.0000e+00,
          2.0957e+00, -5.1424e-02],
        [ 8.4958e+01, -4.5540e+00,  3.5676e+02,  ...,  2.0000e+00,
          2.1405e+00, -5.3552e-02],
        [ 7.7862e+01, -6.1858e+00,  3.2521e+02,  ...,  2.0000e+00,
          2.1337e+00, -7.9279e-02]])


In [63]:
# Debug probe on `lc_ind` (defined in a cell outside this excerpt).
# NOTE(review): in the recorded run this lookup raised
# "ValueError: ... attempting to get 3, index out of range" -- the cell fails
# on Restart & Run All as written.
print(lc_ind[3][3])

ValueError: in ListOffsetArray64 attempting to get 3, index out of range

(https://github.com/scikit-hep/awkward-1.0/blob/1.10.3/src/libawkward/array/ListOffsetArray.cpp#L682)

In [65]:
# Read the 'vertices_x' branch of the 'simtrackstersCP;2' tree into an array.
# NOTE(review): `data` is defined in a cell outside this excerpt (presumably an
# opened uproot file -- confirm), and the ';2' cycle is hard-coded here.
lc_x = data['simtrackstersCP;2']['vertices_x'].array()

In [88]:
# Debug probe: show the first inner entry of the first element of lc_x
# (a list of floats in the recorded output).
print(lc_x[0][0])

[-59.1, -24.8, -58.7, -40.5, -60.4, -48, ... -40.5, -41.2, -43.3, -45.4, -51.6, -56]
