In [1]:
import numpy as np
import random
import glob
import os.path as osp
import uproot
import awkward as ak
import torch
from torch_geometric.data import Data, Dataset, DataLoader

import os
import glob
import random
import subprocess

import numpy as np
import pandas as pd
import h5py
import uproot
import awkward as ak

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import knn_graph

import tqdm
from tqdm import tqdm

import os
import os.path as osp  # This defines 'osp'
import glob

# Function to find the highest numbered branch matching base_name in an uproot file
def find_highest_branch(path, base_name):
    """Return the key ``'<base_name>;<cycle>'`` with the largest cycle number.

    Args:
        path: Path of the ROOT file to open with uproot.
        base_name: Object name to look for, without the ``;cycle`` suffix.

    Returns:
        The matching key with the highest integer after the last ``;``,
        or ``None`` when no key starts with ``base_name + ';'``.
    """
    prefix = base_name + ';'
    with uproot.open(path) as root_file:
        candidates = [key for key in root_file.keys() if key.startswith(prefix)]
    if not candidates:
        return None
    # Order by the numeric cycle suffix and keep the last (largest) one.
    candidates.sort(key=lambda key: int(key.split(';')[-1]))
    return candidates[-1]

# Function to remove duplicates: for each event, only keep the entry with the highest B value for each unique element in A.
def remove_duplicates(A, B):
    """Build a keep-mask that retains one entry per unique value of ``A`` in each event.

    For every event, ``A[event]`` and ``B[event]`` are flattened. Among all
    positions sharing the same value in ``A`` only the one with the largest
    ``B`` is kept (the first one on ties, following ``np.argmax``).

    Args:
        A: Per-event nested array of identifiers.
        B: Array with the same structure as ``A`` holding the ranking values.

    Returns:
        ``ak.Array`` of booleans with the same nesting as ``A``; ``True``
        marks the entries to keep.
    """
    per_event_masks = []
    for event_idx in range(len(A)):
        event_ids = A[event_idx]
        ids = np.array(ak.flatten(event_ids))
        weights = np.array(ak.flatten(B[event_idx]))
        keep = np.zeros_like(ids, dtype=bool)
        for value in np.unique(ids):
            positions = np.flatnonzero(ids == value)
            # argmax over a single position is 0, so unique values are kept as-is.
            keep[positions[np.argmax(weights[positions])]] = True
        # Restore the per-particle nesting of this event.
        per_event_masks.append(ak.unflatten(keep, ak.num(event_ids)))
    return ak.Array(per_event_masks)

