# AthenaReader Testing

**Goal**: Build and test an AthenaReader class

In [1]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import pandas as pd
import yaml

from time import time as tt

### Roadmap

1. Inspect a set of raw Athena files (e.g. hits, clusters, particles)
2. Run the functional implementation on these files and inspect the results
3. Build a class that performs the same functionality
4. Test the class on the same files
5. Find ways to refactor and simplify for maximum user-friendliness!

### 1. Inspect raw Athena files

In [3]:
from processing_utils import read_particles, read_spacepoints, convert_barcodes, read_clusters, get_detectable_particles, get_truth_spacepoints

In [3]:
# NOTE(review): this file was not found when the notebook last ran (see the error below),
# while later cells load "athena_reader_config.yaml" — confirm which file is intended.
config_file = "datatype_information.yaml"
with open(config_file, "r") as config_stream:
    config = yaml.load(config_stream, Loader=yaml.FullLoader)

FileNotFoundError: [Errno 2] No such file or directory: 'datatype_information.yaml'

In [25]:
# Directory holding the raw dumps, and the suffix shared by every file of one event
input_dir, input_event = "./data/", "evt1.txt"

In [26]:
# Each raw data type lives in its own file named "<type>_<event suffix>"
clusters_file, particles_file, spacepoints_file, subevents_file = [
    os.path.join(input_dir, f"{data_type}_{input_event}")
    for data_type in ("clusters", "particles", "spacepoints", "subevents")
]

In [27]:
# Read particles, normalise their barcodes, then enforce the dtypes listed in the config
particles = convert_barcodes(read_particles(particles_file))
particles_datatypes = config["particles_datatypes"]
particles = particles.astype(particles_datatypes)

In [28]:
# Inspect the particle table produced by the previous cell
particles

Unnamed: 0,particle_id,subevent,barcode,px,py,pz,pt,eta,vx,vy,vz,radius,status,charge,pdgId,pass,vProdNIn,vProdNOut,vProdStatus,vProdBarcode
0,1,0,1,0.000000,0.000000,0.000000,0.000000,inf,-1.000000,-1.000000,-1.000000,999.000000,4,1.000000,2212,NO,0,0,-999,999
1,2,0,2,0.000000,0.000000,0.000000,0.000000,-inf,-1.000000,-1.000000,-1.000000,999.000000,4,1.000000,2212,NO,0,0,-999,999
2,3,0,3,0.000000,0.000000,0.000000,0.000000,inf,-0.012448,0.002296,-36.654202,0.012658,21,1.000000,21,NO,1,2,0,-1
3,4,0,4,0.000000,0.000000,0.000000,0.000000,-inf,-0.012448,0.002296,-36.654202,0.012658,21,1.000000,21,NO,1,1,0,-2
4,5,0,5,11332.000000,33428.800781,33428.800781,35297.300781,2.679480,-0.012448,0.002296,-36.654202,0.012658,22,0.666667,6,NO,2,3,0,-3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
217928,69941201275,6994,1201275,-266.597992,247.266006,247.266006,363.614014,0.592441,-195.190994,116.448997,19902.199219,227.287994,1,0.000000,2112,NO,1,1,1111,-200841
217929,69941201282,6994,1201282,-429.354004,191.046997,191.046997,469.941010,2.215460,-104.683998,-16.695000,19745.400391,106.007004,1,-1.000000,-211,NO,1,1,1111,-200850
217930,69941201298,6994,1201298,7.258230,-2.475790,-2.475790,7.668860,1.716160,-91.080597,-20.040501,19745.699219,93.259300,1,1.000000,-11,NO,1,2,1003,-200853
217931,69942200680,6994,2200680,-199.509995,-108.432999,-108.432999,227.072006,-0.636485,429.341003,-934.872986,212.710007,1028.750000,1,1.000000,2212,NO,1,1,1111,-200550


In [29]:
# The spacepoints file is split into separate pixel and strip tables
pixel_spacepoints, strip_spacepoints = read_spacepoints(spacepoints_file)

In [30]:
# Inspect the strip spacepoints (each references two cluster indices)
strip_spacepoints

Unnamed: 0,hit_id,x,y,z,cluster_index_1,cluster_index_2
256608,256608,-385.432,14.13330,-1501.25,256608,256639.0
256609,256609,-400.494,-31.56590,-1501.25,256610,256636.0
256610,256610,-383.252,-30.95050,-1501.25,256611,256636.0
256611,256611,-404.481,13.22680,-1501.25,256614,256647.0
256612,256612,-419.850,5.37124,-1501.25,256615,256646.0
...,...,...,...,...,...,...
340463,340463,909.729,-113.10100,2854.25,475402,475425.0
340464,340464,915.800,-129.20300,2854.25,475404,475423.0
340465,340465,896.554,-141.31400,2854.25,475405,475422.0
340466,340466,951.864,-162.12100,2854.25,475406,475421.0


The hard work begins now: we need to read in the raw cluster information. Each cluster record has a variable length depending on whether the cluster is noise (no particle ID), a simple hit (one particle ID) or a shared hit (multiple particle IDs).

In [31]:
# Parse the variable-length cluster records; the particle table and column lookup are passed in
clusters = read_clusters(clusters_file, particles, config["column_lookup"])

### 2. Process these raw inputs

In [33]:
# Number of clusters left by each particle
num_clusters = clusters.groupby("particle_id")["cluster_id"].count().reset_index(name="num_clusters")

# `.ffill()` replaces the deprecated `fillna(method="ffill")` with identical behaviour.
# NOTE(review): the default inner merge drops particles without clusters, and the forward
# fill copies values from the previous row into any remaining NaNs (i.e. from a different
# particle) — confirm this is intended.
particles = pd.merge(particles, num_clusters, on="particle_id").ffill()

# NOTE(review): the AthenaReader class calls get_detectable_particles(particles, clusters);
# confirm which signature processing_utils currently exposes.
reco_particles = get_detectable_particles(particles)

In [37]:
# NOTE(review): the whole config is passed here, whereas AthenaReader.convert_to_csv passes
# config["spacepoints_datatypes"] — confirm which argument get_truth_spacepoints expects.
truth = get_truth_spacepoints(pixel_spacepoints, strip_spacepoints, clusters, config)

### 3. Build AthenaReader class

In [430]:
from typing import Union
import glob
import re
from itertools import chain, product
from torch_geometric.data import Data
import torch

