In [1]:
import os
import glob
import random
import subprocess

import numpy as np
import pandas as pd
import h5py
import uproot
import awkward as ak

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import knn_graph

import tqdm
from tqdm import tqdm

import os
import os.path as osp  # This defines 'osp'
import glob



def find_highest_branch(path, base_name):
    """Return the key of ``base_name`` with the highest cycle number.

    The ROOT file at ``path`` is opened and its keys are scanned for entries
    of the form ``"<base_name>;<cycle>"``. Keys of objects whose name merely
    begins with ``base_name`` are ignored because the ``;`` separator is part
    of the prefix being matched.

    Args:
        path: Location of the ROOT file, passed straight to ``uproot.open``.
        base_name: Exact object name to look for (without the cycle suffix).

    Returns:
        The matching key with the largest integer cycle, or ``None`` when the
        file contains no key for ``base_name``.
    """
    prefix = base_name + ';'
    best_key = None
    best_cycle = None

    with uproot.open(path) as root_file:
        for key in root_file.keys():
            if not key.startswith(prefix):
                continue
            cycle = int(key.split(';')[-1])
            # ">=" keeps the last of several equal cycles, like a stable sort would.
            if best_cycle is None or cycle >= best_cycle:
                best_key = key
                best_cycle = cycle

    return best_key

