# Loading Training Data

In [28]:
import glob
import os.path as osp
import uproot
import awkward as ak
import torch
import numpy as np
import random
import tqdm
from torch_geometric.data import Data, Dataset

import numpy as np
import subprocess
import tqdm
from tqdm import tqdm
import pandas as pd

import os
import os.path as osp

import glob

import h5py
import uproot

import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.data import Dataset
from torch_geometric.data import DataLoader

import awkward as ak
import random
from torch_geometric.nn import knn_graph
import torch.nn.functional as F

In [29]:

def find_highest_branch(path, base_name):
    with uproot.open(path) as f:
        # Find keys that exactly match the base_name (not containing other variations)
        branches = [k for k in f.keys() if k.startswith(base_name + ';')]
        
        # Sort and select the highest-numbered branch
        sorted_branches = sorted(branches, key=lambda x: int(x.split(';')[-1]))
        return sorted_branches[-1] if sorted_branches else None

class CCV1(Dataset):
    r'''
    Loads trackster-level features and associations for positive/negative edge creation.
    '''

    url = '/dummy/'

    def __init__(self, root, transform=None, max_events=1e8, inp='train'):
        super(CCV1, self).__init__(root, transform)
        self.inp = inp
        self.max_events = max_events
        self.fill_data(max_events)

    def fill_data(self, max_events):
        counter = 0
        print("### Loading tracksters data")


        for path in tqdm(self.raw_paths):
            print(path)
            
            tracksters_path = find_highest_branch(path, 'tracksters')
            associations_path = find_highest_branch(path, 'associations')
            simtrack = find_highest_branch(path, 'simtrackstersCP')
            # Load tracksters features in chunks
            for array in uproot.iterate(
                f"{path}:{tracksters_path}",
                [
                    "time", "raw_energy",
                    "barycenter_x", "barycenter_y", "barycenter_z", 
                    "barycenter_eta", "barycenter_phi",
                    "EV1", "EV2", "EV3",
                    "eVector0_x", "eVector0_y", "eVector0_z",
                    "sigmaPCA1", "sigmaPCA2", "sigmaPCA3"
                ],
            ):

                tmp_time = array["time"]
                tmp_raw_energy = array["raw_energy"]
                tmp_bx = array["barycenter_x"]
                tmp_by = array["barycenter_y"]
                tmp_bz = array["barycenter_z"]
                tmp_beta = array["barycenter_eta"]
                tmp_bphi = array["barycenter_phi"]
                tmp_EV1 = array["EV1"]
                tmp_EV2 = array["EV2"]
                tmp_EV3 = array["EV3"]
                tmp_eV0x = array["eVector0_x"]
                tmp_eV0y = array["eVector0_y"]
                tmp_eV0z = array["eVector0_z"]
                tmp_sigma1 = array["sigmaPCA1"]
                tmp_sigma2 = array["sigmaPCA2"]
                tmp_sigma3 = array["sigmaPCA3"]
                
                
                vert_array = []
                for vert_chunk in uproot.iterate(
                    f"{path}:{simtrack}",
                    ["barycenter_x"],
                ):
                    vert_array = vert_chunk["barycenter_x"]
                    break  # Since we have a matching chunk, no need to continue
                

                # Now load the associations for the same events/chunk
                # 'tsCLUE3D_recoToSim_CP' gives association arrays like [[1,0],[0,1],...]
                # Make sure we read from the same events
                tmp_array = []
                for assoc_chunk in uproot.iterate(
                    f"{path}:{associations_path}",
                    ["tsCLUE3D_recoToSim_CP"],
                ):
                    tmp_array = assoc_chunk["tsCLUE3D_recoToSim_CP"]
                    break  # Since we have a matching chunk, no need to continue
                
                
                skim_mask = []
                for e in vert_array:
                    if 1 <=len(e) <= 5:
                        skim_mask.append(True)
                    elif len(e) == 0:
                        skim_mask.append(False)

                    else:
                        skim_mask.append(False)



                tmp_time = tmp_time[skim_mask]
                tmp_raw_energy = tmp_raw_energy[skim_mask]
                tmp_bx = tmp_bx[skim_mask]
                tmp_by = tmp_by[skim_mask]
                tmp_bz = tmp_bz[skim_mask]
                tmp_beta = tmp_beta[skim_mask]
                tmp_bphi = tmp_bphi[skim_mask]
                tmp_EV1 = tmp_EV1[skim_mask]
                tmp_EV2 = tmp_EV2[skim_mask]
                tmp_EV3 = tmp_EV3[skim_mask]
                tmp_eV0x = tmp_eV0x[skim_mask]
                tmp_eV0y = tmp_eV0y[skim_mask]
                tmp_eV0z = tmp_eV0z[skim_mask]
                tmp_sigma1 = tmp_sigma1[skim_mask]
                tmp_sigma2 = tmp_sigma2[skim_mask]
                tmp_sigma3 = tmp_sigma3[skim_mask]
                tmp_array = tmp_array[skim_mask]
                
                skim_mask = []
                for e in tmp_array:
                    if 2 <= len(e):
                        skim_mask.append(True)

                    elif len(e) == 0:
                        skim_mask.append(False)

                    else:
                        skim_mask.append(False)

                        
                tmp_time = tmp_time[skim_mask]
                tmp_raw_energy = tmp_raw_energy[skim_mask]
                tmp_bx = tmp_bx[skim_mask]
                tmp_by = tmp_by[skim_mask]
                tmp_bz = tmp_bz[skim_mask]
                tmp_beta = tmp_beta[skim_mask]
                tmp_bphi = tmp_bphi[skim_mask]
                tmp_EV1 = tmp_EV1[skim_mask]
                tmp_EV2 = tmp_EV2[skim_mask]
                tmp_EV3 = tmp_EV3[skim_mask]
                tmp_eV0x = tmp_eV0x[skim_mask]
                tmp_eV0y = tmp_eV0y[skim_mask]
                tmp_eV0z = tmp_eV0z[skim_mask]
                tmp_sigma1 = tmp_sigma1[skim_mask]
                tmp_sigma2 = tmp_sigma2[skim_mask]
                tmp_sigma3 = tmp_sigma3[skim_mask]
                tmp_array = tmp_array[skim_mask]

                
                # Concatenate or initialize storage
                if counter == 0:
                    self.time = tmp_time
                    self.raw_energy = tmp_raw_energy
                    self.bx = tmp_bx
                    self.by = tmp_by
                    self.bz = tmp_bz
                    self.beta = tmp_beta
                    self.bphi = tmp_bphi
                    self.EV1 = tmp_EV1
                    self.EV2 = tmp_EV2
                    self.EV3 = tmp_EV3
                    self.eV0x = tmp_eV0x
                    self.eV0y = tmp_eV0y
                    self.eV0z = tmp_eV0z
                    self.sigma1 = tmp_sigma1
                    self.sigma2 = tmp_sigma2
                    self.sigma3 = tmp_sigma3
                    self.assoc = tmp_array
                else:
                    self.time = ak.concatenate((self.time, tmp_time))
                    self.raw_energy = ak.concatenate((self.raw_energy, tmp_raw_energy))
                    self.bx = ak.concatenate((self.bx, tmp_bx))
                    self.by = ak.concatenate((self.by, tmp_by))
                    self.bz = ak.concatenate((self.bz, tmp_bz))
                    self.beta = ak.concatenate((self.beta, tmp_beta))
                    self.bphi = ak.concatenate((self.bphi, tmp_bphi))
                    self.EV1 = ak.concatenate((self.EV1, tmp_EV1))
                    self.EV2 = ak.concatenate((self.EV2, tmp_EV2))
                    self.EV3 = ak.concatenate((self.EV3, tmp_EV3))
                    self.eV0x = ak.concatenate((self.eV0x, tmp_eV0x))
                    self.eV0y = ak.concatenate((self.eV0y, tmp_eV0y))
                    self.eV0z = ak.concatenate((self.eV0z, tmp_eV0z))
                    self.sigma1 = ak.concatenate((self.sigma1, tmp_sigma1))
                    self.sigma2 = ak.concatenate((self.sigma2, tmp_sigma2))
                    self.sigma3 = ak.concatenate((self.sigma3, tmp_sigma3))
                    self.assoc = ak.concatenate((self.assoc, tmp_array))

                counter += len(tmp_bx)
                if counter >= max_events:
                    print(f"Reached {max_events} events!")
                    break
            if counter >= max_events:
                break

    def download(self):
        raise RuntimeError(
            f'Dataset not found. Please download it from {self.url} and move all '
            f'*.root files to {self.raw_dir}')

    def len(self):
        return len(self.time)

    @property
    def raw_file_names(self):
        raw_files = sorted(glob.glob(osp.join(self.raw_dir, '*.root')))
        return raw_files

    @property
    def processed_file_names(self):
        return []



    def get(self, idx):
        
        def reconstruct_array(grouped_indices):
            # Find the maximum index to determine the array length
            max_index = max(max(indices) for indices in grouped_indices.values())

            # Initialize an array with the correct size, filled with a placeholder (e.g., -1)
            reconstructed = [-1] * (max_index + 1)

            # Populate the array based on the dictionary
            for value, indices in grouped_indices.items():
                for idx in indices:
                    reconstructed[idx] = value

            return reconstructed
        # Extract per-event arrays
        event_time = self.time[idx]
        event_raw_energy = self.raw_energy[idx]
        event_bx = self.bx[idx]
        event_by = self.by[idx]
        event_bz = self.bz[idx]
        event_beta = self.beta[idx]
        event_bphi = self.bphi[idx]
        event_EV1 = self.EV1[idx]
        event_EV2 = self.EV2[idx]
        event_EV3 = self.EV3[idx]
        event_eV0x = self.eV0x[idx]
        event_eV0y = self.eV0y[idx]
        event_eV0z = self.eV0z[idx]
        event_sigma1 = self.sigma1[idx]
        event_sigma2 = self.sigma2[idx]
        event_sigma3 = self.sigma3[idx]
        event_assoc = self.assoc[idx]  # Each is now an array (e.g., [0, 1, 2]) indicating the pion group

        # Convert each to numpy arrays if needed
        event_time = np.array(event_time)
        event_raw_energy = np.array(event_raw_energy)
        event_bx = np.array(event_bx)
        event_by = np.array(event_by)
        event_bz = np.array(event_bz)
        event_beta = np.array(event_beta)
        event_bphi = np.array(event_bphi)
        event_EV1 = np.array(event_EV1)
        event_EV2 = np.array(event_EV2)
        event_EV3 = np.array(event_EV3)
        event_eV0x = np.array(event_eV0x)
        event_eV0y = np.array(event_eV0y)
        event_eV0z = np.array(event_eV0z)
        event_sigma1 = np.array(event_sigma1)
        event_sigma2 = np.array(event_sigma2)
        event_sigma3 = np.array(event_sigma3)
        event_assoc = np.array(event_assoc)  # e.g. [[0,1,2], [2,0,1], [1,0,2], ...]

        # Combine features
        flat_feats = np.column_stack((
            event_bx, event_by, event_bz, event_raw_energy,
            event_beta, event_bphi,
            event_EV1, event_EV2, event_EV3,
            event_eV0x, event_eV0y, event_eV0z,
            event_sigma1, event_sigma2, event_sigma3
        ))
        x = torch.from_numpy(flat_feats).float()
        assoc = event_assoc

        total_tracksters = len(event_time)

        # --------------------------------------------------------------------
        # Group tracksters by their association tuple. Two tracksters belong
        # to the same pion group if their association arrays (converted to tuples)
        # match.
        # --------------------------------------------------------------------
        # Group tracksters by their association tuple
         # Group tracksters by the first element of event_assoc
        assoc_groups = {}
        for i, assoc in enumerate(event_assoc):
            key = assoc[0]  # Only use the first element as the key
            if key not in assoc_groups:
                assoc_groups[key] = []
            assoc_groups[key].append(i)
        assoc_array = reconstruct_array(assoc_groups)
        pos_edges = []
        neg_edges = []
        # Ensure positive edges always connect to another trackster in the same group if possible
        for i in range(total_tracksters):
            key = event_assoc[i][0]  # Get first element as group identifier
            same_group = assoc_groups[key]

            # --- Positive edge ---
            if len(same_group) > 1:
                # Always select another trackster from the same group
                pos_target = random.choice([j for j in same_group if j != i])
            else:
                # No other trackster in the group, form a self-loop
                pos_target = i
            pos_edges.append([i, pos_target])

            # --- Negative edge ---
            neg_candidates = [j for j in range(total_tracksters) if event_assoc[j][0] != key]
            if neg_candidates:
                neg_target = random.choice(neg_candidates)
            else:
                neg_target = i
            neg_edges.append([i, neg_target])

        x_pos_edge = torch.tensor(pos_edges, dtype=torch.long)
        x_neg_edge = torch.tensor(neg_edges, dtype=torch.long)

        return Data(x=x, x_pe=x_pos_edge, x_ne=x_neg_edge, assoc = assoc_array)