class EventReader:
    """
    A general base class for reading simulated particle collision events from a set of files.

    Several convenience utilities are built in, and conversion to CSV and PyTorch Geometric
    data objects is enforced. Reading the raw input files, however, is left to the subclass,
    which must implement `convert_to_csv`. General usage is e.g. `AthenaReader.infer(config)`,
    which consists of:
    1. Raw files -> CSV
    2. CSV -> PyG data objects
    """

    def __init__(self, config):
        """
        Args:
            config: Either a configuration dict, or the path to a YAML file holding one.

        Raises:
            NotImplementedError: If `config` is neither a dict nor a str.
        """
        # NOTE(review): `files` is never assigned anywhere in this class; the container
        # methods (`__len__`, `__getitem__`, `__iter__`) only work once something sets it.
        self.files = None
        if isinstance(config, dict):
            self.config = config
        elif isinstance(config, str):
            # Context manager so the handle is closed as soon as the YAML is parsed.
            # NOTE(review): FullLoader should only be used on trusted config files.
            with open(config, "r") as config_stream:
                self.config = yaml.load(config_stream, Loader=yaml.FullLoader)
        else:
            raise NotImplementedError

    @classmethod
    def infer(cls, config):
        """
        The gateway for the inference stage. This class method is called from the infer_stage.py script.
        It assumes a set of basic steps. These are:
        1. Convert to CSV - This should be implemented by the user
        2. Check that the CSV conversion produced readable, non-empty files
        3. Convert to PyG
        """
        reader = cls(config)
        reader.convert_to_csv()
        reader._test_csv_conversion()
        reader._convert_to_pyg()
        return reader

    def convert_to_csv(self):
        """Convert the raw input files into per-event CSVs. Must be implemented by subclasses."""
        raise NotImplementedError

    def _convert_to_pyg(self):
        """
        Build and save one PyG graph per event from the `*particles*` / `*truth*` CSV pairs
        found in `config["output_dir"]`.
        """
        self.csv_events = self.get_file_names(self.config["output_dir"], filename_terms=["particles", "truth"])

        for event in self.csv_events:
            particles = pd.read_csv(event["particles"])
            hits = pd.read_csv(event["truth"])
            hits, particles = self._select_hits(hits, particles)
            hits = self._add_all_features(hits)
            hits = self._clean_noise_duplicates(hits)

            tracks, track_features, hits = self._build_true_tracks(hits, particles)
            hits, particles, tracks = self._custom_processing(hits, particles, tracks)
            graph = self._build_graph(hits, tracks, track_features)
            self._save_pyg_data(graph, event["event_id"])

    def _build_graph(self, hits, tracks, track_features):
        """
        Builds a PyG data object from the hits, track edges and per-edge track features.

        Every column listed in `config["feature_sets"]["hit_features"]` becomes a node-level
        tensor, and every key in `config["feature_sets"]["track_features"]` an edge-level tensor.
        """
        graph = Data()
        for feature in self.config["feature_sets"]["hit_features"]:
            graph[feature] = torch.from_numpy(hits[feature].values)

        graph.track_edges = torch.from_numpy(tracks)
        for feature in self.config["feature_sets"]["track_features"]:
            graph[feature] = torch.from_numpy(track_features[feature])

        return graph

    def _save_pyg_data(self, graph, event_id):
        """Save the constructed PyG graph as `event{event_id}-graph.pyg` in the output directory."""
        torch.save(graph, os.path.join(self.config["output_dir"], f"event{event_id}-graph.pyg"))

    @staticmethod
    def calc_eta(r, z):
        """Pseudorapidity from cylindrical coordinates: eta = -ln(tan(theta / 2)), theta = arctan2(r, z)."""
        theta = np.arctan2(r, z)
        return -1.0 * np.log(np.tan(theta / 2.0))

    def _select_hits(self, hits, particles):
        """
        Attach particle-level truth to every hit and compute the per-particle hit count.

        Hard cuts are meant to be defined in a `hard_cuts` key in the config file, e.g.
        hard_cuts:
            pt: [1000, inf]
            barcode: [0, 200000]
        They are not implemented yet: a non-null `hard_cuts` raises NotImplementedError.

        Returns:
            (hits, particles): hits left-merged with particle columns plus an `nhits` column
            (-1 for noise hits, whose particle_id is set to 0), and particles with a `primary` flag.
        """
        # Barcodes below 200000 are flagged as primary particles
        particles = particles.assign(primary=(particles.barcode < 200000).astype(int))

        if "hard_cuts" in self.config and self.config["hard_cuts"] is not None:
            raise NotImplementedError("Hard cuts not implemented yet")

        hits = hits.merge(
            particles[["particle_id", "pt", "vx", "vy", "vz", "primary", "pdgId", "radius"]],
            on="particle_id",
            how="left",
        )

        hits["nhits"] = hits.groupby("particle_id")["particle_id"].transform("count")
        # Hits with no particle (NaN id) are noise: id 0 and a sentinel hit count of -1
        hits["particle_id"] = hits["particle_id"].fillna(0).astype(int)
        hits.loc[hits.particle_id == 0, "nhits"] = -1

        return hits, particles

    @staticmethod
    def _clean_noise_duplicates(hits):
        """
        This handles the case where a hit is assigned to both a particle and noise (e.g. in a ghost spacepoint).
        This is not sensible, so we remove those duplicated noise hits. Duplicate noise rows are
        also collapsed to one row per hit_id. The result is sorted by hit_id with a fresh 0..N-1 index.
        """
        noise_hits = hits[hits.particle_id == 0].drop_duplicates(subset="hit_id")
        signal_hits = hits[hits.particle_id != 0]

        non_duplicate_noise_hits = noise_hits[~noise_hits.hit_id.isin(signal_hits.hit_id)]
        hits = pd.concat([signal_hits, non_duplicate_noise_hits], ignore_index=True)
        # Sort hits by hit_id for ease of processing
        hits = hits.sort_values("hit_id").reset_index(drop=True)

        return hits

    def _add_all_features(self, hits):
        """Add cylindrical coordinates (r, phi) and pseudorapidity (eta) computed from x, y, z."""
        assert all([col in hits.columns for col in ["x", "y", "z"]]), "Need to add (x,y,z) features"

        r = np.sqrt(hits.x**2 + hits.y**2)
        phi = np.arctan2(hits.y, hits.x)
        eta = self.calc_eta(r, hits.z)
        hits = hits.assign(r=r, phi=phi, eta=eta)

        return hits

    def _build_true_tracks(self, hits, particles):
        """
        Build truth track edges by chaining each particle's hits module-by-module.

        Signal hits (particle_id != 0) are ordered by distance from their production vertex and
        grouped per particle and per module (`config["module_columns"]`, or a default set of
        Athena module columns). Every hit on one module is connected to every hit on the
        particle's next module.

        Args:
            hits: Hit table containing particle_id, hit_id, x, y, z, vx, vy, vz, the module
                columns and every configured track feature. Row labels are assumed to be
                0..N-1 (as `_clean_noise_duplicates` produces), because they are used as
                positions into `hits.hit_id.values`.
            particles: Not used here; kept so all processing hooks share a signature.

        Returns:
            (track_edges, track_features, hits): an int array of shape (2, E) indexing the
            de-duplicated hits, a dict of per-edge track features, and the de-duplicated,
            hit_id-sorted hits.
        """
        assert all([col in hits.columns for col in ["particle_id", "hit_id", "x","y","z","vx","vy","vz"]]), "Need to add (particle_id, hit_id), (x,y,z) and (vx,vy,vz) features to hits dataframe in custom EventReader class"

        signal = hits[(hits.particle_id != 0)]

        # Sort by increasing distance from production
        signal = signal.assign(
            R=np.sqrt(
                (signal.x - signal.vx) ** 2
                + (signal.y - signal.vy) ** 2
                + (signal.z - signal.vz) ** 2
            )
        )

        # reset_index(drop=False) keeps the original row label in an "index" column
        signal = signal.sort_values("R").reset_index(drop=False)

        # Columns that together identify one detector module
        if "module_columns" not in self.config or self.config["module_columns"] is None:
            module_columns = ["barrel_endcap", "hardware", "layer_disk", "eta_module", "phi_module"]
        else:
            module_columns = self.config["module_columns"]

        # Per particle: a list of per-module hit-index lists, modules in order of first appearance (by R)
        signal_index_list = (signal.groupby(
                ["particle_id"] + module_columns,
                sort=False,
            )["index"]
            .agg(lambda x: list(x))
            .groupby(level=0)
            .agg(lambda x: list(x)))

        track_index_edges = []
        for row in signal_index_list.values:
            for i, j in zip(row[:-1], row[1:]):
                track_index_edges.extend(product(i, j))

        # reshape keeps a (2, 0) integer array when there are no edges, so the fancy indexing
        # below stays valid instead of failing on an empty float array
        track_index_edges = np.array(track_index_edges, dtype=np.int64).reshape(-1, 2).T
        track_edges = hits.hit_id.values[track_index_edges]

        assert (hits[hits.hit_id.isin(track_edges.flatten())].particle_id == 0).sum() == 0, "There are hits in the track edges that are noise"

        track_features = self._get_track_features(hits, track_index_edges)

        # Remap hit_ids onto contiguous node indices
        track_edges, hits = self.remap_edges(track_edges, track_features, hits)

        return track_edges, track_features, hits

    def _get_track_features(self, hits, track_index_edges):
        """
        Collect each configured track feature for every edge.

        Both endpoints of a true edge must agree on each feature (asserted); the value at the
        source endpoint is returned, keyed by feature name.
        """
        track_features = {}
        for track_feature in self.config["feature_sets"]["track_features"]:
            # Row 0: values at source hits, row 1: values at target hits
            endpoint_values = hits[track_feature].values[track_index_edges]
            assert (endpoint_values[0] == endpoint_values[1]).all(), f"Track feature {track_feature} differs between edge endpoints"
            track_features[track_feature] = endpoint_values[0]

        return track_features

    @staticmethod
    def remap_edges(track_edges, track_features, hits):
        """
        Map hit_id-valued edges onto contiguous 0..N-1 node indices of the de-duplicated,
        hit_id-sorted hits, and check that shared edges are rare.
        """
        unique_hid = np.unique(hits.hit_id)
        hid_mapping = np.zeros(unique_hid.max() + 1).astype(int)
        hid_mapping[unique_hid] = np.arange(len(unique_hid))

        hits = hits.drop_duplicates(subset="hit_id").sort_values("hit_id")
        assert (hits.hit_id == unique_hid).all(), "If hit IDs are not sequential, this will mess up graph structure!"

        track_edges = hid_mapping[track_edges]

        # This test imposes a limit to how we simplify the graph: We don't allow shared EDGES (i.e. two different particles can share a hit, but not an edge between the same two hits). We want to ensure these are in a tiny minority
        assert ((hits.particle_id.values[track_edges[0]] != track_features["particle_id"]) & (hits.particle_id.values[track_edges[1]] != track_features["particle_id"])).sum() < 50, "The number of shared EDGES is unusually high!"

        return track_edges, hits

    def get_file_names(self, inputdir, filename_terms: Union[str, list] = None):
        """
        Group the files in `inputdir` into events.

        A file belongs to the event given by the LAST run of digits in its basename. Files whose
        basename contains no digits are ignored. For each event id, one file is looked up per
        term in `filename_terms`; a term of "*" matches any file. Events missing a file for some
        term are reported and skipped.

        Args:
            inputdir: Directory to search (non-recursive).
            filename_terms: A term or list of terms that must appear in the file names.
                Defaults to ["*"], i.e. every file.

        Returns:
            A list of dicts ordered by numeric event id. Each dict holds "event_id" (the digit
            string as it appears in the file name) plus one file path per term.
        """
        if isinstance(filename_terms, str):
            filename_terms = [filename_terms]
        elif filename_terms is None:
            filename_terms = ["*"]

        def _event_id_of(file):
            # Only the basename is inspected, so digits in directory names are never used as ids
            numbers = re.findall("[0-9]+", os.path.basename(file))
            return numbers[-1] if numbers else None

        candidate_files = chain.from_iterable(
            glob.glob(os.path.join(inputdir, f"*{term}*")) for term in filename_terms
        )
        # dict keys also de-duplicate files that match several terms
        file_event_ids = {}
        for file in candidate_files:
            event_id = _event_id_of(file)
            if event_id is not None:
                file_event_ids[file] = event_id

        all_event_ids = sorted(set(file_event_ids.values()), key=lambda eid: (int(eid), eid))

        all_events = []
        for event_id in all_event_ids:
            event = {"event_id": event_id}
            for term in filename_terms:
                # Search for a file containing the term and EXACTLY the event id as its last number
                template_file = sorted(
                    file
                    for file, file_event_id in file_event_ids.items()
                    if file_event_id == event_id and (term == "*" or term in os.path.basename(file))
                )
                if not template_file:
                    print(f"Could not find file for term {term} and event id {event_id}")
                    break
                event[term] = template_file[0]
            else:
                all_events.append(event)

        return all_events

    def _custom_processing(self, hits, particles, tracks):
        """
        This is called after the base class has finished processing the hits, particles and tracks.
        Subclasses may override it; the default implementation returns its inputs unchanged, as
        `_convert_to_pyg` unpacks (hits, particles, tracks) from the result.
        """
        return hits, particles, tracks

    def _test_csv_conversion(self):
        """Check that at least one truth/particles CSV pair exists and that the first one is non-empty."""
        self.csv_events = self.get_file_names(self.config["output_dir"], filename_terms=["truth", "particles"])
        assert len(self.csv_events) > 0, "No CSV files found in output directory matching the formats (event[eventID]-truth.csv, event[eventID]-particles.csv). Please check that the conversion to CSV was successful."

        # Load the first event
        event = self.csv_events[0]
        truth = pd.read_csv(event["truth"])
        particles = pd.read_csv(event["particles"])
        assert len(truth) > 0, "No truth spacepoints found in CSV file. Please check that the conversion to CSV was successful."
        assert len(particles) > 0, "No particles found in CSV file. Please check that the conversion to CSV was successful."

    # NOTE(review): these rely on `self.files`, which is None unless a subclass assigns it.
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return self.files[idx]

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