class CCV1(Dataset):
    r'''
    Trackster-level dataset with one sampled positive and one sampled negative
    edge per trackster.

    All ``*.root`` files found in ``<root>/raw`` are read at construction time.
    For every file the highest-cycle ``tracksters``, ``associations`` and
    ``simtrackstersCP`` trees are used. Events whose ``simtrackstersCP``
    ``barycenter_x`` list is empty are dropped; the remaining events are kept
    in memory as awkward arrays (one attribute per branch, see
    ``_TRACKSTER_BRANCHES``, plus ``assoc``).

    Each item returned by :meth:`get` is a ``Data`` object with

    * ``x``     -- float tensor, one row per trackster, columns in the order of
      ``_FEATURE_ATTRS``;
    * ``x_pe``  -- long tensor of ``[i, j]`` pairs, ``j`` drawn from the
      tracksters sharing ``i``'s group (``j == i`` if there is none);
    * ``x_ne``  -- long tensor of ``[i, j]`` pairs, ``j`` drawn from the
      tracksters of any other group (``j == i`` if there is none);
    * ``assoc`` -- list with the group id of every trackster.

    The group id of a trackster is the first entry of its
    ``tsCLUE3D_recoToSim_CP`` list (presumably the best-matching simulated
    object -- confirm the ordering of that branch).

    Edges are drawn with the global ``random`` module, so repeated accesses to
    the same index give different edges unless ``random`` is re-seeded.

    Args:
        root: Dataset root directory; ROOT files are expected in ``root/raw``.
        transform: Optional transform forwarded to the base ``Dataset``.
        max_events: Reading stops after the chunk in which the number of kept
            events reaches this value (the last chunk is not truncated).
        inp: Free-form tag stored on the instance as ``self.inp``.
    '''

    url = '/dummy/'

    # Branch of the ``tracksters`` tree -> attribute the skimmed array is stored under.
    _TRACKSTER_BRANCHES = {
        "time": "time",
        "raw_energy": "raw_energy",
        "barycenter_x": "bx",
        "barycenter_y": "by",
        "barycenter_z": "bz",
        "barycenter_eta": "beta",
        "barycenter_phi": "bphi",
        "EV1": "EV1",
        "EV2": "EV2",
        "EV3": "EV3",
        "eVector0_x": "eV0x",
        "eVector0_y": "eV0y",
        "eVector0_z": "eV0z",
        "sigmaPCA1": "sigma1",
        "sigmaPCA2": "sigma2",
        "sigmaPCA3": "sigma3",
        "raw_pt": "pt",
        "vertices_time": "vt",
    }

    # Branch of the ``associations`` tree stored under ``self.assoc``.
    _ASSOC_BRANCH = "tsCLUE3D_recoToSim_CP"

    # Attributes stacked (in this order) into the columns of ``Data.x``.
    # ``time`` and ``vt`` are loaded but are not part of the feature matrix.
    _FEATURE_ATTRS = (
        "bx", "by", "bz", "raw_energy",
        "beta", "bphi",
        "EV1", "EV2", "EV3",
        "eV0x", "eV0y", "eV0z",
        "sigma1", "sigma2", "sigma3", "pt",
    )

    def __init__(self, root, transform=None, max_events=1e8, inp='train'):
        super().__init__(root, transform)
        self.inp = inp
        self.max_events = max_events
        self.fill_data(max_events)

    def fill_data(self, max_events):
        """Read, skim and store the trackster and association arrays.

        The ``tracksters`` tree is read chunk by chunk. For every chunk the
        ``simtrackstersCP`` and ``associations`` trees are read over exactly
        the same entry range, so the skim mask and the association labels
        always refer to the same events as the trackster features.

        Args:
            max_events: Stop reading once at least this many events are kept.

        Raises:
            KeyError: If one of the three required trees is missing in a file.
        """
        counter = 0
        print("### Loading tracksters data")

        branch_names = list(self._TRACKSTER_BRANCHES)
        # Per-attribute list of skimmed chunks; concatenated once at the end so
        # the already-loaded data is not re-copied for every new chunk.
        chunks = {attr: [] for attr in self._TRACKSTER_BRANCHES.values()}
        chunks["assoc"] = []

        for path in tqdm(self.raw_paths):
            print(path)

            tree_keys = {
                'tracksters': find_highest_branch(path, 'tracksters'),
                'associations': find_highest_branch(path, 'associations'),
                'simtrackstersCP': find_highest_branch(path, 'simtrackstersCP'),
            }
            missing = [name for name, key in tree_keys.items() if key is None]
            if missing:
                raise KeyError(f"{path} does not contain the tree(s) {missing}")

            with uproot.open(path) as root_file:
                tracksters_tree = root_file[tree_keys['tracksters']]
                assoc_tree = root_file[tree_keys['associations']]
                sim_tree = root_file[tree_keys['simtrackstersCP']]

                for array, report in tracksters_tree.iterate(branch_names, report=True):
                    # Entry range of this tracksters chunk inside the tree.
                    start = report.tree_entry_start
                    stop = report.tree_entry_stop

                    sim_bx = sim_tree.arrays(
                        ["barycenter_x"], entry_start=start, entry_stop=stop
                    )["barycenter_x"]
                    assoc = assoc_tree.arrays(
                        [self._ASSOC_BRANCH], entry_start=start, entry_stop=stop
                    )[self._ASSOC_BRANCH]

                    # Keep only events that have at least one sim trackster.
                    skim_mask = ak.num(sim_bx, axis=1) >= 1

                    for branch, attr in self._TRACKSTER_BRANCHES.items():
                        chunks[attr].append(array[branch][skim_mask])
                    chunks["assoc"].append(assoc[skim_mask])

                    counter += len(chunks["bx"][-1])
                    if counter >= max_events:
                        print(f"Reached {max_events} events!")
                        break

            if counter >= max_events:
                break

        for attr, parts in chunks.items():
            # No chunk at all (e.g. empty trees) leaves a zero-length array so
            # that len() still works.
            setattr(self, attr, ak.concatenate(parts) if parts else ak.Array([]))

    def download(self):
        raise RuntimeError(
            f'Dataset not found. Please download it from {self.url} and move all '
            f'*.root files to {self.raw_dir}')

    def len(self):
        """Number of events kept after the skim."""
        return len(self.time)

    @property
    def raw_file_names(self):
        """Sorted paths of every ``*.root`` file inside ``raw_dir``."""
        raw_files = sorted(glob.glob(osp.join(self.raw_dir, '*.root')))
        return raw_files

    @property
    def processed_file_names(self):
        """Nothing is cached on disk, so there are no processed files."""
        return []

    def get(self, idx):
        """Build the ``Data`` object of event ``idx``.

        Args:
            idx: Index of the event among the kept (skimmed) events.

        Returns:
            ``Data`` with ``x`` (features), ``x_pe`` / ``x_ne`` (one positive
            and one negative ``[source, target]`` pair per trackster) and
            ``assoc`` (list of group ids). An event without tracksters yields
            edge tensors of shape ``(0, 2)`` and an empty ``assoc`` list.
        """
        # Feature matrix: one column per attribute listed in _FEATURE_ATTRS.
        feature_columns = [
            np.array(getattr(self, attr)[idx]) for attr in self._FEATURE_ATTRS
        ]
        x = torch.from_numpy(np.column_stack(feature_columns)).float()

        total_tracksters = len(self.time[idx])

        # Group id of a trackster = first entry of its association list.
        event_assoc = np.array(self.assoc[idx])
        labels = [assoc_row[0] for assoc_row in event_assoc]

        # Trackster indices of every group, in ascending order.
        members = {}
        for i, label in enumerate(labels):
            members.setdefault(label, []).append(i)

        # Trackster indices outside every group, in ascending order. Built once
        # per group instead of once per trackster.
        non_members = {
            label: [j for j in range(total_tracksters) if labels[j] != label]
            for label in members
        }

        pos_edges = []
        neg_edges = []
        for i in range(total_tracksters):
            label = labels[i]

            # Positive edge: another trackster of the same group, or a
            # self-loop when the trackster is alone in its group.
            partners = [j for j in members[label] if j != i]
            pos_target = random.choice(partners) if partners else i
            pos_edges.append([i, pos_target])

            # Negative edge: a trackster of any other group, or a self-loop
            # when the event contains a single group.
            others = non_members[label]
            neg_target = random.choice(others) if others else i
            neg_edges.append([i, neg_target])

        # reshape keeps the (n, 2) layout even when there are no edges.
        x_pos_edge = torch.tensor(pos_edges, dtype=torch.long).reshape(-1, 2)
        x_neg_edge = torch.tensor(neg_edges, dtype=torch.long).reshape(-1, 2)

        return Data(x=x, x_pe=x_pos_edge, x_ne=x_neg_edge, assoc=labels)