class CCV1(Dataset):
    r'''
    Dataset of layer clusters taken from simulated tracksters (simtrackstersCP).

    Every ``*.root`` file in ``raw_dir`` is read in chunks of ``step_size``
    events. For each event ``get`` builds a ``Data`` object holding:
      - x: node features, shape (N, 8), columns
           (x, y, z, energy, layer_id, number_of_hits, eta, phi)
      - groups: indices of the top-3 contributing sim tracksters per node, (N, 3)
      - fractions: matching fractions (1 / multiplicity) per node, (N, 3)

    When a node has fewer than three contributors the last contributor is
    repeated so that ``groups`` and ``fractions`` always have three slots.
    '''
    url = '/dummy/'

    # Per-vertex simtrackstersCP branches; each suffix maps to the branch
    # ``vertices_<suffix>`` and to the attribute ``stsCP_vertices_<suffix>``.
    _SIM_VERTEX_FIELDS = ('x', 'y', 'z', 'energy', 'time', 'indexes', 'multiplicity')
    _BARYCENTER_AXES = ('x', 'y', 'z')
    # Branch of the clusters tree -> suffix of the per-vertex attribute derived from it.
    _CLUSTER_FIELDS = {
        'cluster_layer_id': 'layer_id',
        'cluster_number_of_hits': 'noh',
        'position_eta': 'eta',
        'position_phi': 'phi',
    }
    # Column order of the node-feature matrix returned by ``get``.
    _NODE_FEATURES = ('x', 'y', 'z', 'energy', 'layer_id', 'noh', 'eta', 'phi')

    def __init__(self, root, transform=None, max_events=1e8, inp='train'):
        """Load the ROOT files under ``root`` and precompute per-node targets.

        Args:
            root: Dataset root directory (``raw_dir`` is derived from it by the base class).
            transform: Optional transform forwarded to the base ``Dataset``.
            max_events: Loading stops once more than this many events are held
                (checked after each chunk, see ``fill_data``).
            inp: Split label; stored on the instance but not used while loading.
        """
        super(CCV1, self).__init__(root, transform)
        self.step_size = 500
        self.inp = inp
        self.max_events = max_events

        # Filled by precompute_pairings(): one array per event.
        self.precomputed_groups = []
        self.precomputed_fractions = []

        self.fill_data(max_events)
        self.precompute_pairings()

    @staticmethod
    def _gather_per_particle(cluster_values, vertex_indexes):
        """Pick, for every sim trackster of every event, the per-cluster values at its vertex indexes."""
        gathered = []
        for evt_values, evt_indexes in zip(cluster_values, vertex_indexes):
            gathered.append([evt_values[particle_indexes] for particle_indexes in evt_indexes])
        return ak.Array(gathered)

    @staticmethod
    def _in_tracksters_mask(sim_indexes, track_indexes):
        """Mask (same nesting as ``sim_indexes``) of vertices whose index appears in any trackster of the event."""
        mask = []
        for sim_evt, track_evt in zip(sim_indexes, track_indexes):
            track_set = set(ak.to_list(ak.flatten(track_evt)))
            mask.append([[elem in track_set for elem in particle]
                         for particle in ak.to_list(sim_evt)])
        return ak.Array(mask)

    def _process_chunk(self, sim_chunk, track_indexes, cluster_chunk):
        """Filter one chunk of events and return ``{attribute name: array}`` for it.

        Args:
            sim_chunk: simtrackstersCP arrays of the chunk (vertices_* and barycenter_*).
            track_indexes: ``vertices_indexes`` of the tracksters tree for the same events.
            cluster_chunk: ``{suffix: per-cluster array}`` read from the clusters tree
                for the same events.
        """
        vertex = {suffix: sim_chunk[f'vertices_{suffix}'] for suffix in self._SIM_VERTEX_FIELDS}
        for suffix, values in cluster_chunk.items():
            vertex[suffix] = self._gather_per_particle(values, vertex['indexes'])

        # Drop sim-trackster vertices that are not part of any reconstructed trackster.
        mask_track = self._in_tracksters_mask(vertex['indexes'], track_indexes)
        vertex = {suffix: values[mask_track] for suffix, values in vertex.items()}

        # Keep only events with at least two sim tracksters (len() of an event
        # counts its sim tracksters, not its vertices).
        skim_mask = np.array([len(evt) >= 2 for evt in vertex['x']], dtype=bool)
        vertex = {suffix: values[skim_mask] for suffix, values in vertex.items()}

        # Barycenters are per sim trackster, so only the event-level skim applies.
        # Applying it keeps them aligned with every other per-event array.
        out = {
            f'stsCP_barycenter_{axis}': sim_chunk[f'barycenter_{axis}'][skim_mask]
            for axis in self._BARYCENTER_AXES
        }

        # Snapshot before de-duplication: needed to recover every contributor of a cluster.
        out['stsCP_vertices_indexes_unfilt'] = vertex['indexes']
        out['stsCP_vertices_multiplicity_unfilt'] = vertex['multiplicity']
        out['stsCP_vertices_indexes'] = vertex['indexes']

        # Keep each layer cluster once per event, under the sim trackster with
        # the largest share (1 / multiplicity).
        energy_percent = 1 / vertex['multiplicity']
        dedup_mask = remove_duplicates(vertex['indexes'], energy_percent)
        for suffix, values in vertex.items():
            name = 'indexes_filt' if suffix == 'indexes' else suffix
            out[f'stsCP_vertices_{name}'] = values[dedup_mask]
        return out

    def fill_data(self, max_events):
        """Read all ROOT files in ``raw_dir`` and store the filtered arrays on ``self``.

        Events are read ``self.step_size`` at a time. ``max_events`` is checked
        only after a whole chunk has been added, so the dataset can end up with
        more than ``max_events`` events.

        Raises:
            KeyError: if a file lacks one of the required trees.
            RuntimeError: if no chunk of events could be read at all.
        """
        print("### Loading data")
        sim_branches = [f"vertices_{suffix}" for suffix in self._SIM_VERTEX_FIELDS]
        sim_branches += [f"barycenter_{axis}" for axis in self._BARYCENTER_AXES]

        chunks = {}  # attribute name -> list of per-chunk arrays
        n_loaded = 0
        reached_max = False

        for path in sorted(glob.glob(osp.join(self.raw_dir, '*.root'))):
            tree_names = {
                base: find_highest_branch(path, base)
                for base in ('clusters', 'simtrackstersCP', 'tracksters')
            }
            missing = [base for base, key in tree_names.items() if key is None]
            if missing:
                raise KeyError(f"{path} is missing tree(s): {', '.join(missing)}")

            # Iterated in lock-step with the simtrackstersCP tree below.
            tracksters_iter = uproot.iterate(
                f"{path}:{tree_names['tracksters']}",
                ["vertices_indexes"],
                step_size=self.step_size
            )

            with uproot.open(path) as root_file:
                crosstree = root_file[tree_names['clusters']]
                # Running entry offset into the clusters tree. Deriving it from
                # the actual chunk lengths keeps a short final chunk aligned
                # and leaves self.step_size untouched for the next file.
                entry_start = 0
                for sim_chunk in uproot.iterate(f"{path}:{tree_names['simtrackstersCP']}",
                                                sim_branches,
                                                step_size=self.step_size):
                    track_indexes = next(tracksters_iter)["vertices_indexes"]

                    entry_stop = entry_start + len(sim_chunk["vertices_x"])
                    cluster_chunk = {
                        suffix: crosstree[branch].array(entry_start=entry_start,
                                                        entry_stop=entry_stop)
                        for branch, suffix in self._CLUSTER_FIELDS.items()
                    }
                    entry_start = entry_stop

                    processed = self._process_chunk(sim_chunk, track_indexes, cluster_chunk)
                    for name, values in processed.items():
                        chunks.setdefault(name, []).append(values)

                    n_loaded += len(processed['stsCP_vertices_x'])
                    if n_loaded > max_events:
                        print(f"Reached {max_events}!")
                        reached_max = True
                        break
            if reached_max:
                break

        if not chunks:
            raise RuntimeError(f"No events could be read from the ROOT files in {self.raw_dir}")

        # Concatenate once per attribute instead of re-copying after every chunk.
        for name, parts in chunks.items():
            setattr(self, name, parts[0] if len(parts) == 1 else ak.concatenate(parts))

    def precompute_pairings(self):
        """
        Precompute, for each event, the top-3 contributing sim tracksters of
        every kept layer cluster and their fractions (1 / multiplicity).

        Results are appended to ``precomputed_groups`` (int64, (N, 3)) and
        ``precomputed_fractions`` (float32, (N, 3)), one array per event.
        ``precomputed_pos_edges`` / ``precomputed_neg_edges`` are reset to empty
        lists and are not filled here.
        """
        n_events = len(self.stsCP_vertices_x)
        self.precomputed_groups = []
        self.precomputed_fractions = []
        self.precomputed_pos_edges = []
        self.precomputed_neg_edges = []
        for idx in range(n_events):
            unfilt_cp_array = self.stsCP_vertices_indexes_unfilt[idx]
            unfilt_mult_array = self.stsCP_vertices_multiplicity_unfilt[idx]

            # cluster id -> [(sim-trackster index, fraction), ...] from the
            # arrays taken before de-duplication.
            cluster_contributors = {}
            for cp_id in range(len(unfilt_cp_array)):
                for local_lc_id, cluster_id in enumerate(unfilt_cp_array[cp_id]):
                    frac = 1.0 / unfilt_mult_array[cp_id][local_lc_id]
                    cluster_contributors.setdefault(cluster_id, []).append((cp_id, frac))

            # Nodes follow the flattened, de-duplicated cluster order used by get().
            flattened_cluster_ids = ak.flatten(self.stsCP_vertices_indexes_filt[idx])
            total_lc = len(flattened_cluster_ids)
            groups_np = np.zeros((total_lc, 3), dtype=np.int64)
            fractions_np = np.zeros((total_lc, 3), dtype=np.float32)
            for global_lc, cluster_id in enumerate(flattened_cluster_ids):
                contribs = cluster_contributors.get(cluster_id, [])
                top3 = sorted(contribs, key=lambda x: x[1], reverse=True)[:3]
                if not top3:
                    top3 = [(0, 0.0)] * 3
                else:
                    # Pad by repeating the weakest listed contributor.
                    top3 += [top3[-1]] * (3 - len(top3))
                for i in range(3):
                    groups_np[global_lc, i] = top3[i][0]
                    fractions_np[global_lc, i] = top3[i][1]
            self.precomputed_groups.append(groups_np)
            self.precomputed_fractions.append(fractions_np)

    def download(self):
        """Raise an error explaining where the raw ``*.root`` files are expected."""
        raise RuntimeError(
            'Dataset not found. Please download it from {} and move all '
            '*.root files to {}'.format(self.url, self.raw_dir))

    def len(self):
        """Number of events kept after filtering."""
        return len(self.stsCP_vertices_x)

    @property
    def raw_file_names(self):
        """Sorted paths of the ``*.root`` files found in ``raw_dir``."""
        return sorted(glob.glob(osp.join(self.raw_dir, '*.root')))

    @property
    def processed_file_names(self):
        """No processed files are written; everything is kept in memory."""
        return []

    def get(self, idx):
        """Build the ``Data`` object (x, groups, fractions) of event ``idx``."""
        # One flattened column per feature, in the order of _NODE_FEATURES.
        columns = [
            np.array(ak.flatten(getattr(self, f'stsCP_vertices_{suffix}')[idx]))
            for suffix in self._NODE_FEATURES
        ]
        x = torch.from_numpy(np.stack(columns, axis=1)).float()

        return Data(
            x=x,
            groups=torch.from_numpy(self.precomputed_groups[idx]),
            fractions=torch.from_numpy(self.precomputed_fractions[idx])
        )