In [431]:
class AthenaReader(EventReader):
    """
    EventReader for Athena dumps.

    Converts raw clusters / particles / spacepoints text files into per-event truth and
    particle CSVs, and labels hits by detector region during PyG conversion.
    """

    def __init__(self, config):
        super().__init__(config)

    def _custom_processing(self, hits, particles, tracks):
        """
        This is called after the base class has finished processing the hits, particles and tracks.
        In Athena, we will use it for some fine-tuning of the final outputs, including adding region labels to the hits (for heteroGNN stage).
        """
        # Add region labels to hits
        hits = self._add_region_labels(hits)

        return hits, particles, tracks

    def _add_region_labels(self, hits):
        """
        Assign a `region` label to every hit from `self.config["region_labels"]`.

        Each entry maps a region label to a dict of {column: required value}. A hit receives
        that label when it matches ALL listed values. Entries with no conditions are ignored,
        and later entries overwrite earlier ones.

        Raises:
            AssertionError: If any hit is left without a region.
        """
        # Make sure the column exists so the final check cannot fail with AttributeError
        if "region" not in hits.columns:
            hits = hits.assign(region=np.nan)

        for region_label, conditions in self.config["region_labels"].items():
            if not conditions:
                continue
            # The mask only depends on the region's conditions, so build it once per region
            condition_mask = np.logical_and.reduce(
                [hits[column] == value for column, value in conditions.items()]
            )
            hits.loc[condition_mask, "region"] = region_label

        assert (hits.region.isna()).sum() == 0, "There are hits that do not belong to any region!"

        return hits

    def convert_to_csv(self):
        """
        Write `event{id:09}-truth.csv` and `event{id:09}-particles.csv` to `config["output_dir"]`.

        Output is written for every event in `config["input_dir"]` that has clusters, particles
        and spacepoints files. Events whose two CSVs both already exist are skipped.
        """
        input_dir = self.config["input_dir"]
        output_dir = self.config["output_dir"]
        self.raw_events = self.get_file_names(input_dir, filename_terms=["clusters", "particles", "spacepoints"])
        os.makedirs(output_dir, exist_ok=True)

        for event in self.raw_events:
            event_id = event["event_id"]
            truth_csv = os.path.join(output_dir, "event{:09}-truth.csv".format(int(event_id)))
            particles_csv = os.path.join(output_dir, "event{:09}-particles.csv".format(int(event_id)))

            # Check if file already exists
            if os.path.exists(particles_csv) and os.path.exists(truth_csv):
                print("File already exists, skipping...")
                continue

            # Read particles
            particles = read_particles(event["particles"])
            particles = convert_barcodes(particles)
            particles = particles.astype(self.config["particles_datatypes"])

            # Read spacepoints
            pixel_spacepoints, strip_spacepoints = read_spacepoints(event["spacepoints"])

            # Read clusters
            clusters = read_clusters(event["clusters"], particles, self.config["column_lookup"])

            # Get truth spacepoints
            # NOTE(review): the exploratory notebook cell passes the whole config as the last
            # argument; confirm which one get_truth_spacepoints expects.
            truth = get_truth_spacepoints(pixel_spacepoints, strip_spacepoints, clusters, self.config["spacepoints_datatypes"])

            # Get detectable particles
            detectable_particles = get_detectable_particles(particles, clusters)

            # Save to CSV
            truth.to_csv(truth_csv, index=False)
            detectable_particles.to_csv(particles_csv, index=False)