In [2]:

# Build the validation dataset from the *.root files in "<vpath>/raw".
# max_events=100 makes CCV1.fill_data stop reading further chunks once at
# least 100 events have been kept; the chunk being read is stored in full.
vpath = "/vols/cms/mm1221/Data/100k/5e/val/"
data_val = CCV1(vpath, max_events=100, inp='val')

### Loading tracksters data


  0%|                                                     | 0/1 [00:00<?, ?it/s]

/vols/cms/mm1221/Data/100k/5e/val/raw/val.root


  0%|                                                     | 0/1 [00:10<?, ?it/s]

Reached 100 events!





In [3]:
import numpy as np

# Fetch the first event once: every data_val[0] access re-runs CCV1.get(),
# which rebuilds the feature matrix and re-samples the random edges.
sample = data_val[0]

# Convert features and associations to numpy arrays
features = sample.x.numpy()
assoc = np.array(sample.assoc)

# Compute the per-feature mean over the tracksters of each association group
unique_groups = np.unique(assoc)
means = {group: np.mean(features[assoc == group], axis=0) for group in unique_groups}

# Print means for each feature per group
for group, mean_values in means.items():
    print(f"Group {group} Means: {mean_values}")


Group 0 Means: [-3.68170047e+00 -5.29421043e+01  3.35287048e+02  1.19959915e+02
  2.54272842e+00 -1.64022648e+00  3.84049339e+01  7.47460783e-01
  3.40673774e-01  4.04848438e-03 -1.45035386e-01  9.89418209e-01
  6.13216829e+00  9.62122023e-01  9.81783867e-01  1.74052086e+01]
Group 1 Means: [ 2.6506069e+00 -4.8803196e+01  3.4861060e+02  5.6607277e+01
  2.6621523e+00 -1.5166208e+00  2.3438560e+01  2.1770787e-01
  1.0208924e-01  1.5691033e-01 -3.7340480e-01  8.6298758e-01
  4.2646890e+00  6.2055349e-01  1.1281140e+00  8.1703224e+00]
Group 2 Means: [-4.6648107e+00 -4.9610035e+01  3.4267044e+02  3.9465828e+01
  2.6250632e+00 -1.6644167e+00  2.9590332e+01  9.2051888e-01
  4.2493060e-01  5.4838147e-02 -1.4676648e-01  9.7702831e-01
  4.9075565e+00  1.1576029e+00  9.8233032e-01  4.8005271e+00]
Group 3 Means: [-8.3863287e+00 -5.3563755e+01  3.3538666e+02  1.0880519e+02
  2.5219197e+00 -1.7261027e+00  4.2169136e+01  3.2079679e-01
  5.1162496e-02 -3.8178496e-02 -1.6212319e-01  9.8603165e-01
  6.48

In [4]:
# Show the feature matrix of the first event (one row per trackster).
print(data_val[0].x)

tensor([[-4.2841e+00, -4.8579e+01,  3.3563e+02,  1.1379e+02,  2.6273e+00,
         -1.6588e+00,  3.9528e+01,  4.6912e-01,  1.1979e-01, -1.3397e-02,
         -1.1855e-01,  9.9286e-01,  6.2841e+00,  6.2077e-01,  4.9127e-01,
          1.3576e+01],
        [-8.3863e+00, -5.3564e+01,  3.3539e+02,  1.0881e+02,  2.5219e+00,
         -1.7261e+00,  4.2169e+01,  3.2080e-01,  5.1162e-02, -3.8178e-02,
         -1.6212e-01,  9.8603e-01,  6.4884e+00,  5.1264e-01,  4.2261e-01,
          1.8122e+01],
        [ 2.2171e+00, -4.8338e+01,  3.3668e+02,  1.1233e+02,  2.6381e+00,
         -1.5250e+00,  4.0239e+01,  2.3432e-01,  2.0418e-01,  1.4504e-02,
         -1.3940e-01,  9.9013e-01,  6.2117e+00,  9.6617e-01,  1.0766e+00,
          1.5744e+01],
        [-3.6817e+00, -5.2942e+01,  3.3529e+02,  1.1996e+02,  2.5427e+00,
         -1.6402e+00,  3.8405e+01,  7.4746e-01,  3.4067e-01,  4.0485e-03,
         -1.4504e-01,  9.8942e-01,  6.1322e+00,  9.6212e-01,  9.8178e-01,
          1.7405e+01],
        [-8.2142e+00

In [5]:
# Show the association group id assigned to each trackster of the first event.
print(data_val[0].assoc)

[2, 3, 1, 0, 2, 2, 1]
