In [9]:
import numpy as np
import subprocess
import tqdm
from tqdm import tqdm
import pandas as pd

import os
import os.path as osp

import glob

import h5py
import uproot

import torch
from torch import nn


from torch_geometric.data import Data
from torch_geometric.data import Dataset
from torch_geometric.data import DataLoader

import awkward as ak
import random

#singularity shell --bind /afs/cern.ch/user/p/pkakhand/public/CL/  /afs/cern.ch/user/p/pkakhand/geometricdl.sif

#singularity shell --bind /eos/project/c/contrast/public/solar/  /afs/cern.ch/user/p/pkakhand/geometricdl.sif
#source /cvmfs/sft.cern.ch/lcg/views/LCG_103cuda/x86_64-centos9-gcc11-opt/setup.sh

def find_highest_branch(path, base_name):
    with uproot.open(path) as f:
        # Find keys that exactly match the base_name (not containing other variations)
        branches = [k for k in f.keys() if k.startswith(base_name + ';')]
        
        # Sort and select the highest-numbered branch
        sorted_branches = sorted(branches, key=lambda x: int(x.split(';')[-1]))
        return sorted_branches[-1] if sorted_branches else None

def remove_duplicates(A,B):    
    all_masks = []
    for event_idx, event in enumerate(A):
        flat_A = np.array(ak.flatten(A[event_idx]))
        flat_B = np.array(ak.flatten(B[event_idx]))
        
        # Initialize a mask to keep track of which values to keep
        mask = np.zeros_like(flat_A, dtype=bool)

        # Iterate over the unique elements in A
        for elem in np.unique(flat_A):
            # Get the indices where the element occurs in A
            indices = np.where(flat_A == elem)[0]

            # If there's more than one occurrence, keep the one with the max B value
            if len(indices) > 1:
                max_index = indices[np.argmax(flat_B[indices])]
                mask[max_index] = True
            else:
                # If there's only one occurrence, keep it
                mask[indices[0]] = True

        unflattened_mask = ak.unflatten(mask, ak.num(A[event_idx]))
        all_masks.append(unflattened_mask)
        
    return ak.Array(all_masks)