In [2]:
config_file = "athena_reader_config.yaml"
# Context manager so the file handle is closed once the YAML has been parsed
with open(config_file, "r") as config_stream:
    config = yaml.load(config_stream, Loader=yaml.FullLoader)

In [433]:
# Full pipeline: raw files -> CSV (skipping existing outputs) -> CSV check -> PyG graphs
reader = AthenaReader.infer(config)

Could not find file for term particles and event id 2
File already exists, skipping...


In [434]:
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files. Recent torch
# releases may default to weights_only=True, which can reject full PyG Data objects — confirm.
graph = torch.load("output_data/event000000001-graph.pyg")

In [435]:
# Node-level features have one entry per hit; track features have one per track edge
graph

Data(hit_id=[340468], x=[340468], y=[340468], z=[340468], r=[340468], phi=[340468], eta=[340468], cluster_x_1=[340468], cluster_y_1=[340468], cluster_z_1=[340468], cluster_x_2=[340468], cluster_y_2=[340468], cluster_z_2=[340468], norm_x=[340468], norm_y=[340468], norm_z_1=[340468], eta_angle_1=[340468], phi_angle_1=[340468], eta_angle_2=[340468], phi_angle_2=[340468], norm_z_2=[340468], track_edges=[2, 137544], particle_id=[137544], pt=[137544], radius=[137544], primary=[137544], nhits=[137544], pdgId=[137544])

In [4]:
# Validation event 102 from the shared feature store: truth hits, then particles
truth_csv = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/valset/event000000102-truth.csv"
particles_csv = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/valset/event000000102-particles.csv"
hits = pd.read_csv(truth_csv)
particles = pd.read_csv(particles_csv)

In [392]:
# Flag primaries (barcode < 200000) and attach particle-level truth to every hit
particles = particles.assign(primary=lambda frame: (frame.barcode < 200000).astype(int))
hits = pd.merge(
    hits,
    particles[["particle_id", "pt", "vx", "vy", "vz", "primary", "pdgId", "radius"]],
    how="left",
    on="particle_id",
)

# Hits per particle; unmatched (noise) hits become particle 0 with the sentinel count -1
hits["nhits"] = hits.groupby("particle_id")["particle_id"].transform("count")
hits["particle_id"] = hits["particle_id"].fillna(0).astype(int)
hits.loc[hits["particle_id"] == 0, "nhits"] = -1

In [393]:
# A hit_id claimed by a real particle must not also survive as a noise row
signal_hits = hits.loc[hits.particle_id != 0]
noise_hits = hits.loc[hits.particle_id == 0].drop_duplicates(subset="hit_id")
non_duplicate_noise_hits = noise_hits.loc[~noise_hits.hit_id.isin(signal_hits.hit_id)]

# Recombine and order by hit_id for ease of processing
hits = (
    pd.concat([signal_hits, non_duplicate_noise_hits], ignore_index=True)
    .sort_values("hit_id")
    .reset_index(drop=True)
)

In [394]:
signal = hits.loc[hits.particle_id != 0]

# Distance of every signal hit from its particle's production vertex
distance_from_vertex = np.sqrt(
    (signal.x - signal.vx) ** 2 + (signal.y - signal.vy) ** 2 + (signal.z - signal.vz) ** 2
)
# Sort by increasing distance; the original row label is kept in an "index" column
signal = signal.assign(R=distance_from_vertex).sort_values("R").reset_index(drop=False)

# Columns that together identify a detector module
module_columns = ["barrel_endcap", "hardware", "layer_disk", "eta_module", "phi_module"]

In [395]:
# Per particle: hit-index lists grouped per module, modules in order of first appearance (by R)
signal_index_list = (
    signal.groupby(["particle_id"] + module_columns, sort=False)["index"]
    .agg(lambda x: list(x))
    .groupby(level=0)
    .agg(lambda x: list(x))
)

# Connect every hit on one module to every hit on the particle's next module
track_index_edges = [
    edge
    for module_hits in signal_index_list.values
    for current_module, next_module in zip(module_hits[:-1], module_hits[1:])
    for edge in product(current_module, next_module)
]
track_index_edges = np.array(track_index_edges).T

In [396]:
# Translate the row indices of each edge endpoint into hit_id values
track_edges = hits.hit_id.values[track_index_edges]

In [397]:
# No true edge may touch a noise hit (particle_id == 0)
assert (hits[hits.hit_id.isin(track_edges.flatten())].particle_id == 0).sum() == 0, "There are hits in the track edges that are noise"

In [398]:
# Spot-check: both endpoints of every true edge carry the same pdgId
assert (hits.pdgId.values[track_index_edges][0] == hits.pdgId.values[track_index_edges][1]).all()

In [399]:
# Per-edge truth features; both endpoints of a true edge must agree on each of them
track_feature_names = ["particle_id", "pt", "radius", "primary", "nhits", "pdgId"]
track_features = {}
for track_feature in track_feature_names:
    source_values, target_values = hits[track_feature].values[track_index_edges]
    assert (source_values == target_values).all()
    track_features[track_feature] = source_values

In [401]:
# Lookup table: hit_id -> contiguous node index 0..N-1
unique_hid = np.unique(hits.hit_id)
hid_mapping = np.zeros(unique_hid.max() + 1, dtype=int)
hid_mapping[unique_hid] = np.arange(len(unique_hid))

# One row per hit_id, ordered so that row position == node index
hits = hits.drop_duplicates(subset="hit_id").sort_values("hit_id")
assert (hits.hit_id == unique_hid).all(), "If hit IDs are not sequential, this will mess up graph structure!"

# Re-express the true edges in node indices
track_edges = hid_mapping[track_edges]

In [402]:
# Count edges whose endpoints both belong to a different particle than the edge's own label.
# NOTE(review): the threshold here is 100, while EventReader.remap_edges uses 50 — align them.
assert ((hits.particle_id.values[track_edges[0]] != track_features["particle_id"]) & (hits.particle_id.values[track_edges[1]] != track_features["particle_id"])).sum() < 100, "The number of shared EDGES is unusually high!"

In [407]:
# Label each hit with the region whose {column: value} conditions it matches in full.
# Make sure the column exists so the final check cannot fail with AttributeError.
if "region" not in hits.columns:
    hits = hits.assign(region=np.nan)