In [2]:
# Directories holding the training / validation ROOT files (passed as the dataset root).
ipath = "/vols/cms/mm1221/Data/mix/train/"
vpath = "/vols/cms/mm1221/Data/mix/val/"
# NOTE: max_events is only checked after a full chunk of step_size (500) events
# has been loaded, so each dataset can contain more than 10 events.
data_train = CCV1(ipath, max_events=10, inp = 'train')
data_val = CCV1(vpath, max_events=10, inp='val')

### Loading data
Reached 10!
### Loading data
Reached 10!


In [3]:
# Inspect the top-3 contributing groups and their fractions for node j of the first training event.
j=0
print(data_train[0].groups[j])
print(data_train[0].fractions[j])

tensor([0, 3, 3])
tensor([0.9007, 0.0993, 0.0993])


In [4]:
# Same inspection as the previous cell, for node j=4 of the first training event.
j=4
print(data_train[0].groups[j])
print(data_train[0].fractions[j])

tensor([0, 3, 3])
tensor([0.9484, 0.0516, 0.0516])


In [17]:
import torch
import torch.nn as nn
from torch_geometric.nn import DynamicEdgeConv

class Net(nn.Module):
    """DynamicEdgeConv network with a contrastive-embedding head and a split-node head."""

    def __init__(self, hidden_dim=64, num_layers=4, dropout=0.3,
                 contrastive_dim=8, k=20):
        """
        Build the encoder, the stack of DynamicEdgeConv layers and both output heads.

        Args:
            hidden_dim (int): Width of the hidden node representation.
            num_layers (int): Number of DynamicEdgeConv layers.
            dropout (float): Dropout probability used in the conv MLPs and heads.
            contrastive_dim (int): Size of the contrastive embedding.
            k (int): Number of nearest neighbours used by every DynamicEdgeConv.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.contrastive_dim = contrastive_dim
        self.k = k

        # Encoder from the 8 raw node features to hidden_dim.
        self.lc_encode = nn.Sequential(
            nn.Linear(8, 32),
            nn.ELU(),
            nn.Linear(32, hidden_dim),
            nn.ELU(),
        )

        # Every conv layer uses the same k and its own edge MLP.
        self.convs = nn.ModuleList([
            DynamicEdgeConv(self._edge_mlp(hidden_dim, dropout), k=self.k, aggr="max")
            for _ in range(num_layers)
        ])

        # Node-level contrastive embedding.
        self.output = self._head(hidden_dim, dropout, contrastive_dim)
        # One logit per node: is the node a split node?
        self.split_head = self._head(hidden_dim, dropout, 1)

    @staticmethod
    def _edge_mlp(hidden_dim, dropout):
        """MLP applied by one DynamicEdgeConv to its 2 * hidden_dim edge features."""
        return nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ELU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout),
        )

    @staticmethod
    def _head(hidden_dim, dropout, out_dim):
        """Three-layer MLP head: hidden_dim -> 64 -> 32 -> out_dim."""
        return nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 32),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, out_dim),
        )

    def forward(self, x, batch):
        """
        Run the network on a batch of nodes.

        Args:
            x (torch.Tensor): Node features of shape (N, 8).
            batch (torch.Tensor): Batch vector assigning each node to an example.

        Returns:
            tuple: (contrastive embeddings, split logits, batch)
        """
        hidden = self.lc_encode(x)  # (N, hidden_dim)

        # Residual connection around every DynamicEdgeConv.
        for conv in self.convs:
            hidden = conv(hidden, batch) + hidden

        return self.output(hidden), self.split_head(hidden), batch


In [21]:
import torch
import torch.nn.functional as F
import random

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

import torch
import torch.nn.functional as F
from collections import defaultdict

import torch
import torch.nn.functional as F
from collections import defaultdict

def build_shared_energy_matrix_vectorized(groups, scores):
    """
    Build the (N, N) matrix of shared fractions between rows.

    For every group id, each row contributes the fraction of the first slot
    in which that id appears (later repeats of the same id in the row are
    ignored). For every group present in at least two rows,
    ``min(fraction_i, fraction_j)`` is added to ``shared[i, j]`` for all pairs
    of those rows, including ``i == j``.

    Args:
        groups: (N, num_slots) tensor of group ids.
        scores: (N, num_slots) tensor of fractions matching ``groups``.

    Returns:
        (N, N) float tensor created with ``torch.zeros`` (default device).
    """
    N, num_slots = groups.shape
    shared = torch.zeros(N, N)

    # Pull the values out once instead of calling .item() per element.
    group_ids = groups.tolist()
    fractions = scores.tolist()

    # group id -> (row indices, fractions), each row listed once per group.
    members = defaultdict(lambda: ([], []))
    for row in range(N):
        visited = set()
        for slot in range(num_slots):
            group_id = group_ids[row][slot]
            if group_id in visited:
                continue
            visited.add(group_id)
            rows, fracs = members[group_id]
            rows.append(row)
            fracs.append(fractions[row][slot])

    for rows, fracs in members.values():
        if len(rows) < 2:
            continue  # a group seen in a single row adds nothing

        row_ix = torch.tensor(rows, dtype=torch.long)
        frac_ix = torch.tensor(fracs, dtype=torch.float)
        # (m, m) matrix of pairwise minima, added to the matching sub-matrix
        # of `shared` through broadcast index tensors.
        pairwise_min = torch.min(frac_ix.unsqueeze(0), frac_ix.unsqueeze(1))
        shared[row_ix.unsqueeze(0), row_ix.unsqueeze(1)] += pairwise_min

    return shared


def contrastive_loss_fractional(embeddings, groups, scores, temperature=0.1):
    """
    Weighted contrastive loss driven by the shared-energy matrix.

    Pairs with shared energy >= 0.5 are positives weighted by
    ``2 * (shared - 0.5)``; all other pairs are negatives weighted by
    ``2 * (0.5 - shared)``. Self-pairs get zero weight. For each anchor with
    positive weight the loss is ``-log(sum_pos / (sum_pos + sum_neg + 1e-8))``
    over ``exp(cosine_similarity / temperature)``, averaged over those anchors.

    Args:
        embeddings: (N, D) node embeddings.
        groups: (N, num_slots) group ids, passed to
            ``build_shared_energy_matrix_vectorized``.
        scores: (N, num_slots) fractions matching ``groups``.
        temperature: Softmax temperature applied to the cosine similarities.

    Returns:
        Scalar loss tensor; a fresh zero tensor (``requires_grad=True``) when
        no anchor has a positive partner.
    """
    device = embeddings.device

    # Exponentiated, temperature-scaled cosine similarities (N x N).
    unit_emb = F.normalize(embeddings, p=2, dim=1)
    exp_sim = torch.exp((unit_emb @ unit_emb.t()) / temperature)

    # Pairwise shared energy decides which pairs attract or repel, and how strongly.
    shared_energy = build_shared_energy_matrix_vectorized(groups, scores).to(device)
    is_positive = shared_energy >= 0.5
    zero = torch.zeros_like(shared_energy)
    pos_weight = torch.where(is_positive, 2.0 * (shared_energy - 0.5), zero)
    neg_weight = torch.where(is_positive, zero, 2.0 * (0.5 - shared_energy))
    pos_weight.fill_diagonal_(0)
    neg_weight.fill_diagonal_(0)

    numerator = (pos_weight * exp_sim).sum(dim=1)
    denominator = ((pos_weight + neg_weight) * exp_sim).sum(dim=1)

    # Only anchors with at least one positively weighted partner contribute.
    has_positive = pos_weight.sum(dim=1) > 0
    numerator = numerator[has_positive]
    denominator = denominator[has_positive]
    if numerator.numel() == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)

    per_anchor = -torch.log(numerator / (denominator + 1e-8))
    return per_anchor.mean()





import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from collections import defaultdict

import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

def _split_labels(groups, fractions, threshold=0.1):
    """Return 1.0 for nodes with at least two *distinct* groups whose fraction is >= threshold.

    ``groups``/``fractions`` may repeat the same contributor in several slots
    (padding). A slot is counted only when its group id differs from every
    earlier slot of the same node, so a padded single-contributor node such as
    groups [g, g, g] / fractions [1, 1, 1] is not labelled as split.

    Args:
        groups: (N, num_slots) tensor of group ids.
        fractions: (N, num_slots) tensor of fractions matching ``groups``.
        threshold: Minimum fraction for a contributor to count.

    Returns:
        Float tensor of shape (N,) with values in {0.0, 1.0}.
    """
    n_nodes, num_slots = groups.shape
    n_contributors = torch.zeros(n_nodes, dtype=torch.long, device=groups.device)
    for slot in range(num_slots):
        first_occurrence = torch.ones(n_nodes, dtype=torch.bool, device=groups.device)
        for prev in range(slot):
            first_occurrence &= groups[:, slot] != groups[:, prev]
        n_contributors += (first_occurrence & (fractions[:, slot] >= threshold)).long()
    return (n_contributors >= 2).float()


def train_new(train_loader, model, optimizer, device, temperature=0.1, alpha=1.0, beta=1.0, pos_weight=None):
    """
    Run one training epoch with the contrastive loss plus the split-node loss.

    For every batch the two losses are averaged over the events of the batch and
    combined as:
         loss = α * (contrastive loss) + β * (split loss)

    NOTE(review): split labels ignore padded duplicate slots (see
    ``_split_labels``); the validation loop should derive its labels the same
    way so both losses stay comparable.

    Args:
        train_loader: DataLoader yielding batches that expose ``x``, ``x_batch``,
            ``groups`` and ``fractions``.
        model: The network model; called as ``model(x, x_batch)``.
        optimizer: Optimizer.
        device: Torch device.
        temperature: Temperature scaling for contrastive loss.
        alpha (float): Weighting factor for the contrastive loss.
        beta (float): Weighting factor for the split loss.
        pos_weight: Tensor or float weighting positive examples in the
                    binary cross-entropy. If None, defaults to 2.0.

    Returns:
        overall_loss, contrast_loss_avg, split_loss_avg: batch losses weighted
        by the number of nodes per batch and divided by the total node count
        (all 0.0 when the loader yields no nodes).
    """
    model.train()
    total_loss = 0.0
    total_contrast_loss = 0.0
    total_split_loss = 0.0
    n_samples = 0

    # binary_cross_entropy_with_logits needs a tensor, so accept plain floats too.
    pos_weight = torch.as_tensor(
        2.0 if pos_weight is None else pos_weight, dtype=torch.float, device=device
    )

    for data in tqdm(train_loader, desc="Training"):
        data = data.to(device)

        # Node counts per event; nodes of one event are contiguous in the batch.
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        num_events = len(counts)
        if num_events == 0:
            continue  # nothing to learn from an empty batch

        optimizer.zero_grad()

        # Compute both contrastive embeddings and split logits.
        embeddings, split_logits, _ = model(data.x, data.x_batch)

        contrastive_loss_sum = 0.0
        split_loss_sum = 0.0
        start_idx = 0
        # Loop over each event in the batch.
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_split_logit = split_logits[start_idx:end_idx].view(-1)  # shape: (count,)
            event_groups = data.groups[start_idx:end_idx]
            event_fractions = data.fractions[start_idx:end_idx]

            # 1) Contrastive loss for the event.
            contrastive_loss_sum += contrastive_loss_fractional(
                event_embeddings, event_groups, event_fractions, temperature=temperature
            )

            # 2) Split label: node shared by >= 2 distinct groups with fraction >= 0.1.
            split_label = _split_labels(event_groups, event_fractions, threshold=0.1)

            # 3) Binary cross-entropy on the split logits.
            split_loss_sum += F.binary_cross_entropy_with_logits(
                event_split_logit, split_label, pos_weight=pos_weight
            )

            n_samples += event_embeddings.size(0)
            start_idx = end_idx

        # Average losses across events in the current batch.
        batch_contrast_loss = contrastive_loss_sum / num_events
        batch_split_loss = split_loss_sum / num_events

        total_batch_loss = alpha * batch_contrast_loss + beta * batch_split_loss
        total_batch_loss.backward()
        optimizer.step()

        n_nodes = embeddings.size(0)
        total_loss += total_batch_loss.item() * n_nodes
        total_contrast_loss += batch_contrast_loss.item() * n_nodes
        total_split_loss += batch_split_loss.item() * n_nodes

    if n_samples == 0:
        return 0.0, 0.0, 0.0

    overall_loss = total_loss / n_samples
    contrast_loss_avg = total_contrast_loss / n_samples
    split_loss_avg = total_split_loss / n_samples
    return overall_loss, contrast_loss_avg, split_loss_avg


@torch.no_grad()
def test_new(test_loader, model, device, temperature=0.1, alpha=1.0, beta=1.0, pos_weight=None):
    """
    Validation loop that computes both contrastive loss and split loss.

    The total loss is computed as:
         loss = α * (contrastive loss) + β * (split loss)

    For every batch, the per-event losses are averaged over the events in the
    batch; the batch averages are then weighted by the number of nodes in the
    batch, and the result is normalised by the total node count. This is the
    same aggregation used by ``train_new``, so training and validation numbers
    are directly comparable.

    Args:
        test_loader: DataLoader yielding Data objects that expose ``x``,
            ``x_batch``, ``groups`` and ``fractions``.
        model: The network model. Called as ``model(data.x, data.x_batch)`` and
            expected to return ``(embeddings, split_logits, <ignored>)``.
        device: Torch device.
        temperature: Temperature scaling for contrastive loss.
        alpha (float): Weighting factor for the contrastive loss.
        beta (float): Weighting factor for the split loss.
        pos_weight: Tensor or float for BCEWithLogitsLoss positive weighting.
                    If None, defaults to 2.0.

    Returns:
        overall_loss, contrast_loss_avg, split_loss_avg: The average losses per
        node over the validation set, as plain Python floats.

    Raises:
        ValueError: If the loader yields no nodes at all, so no average exists.
    """
    model.eval()
    total_loss = 0.0
    total_contrast_loss = 0.0
    total_split_loss = 0.0
    n_samples = 0

    if pos_weight is None:
        pos_weight = torch.tensor(2.0, device=device)
    elif not torch.is_tensor(pos_weight):
        # binary_cross_entropy_with_logits requires a tensor; accept the
        # float form promised by the docstring.
        pos_weight = torch.tensor(float(pos_weight), device=device)

    for data in tqdm(test_loader, desc="Validation"):
        data = data.to(device)
        embeddings, split_logits, _ = model(data.x, data.x_batch)
        batch_np = data.x_batch.detach().cpu().numpy()
        # Number of nodes per event. The slicing below relies on the nodes of
        # each event being stored contiguously, in ascending event order.
        _, counts = np.unique(batch_np, return_counts=True)

        num_events = len(counts)
        if num_events == 0:
            # Nothing to evaluate in this batch.
            continue

        contrastive_loss_sum = 0.0
        split_loss_sum = 0.0
        start_idx = 0
        for count in counts:
            end_idx = start_idx + int(count)
            event_embeddings = embeddings[start_idx:end_idx]
            event_split_logit = split_logits[start_idx:end_idx]
            event_groups = data.groups[start_idx:end_idx]
            event_fractions = data.fractions[start_idx:end_idx]

            # 1) Contrastive loss for the event.
            loss_contrast = contrastive_loss_fractional(
                event_embeddings, event_groups, event_fractions, temperature=temperature
            )

            # 2) Split label: a node is labelled "split" (1.0) when at least
            #    two of its fraction entries are >= 0.1.
            n_significant = (event_fractions >= 0.1).sum(dim=1)
            split_label = (n_significant >= 2).float()  # shape: (count,)

            # 3) BCE-with-logits loss for the split classification.
            event_split_logit = event_split_logit.view(-1)
            loss_split = F.binary_cross_entropy_with_logits(
                event_split_logit, split_label, pos_weight=pos_weight
            )

            # Accumulate as Python floats so the returned values are not tensors.
            contrastive_loss_sum += float(loss_contrast)
            split_loss_sum += float(loss_split)
            start_idx = end_idx

        # Average over the events of this batch, as train_new does.
        batch_contrast_loss = contrastive_loss_sum / num_events
        batch_split_loss = split_loss_sum / num_events
        batch_total_loss = alpha * batch_contrast_loss + beta * batch_split_loss

        # Weight each batch by its node count so the final division by
        # n_samples gives a per-node average.
        n_nodes = embeddings.size(0)
        total_loss += batch_total_loss * n_nodes
        total_contrast_loss += batch_contrast_loss * n_nodes
        total_split_loss += batch_split_loss * n_nodes
        n_samples += n_nodes

    if n_samples == 0:
        raise ValueError("test_new: the validation loader produced no nodes to evaluate.")

    overall_loss = total_loss / n_samples
    contrast_loss_avg = total_contrast_loss / n_samples
    split_loss_avg = total_split_loss / n_samples
    return overall_loss, contrast_loss_avg, split_loss_avg





In [None]:
import os
import torch
import torch.optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import numpy as np
import pandas as pd

# Device used for the model and for every batch (train_new/test_new move each
# batch to this device, so the model must live there too).
device = torch.device('cpu')

# Initialize model with fixed hyperparameters and place it on the device.
model = Net(
    hidden_dim=64,
    num_layers=3,
    dropout=0.3,
    contrastive_dim=64,
    k=16
)
model = model.to(device)

# Batch size shared by both loaders.
BS = 1
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
# Halve the learning rate every 50 scheduler steps (one step per epoch below).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# Build the loaders; follow_batch=['x'] makes each batch expose `x_batch`,
# which train_new/test_new use to split a batch back into events.
# NOTE(review): the training loader is not shuffled — confirm this is intended.
train_loader = DataLoader(data_train, batch_size=BS, shuffle=False, follow_batch=['x'])
val_loader = DataLoader(data_val, batch_size=BS, shuffle=False, follow_batch=['x'])

# Output directory for saving models and results
output_dir = '/vols/cms/mm1221/hgcal/elec5New/LC/Fraction/resultstest/'
os.makedirs(output_dir, exist_ok=True)

# Initialize arrays for storing losses per epoch.
train_overall_losses = []
train_contrast_losses = []
train_split_losses = []
val_overall_losses = []
val_contrast_losses = []
val_split_losses = []

# Early-stopping bookkeeping. With epochs < patience the early stop can never
# trigger; it only matters for longer runs.
best_val_loss = float('inf')
patience = 30
epochs = 5
no_improvement_epochs = 0

for epoch in range(epochs):
    # Loss weights: total = alpha_val * contrastive + beta_val * split.
    # Both are constant here; make them a function of `epoch` to schedule them.
    alpha_val = 1.0
    beta_val = 1.0

    print(f"Epoch {epoch+1}/{epochs} ")

    # train_new and test_new each return three losses: overall, contrastive, and split.
    train_overall, train_contrast, train_split = train_new(
        train_loader, model, optimizer, device, temperature=0.1, alpha=alpha_val, beta=beta_val
    )
    val_overall, val_contrast, val_split = test_new(
        val_loader, model, device, temperature=0.1, alpha=alpha_val, beta=beta_val
    )

    # Coerce everything to plain Python floats: a loss returned as a 0-dim
    # tensor would otherwise end up in the history lists (and later the CSV)
    # as a tensor object instead of a number.
    train_overall, train_contrast, train_split = (
        float(v) for v in (train_overall, train_contrast, train_split)
    )
    val_overall, val_contrast, val_split = (
        float(v) for v in (val_overall, val_contrast, val_split)
    )

    train_overall_losses.append(train_overall)
    train_contrast_losses.append(train_contrast)
    train_split_losses.append(train_split)
    val_overall_losses.append(val_overall)
    val_contrast_losses.append(val_contrast)
    val_split_losses.append(val_split)

    # One scheduler step per epoch.
    scheduler.step()

    # Save best model if validation loss improves.
    if val_overall < best_val_loss:
        best_val_loss = val_overall
        no_improvement_epochs = 0
        torch.save(model.state_dict(), os.path.join(output_dir, 'best_model.pt'))
    else:
        no_improvement_epochs += 1

    # Save intermediate checkpoint (model, optimizer and scheduler state) every epoch.
    state_dicts = {'model': model.state_dict(),
                   'opt': optimizer.state_dict(),
                   'lr': scheduler.state_dict()}
    torch.save(state_dicts, os.path.join(output_dir, f'epoch-{epoch+1}.pt'))

    print(f"Epoch {epoch+1}/{epochs} - "
          f"Train Overall: {train_overall:.8f}, Contrast: {train_contrast:.8f}, Split: {train_split:.8f} | "
          f"Val Overall: {val_overall:.8f}, Contrast: {val_contrast:.8f}, Split: {val_split:.8f}")

    # Stop once the validation loss has not improved for `patience` epochs in a row.
    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered. No improvement for {patience} epochs.")
        break

# Write the per-epoch loss history to a CSV file, one row per completed epoch.
n_completed_epochs = len(train_overall_losses)
loss_history = {
    'epoch': list(range(1, n_completed_epochs + 1)),
    'train_overall_loss': train_overall_losses,
    'train_contrast_loss': train_contrast_losses,
    'train_split_loss': train_split_losses,
    'val_overall_loss': val_overall_losses,
    'val_contrast_loss': val_contrast_losses,
    'val_split_loss': val_split_losses,
}
results_df = pd.DataFrame(loss_history)
results_csv_path = os.path.join(output_dir, 'continued_training_loss.csv')
results_df.to_csv(results_csv_path, index=False)
print(f"Saved loss curves to {results_csv_path}")

# Store the weights of the model as they are after the last epoch that ran.
final_model_path = os.path.join(output_dir, 'final_model.pt')
torch.save(model.state_dict(), final_model_path)
print("Training complete. Final model saved.")


Epoch 1/5 


Training:  95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌       | 431/456 [02:58<00:09,  2.69it/s]

In [15]:
# Inspect the stored group IDs and fractions of node 389 in training sample 1.
print(data_train[1].groups[389])
print(data_train[1].fractions[389])

tensor([9, 8, 1])
tensor([0.4331, 0.3189, 0.2480])


In [16]:
# Inspect the stored group IDs and fractions of node 388 in training sample 1.
print(data_train[1].groups[388])
print(data_train[1].fractions[388])

tensor([9, 9, 9])
tensor([1., 1., 1.])


In [63]:
# Probe element [3][3] of lc_ind. The recorded run raised a ValueError because
# index 3 is out of range for that inner list.
print(lc_ind[3][3])

ValueError: in ListOffsetArray64 attempting to get 3, index out of range

(https://github.com/scikit-hep/awkward-1.0/blob/1.10.3/src/libawkward/array/ListOffsetArray.cpp#L682)

In [65]:
# Read the 'vertices_x' branch of 'simtrackstersCP;2' into an array.
# NOTE(review): presumably `data` is an opened uproot file — confirm.
lc_x = data['simtrackstersCP;2']['vertices_x'].array()

In [88]:
# Print element [0][0] of lc_x (the first inner list of the first entry).
print(lc_x[0][0])

[-59.1, -24.8, -58.7, -40.5, -60.4, -48, ... -40.5, -41.2, -43.3, -45.4, -51.6, -56]