In [3]:
# Load datasets
ipath = "/vols/cms/mm1221/Data/100k/5pi/train/"
vpath = "/vols/cms/mm1221/Data/100k/5pi/val/"

data_train = CCV1(ipath, max_events=80000, inp = 'train')
data_val = CCV1(vpath, max_events=10000, inp='val')

### Loading tracksters data


  0%|                                                                                                                                                                                 | 0/1 [00:00<?, ?it/s]

/vols/cms/mm1221/Data/100k/5pi/train/raw/train.root


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:42<00:00, 42.93s/it]


### Loading tracksters data


  0%|                                                                                                                                                                                 | 0/2 [00:00<?, ?it/s]

/vols/cms/mm1221/Data/100k/5pi/val/raw/2k5pi.root


 50%|████████████████████████████████████████████████████████████████████████████████████▌                                                                                    | 1/2 [00:01<00:01,  1.20s/it]

/vols/cms/mm1221/Data/100k/5pi/val/raw/8k5pi.root


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00,  2.89s/it]


In [23]:
# Also load explicitely, used for analysis and plots
data_path = '/vols/cms/mm1221/Data/100k/5pi/test/raw/test.root'
data_file = uproot.open(data_path)

Track_ind = data_file['tracksters;2']['vertices_indexes'].array()
GT_ind = data_file['simtrackstersCP;3']['vertices_indexes'].array()
GT_mult = data_file['simtrackstersCP;3']['vertices_multiplicity'].array()
GT_bc = data_file['simtrackstersCP;3']['barycenter_x'].array()
energies = data_file['clusters;4']['energy'].array()
LC_x = data_file['clusters;4']['position_x'].array()
LC_y = data_file['clusters;4']['position_y'].array()
LC_z = data_file['clusters;4']['position_z'].array()
LC_eta = data_file['clusters;4']['position_eta'].array()
MT_ind = data_file['trackstersMerged;2']['vertices_indexes'].array()

#1.3 Filter so get rid of events with 0 calo particles
skim_mask = []
for e in GT_bc:
    if 1 <= len(e) <=5 :
        skim_mask.append(True)
    else:
        skim_mask.append(False)

Track_ind = Track_ind[skim_mask]
GT_ind = GT_ind[skim_mask]
GT_mult = GT_mult[skim_mask]
energies = energies[skim_mask]
LC_x = LC_x[skim_mask]
LC_y = LC_y[skim_mask]
LC_z = LC_z[skim_mask]
LC_eta = LC_eta[skim_mask]
MT_ind = MT_ind[skim_mask]


# Initialise the Network

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomStaticEdgeConv(nn.Module):
    def __init__(self, nn_module):
        super(CustomStaticEdgeConv, self).__init__()
        self.nn_module = nn_module

    def forward(self, x, edge_index):
        """
        Args:
            x (torch.Tensor): Node features of shape (N, F).
            edge_index (torch.Tensor): Predefined edges [2, E], where E is the number of edges.

        Returns:
            torch.Tensor: Node features after static edge aggregation.
        """
        row, col = edge_index  # Extract row (source) and col (target) nodes
        x_center = x[row]
        x_neighbor = x[col]

        # Compute edge features (relative)
        edge_features = torch.cat([x_center, x_neighbor - x_center], dim=-1)
        edge_features = self.nn_module(edge_features)

        # Aggregate features back to nodes
        num_nodes = x.size(0)
        node_features = torch.zeros(num_nodes, edge_features.size(-1), device=x.device)
        node_features.index_add_(0, row, edge_features)

        # Normalization (Divide by node degrees)
        counts = torch.bincount(row, minlength=num_nodes).clamp(min=1).view(-1, 1)
        node_features = node_features / counts

        return node_features


class CustomGATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, heads=1, concat=True, dropout=0.6, alpha=0.4):
        """
        Initializes the Custom GAT Layer.

        Args:
            in_dim (int): Input feature dimension.
            out_dim (int): Output feature dimension per head.
            heads (int): Number of attention heads.
            concat (bool): Whether to concatenate the heads' output or average them.
            dropout (float): Dropout rate on attention coefficients.
            alpha (float): Negative slope for LeakyReLU.
        """
        super(CustomGATLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.heads = heads
        self.concat = concat

        # Linear transformation for node features
        self.W = nn.Linear(in_dim, heads * out_dim, bias=False)

        # Attention mechanism: a vector for each head
        self.a_src = nn.Parameter(torch.zeros(heads, out_dim))
        self.a_tgt = nn.Parameter(torch.zeros(heads, out_dim))
        nn.init.xavier_uniform_(self.a_src.data, gain=1.414)
        nn.init.xavier_uniform_(self.a_tgt.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout = nn.Dropout(dropout)

        # Optional batch normalization
        self.batch_norm = nn.BatchNorm1d(heads * out_dim) if concat else nn.BatchNorm1d(out_dim)

    def forward(self, x, edge_index):
        """
        Forward pass of the GAT layer.

        Args:
            x (torch.Tensor): Node features of shape (N, in_dim).
            edge_index (torch.Tensor): Edge indices of shape (2, E).

        Returns:
            torch.Tensor: Updated node features after attention-based aggregation.
        """
        src, tgt = edge_index  # Source and target node indices
        N = x.size(0)

        # Apply linear transformation and reshape for multi-head attention
        h = self.W(x)  # Shape: (N, heads * out_dim)
        h = h.view(N, self.heads, self.out_dim)  # Shape: (N, heads, out_dim)

        # Gather node features for each edge
        h_src = h[src]  # Shape: (E, heads, out_dim)
        h_tgt = h[tgt]  # Shape: (E, heads, out_dim)

        # Compute attention coefficients using separate vectors for source and target
        e_src = (h_src * self.a_src).sum(dim=-1)  # Shape: (E, heads)
        e_tgt = (h_tgt * self.a_tgt).sum(dim=-1)  # Shape: (E, heads)
        e = self.leakyrelu(e_src + e_tgt)  # Shape: (E, heads)

        # Compute softmax normalization for attention coefficients
        # To ensure numerical stability
        e = e - e.max(dim=0, keepdim=True)[0]
        alpha = torch.exp(e)  # Shape: (E, heads)

        # Sum of attention coefficients for each target node and head
        alpha_sum = torch.zeros(N, self.heads, device=x.device).scatter_add_(0, tgt.unsqueeze(-1).expand(-1, self.heads), alpha)

        # Avoid division by zero
        alpha_sum = alpha_sum + 1e-16

        # Normalize attention coefficients
        alpha = alpha / alpha_sum[tgt]  # Shape: (E, heads)
        alpha = self.dropout(alpha)

        # Weighted aggregation of source node features
        h_prime = h_src * alpha.unsqueeze(-1)  # Shape: (E, heads, out_dim)

        # Initialize output tensor and aggregate
        out = torch.zeros(N, self.heads, self.out_dim, device=x.device)
        out.scatter_add_(0, tgt.unsqueeze(-1).unsqueeze(-1).expand(-1, self.heads, self.out_dim), h_prime)  # Shape: (N, heads, out_dim)

        # Concatenate or average the heads
        if self.concat:
            out = out.view(N, self.heads * self.out_dim)  # Shape: (N, heads*out_dim)
        else:
            out = out.mean(dim=1)  # Shape: (N, out_dim)

        # Apply batch normalization
        out = self.batch_norm(out)

        return out

class Net(nn.Module):
    def __init__(self, hidden_dim=64, num_layers=4, dropout=0.3, contrastive_dim=8, heads=4):
        """
        Initializes the neural network with alternating StaticEdgeConv and GAT layers.

        Args:
            hidden_dim (int): Dimension of hidden layers.
            num_layers (int): Total number of convolutional layers (both StaticEdgeConv and GAT).
            dropout (float): Dropout rate.
            contrastive_dim (int): Dimension of the contrastive output.
            heads (int): Number of attention heads in GAT layers.
        """
        super(Net, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.contrastive_dim = contrastive_dim
        self.heads = heads

        # Input encoder
        self.lc_encode = nn.Sequential(
            nn.Linear(15, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU()
        )

        # Define the network's convolutional layers, alternating between StaticEdgeConv and GAT
        self.convs = nn.ModuleList()
        for layer_idx in range(num_layers):
            if layer_idx % 2 == 0:
                # Even-indexed layers: StaticEdgeConv
                conv = CustomStaticEdgeConv(
                    nn.Sequential(
                        nn.Linear(2 * hidden_dim, hidden_dim),
                        nn.ELU(),
                        nn.BatchNorm1d(hidden_dim),
                        nn.Dropout(p=dropout)
                    )
                )
                self.convs.append(conv)
            else:
                # Odd-indexed layers: GAT
                gat = CustomGATLayer(
                    in_dim=hidden_dim,
                    out_dim=hidden_dim // heads if heads > 1 else hidden_dim,
                    heads=heads,
                    concat=True,
                    dropout=0.6,
                    alpha=0.4
                )
                self.convs.append(gat)

        # Output layer
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, 16),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(16, contrastive_dim)
        )

    def forward(self, x, edge_index, batch):
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input node features of shape (N, 15).
            edge_index (torch.Tensor): Edge indices of shape (2, E).
            batch (torch.Tensor): Batch vector.

        Returns:
            torch.Tensor: Output features after processing.
            torch.Tensor: Batch vector.
        """
        # Input encoding
        x_lc_enc = self.lc_encode(x)  # Shape: (N, hidden_dim)

        # Apply convolutional layers with residual connections
        feats = x_lc_enc
        for idx, conv in enumerate(self.convs):
            feats = conv(feats, edge_index) + feats  # Residual connection

        # Final output
        out = self.output(feats)
        return out, batch


In [8]:
# Initialize model with passed hyperparameters
model = Net(
    hidden_dim=128,
    num_layers=4,
    dropout=0.3,
    contrastive_dim=64,
    heads=16
)

k = 16
BS = 64

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# Define The Loss term and the Training + Val Setup

In [6]:
import torch
import torch.nn.functional as F
import random

def contrastive_loss_random(embeddings, pos_indices, group_ids, temperature=0.3):
    """
    Contrastive loss using a randomly selected negative.
    """
    loss_sum = 0.0
    count = 0
    group_ids = group_ids.long()
    
    for i in range(len(embeddings)):
        anchor = embeddings[i]
        positive = embeddings[pos_indices[i]]
        neg_mask = (group_ids != group_ids[i])
        if neg_mask.sum() == 0:
            continue
        negatives = embeddings[neg_mask]
        # Randomly sample one negative from the candidates.
        rand_idx = torch.randint(0, negatives.size(0), (1,)).item()
        neg_sample = negatives[rand_idx]
        
        pos_sim = F.cosine_similarity(anchor.unsqueeze(0), positive.unsqueeze(0))
        neg_sim = F.cosine_similarity(anchor.unsqueeze(0), neg_sample.unsqueeze(0))
        
        loss = -torch.log(
            torch.exp(pos_sim/temperature) / (torch.exp(pos_sim/temperature) + torch.exp(neg_sim/temperature))
        )
        loss_sum += loss
        count += 1
    return loss_sum / count if count > 0 else torch.tensor(0.0, device=embeddings.device)

def contrastive_loss_hard(embeddings, pos_indices, group_ids, temperature=0.3):
    """
    Contrastive loss using hard negative mining.
    """
    loss_sum = 0.0
    count = 0
    group_ids = group_ids.long()
    
    for i in range(len(embeddings)):
        anchor = embeddings[i]
        positive = embeddings[pos_indices[i]]
        neg_mask = (group_ids != group_ids[i])
        if neg_mask.sum() == 0:
            continue
        negatives = embeddings[neg_mask]
        # Hard negative: the candidate with maximum cosine similarity.
        cos_sim = F.cosine_similarity(anchor.unsqueeze(0), negatives)
        hard_neg_sim = cos_sim.max()
        
        pos_sim = F.cosine_similarity(anchor.unsqueeze(0), positive.unsqueeze(0))
        
        loss = -torch.log(
            torch.exp(pos_sim/temperature) / (torch.exp(pos_sim/temperature) + torch.exp(hard_neg_sim/temperature))
        )
        loss_sum += loss
        count += 1
    return loss_sum / count if count > 0 else torch.tensor(0.0, device=embeddings.device)

def contrastive_loss_curriculum(embeddings, pos_indices, group_ids, temperature=0.3, alpha=1.0):
    """
    Blends the random-negative loss and the hard-negative loss.
    When alpha=0, uses only random negatives (easy scenario);
    When alpha=1, uses only hard negatives.
    """
    loss_random = contrastive_loss_random(embeddings, pos_indices, group_ids, temperature)
    loss_hard = contrastive_loss_hard(embeddings, pos_indices, group_ids, temperature)
    return (1 - alpha) * loss_random + alpha * loss_hard


In [7]:

##############################################
# OLD CONTRASTIVE LOSS & TRAIN/TEST FUNCTIONS
##############################################

def contrastive_loss(start_all, end_all, temperature=0.1):
    # Normalize the start and end embeddings
    z_start = F.normalize(start_all, dim=1)
    z_end = F.normalize(end_all, dim=1)
    # Split into positive and negative halves
    half = int(len(z_start) / 2)
    positives = torch.exp(F.cosine_similarity(z_start[:half], z_end[:half], dim=1))
    negatives = torch.exp(F.cosine_similarity(z_start[half:], z_end[half:], dim=1))
    nominator = positives / temperature
    denominator = negatives
    loss = torch.exp(-nominator.sum() / denominator.sum())
    return loss

def train_old(train_loader, model, optimizer, device, k_value):
    model.train()
    total_loss = 0.0
    for data in tqdm.tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        out = model(data.x, edge_index, data.x_batch)

        # Get event partitioning information.
        values, counts = np.unique(data.x_batch.detach().cpu().numpy(), return_counts=True)
        losses = []
        for e in range(len(counts)):
            lower_edge = 0 if e == 0 else np.sum(counts[:e])
            upper_edge = lower_edge + counts[e]

            start_pos = out[0][lower_edge:upper_edge][data.x_pe[lower_edge:upper_edge, 0]]
            end_pos   = out[0][lower_edge:upper_edge][data.x_pe[lower_edge:upper_edge, 1]]
            start_neg = out[0][lower_edge:upper_edge][data.x_ne[lower_edge:upper_edge, 0]]
            end_neg   = out[0][lower_edge:upper_edge][data.x_ne[lower_edge:upper_edge, 1]]

            start_all = torch.cat((start_pos, start_neg), 0)
            end_all   = torch.cat((end_pos, end_neg), 0)

            if len(losses) == 0:
                losses.append(contrastive_loss(start_all, end_all, 0.3))
            else:
                losses.append(losses[-1] + contrastive_loss(start_all, end_all, 0.3))

        loss = losses[-1]
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test_old(test_loader, model, device, k_value):
    model.eval()
    total_loss = 0.0
    for data in tqdm.tqdm(test_loader):
        data = data.to(device)
        edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        out = model(data.x, edge_index, data.x_batch)
        values, counts = np.unique(data.x_batch.detach().cpu().numpy(), return_counts=True)
        losses = []
        for e in range(len(counts)):
            lower_edge = 0 if e == 0 else np.sum(counts[:e])
            upper_edge = lower_edge + counts[e]

            start_pos = out[0][lower_edge:upper_edge][data.x_pe[lower_edge:upper_edge, 0]]
            end_pos   = out[0][lower_edge:upper_edge][data.x_pe[lower_edge:upper_edge, 1]]
            start_neg = out[0][lower_edge:upper_edge][data.x_ne[lower_edge:upper_edge, 0]]
            end_neg   = out[0][lower_edge:upper_edge][data.x_ne[lower_edge:upper_edge, 1]]

            start_all = torch.cat((start_pos, start_neg), 0)
            end_all   = torch.cat((end_pos, end_neg), 0)

            if len(losses) == 0:
                losses.append(contrastive_loss(start_all, end_all, 0.3))
            else:
                losses.append(losses[-1] + contrastive_loss(start_all, end_all, 0.3))
        loss = losses[-1]
        total_loss += loss.item()
    return total_loss / len(test_loader.dataset)

##############################################
# NEW (Curriculum) Loss & TRAIN/TEST FUNCTIONS
##############################################

def contrastive_loss_random(embeddings, pos_indices, group_ids, temperature=0.3):
    """
    Contrastive loss using a randomly selected negative.
    """
    loss_sum = 0.0
    count = 0
    group_ids = group_ids.long()
    
    for i in range(len(embeddings)):
        anchor = embeddings[i]
        positive = embeddings[pos_indices[i]]
        neg_mask = (group_ids != group_ids[i])
        if neg_mask.sum() == 0:
            continue
        negatives = embeddings[neg_mask]
        # Randomly sample one negative.
        rand_idx = torch.randint(0, negatives.size(0), (1,)).item()
        neg_sample = negatives[rand_idx]
        pos_sim = F.cosine_similarity(anchor.unsqueeze(0), positive.unsqueeze(0))
        neg_sim = F.cosine_similarity(anchor.unsqueeze(0), neg_sample.unsqueeze(0))
        loss = -torch.log(
            torch.exp(pos_sim/temperature) / (torch.exp(pos_sim/temperature) + torch.exp(neg_sim/temperature))
        )
        loss_sum += loss
        count += 1
    return loss_sum / count if count > 0 else torch.tensor(0.0, device=embeddings.device)

def contrastive_loss_hard(embeddings, pos_indices, group_ids, temperature=0.3):
    """
    Contrastive loss using hard negative mining (loop version).
    """
    loss_sum = 0.0
    count = 0
    group_ids = group_ids.long()
    
    for i in range(len(embeddings)):
        anchor = embeddings[i]
        positive = embeddings[pos_indices[i]]
        neg_mask = (group_ids != group_ids[i])
        if neg_mask.sum() == 0:
            continue
        negatives = embeddings[neg_mask]
        cos_sim = F.cosine_similarity(anchor.unsqueeze(0), negatives)
        hard_neg_sim = cos_sim.max()
        pos_sim = F.cosine_similarity(anchor.unsqueeze(0), positive.unsqueeze(0))
        loss = -torch.log(
            torch.exp(pos_sim/temperature) / (torch.exp(pos_sim/temperature) + torch.exp(hard_neg_sim/temperature))
        )
        loss_sum += loss
        count += 1
    return loss_sum / count if count > 0 else torch.tensor(0.0, device=embeddings.device)

def contrastive_loss_curriculum(embeddings, pos_indices, group_ids, temperature=0.3, alpha=1.0):
    """
    Blends the random-negative loss and the hard-negative loss.
    When alpha=0: only random negatives; when alpha=1: only hard negatives.
    """
    loss_random = contrastive_loss_random(embeddings, pos_indices, group_ids, temperature)
    loss_hard = contrastive_loss_hard(embeddings, pos_indices, group_ids, temperature)
    return (1 - alpha) * loss_random + alpha * loss_hard

def train_new(train_loader, model, optimizer, device, k_value, alpha):
    model.train()
    total_loss = 0.0
    for data in tqdm.tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        # Convert data.assoc to tensor if needed.
        if isinstance(data.assoc, list):
            if isinstance(data.assoc[0], list):
                assoc_tensor = torch.cat([torch.tensor(a, dtype=torch.int64, device=data.x.device)
                                          for a in data.assoc])
            else:
                assoc_tensor = torch.tensor(data.assoc, device=data.x.device)
        else:
            assoc_tensor = data.assoc

        edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        embeddings, _ = model(data.x, edge_index, data.x_batch)
        
        # Partition batch by event.
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = 0.0
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_group_ids = assoc_tensor[start_idx:end_idx]
            event_pos_indices = data.x_pe[start_idx:end_idx, 1].view(-1)
            loss_event = contrastive_loss_curriculum(event_embeddings, event_pos_indices, event_group_ids,
                                                     temperature=0.3, alpha=alpha)
            loss_event_total += loss_event
            start_idx = end_idx
        
        loss = loss_event_total / len(counts)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test_new(test_loader, model, device, k_value):
    model.eval()
    total_loss = 0.0
    for data in tqdm.tqdm(test_loader):
        data = data.to(device)
        
        if isinstance(data.assoc, list):
            if isinstance(data.assoc[0], list):
                assoc_tensor = torch.cat([torch.tensor(a, dtype=torch.int64, device=data.x.device)
                                          for a in data.assoc])
            else:
                assoc_tensor = torch.tensor(data.assoc, device=data.x.device)
        else:
            assoc_tensor = data.assoc
        
        edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        embeddings, _ = model(data.x, edge_index, data.x_batch)
        
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = 0.0
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_group_ids = assoc_tensor[start_idx:end_idx]
            event_pos_indices = data.x_pe[start_idx:end_idx, 1].view(-1)
            # For testing, we can use the hard-negative loss.
            loss_event = contrastive_loss_hard(event_embeddings, event_pos_indices, event_group_ids, temperature=0.3)
            loss_event_total += loss_event
            start_idx = end_idx
        total_loss += loss_event_total / len(counts)
    return total_loss / len(test_loader.dataset)


# Train and Validate

In [8]:
device = torch.device('cpu')
# Load DataLoader with current batch_size
train_loader = DataLoader(data_train, batch_size=BS, shuffle=True, follow_batch=['x'])
val_loader = DataLoader(data_val, batch_size=BS, shuffle=False, follow_batch=['x'])

# Train and evaluate the model for the specified number of epochs
best_val_loss = float('inf')

# Store train and validation losses for all epochs
train_losses = []
val_losses = []

output_dir = '/vols/cms/mm1221/hgcal/pion5New/Track/resultsEndHard/SEGATxyz/'

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    
epochs = 20
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
    
    # For epochs 1-20 use the old loss method.
    if epoch < 2:
        train_loss = train_old(train_loader, model, optimizer, device, k)
        val_loss = test_old(val_loader, model, device, k)
        print(f"Old Method: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    # For epochs 21-30, gradually shift from random negatives (alpha=0) to hard negatives (alpha=1).
    elif epoch < 4:
        alpha = (epoch - 2) / 2.0   # linearly increase from 0 to 1
        train_loss = train_new(train_loader, model, optimizer, device, k, alpha)
        val_loss = test_new(val_loader, model, device, k)
        print(f"Transition Method (alpha={alpha:.2f}): Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    # For epochs 31-40, use the new method with hard negatives (alpha=1).
    else:
        alpha = 1.0
        train_loss = train_new(train_loader, model, optimizer, device, k, alpha)
        val_loss = test_new(val_loader, model, device, k)
        print(f"New Method (alpha={alpha}): Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Adjust the learning rate
    scheduler.step()

    # Save the best model if this epoch's validation loss is lower
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(output_dir, 'best_model.pt'))

    # Save intermediate state dictionaries
    state_dicts = {'model': model.state_dict(), 'opt': optimizer.state_dict(), 'lr': scheduler.state_dict()}
    torch.save(state_dicts, os.path.join(output_dir, f'epoch-{epoch}.pt'))

    print(f'Epoch {epoch+1}/{epochs} - Train Loss: {train_loss}, Validation Loss: {val_loss}')



# Save training and validation loss curves
loss_result_filename = (
    'result.csv'
)

# Dynamically adjust the epoch range to match the length of train_losses and val_losses
results_df = pd.DataFrame({
    'epoch': list(range(1, len(train_losses) + 1)),  # Adjusted to the actual length of losses
    'train_loss': train_losses,
    'val_loss': val_losses
})

# Save to a CSV file in the output directory
results_df.to_csv(os.path.join(output_dir, loss_result_filename), index=False)

print(f'Saved training and validation losses to {os.path.join(output_dir, loss_result_filename)}')



Epoch 1/20


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 194/194 [01:49<00:00,  1.77it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:06<00:00,  3.63it/s]


Old Method: Train Loss: 0.0083, Val Loss: 0.0020
Epoch 1/20 - Train Loss: 0.008330954226158543, Validation Loss: 0.002001957932555419
Epoch 2/20


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 194/194 [01:44<00:00,  1.85it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:06<00:00,  3.66it/s]


Old Method: Train Loss: 0.0019, Val Loss: 0.0009
Epoch 2/20 - Train Loss: 0.0019222859552106676, Validation Loss: 0.0009232569384646476
Epoch 3/20


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 194/194 [09:17<00:00,  2.88s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:12<00:00,  2.04it/s]


TypeError: unsupported format string passed to Tensor.__format__

# Testing

## Loading Testing Data

In [30]:
testpath = "/vols/cms/mm1221/Data/100k/5pi/test/"

data_test = CCV1(testpath, max_events=10000, inp='test')

### Loading tracksters data


  0%|                                                                                                                                                                                 | 0/1 [00:00<?, ?it/s]

/vols/cms/mm1221/Data/100k/5pi/test/raw/test.root


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.42s/it]


In [31]:
#checkpoint= torch.load('/vols/cms/mm1221/hgcal/pion5New/Track/StaticEdge/results/SEGAT/results_lr0.001_bs64_hd128_nl4_do0.3_k16_cd64/best_model.pt',  map_location=torch.device('cpu'))
checkpoint= torch.load('/vols/cms/mm1221/hgcal/pion5New/Track/NegativeMining/resultsRandHard/best_model.pt',  map_location=torch.device('cpu'))

model.load_state_dict(checkpoint)  
model.eval()  

Net(
  (lc_encode): Sequential(
    (0): Linear(in_features=15, out_features=128, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ELU(alpha=1.0)
  )
  (convs): ModuleList(
    (0): CustomStaticEdgeConv(
      (nn_module): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): ELU(alpha=1.0)
        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Dropout(p=0.3, inplace=False)
      )
    )
    (1): CustomGATLayer(
      (W): Linear(in_features=128, out_features=128, bias=False)
      (leakyrelu): LeakyReLU(negative_slope=0.4)
      (dropout): Dropout(p=0.6, inplace=False)
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): CustomStaticEdgeConv(
      (nn_module): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): ELU(alpha=1.0)
        (2): B

In [32]:
all_predictions = []  


for i, data in enumerate(data_test):
    edge_index = knn_graph(data.x[:, :3], k=k)  
    predictions = model(data.x, edge_index, 1)
    all_predictions.append(predictions[0].detach().cpu().numpy())  
    
all_predictions = np.array(all_predictions)

  all_predictions = np.array(all_predictions)


In [33]:
import numpy as np
import hdbscan

def HDBSCANClustering(all_predictions, 
                      min_cluster_size=5, 
                      min_samples=None, 
                      metric='euclidean', 
                      alpha=1.0,
                      cluster_selection_method='eom',
                      prediction_data=False,
                      allow_single_cluster=True,
                      core_dist_n_jobs=1,
                      cluster_selection_epsilon=0.0):
    """
    Performs HDBSCAN clustering on a list of prediction arrays with more hyperparameter control.

    Parameters:
    - all_predictions: List of numpy arrays, each containing data points for an event.
    - min_cluster_size: Minimum size of clusters.
    - min_samples: Number of samples in a neighborhood for a point to be considered a core point.
                   If None, it defaults to min_cluster_size.
    - metric: Distance metric to use.
    - alpha: Controls the balance between single linkage and average linkage clustering.
    - cluster_selection_method: 'eom' (Excess of Mass) or 'leaf' for finer clusters.
    - prediction_data: If True, allows later predictions on new data.
    - allow_single_cluster: If True, allows a single large cluster when applicable.
    - core_dist_n_jobs: Number of parallel jobs (-1 uses all cores).
    - cluster_selection_epsilon: Threshold distance for cluster selection (default 0.0).

    Returns:
    - all_cluster_labels: NumPy array of cluster labels for all events.
    """
    all_cluster_labels = []             

    for i, pred in enumerate(all_predictions):
        print(f"Processing event {i+1}/{len(all_predictions)}...")
        
        if len(pred) < 2:
            # Assign all points to cluster 0 (since HDBSCAN uses -1 for noise)
            cluster_labels = np.zeros(len(pred), dtype=int) 
        else:
            # Initialize HDBSCAN with specified parameters
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=min_cluster_size,
                min_samples=min_samples if min_samples is not None else min_cluster_size,
                metric=metric,
                alpha=alpha,
                cluster_selection_method=cluster_selection_method,
                prediction_data=prediction_data,
                allow_single_cluster=allow_single_cluster,
                core_dist_n_jobs=core_dist_n_jobs,
                cluster_selection_epsilon=cluster_selection_epsilon
            )
            
            # Perform clustering
            cluster_labels = clusterer.fit_predict(pred)  
        
        all_cluster_labels.append(cluster_labels)
    
    # Convert the list of cluster labels to a NumPy array
    all_cluster_labels = np.array(all_cluster_labels)
    return all_cluster_labels

all_cluster_labels = HDBSCANClustering(
    all_predictions, 
    min_cluster_size=2,  # Ensures at least 3 points per cluster
    metric='euclidean',  # Change distance metric
    alpha=1.0,  # Increase single linkage weighting
    cluster_selection_method='leaf',  # Allow finer clusters
    prediction_data=False,  # Enable future prediction capability
    allow_single_cluster=True,  # Allow one large cluster
    core_dist_n_jobs=-1  # Use all available CPU cores
)


Processing event 1/9536...
Processing event 2/9536...
Processing event 3/9536...
Processing event 4/9536...
Processing event 5/9536...
Processing event 6/9536...
Processing event 7/9536...
Processing event 8/9536...
Processing event 9/9536...
Processing event 10/9536...
Processing event 11/9536...
Processing event 12/9536...
Processing event 13/9536...
Processing event 14/9536...
Processing event 15/9536...
Processing event 16/9536...
Processing event 17/9536...
Processing event 18/9536...
Processing event 19/9536...
Processing event 20/9536...
Processing event 21/9536...
Processing event 22/9536...
Processing event 23/9536...
Processing event 24/9536...
Processing event 25/9536...
Processing event 26/9536...
Processing event 27/9536...
Processing event 28/9536...
Processing event 29/9536...
Processing event 30/9536...
Processing event 31/9536...
Processing event 32/9536...
Processing event 33/9536...
Processing event 34/9536...
Processing event 35/9536...
Processing event 36/9536...
P

Processing event 391/9536...
Processing event 392/9536...
Processing event 393/9536...
Processing event 394/9536...
Processing event 395/9536...
Processing event 396/9536...
Processing event 397/9536...
Processing event 398/9536...
Processing event 399/9536...
Processing event 400/9536...
Processing event 401/9536...
Processing event 402/9536...
Processing event 403/9536...
Processing event 404/9536...
Processing event 405/9536...
Processing event 406/9536...
Processing event 407/9536...
Processing event 408/9536...
Processing event 409/9536...
Processing event 410/9536...
Processing event 411/9536...
Processing event 412/9536...
Processing event 413/9536...
Processing event 414/9536...
Processing event 415/9536...
Processing event 416/9536...
Processing event 417/9536...
Processing event 418/9536...
Processing event 419/9536...
Processing event 420/9536...
Processing event 421/9536...
Processing event 422/9536...
Processing event 423/9536...
Processing event 424/9536...
Processing eve

Processing event 763/9536...
Processing event 764/9536...
Processing event 765/9536...
Processing event 766/9536...
Processing event 767/9536...
Processing event 768/9536...
Processing event 769/9536...
Processing event 770/9536...
Processing event 771/9536...
Processing event 772/9536...
Processing event 773/9536...
Processing event 774/9536...
Processing event 775/9536...
Processing event 776/9536...
Processing event 777/9536...
Processing event 778/9536...
Processing event 779/9536...
Processing event 780/9536...
Processing event 781/9536...
Processing event 782/9536...
Processing event 783/9536...
Processing event 784/9536...
Processing event 785/9536...
Processing event 786/9536...
Processing event 787/9536...
Processing event 788/9536...
Processing event 789/9536...
Processing event 790/9536...
Processing event 791/9536...
Processing event 792/9536...
Processing event 793/9536...
Processing event 794/9536...
Processing event 795/9536...
Processing event 796/9536...
Processing eve

Processing event 1135/9536...
Processing event 1136/9536...
Processing event 1137/9536...
Processing event 1138/9536...
Processing event 1139/9536...
Processing event 1140/9536...
Processing event 1141/9536...
Processing event 1142/9536...
Processing event 1143/9536...
Processing event 1144/9536...
Processing event 1145/9536...
Processing event 1146/9536...
Processing event 1147/9536...
Processing event 1148/9536...
Processing event 1149/9536...
Processing event 1150/9536...
Processing event 1151/9536...
Processing event 1152/9536...
Processing event 1153/9536...
Processing event 1154/9536...
Processing event 1155/9536...
Processing event 1156/9536...
Processing event 1157/9536...
Processing event 1158/9536...
Processing event 1159/9536...
Processing event 1160/9536...
Processing event 1161/9536...
Processing event 1162/9536...
Processing event 1163/9536...
Processing event 1164/9536...
Processing event 1165/9536...
Processing event 1166/9536...
Processing event 1167/9536...
Processing

Processing event 1515/9536...
Processing event 1516/9536...
Processing event 1517/9536...
Processing event 1518/9536...
Processing event 1519/9536...
Processing event 1520/9536...
Processing event 1521/9536...
Processing event 1522/9536...
Processing event 1523/9536...
Processing event 1524/9536...
Processing event 1525/9536...
Processing event 1526/9536...
Processing event 1527/9536...
Processing event 1528/9536...
Processing event 1529/9536...
Processing event 1530/9536...
Processing event 1531/9536...
Processing event 1532/9536...
Processing event 1533/9536...
Processing event 1534/9536...
Processing event 1535/9536...
Processing event 1536/9536...
Processing event 1537/9536...
Processing event 1538/9536...
Processing event 1539/9536...
Processing event 1540/9536...
Processing event 1541/9536...
Processing event 1542/9536...
Processing event 1543/9536...
Processing event 1544/9536...
Processing event 1545/9536...
Processing event 1546/9536...
Processing event 1547/9536...
Processing

Processing event 1888/9536...
Processing event 1889/9536...
Processing event 1890/9536...
Processing event 1891/9536...
Processing event 1892/9536...
Processing event 1893/9536...
Processing event 1894/9536...
Processing event 1895/9536...
Processing event 1896/9536...
Processing event 1897/9536...
Processing event 1898/9536...
Processing event 1899/9536...
Processing event 1900/9536...
Processing event 1901/9536...
Processing event 1902/9536...
Processing event 1903/9536...
Processing event 1904/9536...
Processing event 1905/9536...
Processing event 1906/9536...
Processing event 1907/9536...
Processing event 1908/9536...
Processing event 1909/9536...
Processing event 1910/9536...
Processing event 1911/9536...
Processing event 1912/9536...
Processing event 1913/9536...
Processing event 1914/9536...
Processing event 1915/9536...
Processing event 1916/9536...
Processing event 1917/9536...
Processing event 1918/9536...
Processing event 1919/9536...
Processing event 1920/9536...
Processing

Processing event 2262/9536...
Processing event 2263/9536...
Processing event 2264/9536...
Processing event 2265/9536...
Processing event 2266/9536...
Processing event 2267/9536...
Processing event 2268/9536...
Processing event 2269/9536...
Processing event 2270/9536...
Processing event 2271/9536...
Processing event 2272/9536...
Processing event 2273/9536...
Processing event 2274/9536...
Processing event 2275/9536...
Processing event 2276/9536...
Processing event 2277/9536...
Processing event 2278/9536...
Processing event 2279/9536...
Processing event 2280/9536...
Processing event 2281/9536...
Processing event 2282/9536...
Processing event 2283/9536...
Processing event 2284/9536...
Processing event 2285/9536...
Processing event 2286/9536...
Processing event 2287/9536...
Processing event 2288/9536...
Processing event 2289/9536...
Processing event 2290/9536...
Processing event 2291/9536...
Processing event 2292/9536...
Processing event 2293/9536...
Processing event 2294/9536...
Processing

Processing event 2639/9536...
Processing event 2640/9536...
Processing event 2641/9536...
Processing event 2642/9536...
Processing event 2643/9536...
Processing event 2644/9536...
Processing event 2645/9536...
Processing event 2646/9536...
Processing event 2647/9536...
Processing event 2648/9536...
Processing event 2649/9536...
Processing event 2650/9536...
Processing event 2651/9536...
Processing event 2652/9536...
Processing event 2653/9536...
Processing event 2654/9536...
Processing event 2655/9536...
Processing event 2656/9536...
Processing event 2657/9536...
Processing event 2658/9536...
Processing event 2659/9536...
Processing event 2660/9536...
Processing event 2661/9536...
Processing event 2662/9536...
Processing event 2663/9536...
Processing event 2664/9536...
Processing event 2665/9536...
Processing event 2666/9536...
Processing event 2667/9536...
Processing event 2668/9536...
Processing event 2669/9536...
Processing event 2670/9536...
Processing event 2671/9536...
Processing

Processing event 3017/9536...
Processing event 3018/9536...
Processing event 3019/9536...
Processing event 3020/9536...
Processing event 3021/9536...
Processing event 3022/9536...
Processing event 3023/9536...
Processing event 3024/9536...
Processing event 3025/9536...
Processing event 3026/9536...
Processing event 3027/9536...
Processing event 3028/9536...
Processing event 3029/9536...
Processing event 3030/9536...
Processing event 3031/9536...
Processing event 3032/9536...
Processing event 3033/9536...
Processing event 3034/9536...
Processing event 3035/9536...
Processing event 3036/9536...
Processing event 3037/9536...
Processing event 3038/9536...
Processing event 3039/9536...
Processing event 3040/9536...
Processing event 3041/9536...
Processing event 3042/9536...
Processing event 3043/9536...
Processing event 3044/9536...
Processing event 3045/9536...
Processing event 3046/9536...
Processing event 3047/9536...
Processing event 3048/9536...
Processing event 3049/9536...
Processing

Processing event 3396/9536...
Processing event 3397/9536...
Processing event 3398/9536...
Processing event 3399/9536...
Processing event 3400/9536...
Processing event 3401/9536...
Processing event 3402/9536...
Processing event 3403/9536...
Processing event 3404/9536...
Processing event 3405/9536...
Processing event 3406/9536...
Processing event 3407/9536...
Processing event 3408/9536...
Processing event 3409/9536...
Processing event 3410/9536...
Processing event 3411/9536...
Processing event 3412/9536...
Processing event 3413/9536...
Processing event 3414/9536...
Processing event 3415/9536...
Processing event 3416/9536...
Processing event 3417/9536...
Processing event 3418/9536...
Processing event 3419/9536...
Processing event 3420/9536...
Processing event 3421/9536...
Processing event 3422/9536...
Processing event 3423/9536...
Processing event 3424/9536...
Processing event 3425/9536...
Processing event 3426/9536...
Processing event 3427/9536...
Processing event 3428/9536...
Processing

Processing event 3771/9536...
Processing event 3772/9536...
Processing event 3773/9536...
Processing event 3774/9536...
Processing event 3775/9536...
Processing event 3776/9536...
Processing event 3777/9536...
Processing event 3778/9536...
Processing event 3779/9536...
Processing event 3780/9536...
Processing event 3781/9536...
Processing event 3782/9536...
Processing event 3783/9536...
Processing event 3784/9536...
Processing event 3785/9536...
Processing event 3786/9536...
Processing event 3787/9536...
Processing event 3788/9536...
Processing event 3789/9536...
Processing event 3790/9536...
Processing event 3791/9536...
Processing event 3792/9536...
Processing event 3793/9536...
Processing event 3794/9536...
Processing event 3795/9536...
Processing event 3796/9536...
Processing event 3797/9536...
Processing event 3798/9536...
Processing event 3799/9536...
Processing event 3800/9536...
Processing event 3801/9536...
Processing event 3802/9536...
Processing event 3803/9536...
Processing

Processing event 4148/9536...
Processing event 4149/9536...
Processing event 4150/9536...
Processing event 4151/9536...
Processing event 4152/9536...
Processing event 4153/9536...
Processing event 4154/9536...
Processing event 4155/9536...
Processing event 4156/9536...
Processing event 4157/9536...
Processing event 4158/9536...
Processing event 4159/9536...
Processing event 4160/9536...
Processing event 4161/9536...
Processing event 4162/9536...
Processing event 4163/9536...
Processing event 4164/9536...
Processing event 4165/9536...
Processing event 4166/9536...
Processing event 4167/9536...
Processing event 4168/9536...
Processing event 4169/9536...
Processing event 4170/9536...
Processing event 4171/9536...
Processing event 4172/9536...
Processing event 4173/9536...
Processing event 4174/9536...
Processing event 4175/9536...
Processing event 4176/9536...
Processing event 4177/9536...
Processing event 4178/9536...
Processing event 4179/9536...
Processing event 4180/9536...
Processing

Processing event 4527/9536...
Processing event 4528/9536...
Processing event 4529/9536...
Processing event 4530/9536...
Processing event 4531/9536...
Processing event 4532/9536...
Processing event 4533/9536...
Processing event 4534/9536...
Processing event 4535/9536...
Processing event 4536/9536...
Processing event 4537/9536...
Processing event 4538/9536...
Processing event 4539/9536...
Processing event 4540/9536...
Processing event 4541/9536...
Processing event 4542/9536...
Processing event 4543/9536...
Processing event 4544/9536...
Processing event 4545/9536...
Processing event 4546/9536...
Processing event 4547/9536...
Processing event 4548/9536...
Processing event 4549/9536...
Processing event 4550/9536...
Processing event 4551/9536...
Processing event 4552/9536...
Processing event 4553/9536...
Processing event 4554/9536...
Processing event 4555/9536...
Processing event 4556/9536...
Processing event 4557/9536...
Processing event 4558/9536...
Processing event 4559/9536...
Processing

Processing event 4902/9536...
Processing event 4903/9536...
Processing event 4904/9536...
Processing event 4905/9536...
Processing event 4906/9536...
Processing event 4907/9536...
Processing event 4908/9536...
Processing event 4909/9536...
Processing event 4910/9536...
Processing event 4911/9536...
Processing event 4912/9536...
Processing event 4913/9536...
Processing event 4914/9536...
Processing event 4915/9536...
Processing event 4916/9536...
Processing event 4917/9536...
Processing event 4918/9536...
Processing event 4919/9536...
Processing event 4920/9536...
Processing event 4921/9536...
Processing event 4922/9536...
Processing event 4923/9536...
Processing event 4924/9536...
Processing event 4925/9536...
Processing event 4926/9536...
Processing event 4927/9536...
Processing event 4928/9536...
Processing event 4929/9536...
Processing event 4930/9536...
Processing event 4931/9536...
Processing event 4932/9536...
Processing event 4933/9536...
Processing event 4934/9536...
Processing

Processing event 5274/9536...
Processing event 5275/9536...
Processing event 5276/9536...
Processing event 5277/9536...
Processing event 5278/9536...
Processing event 5279/9536...
Processing event 5280/9536...
Processing event 5281/9536...
Processing event 5282/9536...
Processing event 5283/9536...
Processing event 5284/9536...
Processing event 5285/9536...
Processing event 5286/9536...
Processing event 5287/9536...
Processing event 5288/9536...
Processing event 5289/9536...
Processing event 5290/9536...
Processing event 5291/9536...
Processing event 5292/9536...
Processing event 5293/9536...
Processing event 5294/9536...
Processing event 5295/9536...
Processing event 5296/9536...
Processing event 5297/9536...
Processing event 5298/9536...
Processing event 5299/9536...
Processing event 5300/9536...
Processing event 5301/9536...
Processing event 5302/9536...
Processing event 5303/9536...
Processing event 5304/9536...
Processing event 5305/9536...
Processing event 5306/9536...
Processing

Processing event 5650/9536...
Processing event 5651/9536...
Processing event 5652/9536...
Processing event 5653/9536...
Processing event 5654/9536...
Processing event 5655/9536...
Processing event 5656/9536...
Processing event 5657/9536...
Processing event 5658/9536...
Processing event 5659/9536...
Processing event 5660/9536...
Processing event 5661/9536...
Processing event 5662/9536...
Processing event 5663/9536...
Processing event 5664/9536...
Processing event 5665/9536...
Processing event 5666/9536...
Processing event 5667/9536...
Processing event 5668/9536...
Processing event 5669/9536...
Processing event 5670/9536...
Processing event 5671/9536...
Processing event 5672/9536...
Processing event 5673/9536...
Processing event 5674/9536...
Processing event 5675/9536...
Processing event 5676/9536...
Processing event 5677/9536...
Processing event 5678/9536...
Processing event 5679/9536...
Processing event 5680/9536...
Processing event 5681/9536...
Processing event 5682/9536...
Processing

Processing event 6024/9536...
Processing event 6025/9536...
Processing event 6026/9536...
Processing event 6027/9536...
Processing event 6028/9536...
Processing event 6029/9536...
Processing event 6030/9536...
Processing event 6031/9536...
Processing event 6032/9536...
Processing event 6033/9536...
Processing event 6034/9536...
Processing event 6035/9536...
Processing event 6036/9536...
Processing event 6037/9536...
Processing event 6038/9536...
Processing event 6039/9536...
Processing event 6040/9536...
Processing event 6041/9536...
Processing event 6042/9536...
Processing event 6043/9536...
Processing event 6044/9536...
Processing event 6045/9536...
Processing event 6046/9536...
Processing event 6047/9536...
Processing event 6048/9536...
Processing event 6049/9536...
Processing event 6050/9536...
Processing event 6051/9536...
Processing event 6052/9536...
Processing event 6053/9536...
Processing event 6054/9536...
Processing event 6055/9536...
Processing event 6056/9536...
Processing

Processing event 6403/9536...
Processing event 6404/9536...
Processing event 6405/9536...
Processing event 6406/9536...
Processing event 6407/9536...
Processing event 6408/9536...
Processing event 6409/9536...
Processing event 6410/9536...
Processing event 6411/9536...
Processing event 6412/9536...
Processing event 6413/9536...
Processing event 6414/9536...
Processing event 6415/9536...
Processing event 6416/9536...
Processing event 6417/9536...
Processing event 6418/9536...
Processing event 6419/9536...
Processing event 6420/9536...
Processing event 6421/9536...
Processing event 6422/9536...
Processing event 6423/9536...
Processing event 6424/9536...
Processing event 6425/9536...
Processing event 6426/9536...
Processing event 6427/9536...
Processing event 6428/9536...
Processing event 6429/9536...
Processing event 6430/9536...
Processing event 6431/9536...
Processing event 6432/9536...
Processing event 6433/9536...
Processing event 6434/9536...
Processing event 6435/9536...
Processing

Processing event 6779/9536...
Processing event 6780/9536...
Processing event 6781/9536...
Processing event 6782/9536...
Processing event 6783/9536...
Processing event 6784/9536...
Processing event 6785/9536...
Processing event 6786/9536...
Processing event 6787/9536...
Processing event 6788/9536...
Processing event 6789/9536...
Processing event 6790/9536...
Processing event 6791/9536...
Processing event 6792/9536...
Processing event 6793/9536...
Processing event 6794/9536...
Processing event 6795/9536...
Processing event 6796/9536...
Processing event 6797/9536...
Processing event 6798/9536...
Processing event 6799/9536...
Processing event 6800/9536...
Processing event 6801/9536...
Processing event 6802/9536...
Processing event 6803/9536...
Processing event 6804/9536...
Processing event 6805/9536...
Processing event 6806/9536...
Processing event 6807/9536...
Processing event 6808/9536...
Processing event 6809/9536...
Processing event 6810/9536...
Processing event 6811/9536...
Processing

Processing event 7148/9536...
Processing event 7149/9536...
Processing event 7150/9536...
Processing event 7151/9536...
Processing event 7152/9536...
Processing event 7153/9536...
Processing event 7154/9536...
Processing event 7155/9536...
Processing event 7156/9536...
Processing event 7157/9536...
Processing event 7158/9536...
Processing event 7159/9536...
Processing event 7160/9536...
Processing event 7161/9536...
Processing event 7162/9536...
Processing event 7163/9536...
Processing event 7164/9536...
Processing event 7165/9536...
Processing event 7166/9536...
Processing event 7167/9536...
Processing event 7168/9536...
Processing event 7169/9536...
Processing event 7170/9536...
Processing event 7171/9536...
Processing event 7172/9536...
Processing event 7173/9536...
Processing event 7174/9536...
Processing event 7175/9536...
Processing event 7176/9536...
Processing event 7177/9536...
Processing event 7178/9536...
Processing event 7179/9536...
Processing event 7180/9536...
Processing

Processing event 7521/9536...
Processing event 7522/9536...
Processing event 7523/9536...
Processing event 7524/9536...
Processing event 7525/9536...
Processing event 7526/9536...
Processing event 7527/9536...
Processing event 7528/9536...
Processing event 7529/9536...
Processing event 7530/9536...
Processing event 7531/9536...
Processing event 7532/9536...
Processing event 7533/9536...
Processing event 7534/9536...
Processing event 7535/9536...
Processing event 7536/9536...
Processing event 7537/9536...
Processing event 7538/9536...
Processing event 7539/9536...
Processing event 7540/9536...
Processing event 7541/9536...
Processing event 7542/9536...
Processing event 7543/9536...
Processing event 7544/9536...
Processing event 7545/9536...
Processing event 7546/9536...
Processing event 7547/9536...
Processing event 7548/9536...
Processing event 7549/9536...
Processing event 7550/9536...
Processing event 7551/9536...
Processing event 7552/9536...
Processing event 7553/9536...
Processing

Processing event 7897/9536...
Processing event 7898/9536...
Processing event 7899/9536...
Processing event 7900/9536...
Processing event 7901/9536...
Processing event 7902/9536...
Processing event 7903/9536...
Processing event 7904/9536...
Processing event 7905/9536...
Processing event 7906/9536...
Processing event 7907/9536...
Processing event 7908/9536...
Processing event 7909/9536...
Processing event 7910/9536...
Processing event 7911/9536...
Processing event 7912/9536...
Processing event 7913/9536...
Processing event 7914/9536...
Processing event 7915/9536...
Processing event 7916/9536...
Processing event 7917/9536...
Processing event 7918/9536...
Processing event 7919/9536...
Processing event 7920/9536...
Processing event 7921/9536...
Processing event 7922/9536...
Processing event 7923/9536...
Processing event 7924/9536...
Processing event 7925/9536...
Processing event 7926/9536...
Processing event 7927/9536...
Processing event 7928/9536...
Processing event 7929/9536...
Processing

Processing event 8270/9536...
Processing event 8271/9536...
Processing event 8272/9536...
Processing event 8273/9536...
Processing event 8274/9536...
Processing event 8275/9536...
Processing event 8276/9536...
Processing event 8277/9536...
Processing event 8278/9536...
Processing event 8279/9536...
Processing event 8280/9536...
Processing event 8281/9536...
Processing event 8282/9536...
Processing event 8283/9536...
Processing event 8284/9536...
Processing event 8285/9536...
Processing event 8286/9536...
Processing event 8287/9536...
Processing event 8288/9536...
Processing event 8289/9536...
Processing event 8290/9536...
Processing event 8291/9536...
Processing event 8292/9536...
Processing event 8293/9536...
Processing event 8294/9536...
Processing event 8295/9536...
Processing event 8296/9536...
Processing event 8297/9536...
Processing event 8298/9536...
Processing event 8299/9536...
Processing event 8300/9536...
Processing event 8301/9536...
Processing event 8302/9536...
Processing

Processing event 8644/9536...
Processing event 8645/9536...
Processing event 8646/9536...
Processing event 8647/9536...
Processing event 8648/9536...
Processing event 8649/9536...
Processing event 8650/9536...
Processing event 8651/9536...
Processing event 8652/9536...
Processing event 8653/9536...
Processing event 8654/9536...
Processing event 8655/9536...
Processing event 8656/9536...
Processing event 8657/9536...
Processing event 8658/9536...
Processing event 8659/9536...
Processing event 8660/9536...
Processing event 8661/9536...
Processing event 8662/9536...
Processing event 8663/9536...
Processing event 8664/9536...
Processing event 8665/9536...
Processing event 8666/9536...
Processing event 8667/9536...
Processing event 8668/9536...
Processing event 8669/9536...
Processing event 8670/9536...
Processing event 8671/9536...
Processing event 8672/9536...
Processing event 8673/9536...
Processing event 8674/9536...
Processing event 8675/9536...
Processing event 8676/9536...
Processing

Processing event 9020/9536...
Processing event 9021/9536...
Processing event 9022/9536...
Processing event 9023/9536...
Processing event 9024/9536...
Processing event 9025/9536...
Processing event 9026/9536...
Processing event 9027/9536...
Processing event 9028/9536...
Processing event 9029/9536...
Processing event 9030/9536...
Processing event 9031/9536...
Processing event 9032/9536...
Processing event 9033/9536...
Processing event 9034/9536...
Processing event 9035/9536...
Processing event 9036/9536...
Processing event 9037/9536...
Processing event 9038/9536...
Processing event 9039/9536...
Processing event 9040/9536...
Processing event 9041/9536...
Processing event 9042/9536...
Processing event 9043/9536...
Processing event 9044/9536...
Processing event 9045/9536...
Processing event 9046/9536...
Processing event 9047/9536...
Processing event 9048/9536...
Processing event 9049/9536...
Processing event 9050/9536...
Processing event 9051/9536...
Processing event 9052/9536...
Processing

Processing event 9396/9536...
Processing event 9397/9536...
Processing event 9398/9536...
Processing event 9399/9536...
Processing event 9400/9536...
Processing event 9401/9536...
Processing event 9402/9536...
Processing event 9403/9536...
Processing event 9404/9536...
Processing event 9405/9536...
Processing event 9406/9536...
Processing event 9407/9536...
Processing event 9408/9536...
Processing event 9409/9536...
Processing event 9410/9536...
Processing event 9411/9536...
Processing event 9412/9536...
Processing event 9413/9536...
Processing event 9414/9536...
Processing event 9415/9536...
Processing event 9416/9536...
Processing event 9417/9536...
Processing event 9418/9536...
Processing event 9419/9536...
Processing event 9420/9536...
Processing event 9421/9536...
Processing event 9422/9536...
Processing event 9423/9536...
Processing event 9424/9536...
Processing event 9425/9536...
Processing event 9426/9536...
Processing event 9427/9536...
Processing event 9428/9536...
Processing

  all_cluster_labels = np.array(all_cluster_labels)


In [34]:
recon_ind = []

for event_idx, labels in enumerate(all_cluster_labels):
    event_clusters = {} 
    
    for cluster_idx, cluster_label in enumerate(labels):
        if cluster_label not in event_clusters:
            event_clusters[cluster_label] = []
        event_clusters[cluster_label].extend(Track_ind[event_idx][cluster_idx])
    
    recon_ind.append([event_clusters[label] for label in sorted(event_clusters.keys())])

ValueError: in ListOffsetArray64 attempting to get 1, index out of range

(https://github.com/scikit-hep/awkward-1.0/blob/1.10.3/src/libawkward/array/ListOffsetArray.cpp#L682)

In [None]:
def calculate_sim_to_reco_score(CaloParticle, energies_indices, ReconstructedTrackster, Multi):
    """
    Calculate the sim-to-reco score for a given CaloParticle and ReconstructedTrackster.
    
    Parameters:
    - CaloParticle: array of Layer Clusters in the CaloParticle.
    - Multiplicity: array of Multiplicity for layer clusters in CP
    - energies_indices: array of energies associated with all LC (indexed by LC).
    - ReconstructedTrackster: array of LC in the reconstructed Trackster.
    
    Returns:
    - sim_to_reco_score: the calculated sim-to-reco score.
    """
    numerator = 0.0
    denominator = 0.0

    energy_caloparticle_lc = energies_indices[CaloParticle] / Multi
    total_energy_caloparticle = sum(energy_caloparticle_lc)
    if total_energy_caloparticle == 0:
        return 1.0  # No energy in the CaloParticle implies perfect mismatch

    # Calculate total energy of the ReconstructedTrackster
    total_energy_trackster = sum(energies_indices[det_id] for det_id in ReconstructedTrackster)
    i = 0
    # Iterate over all DetIds in the CaloParticle
    for det_id in CaloParticle:
        energy_k = energies_indices[det_id]  # Energy for the current DetId in CaloParticle
        # Fraction of energy in the Trackster (fr_k^TST)
        fr_tst_k = 1 if det_id in ReconstructedTrackster else 0.0
        # Fraction of energy in the CaloParticle (fr_k^SC)
        fr_sc_k = 1 / Multi[i]

        # Update numerator using the min function
        numerator += min(
            (fr_tst_k - fr_sc_k) ** 2,  # First term in the min function
            fr_sc_k ** 2                # Second term in the min function
        ) * (energy_k ** 2)

        # Update denominator
        denominator += (fr_sc_k ** 2) * (energy_k ** 2)
        i+=1

    # Calculate score
    sim_to_reco_score = numerator / denominator if denominator != 0 else 1.0
    return sim_to_reco_score

def calculate_reco_to_sim_score(ReconstructedTrackster, energies_indices, CaloParticle, Multi):
    """
    Calculate the reco-to-sim score for a given ReconstructedTrackster and CaloParticle.

    Parameters:
    - ReconstructedTrackster: array of DetIds in the ReconstructedTrackster.
    - energies_indices: array of energies associated with all DetIds (indexed by DetId).
    - CaloParticle: array of DetIds in the CaloParticle.

    Returns:
    - reco_to_sim_score: the calculated reco-to-sim score.
    """
    numerator = 0.0
    denominator = 0.0

    # Calculate total energy of the ReconstructedTrackster
    total_energy_trackster = sum(energies_indices[det_id] for det_id in ReconstructedTrackster)
    if total_energy_trackster == 0:
        return 1.0  # No energy in the Trackster implies perfect mismatch

    energy_caloparticle_lc = energies_indices[CaloParticle] / Multi
    total_energy_caloparticle = sum(energy_caloparticle_lc)
    # Iterate over all DetIds in the ReconstructedTrackster
    for det_id in ReconstructedTrackster:
        energy_k = energies_indices[det_id]  # Energy for the current DetId in the Trackster
        
        # Fraction of energy in the Trackster (fr_k^TST)
        fr_tst_k = 1

        #fr_sc_k = 1 if det_id in CaloParticle else 0.0
        if det_id in CaloParticle:
            index = np.where(CaloParticle == det_id)[0][0]  # Find the index
            Multiplicity = Multi[index]
            fr_sc_k = 1
        else:
            fr_sc_k = 0
            
        # Update numerator using the min function
        numerator += min(
            (fr_tst_k - fr_sc_k) ** 2,  # First term in the min function
            fr_tst_k ** 2               # Second term in the min function
        ) * (energy_k ** 2)

        # Update denominator
        denominator += (fr_tst_k ** 2) * (energy_k ** 2)

    # Calculate score
    reco_to_sim_score = numerator / denominator if denominator != 0 else 1.0
    return reco_to_sim_score



In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm  # For progress bar
def calculate_all_event_scores(GT_ind, energies, recon_ind, LC_x, LC_y, LC_z, LC_eta, multi, num_events = 100):
    """
    Calculate sim-to-reco and reco-to-sim scores for all CaloParticle and ReconstructedTrackster combinations across all events.

    Parameters:
    - GT_ind: List of CaloParticle indices for all events.
    - energies: List of energy arrays for all events.
    - recon_ind: List of ReconstructedTrackster indices for all events.
    - LC_x, LC_y, LC_z, LC_eta: Lists of x, y, z positions and eta values for all DetIds across events.

    Returns:
    - DataFrame containing scores and additional features for each CaloParticle-Trackster combination across all events.
    """
    # Initialize an empty list to store results
    all_results = []

    # Loop over all events with a progress bar
    for event_index in range(num_events):
        caloparticles = GT_ind[event_index]  # Indices for all CaloParticles in the event
        tracksters = recon_ind[event_index]  # Indices for all ReconstructedTracksters in the event
        event_energies = energies[event_index]  # Energies for this event
        event_multi = multi[event_index]

        # Extract layer cluster positions and eta for this event
        event_x = np.array(LC_x[event_index])
        event_y = np.array(LC_y[event_index])
        event_z = np.array(LC_z[event_index])
        event_eta = np.array(LC_eta[event_index])

        # Compute barycenter for each CaloParticle
        cp_barycenters = []
        cp_avg_etas = []
        for caloparticle in caloparticles:
            # Compute barycenter (x, y, z)
            
            barycenter_x = np.mean([event_x[det_id] for det_id in caloparticle])
            barycenter_y = np.mean([event_y[det_id] for det_id in caloparticle])
            barycenter_z = np.mean([event_z[det_id] for det_id in caloparticle])
            cp_barycenters.append(np.array([barycenter_x, barycenter_y, barycenter_z]))
            
            # Compute average eta
            avg_eta = np.mean([event_eta[det_id] for det_id in caloparticle])
            cp_avg_etas.append(avg_eta)

        # Compute separation between two CaloParticles if at least two exist
        if len(cp_barycenters) >= 2:
            cp_separation = np.linalg.norm(cp_barycenters[0] - cp_barycenters[1])
        else:
            cp_separation = 0.0
            
        trackster_det_id_sets = [set(trackster) for trackster in tracksters]

        # Loop over all CaloParticles
        for calo_idx, caloparticle in enumerate(caloparticles):
            Calo_multi = event_multi[calo_idx]
            calo_det_ids = set(calo_id for calo_id in caloparticle)
            # Loop over all Tracksters
            for trackster_idx, trackster in enumerate(tracksters):
                # Calculate sim-to-reco score
                trackster_det_ids = trackster_det_id_sets[trackster_idx]
                shared_det_ids = calo_det_ids.intersection(trackster_det_ids)
                
                # Calculate shared_energy by summing energies of shared det_ids
                shared_energy = np.sum(event_energies[list(shared_det_ids)]) if shared_det_ids else 0.0
                
                
                sim_to_reco_score = calculate_sim_to_reco_score(caloparticle, event_energies, trackster, Calo_multi)
                # Calculate reco-to-sim score
                reco_to_sim_score = calculate_reco_to_sim_score(trackster, event_energies, caloparticle, Calo_multi)

                # Calculate total energy for CaloParticle and Trackster
                cp_energy_lc2 = event_energies[caloparticle] / Calo_multi
                cp_energy = np.sum(cp_energy_lc2)
                
                trackster_energy = np.sum([event_energies[det_id] for det_id in trackster])

                # Calculate energy difference ratio
                energy_diff_ratio = (trackster_energy / cp_energy if cp_energy != 0 else None)

                # Append results
                all_results.append({
                    "event_index": event_index,
                    "cp_id": calo_idx,
                    "trackster_id": trackster_idx,
                    "sim_to_reco_score": sim_to_reco_score,
                    "reco_to_sim_score": reco_to_sim_score,
                    "cp_energy": cp_energy,
                    "trackster_energy": trackster_energy,
                    "cp_avg_eta": cp_avg_etas[calo_idx],
                    "cp_separation": cp_separation,
                    "energy_ratio": energy_diff_ratio,
                    "shared_energy": shared_energy  # New column
                })

    # Convert results to a DataFrame
    df = pd.DataFrame(all_results)
    return df



In [None]:
df_CL = calculate_all_event_scores(GT_ind, energies, recon_ind, LC_x, LC_y, LC_z, LC_eta, GT_mult, num_events = 100)
df_TICL = calculate_all_event_scores(GT_ind, energies, MT_ind, LC_x, LC_y, LC_z, LC_eta, GT_mult, num_events = 100)

In [None]:
#5: Print metrics

def calculate_metrics(df, model_name):
    # ----- Efficiency Calculation -----
    # Step 1: Filter out rows where 'cp_id' is NaN
    cp_valid = df.dropna(subset=['cp_id']).copy()

    # Step 2: Group by 'event_index' and 'cp_id' to process each CaloParticle individually
    cp_grouped = cp_valid.groupby(['event_index', 'cp_id'])

    # Step 3: For each CaloParticle, check if any 'shared_energy' >= 50% of 'cp_energy'
    def is_cp_associated(group):
        cp_energy = group['cp_energy'].iloc[0]  # Assuming 'cp_energy' is consistent within the group
        threshold = 0.5 * cp_energy
        return (group['shared_energy'] >= threshold).any()

    # Apply the association function to each group
    cp_associated = cp_grouped.apply(is_cp_associated)

    # Step 4: Calculate the number of associated CaloParticles and total CaloParticles
    num_associated_cp = cp_associated.sum()
    total_cp = cp_associated.count()
    efficiency = num_associated_cp / total_cp if total_cp > 0 else 0

    # ----- Purity Calculation -----
    tst_valid = df.dropna(subset=['trackster_id']).copy()
    tst_grouped = tst_valid.groupby(['event_index', 'trackster_id'])
    tst_associated = tst_grouped['reco_to_sim_score'].min() < 0.2
    num_associated_tst = tst_associated.sum()
    total_tst = tst_associated.count()
    purity = num_associated_tst / total_tst if total_tst > 0 else 0

    # ----- Average Energy Ratio Calculation -----
    low_score_mask = df['sim_to_reco_score'] < 0.2
    low_score_events = df[low_score_mask]
    if not low_score_events.empty:
        avg_energy_ratio = (low_score_events['trackster_energy'] / low_score_events['cp_energy']).mean()
    else:
        avg_energy_ratio = 0

    # Print results for the model
    print(f"\nModel: {model_name}")
    print(f"Efficiency: {efficiency:.4f} ({num_associated_cp} associated CPs out of {total_cp} total CPs)")
    print(f"Purity: {purity:.4f} ({num_associated_tst} associated Tracksters out of {total_tst} total Tracksters)")
    print(f"Num tracksters ratio: {total_tst / total_cp if total_cp > 0 else 0:.4f}")
    print(f"Average Energy Ratio: {avg_energy_ratio:.4f}")

    return {
        'efficiency': efficiency,
        'purity': purity,
        'avg_energy_ratio': avg_energy_ratio,
    }

# Example usage
our_model_metrics = calculate_metrics(df_CL, "Our Model")
cern_model_metrics = calculate_metrics(df_TICL, "CERN Model")