for region_label, conditions in config["region_labels"].items():
    if not conditions:
        continue
    # The mask depends only on the region's conditions, so build it once per region
    condition_mask = np.logical_and.reduce(
        [hits[column] == value for column, value in conditions.items()]
    )
    hits.loc[condition_mask, "region"] = region_label

assert (hits.region.isna()).sum() == 0, "There are hits that do not belong to any region!"

### 4. Inspect Outputs

In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Truth hits for test-set event 194 from the shared feature store
test_csv_path = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/testset/event000000194-truth.csv"
test_hits = pd.read_csv(test_csv_path)

In [5]:
# Inspect the test-set truth hits
test_hits

Unnamed: 0,hit_id,x,y,z,cluster_index_1,cluster_index_2,particle_id,particle_id_1,particle_id_2,hardware,...,eta_angle_1,phi_angle_1,cluster_x_2,cluster_y_2,cluster_z_2,eta_angle_2,phi_angle_2,norm_z_2,region,ID
0,0,-36.6581,-4.29661,-263.00,0,-1,67480000184,67480000184,-1,PIXEL,...,0.982794,0.982794,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
1,1,-50.8359,-18.85710,-263.00,1,-1,0,0,-1,PIXEL,...,1.249050,1.249050,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
2,2,-43.2095,-13.32110,-263.00,2,-1,0,0,-1,PIXEL,...,0.321751,0.291457,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
3,3,-38.5501,-7.39981,-263.00,3,-1,67520000493,67520000493,-1,PIXEL,...,1.249050,0.982794,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
4,4,-41.2780,-10.57310,-263.00,4,-1,69550000406,69550000406,-1,PIXEL,...,0.982794,1.249050,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
361778,360259,954.5330,-138.13500,2854.25,496951,496965,67850001058,67850001058,67850001058,STRIP,...,1.324640,0.006379,927.959,-134.836,2860.75,1.32464,0.006379,-1.0,6.0,1927426291305283584
361779,360260,950.3910,-170.52700,2854.25,496953,496964,0,0,0,STRIP,...,1.115860,0.006380,922.790,-166.575,2860.75,1.32464,0.006379,-1.0,6.0,1927426291305283584
361780,360261,905.5270,-164.65300,2854.25,496954,496963,0,0,0,STRIP,...,0.311485,0.006384,922.625,-167.485,2860.75,1.11611,0.006379,-1.0,6.0,1927426291305283584
361781,360262,893.0260,-162.12200,2854.25,496954,496964,0,0,0,STRIP,...,0.311485,0.006384,922.790,-166.575,2860.75,1.32464,0.006379,-1.0,6.0,1927426291305283584


In [2]:
# Graph produced for the same test-set event.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
test_pyg_path = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/testset/event000000194-graph.pyg"
test_graph = torch.load(test_pyg_path)

In [3]:
# Inspect the stored graph, including its attached config
test_graph

Data(hit_id=[360264], x=[360264], y=[360264], z=[360264], r=[360264], phi=[360264], eta=[360264], region=[360264], cluster_x_1=[360264], cluster_y_1=[360264], cluster_z_1=[360264], cluster_x_2=[360264], cluster_y_2=[360264], cluster_z_2=[360264], norm_x=[360264], norm_y=[360264], norm_z_1=[360264], eta_angle_1=[360264], phi_angle_1=[360264], eta_angle_2=[360264], phi_angle_2=[360264], norm_z_2=[360264], track_edges=[2, 145405], particle_id=[145405], pt=[145405], radius=[145405], primary=[145405], nhits=[145405], pdgId=[145405], config=[1])

In [12]:
# Use the GPU when one is available, otherwise stay on CPU so the cell still runs
device = "cuda" if torch.cuda.is_available() else "cpu"
test_graph = test_graph.to(device)

In [16]:
# Show the configuration stored alongside the graph
print(test_graph.config)