class CCV1(Dataset):
    r'''
        input: layer clusters

    '''

    url = '/dummy/'

    def __init__(self, root, transform=None, max_events=1e8, inp = 'train'):
        super(CCV1, self).__init__(root, transform)
        self.step_size = 500
        self.inp = inp
        self.max_events = max_events
        self.fill_data(max_events)

    def fill_data(self,max_events):
        counter = 0
        arrLens0 = []
        arrLens1 = []

        print("### Loading data")
        for fi,path in enumerate(tqdm(self.raw_paths)):


            if self.inp == 'train':
                cluster_path = find_highest_branch(path, 'clusters')
                sim_path = find_highest_branch(path, 'simtrackstersCP')
            elif self.inp == 'val':
                cluster_path = find_highest_branch(path, 'clusters')
                sim_path = find_highest_branch(path, 'simtrackstersCP')
            else:
                cluster_path = find_highest_branch(path, 'clusters')
                sim_path = find_highest_branch(path, 'simtrackstersCP')
            
            crosstree =  uproot.open(path)[cluster_path]
            crosscounter = 0
            for array in uproot.iterate(f"{path}:{sim_path}", ["vertices_x", "vertices_y", "vertices_z", 
            "vertices_energy", "vertices_multiplicity", "vertices_time", "vertices_indexes", "barycenter_x", "barycenter_y", "barycenter_z"], step_size=self.step_size):
            
                tmp_stsCP_vertices_x = array['vertices_x']
                tmp_stsCP_vertices_y = array['vertices_y']
                tmp_stsCP_vertices_z = array['vertices_z']
                tmp_stsCP_vertices_energy = array['vertices_energy']
                tmp_stsCP_vertices_time = array['vertices_time']
                tmp_stsCP_vertices_indexes = array['vertices_indexes']
                tmp_stsCP_barycenter_x = array['barycenter_x']
                tmp_stsCP_barycenter_y = array['barycenter_y']
                tmp_stsCP_barycenter_z = array['barycenter_z']


                tmp_stsCP_vertices_multiplicity = array['vertices_multiplicity']
                
                # weighted energies (A LC appears in its caloparticle assignment array as the energy it contributes not full energy)
                #tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy * tmp_stsCP_vertices_multiplicity
                
                self.step_size = min(self.step_size,len(tmp_stsCP_vertices_x))


                # Code block for reading from other tree
                tmp_all_vertices_layer_id = crosstree['cluster_layer_id'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                #tmp_all_vertices_radius = crosstree['cluster_radius'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                tmp_all_vertices_noh = crosstree['cluster_number_of_hits'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                tmp_all_vertices_eta = crosstree['position_eta'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                tmp_all_vertices_phi = crosstree['position_phi'].array(entry_start=crosscounter*self.step_size,entry_stop=(crosscounter+1)*self.step_size)
                crosscounter += 1

                layer_id_list = []
                radius_list = []
                noh_list = []
                eta_list = []
                phi_list = []
                for evt_row in range(len(tmp_all_vertices_noh)):
                    #print("Event no: %i"%evt_row)
                    #print("There are %i particles in this event"%len(tmp_stsCP_vertices_indexes[evt_row]))
                    layer_id_list_one_event = []
                    #radius_list_one_event = []
                    noh_list_one_event = []
                    eta_list_one_event = []
                    phi_list_one_event = []
                    for particle in range(len(tmp_stsCP_vertices_indexes[evt_row])):
                        #print("Particle no: %i"%particle)
                        #print("A")
                        #print(np.array(tmp_all_vertices_radius[evt_row]).shape)
                        #print("B")
                        #print(np.array(tmp_stsCP_vertices_indexes[evt_row][particle]).shape)
                        #print("C")
                        tmp_stsCP_vertices_layer_id_one_particle = tmp_all_vertices_layer_id[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        #tmp_stsCP_vertices_radius_one_particle = tmp_all_vertices_radius[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        tmp_stsCP_vertices_noh_one_particle = tmp_all_vertices_noh[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        tmp_stsCP_vertices_eta_one_particle = tmp_all_vertices_eta[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        tmp_stsCP_vertices_phi_one_particle = tmp_all_vertices_phi[evt_row][tmp_stsCP_vertices_indexes[evt_row][particle]]
                        #print(tmp_stsCP_vertices_radius_one_particle)
                        layer_id_list_one_event.append(tmp_stsCP_vertices_layer_id_one_particle)
                        #radius_list_one_event.append(tmp_stsCP_vertices_radius_one_particle)
                        noh_list_one_event.append(tmp_stsCP_vertices_noh_one_particle)
                        eta_list_one_event.append(tmp_stsCP_vertices_eta_one_particle)
                        phi_list_one_event.append(tmp_stsCP_vertices_phi_one_particle)
                    layer_id_list.append(layer_id_list_one_event)
                    #radius_list.append(radius_list_one_event)
                    noh_list.append(noh_list_one_event)
                    eta_list.append(eta_list_one_event)
                    phi_list.append(phi_list_one_event)
                tmp_stsCP_vertices_layer_id = ak.Array(layer_id_list)                
                #tmp_stsCP_vertices_radius = ak.Array(radius_list)                
                tmp_stsCP_vertices_noh = ak.Array(noh_list)                
                tmp_stsCP_vertices_eta = ak.Array(eta_list)                
                tmp_stsCP_vertices_phi = ak.Array(phi_list)                
                
                # Apply filter noh > 1 for the LCs
                skim_mask_noh = tmp_stsCP_vertices_noh > 1.0
                tmp_stsCP_vertices_x = tmp_stsCP_vertices_x[skim_mask_noh]
                tmp_stsCP_vertices_y = tmp_stsCP_vertices_y[skim_mask_noh]
                tmp_stsCP_vertices_z = tmp_stsCP_vertices_z[skim_mask_noh]
                tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy[skim_mask_noh]
                tmp_stsCP_vertices_time = tmp_stsCP_vertices_time[skim_mask_noh]
                tmp_stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id[skim_mask_noh]
                #tmp_stsCP_vertices_radius = tmp_stsCP_vertices_radius[skim_mask_energyPercent]
                tmp_stsCP_vertices_noh = tmp_stsCP_vertices_noh[skim_mask_noh]
                tmp_stsCP_vertices_eta = tmp_stsCP_vertices_eta[skim_mask_noh]
                tmp_stsCP_vertices_phi = tmp_stsCP_vertices_phi[skim_mask_noh]
                #tmp_stsCP_vertices_indexes_unmasked = tmp_stsCP_vertices_indexes
                tmp_stsCP_vertices_indexes = tmp_stsCP_vertices_indexes[skim_mask_noh]
                tmp_stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity[skim_mask_noh]
                
                # Remove duplicates by only allowing the caloparticle that contributed the most energy to a LC to actually contribute.
                energyPercent = 1/tmp_stsCP_vertices_multiplicity
                skim_mask_energyPercent = remove_duplicates(tmp_stsCP_vertices_indexes,energyPercent)
                tmp_stsCP_vertices_x = tmp_stsCP_vertices_x[skim_mask_energyPercent]
                tmp_stsCP_vertices_y = tmp_stsCP_vertices_y[skim_mask_energyPercent]
                tmp_stsCP_vertices_z = tmp_stsCP_vertices_z[skim_mask_energyPercent]
                tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy[skim_mask_energyPercent]
                tmp_stsCP_vertices_time = tmp_stsCP_vertices_time[skim_mask_energyPercent]
                tmp_stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id[skim_mask_energyPercent]
                #tmp_stsCP_vertices_radius = tmp_stsCP_vertices_radius[skim_mask_energyPercent]
                tmp_stsCP_vertices_noh = tmp_stsCP_vertices_noh[skim_mask_energyPercent]
                tmp_stsCP_vertices_eta = tmp_stsCP_vertices_eta[skim_mask_energyPercent]
                tmp_stsCP_vertices_phi = tmp_stsCP_vertices_phi[skim_mask_energyPercent]
                #tmp_stsCP_vertices_indexes_unmasked = tmp_stsCP_vertices_indexes
                tmp_stsCP_vertices_indexes_filt = tmp_stsCP_vertices_indexes[skim_mask_energyPercent]
                tmp_stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity[skim_mask_energyPercent]
               
                
                
                #SHOULD BE LEN(E) >= 2 for MULTI particles
                skim_mask = []
                for e in tmp_stsCP_vertices_x:
                    if 2 <= len(e): #<------ only train on samples with > 1 particle
                        skim_mask.append(True)
                    else:
                        skim_mask.append(False)
                tmp_stsCP_vertices_x = tmp_stsCP_vertices_x[skim_mask]
                tmp_stsCP_vertices_y = tmp_stsCP_vertices_y[skim_mask]
                tmp_stsCP_vertices_z = tmp_stsCP_vertices_z[skim_mask]
                tmp_stsCP_vertices_energy = tmp_stsCP_vertices_energy[skim_mask]
                tmp_stsCP_vertices_time = tmp_stsCP_vertices_time[skim_mask]
                tmp_stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id[skim_mask]
                #tmp_stsCP_vertices_radius = tmp_stsCP_vertices_radius[skim_mask]
                tmp_stsCP_vertices_noh = tmp_stsCP_vertices_noh[skim_mask]
                tmp_stsCP_vertices_eta = tmp_stsCP_vertices_eta[skim_mask]
                tmp_stsCP_vertices_phi = tmp_stsCP_vertices_phi[skim_mask]
                #tmp_stsCP_vertices_indexes_unmasked = tmp_stsCP_vertices_indexes_unmasked[skim_mask]
                tmp_stsCP_vertices_indexes = tmp_stsCP_vertices_indexes[skim_mask]
                tmp_stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity[skim_mask]
                tmp_stsCP_vertices_indexes_filt = tmp_stsCP_vertices_indexes_filt[skim_mask]


                if counter == 0:
                    self.stsCP_vertices_x = tmp_stsCP_vertices_x
                    self.stsCP_vertices_y = tmp_stsCP_vertices_y
                    self.stsCP_vertices_z = tmp_stsCP_vertices_z
                    self.stsCP_vertices_energy = tmp_stsCP_vertices_energy
                    self.stsCP_vertices_time = tmp_stsCP_vertices_time
                    self.stsCP_vertices_layer_id = tmp_stsCP_vertices_layer_id
                    #self.stsCP_vertices_radius = tmp_stsCP_vertices_radius
                    self.stsCP_vertices_noh = tmp_stsCP_vertices_noh
                    self.stsCP_vertices_eta = tmp_stsCP_vertices_eta
                    self.stsCP_vertices_phi = tmp_stsCP_vertices_phi
                    self.stsCP_vertices_indexes = tmp_stsCP_vertices_indexes
                    #self.stsCP_vertices_indexes_unmasked = tmp_stsCP_vertices_indexes_unmasked
                    self.stsCP_barycenter_x = tmp_stsCP_barycenter_x
                    self.stsCP_barycenter_y = tmp_stsCP_barycenter_y
                    self.stsCP_barycenter_z = tmp_stsCP_barycenter_z
                    self.stsCP_vertices_multiplicity = tmp_stsCP_vertices_multiplicity
                    self.stsCP_vertices_indexes_filt = tmp_stsCP_vertices_indexes_filt
                else:
                    self.stsCP_vertices_x = ak.concatenate((self.stsCP_vertices_x,tmp_stsCP_vertices_x))
                    self.stsCP_vertices_y = ak.concatenate((self.stsCP_vertices_y,tmp_stsCP_vertices_y))
                    self.stsCP_vertices_z = ak.concatenate((self.stsCP_vertices_z,tmp_stsCP_vertices_z))
                    self.stsCP_vertices_energy = ak.concatenate((self.stsCP_vertices_energy,tmp_stsCP_vertices_energy))
                    self.stsCP_vertices_time = ak.concatenate((self.stsCP_vertices_time,tmp_stsCP_vertices_time))
                    self.stsCP_vertices_layer_id = ak.concatenate((self.stsCP_vertices_layer_id,tmp_stsCP_vertices_layer_id))
                    #self.stsCP_vertices_radius = ak.concatenate((self.stsCP_vertices_radius,tmp_stsCP_vertices_radius))
                    self.stsCP_vertices_noh = ak.concatenate((self.stsCP_vertices_noh,tmp_stsCP_vertices_noh))
                    self.stsCP_vertices_eta = ak.concatenate((self.stsCP_vertices_eta,tmp_stsCP_vertices_eta))
                    self.stsCP_vertices_phi = ak.concatenate((self.stsCP_vertices_phi,tmp_stsCP_vertices_phi))
                    self.stsCP_vertices_indexes = ak.concatenate((self.stsCP_vertices_indexes,tmp_stsCP_vertices_indexes))
                    #self.stsCP_vertices_indexes_unmasked =  ak.concatenate((self.stsCP_vertices_indexes_unmasked,tmp_stsCP_vertices_indexes_unmasked))
                    self.stsCP_barycenter_x = ak.concatenate((self.stsCP_barycenter_x,tmp_stsCP_barycenter_x))
                    self.stsCP_barycenter_y = ak.concatenate((self.stsCP_barycenter_y,tmp_stsCP_barycenter_y))
                    self.stsCP_barycenter_z = ak.concatenate((self.stsCP_barycenter_z,tmp_stsCP_barycenter_z))
                    self.stsCP_vertices_multiplicity = ak.concatenate((self.stsCP_vertices_multiplicity, tmp_stsCP_vertices_multiplicity))
                    self.stsCP_vertices_indexes_filt = ak.concatenate((self.stsCP_vertices_indexes_filt,tmp_stsCP_vertices_indexes_filt))
                #print(len(self.stsCP_vertices_x))
                counter += 1
                if len(self.stsCP_vertices_x) > max_events:
                    print(f"Reached {max_events}!")
                    break
            if len(self.stsCP_vertices_x) > max_events:
                break
     
            
            
    def download(self):
        raise RuntimeError(
            'Dataset not found. Please download it from {} and move all '
            '*.z files to {}'.format(self.url, self.raw_dir))

    def len(self):
        return len(self.stsCP_vertices_x)

    @property
    def raw_file_names(self):
        raw_files = sorted(glob.glob(osp.join(self.raw_dir, '*.root')))
        
        #raw_files = [osp.join(self.raw_dir, 'step3_NTUPLE.root')]

        return raw_files

    @property
    def processed_file_names(self):
        return []


    def get(self, idx):
        """
        Return a Data object that includes:
          - x:            node features, shape (N, D)
          - x_pe:         a list of edges [anchor, positive]
          - x_ne:         a list of edges [anchor, negative]
                         such that the negative does NOT share any CP with anchor.
        """
        # ----------------------------------------------------
        # 1) Flatten your node features as you already do
        # ----------------------------------------------------
        edge_index = torch.empty((2,0), dtype=torch.long)

        lc_x = self.stsCP_vertices_x[idx]
        #print(ak.to_numpy(lc_x[0]).shape)
        #print(ak.to_numpy(lc_x[1]).shape)
        flat_lc_x = np.expand_dims(np.array(ak.flatten(lc_x)),axis=1)
        lc_y = self.stsCP_vertices_y[idx]
        flat_lc_y = np.expand_dims(np.array(ak.flatten(lc_y)),axis=1)
        lc_z = self.stsCP_vertices_z[idx]
        flat_lc_z = np.expand_dims(np.array(ak.flatten(lc_z)),axis=1)
        lc_e = self.stsCP_vertices_energy[idx]
        flat_lc_e = np.expand_dims(np.array(ak.flatten(lc_e)),axis=1)     
        lc_t = self.stsCP_vertices_time[idx]
        flat_lc_t = np.expand_dims(np.array(ak.flatten(lc_t)),axis=1)  
        lc_layer_id = self.stsCP_vertices_layer_id[idx]
        flat_lc_layer_id = np.expand_dims(np.array(ak.flatten(lc_layer_id)),axis=1)  
        #lc_radius = self.stsCP_vertices_radius[idx]
        #flat_lc_radius = np.expand_dims(np.array(ak.flatten(lc_radius)),axis=1)  
        lc_noh = self.stsCP_vertices_noh[idx]
        flat_lc_noh = np.expand_dims(np.array(ak.flatten(lc_noh)),axis=1)  
        lc_eta = self.stsCP_vertices_eta[idx]
        flat_lc_eta = np.expand_dims(np.array(ak.flatten(lc_eta)),axis=1)  
        lc_phi = self.stsCP_vertices_phi[idx]
        flat_lc_phi = np.expand_dims(np.array(ak.flatten(lc_phi)),axis=1) 
        # ... continue as needed ...

        # Suppose after filtering, we have shape (N,1) => total_lc = N
        total_lc = flat_lc_x.shape[0]

        # Build a big feature array
        flat_lc_feats = np.concatenate(
            (flat_lc_x, flat_lc_y, flat_lc_z, flat_lc_e, flat_lc_layer_id, flat_lc_noh, flat_lc_eta, flat_lc_phi),
            axis=-1
        )
        # For example, shape => (N, D)

        # ----------------------------------------------------
        # 2) Figure out how each node maps to CP IDs, i.e. lc2cp
        # ----------------------------------------------------
        # The user code has something like 'lc_x[cp]' for the LCs in each caloparticle.
        # So let's build "cp2lc": for each cp index, which global LC indices does it contain?
        # Then we can invert that to get lc2cp.

        # "lc_x" is a list-of-lists: outer = CP index, inner = LCs for that CP
        # Actually, from your question it looks like:
        #   for cp in range(len(lc_x)):
        #       n_lc_cp = len(lc_x[cp])  # how many LCs belong to CP 'cp' in this event
        #
        # We'll walk these in order, building a global_index from 0..N-1.

        cp2lc = []
        lc2cp = [set() for _ in range(total_lc)]  # each entry will be a set of cp IDs
        global_idx = 0

        for cp_id in range(len(lc_x)):  # number of CP in this event
            n_lc_in_this_cp = len(lc_x[cp_id])
            # The node indices for this CP are range(global_idx, global_idx + n_lc_in_this_cp)
            # We'll store that
            cp_nodes = list(range(global_idx, global_idx + n_lc_in_this_cp))
            cp2lc.append(cp_nodes)
            # also mark each of those node indices as belonging to 'cp_id'
            for node_i in cp_nodes:
                lc2cp[node_i].add(cp_id)
            global_idx += n_lc_in_this_cp

        # Now cp2lc[cp] is the list of node indices that belong to caloparticle cp.
        # And lc2cp[node] is the set of CP IDs that node belongs to.

        # ----------------------------------------------------
        # 3) Build pos_edges and neg_edges with the new logic
        # ----------------------------------------------------
        pos_edges = []
        neg_edges = []

        # We'll just do a simple loop over each CP, then each node in that CP
        # for anchor. Then pick a random positive from the same CP, and a random
        # negative from outside ANY of anchor's CP IDs.
        for cp_id, node_list in enumerate(cp2lc):
            # node_list = e.g. [10, 11, 12] if those are the global indices
            for anchor_i in node_list:
                # 3a) Positive = pick from node_list (the same CP), excluding anchor if possible
                if len(node_list) > 1:
                    # pick a random node from node_list excluding anchor_i
                    candidates_pos = [n for n in node_list if n != anchor_i]
                    if len(candidates_pos) == 0:
                        pos_i = anchor_i
                    else:
                        pos_i = random.choice(candidates_pos)
                else:
                    # fallback = self
                    pos_i = anchor_i

                # 3b) Negative = pick from [0..total_lc) such that
                #   lc2cp[anchor_i] ∩ lc2cp[candidate_neg] = ∅
                # We'll do a small while loop or random sampling until we find one that works.
                # In your original snippet, you just keep picking until it doesn't belong to the same CP block.
                # But now the rule is "must not share ANY CP ID," not just this cp_id.

                anchor_cp_ids = lc2cp[anchor_i]
                max_tries = 20  # to avoid infinite loop, or up to N tries
                neg_i = anchor_i  # fallback
                for _ in range(max_tries):
                    candidate_neg = random.randint(0, total_lc - 1)
                    # check intersection
                    if lc2cp[candidate_neg].isdisjoint(anchor_cp_ids):
                        # means no shared CP => valid negative
                        neg_i = candidate_neg
                        break

                # store the edges
                pos_edges.append([anchor_i, pos_i])
                neg_edges.append([anchor_i, neg_i])

        # ----------------------------------------------------
        # 4) Wrap up in a Data object
        # ----------------------------------------------------
        x = torch.from_numpy(flat_lc_feats).float()  # shape (N, D)
        y = torch.zeros(x.size(0), dtype=torch.float) # if needed

        x_pos_edge = torch.from_numpy(np.array(pos_edges, dtype=np.int64))
        x_neg_edge = torch.from_numpy(np.array(neg_edges, dtype=np.int64))

        # Return a Data object with everything
        data = Data(
            x=x,                     # shape (N, D)
            edge_index=edge_index,   # empty here, if you don't have real GNN edges
            y=y,
            x_pe=x_pos_edge,         # shape (#pairs, 2)
            x_ne=x_neg_edge,         # shape (#pairs, 2)
        )

        return data


In [10]:
ipath = "/vols/cms/mm1221/Data/mix/train/"
vpath = "/vols/cms/mm1221/Data/mix/val/"
data_train = CCV1(ipath, max_events=10, inp = 'train')
data_val = CCV1(vpath, max_events=10, inp='val')

### Loading data


  0%|                                                                                                                                                                                 | 0/3 [00:37<?, ?it/s]


Reached 10!
### Loading data


  0%|                                                                                                                                                                                 | 0/1 [00:35<?, ?it/s]

Reached 10!





In [14]:
import torch
import torch.nn.functional as F
import random
from torch_geometric.nn import knn_graph
from tqdm import tqdm

###############################
# Contrastive Loss With Edges
###############################
def contrastive_loss_edges(
    embeddings,
    x_pe,  # (Epos, 2) Positive edges: each row = [anchor, pos_node].
    x_ne,  # (Eneg, 2) Negative edges: each row = [anchor, neg_node].
    temperature=0.1
):
    """
    Contrastive loss using explicit positive/negative edges from data.x_pe and data.x_ne.

    For each node i in [0..N-1]:
      - Find all positive edges that start at i in x_pe. Pick one at random, or fallback to i if none.
      - Find all negative edges that start at i in x_ne. Pick one at random, or fallback to i if none.
      - Compute an NT-Xent style loss comparing sim(i, pos_i) vs sim(i, neg_i).

    Args:
        embeddings:  (N, D) Node embeddings (a torch.Tensor).
        x_pe:        (Epos, 2) Positive edges. Each row: [anchor, pos].
        x_ne:        (Eneg, 2) Negative edges. Each row: [anchor, neg].
        temperature: float, softmax temperature for InfoNCE/NT-Xent.

    Returns:
        A scalar tensor (mean contrastive loss).
    """
    device = embeddings.device
    N = embeddings.size(0)
    
    # 1) Build adjacency lists: pos_dict[i], neg_dict[i]
    pos_dict = [[] for _ in range(N)]
    for edge in x_pe:
        anchor = edge[0].item()
        pos_tgt = edge[1].item()
        pos_dict[anchor].append(pos_tgt)

    neg_dict = [[] for _ in range(N)]
    for edge in x_ne:
        anchor = edge[0].item()
        neg_tgt = edge[1].item()
        neg_dict[anchor].append(neg_tgt)
    
    # 2) For each node i, pick a single positive & negative
    pos_indices = []
    neg_indices = []
    for i in range(N):
        # Positive
        if len(pos_dict[i]) > 0:
            j = random.choice(pos_dict[i])
        else:
            j = i  # fallback to self-loop

        # Negative
        if len(neg_dict[i]) > 0:
            k = random.choice(neg_dict[i])
        else:
            k = i  # fallback to self-loop

        pos_indices.append(j)
        neg_indices.append(k)
    
    pos_indices = torch.tensor(pos_indices, dtype=torch.long, device=device)
    neg_indices = torch.tensor(neg_indices, dtype=torch.long, device=device)

    # 3) Cosine similarities
    norm_emb = F.normalize(embeddings, p=2, dim=1)  # shape (N, D)
    sim_matrix = norm_emb @ norm_emb.t()            # shape (N, N)
    idx = torch.arange(N, device=device)

    # 4) Gather positive & negative similarities
    pos_sims = sim_matrix[idx, pos_indices]  # shape (N,)
    neg_sims = sim_matrix[idx, neg_indices]  # shape (N,)

    # 5) NT-Xent
    numer = torch.exp(pos_sims / temperature)
    denom = numer + torch.exp(neg_sims / temperature)
    loss = -torch.log(numer / denom)

    return loss.mean()

###############################
# Training & Testing Pipeline
###############################
def train_new(train_loader, model, optimizer, device, k_value, alpha):
    model.train()
    total_loss = torch.zeros(1, device=device)

    for data in tqdm(train_loader, desc="Training"):
        data = data.to(device)
        optimizer.zero_grad()
        
        # Convert x_pe, x_ne to tensors if they're lists
        x_pe = data.x_pe
        if not isinstance(x_pe, torch.Tensor):
            x_pe = torch.tensor(x_pe, dtype=torch.long, device=data.x.device)
        x_ne = data.x_ne
        if not isinstance(x_ne, torch.Tensor):
            x_ne = torch.tensor(x_ne, dtype=torch.long, device=data.x.device)

        # Build edges (if needed) and get embeddings
        #edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        out = model(data.x, data.x_batch)
        # Unwrap if model returns a tuple
        embeddings = out[0] if isinstance(out, (tuple, list)) else out

        # Partition batch by event
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)

        loss_event_total = torch.zeros(1, device=device)
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count

            # Slice embeddings for this event
            event_embeddings = embeddings[start_idx:end_idx]

            # Filter x_pe, x_ne for edges whose anchor is in [start_idx, end_idx)
            pe_mask = (x_pe[:,0] >= start_idx) & (x_pe[:,0] < end_idx)
            pe_event = x_pe[pe_mask].clone()
            pe_event[:,0] -= start_idx  # re-map anchor to local index
            # pos node might also need re-mapping if your code expects both columns in [0,count)
            pe_event[:,1] -= start_idx

            ne_mask = (x_ne[:,0] >= start_idx) & (x_ne[:,0] < end_idx)
            ne_event = x_ne[ne_mask].clone()
            ne_event[:,0] -= start_idx
            ne_event[:,1] -= start_idx
            
            print(pe_event)

            # Loss for this event
            loss_event = contrastive_loss_edges(
                event_embeddings,
                pe_event,
                ne_event,
                temperature=0.1
            )
            loss_event_total += loss_event

            start_idx = end_idx
        
        loss = loss_event_total / len(counts)
        loss.backward()
        optimizer.step()

        total_loss += loss

    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test_new(test_loader, model, device, k_value, alpha):
    model.eval()
    total_loss = torch.zeros(1, device=device)

    for data in tqdm(test_loader, desc="Validation"):
        data = data.to(device)
        
        # Convert x_pe, x_ne if needed
        x_pe = data.x_pe
        if not isinstance(x_pe, torch.Tensor):
            x_pe = torch.tensor(x_pe, dtype=torch.long, device=data.x.device)
        x_ne = data.x_ne
        if not isinstance(x_ne, torch.Tensor):
            x_ne = torch.tensor(x_ne, dtype=torch.long, device=data.x.device)

        #edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        out = model(data.x, data.x_batch)
        embeddings = out[0] if isinstance(out, (tuple, list)) else out

        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = torch.zeros(1, device=device)
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count

            event_embeddings = embeddings[start_idx:end_idx]

            pe_mask = (x_pe[:,0] >= start_idx) & (x_pe[:,0] < end_idx)
            pe_event = x_pe[pe_mask].clone()
            pe_event[:,0] -= start_idx
            pe_event[:,1] -= start_idx

            ne_mask = (x_ne[:,0] >= start_idx) & (x_ne[:,0] < end_idx)
            ne_event = x_ne[ne_mask].clone()
            ne_event[:,0] -= start_idx
            ne_event[:,1] -= start_idx

            loss_event = contrastive_loss_edges(event_embeddings, pe_event, ne_event, temperature=0.1)
            loss_event_total += loss_event

            start_idx = end_idx
        
        total_loss += loss_event_total / len(counts)

    return total_loss / len(test_loader.dataset)


In [15]:
import torch
import torch.nn as nn
from torch_geometric.nn import DynamicEdgeConv

class Net(nn.Module):
    def __init__(self, hidden_dim=64, num_layers=4, dropout=0.3, contrastive_dim=8, k=20):
        """
        Initializes the neural network with DynamicEdgeConv layers.

        Args:
            hidden_dim (int): Dimension of hidden layers.
            num_layers (int): Total number of DynamicEdgeConv layers.
            dropout (float): Dropout rate.
            contrastive_dim (int): Dimension of the contrastive output.
            k (int): Number of nearest neighbors to use in DynamicEdgeConv.
        """
        super(Net, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.contrastive_dim = contrastive_dim
        self.k = k

        # Input feature normalization (BatchNorm1d normalizes feature-wise across samples)
        # self.input_norm = nn.BatchNorm1d(8)

        # Input encoder
        self.lc_encode = nn.Sequential(
            nn.Linear(8, 32),
            nn.ELU(),
            nn.Linear(32, hidden_dim),
            nn.ELU()
        )

        # Define the network's convolutional layers using DynamicEdgeConv layers
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            # For odd-numbered layers (1-based index: i+1), use k//2
            # For even-numbered layers, use k
            if (i + 1) % 2 == 0:
                current_k = self.k
            else:
                current_k = self.k // 4

            mlp = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim),
                nn.ELU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(p=dropout)
            )
            conv = DynamicEdgeConv(mlp, k=current_k, aggr="max")
            self.convs.append(conv)

        # Output layer
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 32),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, contrastive_dim)
        )

    def forward(self, x, batch):
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input node features of shape (N, 8).
            batch (torch.Tensor): Batch vector that assigns each node to an example in the batch.

        Returns:
            torch.Tensor: Output features after processing.
            torch.Tensor: Batch vector.
        """
        # Normalize input features
        # x = self.input_norm(x)

        # Input encoding
        x_lc_enc = self.lc_encode(x)  # Shape: (N, hidden_dim)

        # Apply DynamicEdgeConv layers with residual connections
        feats = x_lc_enc
        for conv in self.convs:
            feats = conv(feats, batch) + feats

        # Final output
        out = self.output(feats)
        return out, batch

In [16]:
import argparse
import os
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import DataLoader
from torch_geometric.nn import knn_graph
from tqdm import tqdm



# Set device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Loading data...")

print("Instantiating model...")
# Instantiate model.
model = Net(
    hidden_dim=128,
    dropout=0.3,
    contrastive_dim=128,
    k=48
).to(device)

k_value = 48
BS = 1

# Setup optimizer and scheduler.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.5)

# Create DataLoaders.
train_loader = DataLoader(data_train, batch_size=BS, shuffle=False, follow_batch=['x'])
val_loader = DataLoader(data_val, batch_size=BS, shuffle=False, follow_batch=['x'])

# Setup output directory.
output_dir = '/vols/cms/mm1221/hgcal/Mixed/LC/NegativeMining/runs/DEC/hd128nl3cd128k24_48/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# (Optionally, you could load a pretrained model here if needed.)

best_val_loss = float('inf')
train_losses = []
val_losses = []
patience = 300
no_improvement_epochs = 0

print("Starting full training with curriculum for hard negative mining...")

epochs = 300
for epoch in range(epochs):
    # For epochs 1 to 150, gradually increase alpha from 0 to 1.
    # From epoch 151 onward, set alpha = 1 (fully hard negatives).
    if epoch < 75:
        alpha = 0
        alpha2 = 0
    elif epoch < 150:
        alpha = (epoch - 75) / 75.0  # Linearly increase from 0 to 1
        alpha2 = 1.0
    else:
        alpha = 1
        alpha2 = 1


    print(f"Epoch {epoch+1}/{epochs} | Alpha: {alpha:.2f}")
    train_loss = train_new(train_loader, model, optimizer, device,k_value, alpha)
    val_loss = test_new(val_loader, model, device, k_value, alpha = alpha2)

    train_losses.append(train_loss.item())
    val_losses.append(val_loss.item())
    scheduler.step()

    # Save best model if validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
        torch.save(model.state_dict(), os.path.join(output_dir, 'best_model.pt'))
    else:
        no_improvement_epochs += 1

    # Save intermediate checkpoint.
    state_dicts = {'model': model.state_dict(),
                   'opt': optimizer.state_dict(),
                   'lr': scheduler.state_dict()}
    torch.save(state_dicts, os.path.join(output_dir, f'epoch-{epoch+1}.pt'))

    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss.item():.8f}, Validation Loss: {val_loss.item():.8f}")
    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered. No improvement for {patience} epochs.")
        break

# Save training history.
import pandas as pd
results_df = pd.DataFrame({
    'epoch': list(range(1, len(train_losses) + 1)),
    'train_loss': train_losses,
    'val_loss': val_losses
})
results_df.to_csv(os.path.join(output_dir, 'continued_training_loss.csv'), index=False)
print(f"Saved loss curves to {os.path.join(output_dir, 'continued_training_loss.csv')}")

# Save final model.
torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pt'))
print("Training complete. Final model saved.")

Loading data...
Instantiating model...
Starting full training with curriculum for hard negative mining...
Epoch 1/300 | Alpha: 0.00


Training:   0%|                                                                                                                                                                     | 0/456 [00:00<?, ?it/s]

tensor([[  0,   2],
        [  1,   0],
        [  2,  13],
        [  3,   0],
        [  4,  14],
        [  5,  10],
        [  6,  14],
        [  7,   4],
        [  8,   3],
        [  9,  10],
        [ 10,  14],
        [ 11,   3],
        [ 12,   8],
        [ 13,   4],
        [ 14,  11],
        [ 15,  10],
        [ 16,  50],
        [ 17,  88],
        [ 18,  52],
        [ 19,  53],
        [ 20,  81],
        [ 21,  69],
        [ 22,  48],
        [ 23,  32],
        [ 24,  50],
        [ 25,  77],
        [ 26,  85],
        [ 27,  46],
        [ 28,  69],
        [ 29,  96],
        [ 30,  41],
        [ 31,  70],
        [ 32,  48],
        [ 33,  26],
        [ 34,  19],
        [ 35,  20],
        [ 36,  88],
        [ 37,  79],
        [ 38,  63],
        [ 39,  63],
        [ 40,  47],
        [ 41,  55],
        [ 42,  45],
        [ 43,  57],
        [ 44,  33],
        [ 45,  74],
        [ 46,  42],
        [ 47,  67],
        [ 48,  47],
        [ 49,  43],


Training:   0%|▋                                                                                                                                                            | 2/456 [00:01<03:53,  1.95it/s]

tensor([[  0,  60],
        [  1,  55],
        [  2,   3],
        ...,
        [654, 621],
        [655, 612],
        [656, 639]])


Training:   1%|█                                                                                                                                                            | 3/456 [00:01<03:08,  2.40it/s]

tensor([[  0,  27],
        [  1,  58],
        [  2,  22],
        [  3,  38],
        [  4,  17],
        [  5,  33],
        [  6,  18],
        [  7,  13],
        [  8,  25],
        [  9,  46],
        [ 10,  34],
        [ 11,  12],
        [ 12,  56],
        [ 13,   8],
        [ 14,  60],
        [ 15,  44],
        [ 16,  51],
        [ 17,  55],
        [ 18,  37],
        [ 19,  54],
        [ 20,  60],
        [ 21,  60],
        [ 22,  58],
        [ 23,  49],
        [ 24,  52],
        [ 25,  20],
        [ 26,  12],
        [ 27,   8],
        [ 28,   3],
        [ 29,  57],
        [ 30,   5],
        [ 31,  55],
        [ 32,   6],
        [ 33,   3],
        [ 34,  19],
        [ 35,   7],
        [ 36,  55],
        [ 37,  12],
        [ 38,  23],
        [ 39,  28],
        [ 40,  39],
        [ 41,  45],
        [ 42,  37],
        [ 43,   9],
        [ 44,  24],
        [ 45,  44],
        [ 46,  58],
        [ 47,  28],
        [ 48,  50],
        [ 49,  57],


Training:   1%|█▍                                                                                                                                                           | 4/456 [00:01<02:38,  2.85it/s]

tensor([[  0,  73],
        [  1,  37],
        [  2,  37],
        [  3,  36],
        [  4,   6],
        [  5,  40],
        [  6,  20],
        [  7,  60],
        [  8,  67],
        [  9,   6],
        [ 10,  22],
        [ 11,  32],
        [ 12,  74],
        [ 13,  30],
        [ 14,  30],
        [ 15,  26],
        [ 16,  46],
        [ 17,  51],
        [ 18,  31],
        [ 19,  27],
        [ 20,  48],
        [ 21,  17],
        [ 22,  64],
        [ 23,  59],
        [ 24,  11],
        [ 25,  11],
        [ 26,  16],
        [ 27,  14],
        [ 28,  65],
        [ 29,  24],
        [ 30,  69],
        [ 31,  68],
        [ 32,  49],
        [ 33,  75],
        [ 34,  52],
        [ 35,  68],
        [ 36,  72],
        [ 37,  30],
        [ 38,  39],
        [ 39,  17],
        [ 40,   4],
        [ 41,  32],
        [ 42,  48],
        [ 43,   8],
        [ 44,  57],
        [ 45,  25],
        [ 46,  12],
        [ 47,  58],
        [ 48,  32],
        [ 49,  79],


Training:   1%|██                                                                                                                                                           | 6/456 [00:02<02:42,  2.76it/s]

tensor([[  0,  29],
        [  1,  71],
        [  2,  72],
        ...,
        [733, 724],
        [734, 725],
        [735, 658]])
tensor([[  0,  10],
        [  1,  36],
        [  2,  34],
        [  3,  11],
        [  4,   5],
        [  5,  36],
        [  6,  36],
        [  7,  25],
        [  8,  29],
        [  9,   2],
        [ 10,   1],
        [ 11,  13],
        [ 12,  20],
        [ 13,  33],
        [ 14,  42],
        [ 15,   4],
        [ 16,  33],
        [ 17,  19],
        [ 18,  27],
        [ 19,  41],
        [ 20,  28],
        [ 21,   8],
        [ 22,   8],
        [ 23,  31],
        [ 24,  34],
        [ 25,  24],
        [ 26,  28],
        [ 27,   7],
        [ 28,  15],
        [ 29,   0],
        [ 30,  26],
        [ 31,   1],
        [ 32,  18],
        [ 33,  30],
        [ 34,  41],
        [ 35,  31],
        [ 36,   9],
        [ 37,   3],
        [ 38,  39],
        [ 39,   9],
        [ 40,  12],
        [ 41,  36],
        [ 42,  34],
      

Training:   2%|██▊                                                                                                                                                          | 8/456 [00:02<02:24,  3.10it/s]

tensor([[  0,  42],
        [  1,  10],
        [  2,  10],
        ...,
        [606, 585],
        [607, 570],
        [608, 599]])


Training:   2%|███                                                                                                                                                          | 9/456 [00:02<02:02,  3.66it/s]

tensor([[  0,  99],
        [  1,  55],
        [  2,  60],
        [  3,  64],
        [  4,  87],
        [  5,  69],
        [  6,  42],
        [  7,  33],
        [  8,   9],
        [  9,  62],
        [ 10,  37],
        [ 11,  79],
        [ 12, 108],
        [ 13, 129],
        [ 14,  58],
        [ 15,   5],
        [ 16, 119],
        [ 17,  73],
        [ 18,  54],
        [ 19,  59],
        [ 20, 124],
        [ 21,  79],
        [ 22,  78],
        [ 23,   9],
        [ 24,  20],
        [ 25,  18],
        [ 26, 114],
        [ 27,  92],
        [ 28,  75],
        [ 29, 100],
        [ 30,  38],
        [ 31,  55],
        [ 32,  89],
        [ 33,   8],
        [ 34,  61],
        [ 35,  26],
        [ 36,   1],
        [ 37, 100],
        [ 38, 119],
        [ 39,  98],
        [ 40, 131],
        [ 41,  10],
        [ 42, 117],
        [ 43, 131],
        [ 44,  20],
        [ 45, 112],
        [ 46,  83],
        [ 47,  83],
        [ 48,  73],
        [ 49, 131],


Training:   2%|███                                                                                                                                                          | 9/456 [00:03<02:41,  2.76it/s]

tensor([[  0,  52],
        [  1,  26],
        [  2,  46],
        [  3,   1],
        [  4,  57],
        [  5,   3],
        [  6,  16],
        [  7,   9],
        [  8,  20],
        [  9,   3],
        [ 10,   9],
        [ 11,  38],
        [ 12,  30],
        [ 13,  52],
        [ 14,  36],
        [ 15,  32],
        [ 16,  12],
        [ 17,  24],
        [ 18,  53],
        [ 19,  48],
        [ 20,  57],
        [ 21,   9],
        [ 22,  25],
        [ 23,   5],
        [ 24,   4],
        [ 25,  12],
        [ 26,  50],
        [ 27,  24],
        [ 28,  49],
        [ 29,   7],
        [ 30,   4],
        [ 31,  16],
        [ 32,  16],
        [ 33,  13],
        [ 34,  25],
        [ 35,   3],
        [ 36,  28],
        [ 37,  52],
        [ 38,  51],
        [ 39,  23],
        [ 40,   6],
        [ 41,  44],
        [ 42,  48],
        [ 43,  18],
        [ 44,  54],
        [ 45,  47],
        [ 46,  38],
        [ 47,  40],
        [ 48,  45],
        [ 49,   0],





KeyboardInterrupt: 

In [27]:
print(data_train[0].cp[1])

tensor([ 0,  3,  5,  8, -1, -1, -1])