[{'stage': 'data_reading', 'model': 'AthenaReader', 'input_dir': '/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/athena_100_events', 'stage_dir': '/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/', 'module_lookup_path': '/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/Modules_geo_10events.txt', 'feature_sets': {'hit_features': ['hit_id', 'x', 'y', 'z', 'r', 'phi', 'eta', 'region', 'cluster_x_1', 'cluster_y_1', 'cluster_z_1', 'cluster_x_2', 'cluster_y_2', 'cluster_z_2', 'norm_x', 'norm_y', 'norm_z_1', 'eta_angle_1', 'phi_angle_1', 'eta_angle_2', 'phi_angle_2', 'norm_z_2'], 'track_features': ['particle_id', 'pt', 'radius', 'primary', 'nhits', 'pdgId']}, 'region_labels': {1: {'hardware': 'PIXEL', 'barrel_endcap': -2}, 2: {'hardware': 'STRIP', 'barrel_endcap': -2}, 3: {'hardware': 'PIXEL', 'barrel_endcap': 0}, 4: {'hardware': 'STRIP', 'barrel_endcap': 0}, 5: {'hardware': 'PIXEL', 'barrel_endcap'

## Test Redundant Split Edge Feature

In [39]:
from gnn4itk_cf.stages.data_reading import AthenaReader
from itertools import chain, product, combinations

In [13]:
config_file = "athena_reader_config.yaml"
# Use a context manager so the file handle is closed as soon as the YAML is parsed.
with open(config_file, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [22]:
# Build the reader from the YAML config loaded above.
reader = AthenaReader(config)

In [95]:
# Pre-processed truth hits and particles for a single validation event.
valset_dir = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/valset"
event_prefix = f"{valset_dir}/event000000102"
hits = pd.read_csv(f"{event_prefix}-truth.csv")
particles = pd.read_csv(f"{event_prefix}-particles.csv")

In [96]:
# Run the reader's private processing steps on the raw frames. Per their names, these steps
# merge particle info onto hits, add hand-engineered hit features, then clean noise/duplicate hits.
# NOTE(review): presumably each step relies on columns produced by the previous one — confirm before reordering.
hits, particles = reader._merge_particles_to_hits(hits, particles)
hits = reader._add_handengineered_features(hits)
hits = reader._clean_noise_duplicates(hits)

In [97]:
# Columns available on `hits` after the reader steps above.
hits.columns

Index(['hit_id', 'x', 'y', 'z', 'cluster_index_1', 'cluster_index_2',
       'particle_id', 'particle_id_1', 'particle_id_2', 'hardware',
       'cluster_x_1', 'cluster_y_1', 'cluster_z_1', 'barrel_endcap',
       'layer_disk', 'eta_module', 'phi_module', 'norm_x', 'norm_y',
       'norm_z_1', 'eta_angle_1', 'phi_angle_1', 'cluster_x_2', 'cluster_y_2',
       'cluster_z_2', 'eta_angle_2', 'phi_angle_2', 'norm_z_2', 'region', 'ID',
       'pt', 'radius', 'primary', 'nhits', 'pdgId', 'vx', 'vy', 'vz', 'r',
       'phi', 'eta'],
      dtype='object')

In [98]:
# Distance R of every hit from its particle's production vertex (vx, vy, vz).
dx, dy, dz = (hits[axis] - hits[f"v{axis}"] for axis in ("x", "y", "z"))
hits = hits.assign(R=np.sqrt(dx ** 2 + dy ** 2 + dz ** 2))

# Keep only hits attached to a particle (particle_id 0 is noise), nearest to the vertex first.
# reset_index() moves the original `hits` row labels into an "index" column used for edge building.
signal = hits.loc[hits.particle_id != 0].sort_values("R").reset_index()

In [99]:
# Count signal rows repeating an (ID, particle_id) pair already seen, i.e. extra hits of a
# particle sharing the same ID (presumably the module identifier — confirm).
signal.duplicated(subset=["ID", "particle_id"]).sum()

892

In [100]:
# Group by particle ID
# Columns identifying a detector module: taken from the config when set, otherwise this default.
module_columns = config.get("module_columns")
if module_columns is None:
    module_columns = ["barrel_endcap", "hardware", "layer_disk", "eta_module", "phi_module"]

# For each particle: a list of per-module lists of `hits` row labels (the "index" column),
# with modules in the order they are first reached in the R-sorted `signal` frame.
signal_index_list = (
    signal.groupby(["particle_id"] + module_columns, sort=False)["index"]
    .agg(list)
    .groupby(level=0)
    .agg(list)
)

# Connect every hit on one module to every hit on the next module of the same particle.
track_index_edges = [
    edge
    for module_hits in signal_index_list.to_numpy()
    for src_hits, dst_hits in zip(module_hits[:-1], module_hits[1:])
    for edge in product(src_hits, dst_hits)
]
# reshape(-1, 2) keeps the (2, num_edges) layout even when no edges were built;
# np.array([]).T would have shape (0,) and break the positional indexing downstream.
track_index_edges = np.array(track_index_edges, dtype=np.int64).reshape(-1, 2).T

In [101]:
# track_index_edges holds row *labels* of `hits` (the "index" column created by reset_index),
# but .values[...] looks up by *position*. The two agree only when `hits` has a default 0..n-1 index.
assert hits.index.equals(pd.RangeIndex(len(hits))), "hits must have a default 0..n-1 index for positional edge lookup"
track_edges = hits.hit_id.values[track_index_edges]
# Every endpoint of a truth edge must belong to a particle (particle_id 0 = noise).
assert (hits.loc[hits.hit_id.isin(track_edges.ravel()), "particle_id"] != 0).all(), "There are hits in the track edges that are noise"
track_features = reader._get_track_features(hits, track_index_edges)

In [102]:
# Per-edge track features returned by the reader (dict of column name -> array).
track_features

{'particle_id': array([        181,         181,         181, ..., 62280201002,
        62280201002, 62280201002]),
 'pt': array([60739.2  , 60739.2  , 60739.2  , ...,   330.628,   330.628,
          330.628]),
 'radius': array([1.93578e-02, 1.93578e-02, 1.93578e-02, ..., 5.57083e+02,
        5.57083e+02, 5.57083e+02]),
 'primary': array([1., 1., 1., ..., 0., 0., 0.]),
 'nhits': array([1.00000e+01, 1.00000e+01, 1.00000e+01, ..., 1.49991e+05,
        1.49991e+05, 1.49991e+05]),
 'pdgId': array([  11.,   11.,   11., ..., -211., -211., -211.])}

In [103]:
# Truth-edge table: one row per edge (source/target hit_id) alongside its track features.
edge_columns = pd.DataFrame(track_edges.T, columns=["hit_id_0", "hit_id_1"])
truth_track_df = pd.concat([edge_columns, pd.DataFrame(track_features)], axis=1)

In [106]:
# ID value of each unique hit_id (first occurrence kept).
hits.drop_duplicates(subset="hit_id").ID

0             158329674399744
1             158329674399744
2             158329674399744
3             158329674399744
4             158329674399744
                 ...         
278460    1927426291305283584
278461    1927426291305283584
278462    1927426291305283584
278463    1927426291305283584
278464    1927426291305283584
Name: ID, Length: 278183, dtype: int64

In [110]:
# (num_edges, num_columns) of the truth-edge table.
truth_track_df.shape

(115312, 8)

In [138]:
# One row per hit_id carrying its ID, distance R and module columns.
hits_unique = hits.drop_duplicates(subset="hit_id").loc[:, ["hit_id", "ID", "R", *module_columns]]

# Attach source-hit info (suffix _x) and then target-hit info (suffix _y) to every truth edge.
truth_track_df = (
    truth_track_df[["hit_id_0", "hit_id_1", "particle_id"]]
    .merge(hits_unique, how="left", left_on="hit_id_0", right_on="hit_id")
    .drop(columns=["hit_id"])
    .merge(hits_unique, how="left", left_on="hit_id_1", right_on="hit_id")
    .drop(columns=["hit_id"])
)


In [139]:
# Edges whose target ID (ID_y) repeats for the same particle, i.e. redundant edges into one module.
truth_track_df[truth_track_df.duplicated(subset=["ID_y", "particle_id"])]

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,barrel_endcap_x,hardware_x,layer_disk_x,eta_module_x,phi_module_x,ID_y,R_y,barrel_endcap_y,hardware_y,layer_disk_y,eta_module_y,phi_module_y
60,80473,88504,225,293420071034814464,113.917780,0,PIXEL,0,-4,9,331120125727997952,275.807582,0,PIXEL,1,-6,12
62,88504,18203,225,331120125727997952,275.807582,0,PIXEL,1,-6,12,36750076646785024,307.133043,-2,PIXEL,1,0,1
303,76686,34681,353,290015983035219968,170.633432,0,PIXEL,0,-7,3,42968914413486080,383.124251,-2,PIXEL,1,3,12
426,131410,247319,384,456253345061928960,561.971057,0,PIXEL,4,12,42,1485554558334664704,1071.755886,0,STRIP,1,10,29
428,247319,255552,384,1485554558334664704,1071.755886,0,STRIP,1,10,29,1524433289492824064,1495.754639,0,STRIP,2,14,39
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
114774,231833,242446,62280001100,1442770361874644992,1000.222905,0,STRIP,0,10,5,1479397293219119104,1417.732382,0,STRIP,1,14,7
114776,242446,265305,62280001100,1479397293219119104,1417.732382,0,STRIP,1,14,7,1768938287231139840,1795.494585,2,STRIP,1,3,12
114780,76511,155720,62280001101,289593770570153984,128.937767,0,PIXEL,0,9,2,613791371089674240,276.831507,2,PIXEL,1,2,2
114902,132765,132542,62280001176,459507899480145920,313.423391,0,PIXEL,4,-2,48,458936153433702400,321.784722,0,PIXEL,4,-3,47


In [159]:
# Edges whose source ID (ID_x) repeats for the same particle, i.e. redundant edges out of one module.
truth_track_df[truth_track_df.duplicated(subset=["ID_x", "particle_id"])]

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,barrel_endcap_x,hardware_x,layer_disk_x,eta_module_x,phi_module_x,ID_y,R_y,barrel_endcap_y,hardware_y,layer_disk_y,eta_module_y,phi_module_y
60,80473,88504,225,293420071034814464,113.917780,0,PIXEL,0,-4,9,331120125727997952,275.807582,0,PIXEL,1,-6,12
62,88504,18203,225,331120125727997952,275.807582,0,PIXEL,1,-6,12,36750076646785024,307.133043,-2,PIXEL,1,0,1
303,76686,34681,353,290015983035219968,170.633432,0,PIXEL,0,-7,3,42968914413486080,383.124251,-2,PIXEL,1,3,12
426,131410,247319,384,456253345061928960,561.971057,0,PIXEL,4,12,42,1485554558334664704,1071.755886,0,STRIP,1,10,29
428,247319,255552,384,1485554558334664704,1071.755886,0,STRIP,1,10,29,1524433289492824064,1495.754639,0,STRIP,2,14,39
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
114774,231833,242446,62280001100,1442770361874644992,1000.222905,0,STRIP,0,10,5,1479397293219119104,1417.732382,0,STRIP,1,14,7
114776,242446,265305,62280001100,1479397293219119104,1417.732382,0,STRIP,1,14,7,1768938287231139840,1795.494585,2,STRIP,1,3,12
114780,76511,155720,62280001101,289593770570153984,128.937767,0,PIXEL,0,9,2,613791371089674240,276.831507,2,PIXEL,1,2,2
114902,132765,132542,62280001176,459507899480145920,313.423391,0,PIXEL,4,-2,48,458936153433702400,321.784722,0,PIXEL,4,-3,47


In [136]:
# NOTE(review): recomputes the same signal_index_list as the "Group by particle ID" cell above.
signal_index_list = (
    signal
    .groupby(["particle_id"] + module_columns, sort=False)["index"]
    .agg(list)  # hit labels per (particle, module)
    .groupby(level=0)
    .agg(list)  # per particle: per-module label lists, in first-seen order
)

In [144]:
# For each (target ID_y, particle) keep the edge whose target hit has the smallest R_y.
min_r_y_labels = truth_track_df.groupby(["ID_y", "particle_id"])["R_y"].idxmin()
R_y_min = truth_track_df.loc[min_r_y_labels]

In [145]:
# Edges kept by the idxmin-based deduplication.
R_y_min

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,barrel_endcap_x,hardware_x,layer_disk_x,eta_module_x,phi_module_x,ID_y,R_y,barrel_endcap_y,hardware_y,layer_disk_y,eta_module_y,phi_module_y
6336,15977,21,60430000113,10854378789404672,236.308941,-2,PIXEL,0,0,19,158329674399744,240.365010,-2,PIXEL,0,0,0
16277,80346,0,60620000748,293376090569703424,211.951872,0,PIXEL,0,-9,9,158329674399744,293.247020,-2,PIXEL,0,0,0
17359,15968,12,60630001115,10854378789404672,255.201796,-2,PIXEL,0,0,19,158329674399744,259.280577,-2,PIXEL,0,0,0
24826,15995,37,60750000259,10854378789404672,257.026768,-2,PIXEL,0,0,19,158329674399744,261.084194,-2,PIXEL,0,0,0
27528,16002,4,60790000076,10854378789404672,247.319669,-2,PIXEL,0,0,19,158329674399744,251.387901,-2,PIXEL,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4915,273645,278169,60400000364,1882363906752512000,2570.391006,2,STRIP,4,2,31,1927408699119239168,2895.009515,2,STRIP,5,3,63
109286,273650,278170,62230200007,1882363906752512000,1304.858321,2,STRIP,4,2,31,1927408699119239168,1632.637992,2,STRIP,5,3,63
95539,271208,278175,61940200156,1873646978567503872,1814.804453,2,STRIP,4,3,0,1927417495212261376,2167.985602,2,STRIP,5,4,63
114793,271211,278174,62280001101,1873646978567503872,2577.615761,2,STRIP,4,3,0,1927417495212261376,2922.021775,2,STRIP,5,4,63


In [162]:
# Alternative dedup: per (ID_y, particle) keep the edge with the smallest source distance R_x
# (ties broken by R_y), by sorting and then keeping the first row of each group.
R_y_min_2 = (
    truth_track_df
    .sort_values(by=["ID_y", "R_x", "R_y"])
    .drop_duplicates(subset=["ID_y", "particle_id"], keep="first")
)

In [163]:
# Boolean mask: True for rows of truth_track_df that R_y_min_2 discarded.
dropped_rows = np.logical_not(truth_track_df.index.isin(R_y_min_2.index))

In [167]:
# Mask over truth_track_df rows; True marks an edge dropped by R_y_min_2.
dropped_rows

array([False, False, False, ..., False, False, False])

In [158]:
# Truth edges removed by the R_y_min_2 deduplication.
truth_track_df[dropped_rows]

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,barrel_endcap_x,hardware_x,layer_disk_x,eta_module_x,phi_module_x,ID_y,R_y,barrel_endcap_y,hardware_y,layer_disk_y,eta_module_y,phi_module_y
60,80473,88504,225,293420071034814464,113.917780,0,PIXEL,0,-4,9,331120125727997952,275.807582,0,PIXEL,1,-6,12
62,88504,18203,225,331120125727997952,275.807582,0,PIXEL,1,-6,12,36750076646785024,307.133043,-2,PIXEL,1,0,1
303,76686,34681,353,290015983035219968,170.633432,0,PIXEL,0,-7,3,42968914413486080,383.124251,-2,PIXEL,1,3,12
426,131410,247319,384,456253345061928960,561.971057,0,PIXEL,4,12,42,1485554558334664704,1071.755886,0,STRIP,1,10,29
428,247319,255552,384,1485554558334664704,1071.755886,0,STRIP,1,10,29,1524433289492824064,1495.754639,0,STRIP,2,14,39
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
114774,231833,242446,62280001100,1442770361874644992,1000.222905,0,STRIP,0,10,5,1479397293219119104,1417.732382,0,STRIP,1,14,7
114776,242446,265305,62280001100,1479397293219119104,1417.732382,0,STRIP,1,14,7,1768938287231139840,1795.494585,2,STRIP,1,3,12
114780,76511,155720,62280001101,289593770570153984,128.937767,0,PIXEL,0,9,2,613791371089674240,276.831507,2,PIXEL,1,2,2
114902,132765,132542,62280001176,459507899480145920,313.423391,0,PIXEL,4,-2,48,458936153433702400,321.784722,0,PIXEL,4,-3,47


In [164]:
# NOTE(review): same display as the previous cell; candidate for deletion.
truth_track_df[dropped_rows]

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,barrel_endcap_x,hardware_x,layer_disk_x,eta_module_x,phi_module_x,ID_y,R_y,barrel_endcap_y,hardware_y,layer_disk_y,eta_module_y,phi_module_y
60,80473,88504,225,293420071034814464,113.917780,0,PIXEL,0,-4,9,331120125727997952,275.807582,0,PIXEL,1,-6,12
62,88504,18203,225,331120125727997952,275.807582,0,PIXEL,1,-6,12,36750076646785024,307.133043,-2,PIXEL,1,0,1
303,76686,34681,353,290015983035219968,170.633432,0,PIXEL,0,-7,3,42968914413486080,383.124251,-2,PIXEL,1,3,12
426,131410,247319,384,456253345061928960,561.971057,0,PIXEL,4,12,42,1485554558334664704,1071.755886,0,STRIP,1,10,29
428,247319,255552,384,1485554558334664704,1071.755886,0,STRIP,1,10,29,1524433289492824064,1495.754639,0,STRIP,2,14,39
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
114774,231833,242446,62280001100,1442770361874644992,1000.222905,0,STRIP,0,10,5,1479397293219119104,1417.732382,0,STRIP,1,14,7
114776,242446,265305,62280001100,1479397293219119104,1417.732382,0,STRIP,1,14,7,1768938287231139840,1795.494585,2,STRIP,1,3,12
114780,76511,155720,62280001101,289593770570153984,128.937767,0,PIXEL,0,9,2,613791371089674240,276.831507,2,PIXEL,1,2,2
114902,132765,132542,62280001176,459507899480145920,313.423391,0,PIXEL,4,-2,48,458936153433702400,321.784722,0,PIXEL,4,-3,47


In [161]:
# Sanity check: after deduplication no particle has two edges leaving the same source ID (expect 0).
R_y_min_2.duplicated(subset=["ID_x", "particle_id"]).sum()

0

In [156]:
# Full truth-edge table with source (_x) and target (_y) module info.
truth_track_df

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,barrel_endcap_x,hardware_x,layer_disk_x,eta_module_x,phi_module_x,ID_y,R_y,barrel_endcap_y,hardware_y,layer_disk_y,eta_module_y,phi_module_y
0,76902,85968,181,290086351779397632,41.896161,0,PIXEL,0,1,3,326660506565738496,105.623154,0,PIXEL,1,-1,4
1,85968,85657,181,326660506565738496,105.623154,0,PIXEL,1,-1,4,326097556612317184,112.916188,0,PIXEL,1,-1,3
2,85657,94016,181,326097556612317184,112.916188,0,PIXEL,1,-1,3,363243457445101568,179.972384,0,PIXEL,2,-2,5
3,94016,110156,181,363243457445101568,179.972384,0,PIXEL,2,-2,5,400961104324329472,246.317601,0,PIXEL,3,-2,8
4,110156,109818,181,400961104324329472,246.317601,0,PIXEL,3,-2,8,400398154370908160,254.975531,0,PIXEL,3,-2,7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115307,269123,271776,62280201001,1839588506385514496,2151.606966,2,STRIP,3,3,7,1875626099497500672,2461.954309,2,STRIP,4,4,7
115308,271776,275041,62280201001,1875626099497500672,2461.954309,2,STRIP,4,4,7,1911382217632776192,2794.787157,2,STRIP,5,5,6
115309,264335,266241,62280201002,1740790789560074240,425.613454,2,STRIP,0,3,40,1776819586579038208,624.307007,2,STRIP,1,3,40
115310,266241,268362,62280201002,1776819586579038208,624.307007,2,STRIP,1,3,40,1813138654667735040,884.224167,2,STRIP,2,4,41


In [151]:
# Inspect the idxmin-based deduplication for a single particle (particle_id 225).
R_y_min[R_y_min.particle_id == 225]

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,barrel_endcap_x,hardware_x,layer_disk_x,eta_module_x,phi_module_x,ID_y,R_y,barrel_endcap_y,hardware_y,layer_disk_y,eta_module_y,phi_module_y
61,88505,18203,225,331120125727997952,275.631079,0,PIXEL,1,-6,12,36750076646785024,307.133043,-2,PIXEL,1,0,1
63,18203,18278,225,36750076646785024,307.133043,-2,PIXEL,1,0,1,36758872739807232,331.647759,-2,PIXEL,1,1,1
59,80473,88505,225,293420071034814464,113.91778,0,PIXEL,0,-4,9,331120125727997952,275.631079,0,PIXEL,1,-6,12
65,101457,101439,225,371617338002243584,468.110222,0,PIXEL,2,-10,20,371608541909221376,540.658879,0,PIXEL,2,-11,20
64,18278,101457,225,36758872739807232,331.647759,-2,PIXEL,1,1,1,371617338002243584,468.110222,0,PIXEL,2,-10,20
66,101439,116355,225,371608541909221376,540.658879,0,PIXEL,2,-11,20,412123346369511424,677.618091,0,PIXEL,3,-13,28
67,116355,116640,225,412123346369511424,677.618091,0,PIXEL,3,-13,28,412677500229910528,745.942278,0,PIXEL,3,-14,29
69,130040,130030,225,453192304690200576,897.431858,0,PIXEL,4,-16,37,453183508597178368,987.048103,0,PIXEL,4,-17,37
68,116640,130040,225,412677500229910528,745.942278,0,PIXEL,3,-14,29,453192304690200576,897.431858,0,PIXEL,4,-16,37
70,130030,227733,225,453183508597178368,987.048103,0,PIXEL,4,-17,37,1339944034445033472,2896.947255,-2,STRIP,5,0,24


In [131]:
# All truth edges of a single particle, for close inspection.
sample = truth_track_df.loc[truth_track_df["particle_id"].eq(225)]

In [152]:
# NOTE(review): per the execution counts, `sample` was built (In[131]) before the module columns
# were merged in (In[138]), so it only has the ID/R columns.
sample

Unnamed: 0,hit_id_0,hit_id_1,particle_id,ID_x,R_x,ID_y,R_y
59,80473,88505,225,293420071034814464,113.91778,331120125727997952,275.631079
60,80473,88504,225,293420071034814464,113.91778,331120125727997952,275.807582
61,88505,18203,225,331120125727997952,275.631079,36750076646785024,307.133043
62,88504,18203,225,331120125727997952,275.807582,36750076646785024,307.133043
63,18203,18278,225,36750076646785024,307.133043,36758872739807232,331.647759
64,18278,101457,225,36758872739807232,331.647759,371617338002243584,468.110222
65,101457,101439,225,371617338002243584,468.110222,371608541909221376,540.658879
66,101439,116355,225,371608541909221376,540.658879,412123346369511424,677.618091
67,116355,116640,225,412123346369511424,677.618091,412677500229910528,745.942278
68,116640,130040,225,412677500229910528,745.942278,453192304690200576,897.431858


## Truth Track Mapping Refactor

In [68]:
# Translate positional edge indices into hit_id values (edges before any remapping).
track_edges_pre = hits["hit_id"].to_numpy()[track_index_edges]

In [69]:
# (2, num_edges) array of hit_id pairs.
track_edges_pre

array([[ 76902,  85968,  85657, ..., 264335, 266241, 268362],
       [ 85968,  85657,  94016, ..., 266241, 268362, 270759]])

In [70]:
# Every endpoint of the pre-remap truth edges must belong to a particle (particle_id 0 = noise).
# Check track_edges_pre (computed just above), not a `track_edges` left over from an earlier section.
assert (hits.loc[hits.hit_id.isin(track_edges_pre.ravel()), "particle_id"] != 0).all(), "There are hits in the track edges that are noise"

In [71]:
# Per-edge track features from the reader (same call as in the earlier section).
track_features = reader._get_track_features(hits, track_index_edges)

In [72]:
# Compact the (possibly sparse) hit IDs onto contiguous node indices 0..N-1, in ascending hit_id order.
unique_hid = np.unique(hits.hit_id)
hid_mapping = np.zeros(unique_hid.max() + 1, dtype=int)
hid_mapping[unique_hid] = np.arange(len(unique_hid))

# Graph node order = deduplicated hits sorted by hit_id; it must line up with hid_mapping.
hits_post = hits.drop_duplicates(subset="hit_id").sort_values("hit_id")
assert (hits_post.hit_id.to_numpy() == unique_hid).all(), "Deduplicated hits must be ordered exactly like unique_hid, otherwise remapped edges point at the wrong nodes!"

track_edges_post = hid_mapping[track_edges_pre]

In [77]:
# Shape of the edges before remapping.
track_edges_pre.shape

(2, 115312)

In [78]:
# Shape after remapping; should match the pre-remap shape.
track_edges_post.shape

(2, 115312)

In [59]:
# Remap the edges with the reader's own implementation; it returns edges, features and a new `hits`.
# NOTE(review): this rebinds `hits`, so re-running earlier cells afterwards operates on different data.
track_edges_post, track_features, hits = reader.remap_edges(track_edges_pre, track_features, hits)

In [60]:
# Remapped (2, num_edges) edge array returned by the reader.
track_edges_post

array([[     0,      3,      4, ..., 274275, 274286, 274290],
       [ 17022,    856,  16034, ..., 278165, 278164, 278178]])

In [None]:
# Now, want to remove track edges that are duplicated by particle ID and module ID

# First, build a dataframe of track_edges, with the particle ID and module ID.
# track_edges is shaped (2, num_edges): transpose to one row per edge before naming the two columns.
track_edges_df = pd.DataFrame(track_edges.T, columns=["hit_id_1", "hit_id_2"])
