# TrackMLReader Testing

**Goal**: Build and test an AthenaReader class

In [125]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import pandas as pd
import yaml
from itertools import chain, product, combinations
import torch

from time import time as tt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
sys.path.append("../../../")
from gnn4itk_cf.stages.data_reading.models.trackml_utils import *

## Build TrackML Reader

In [3]:
from gnn4itk_cf.stages.data_reading.data_reading_stage import EventReader

In [121]:
from gnn4itk_cf.stages.data_reading.models.trackml_reader import TrackMLReader

In [122]:
config_file = "trackml_reader_config.yaml"
config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader)

In [123]:
reader = TrackMLReader.infer(config)

  Extracting thicknesses...
  Done.
  Extracting rotations...
  Done.
  Extracting thicknesses...
  Done.


Building trainset CSV files:   0%|          | 0/8 [00:00<?, ?it/s]

   hit_id         particle_id         tx        ty      tz       tpx  \
0       1  616997409557315584 -85.034897  -2.41392 -1502.5 -0.208887   
1       2  716072478190862336 -68.635300  -8.34244 -1502.5 -0.374129   
2       3  189154482884444160 -32.215599   2.22252 -1502.5  0.665540   
3       4  481893818782711808 -75.292099  -3.60834 -1502.5 -0.490148   
4       5  526924867254091776 -49.505001 -10.29500 -1502.5  0.194509   

        tpy      tpz    weight          x         y       z  volume_id  \
0  0.026238 -3.67787  0.000010 -85.049698  -2.40534 -1502.5          7   
1 -0.066007 -8.68605  0.000008 -68.623199  -8.33481 -1502.5          7   
2  0.018533 -9.69207  0.000012 -32.212898   2.20440 -1502.5          7   
3  0.002686 -9.60058  0.000010 -75.308403  -3.59471 -1502.5          7   
4 -0.011123 -3.15250  0.000009 -49.510799 -10.29130 -1502.5          7   

   layer_id  module_id  
0         2          1  
1         2          1  
2         2          1  
3         2          1

Building trainset CSV files:  12%|█▎        | 1/8 [00:01<00:09,  1.30s/it]

   hit_id         particle_id         tx         ty      tz       tpx  \
0       1  207172317367762944 -57.984600   3.262830 -1502.5 -1.486190   
1       2  450386007418732544 -85.290802  -6.018620 -1502.5 -0.747519   
2       3   67558598615498752 -97.424599  -6.647630 -1502.5 -0.633927   
3       4  585468295155548160 -79.817902  -0.893004 -1502.5 -0.676604   
4       5  878208043370676224 -78.528000 -13.592100 -1502.5 -0.336242   

        tpy        tpz    weight          x          y       z  volume_id  \
0  0.097822 -35.055099  0.000026 -57.971001   3.286780 -1502.5          7   
1 -0.078641 -12.739100  0.000010 -85.279999  -6.034600 -1502.5          7   
2 -0.014421  -9.730320  0.000010 -97.419998  -6.638950 -1502.5          7   
3 -0.035643 -13.760800  0.000010 -79.816803  -0.890101 -1502.5          7   
4 -0.085959  -6.456910  0.000007 -78.536102 -13.578700 -1502.5          7   

   layer_id  module_id  
0         2          1  
1         2          1  
2         2          1 

Building trainset CSV files:  25%|██▌       | 2/8 [00:02<00:08,  1.43s/it]

   hit_id         particle_id         tx        ty      tz       tpx  \
0       1  445859867802992640 -61.585999 -12.69010 -1502.5 -0.260413   
1       2  499900245832892416 -90.883499  -1.73100 -1502.5 -0.725090   
2       3  207180220107587584 -61.738300  -7.70857 -1502.5 -0.431008   
3       4  756608242091556864 -76.674103  -5.84760 -1502.5 -0.404668   
4       5  481887771468759040 -80.884903  -9.04320 -1502.5 -0.999074   

        tpy       tpz    weight          x         y       z  volume_id  \
0 -0.033349  -6.50986  0.000008 -61.565701 -12.69450 -1502.5          7   
1  0.013067 -11.81480  0.000009 -90.867798  -1.70968 -1502.5          7   
2 -0.155117  -9.30086  0.000000 -61.733601  -7.69228 -1502.5          7   
3 -0.058954  -7.94034  0.000009 -76.661102  -5.85782 -1502.5          7   
4 -0.133594 -17.82250  0.000015 -80.867500  -9.04768 -1502.5          7   

   layer_id  module_id  
0         2          1  
1         2          1  
2         2          1  
3         2     

Building trainset CSV files:  38%|███▊      | 3/8 [00:04<00:07,  1.47s/it]

   hit_id         particle_id         tx       ty      tz       tpx       tpy  \
0       1  729588293594775552 -55.112999 -2.31708 -1502.5 -0.297104  0.003612   
1       2  441367538169806848 -79.754204 -9.90477 -1502.5 -0.514688 -0.042898   
2       3  756608310811033600 -66.508400 -8.08303 -1502.5 -0.292262 -0.012456   
3       4   18014742106865664 -54.464001 -8.21359 -1502.5 -0.537134 -0.099158   
4       5  635015793796448256 -56.381401 -1.06888 -1502.5 -0.352365  0.011734   

        tpz    weight          x        y       z  volume_id  layer_id  \
0  -7.89829  0.000011 -55.104599 -2.30556 -1502.5          7         2   
1  -9.54239  0.000007 -79.751099 -9.91275 -1502.5          7         2   
2  -6.43837  0.000008 -66.500198 -8.06741 -1502.5          7         2   
3 -14.93740  0.000013 -54.469501 -8.22398 -1502.5          7         2   
4  -8.96996  0.000011 -56.388100 -1.05241 -1502.5          7         2   

   module_id  
0          1  
1          1  
2          1  
3       

Building trainset CSV files:  50%|█████     | 4/8 [00:05<00:05,  1.47s/it]

   hit_id         particle_id         tx         ty      tz            tpx  \
0       1  567460356276879360 -93.806297   0.323528 -1502.5      -0.523865   
1       2                   0 -53.281300  -6.947830 -1502.5  398562.000000   
2       3  797141051054751744 -68.739899 -13.692200 -1502.5      -0.253448   
3       4  351295683061350400 -79.322502 -10.611700 -1502.5      -0.620186   
4       5  342275289667076096 -90.384697  -6.882020 -1502.5      -0.442011   

            tpy           tpz    weight          x          y       z  \
0     -0.129955      -7.34948  0.000000 -93.792099   0.317131 -1502.5   
1 -52633.699219 -915630.00000  0.000000 -53.282501  -6.945030 -1502.5   
2     -0.070047      -5.55139  0.000008 -68.747002 -13.711100 -1502.5   
3     -0.056223     -11.27790  0.000008 -79.332199 -10.632100 -1502.5   
4     -0.003902      -7.91382  0.000008 -90.403801  -6.889240 -1502.5   

   volume_id  layer_id  module_id  
0          7         2          1  
1          7        

Building trainset CSV files:  62%|██████▎   | 5/8 [00:07<00:04,  1.54s/it]

   hit_id         particle_id         tx       ty      tz       tpx       tpy  \
0       1  261212076922372096 -56.515598 -9.48530 -1502.5 -0.245977 -0.022694   
1       2   13512379430076416 -59.658199 -7.91937 -1502.5 -0.504114 -0.045715   
2       3  243206611944865792 -55.727901 -9.75980 -1502.5 -0.443663 -0.102055   
3       4  171144688579903488 -53.071499 -2.10350 -1502.5 -0.364669 -0.033409   
4       5  148637754278805504 -44.052101 -9.91345 -1502.5 -0.520255 -0.237466   

        tpz    weight          x        y       z  volume_id  layer_id  \
0  -6.67421  0.000011 -56.514099 -9.48830 -1502.5          7         2   
1 -12.23440  0.000012 -59.662601 -7.93052 -1502.5          7         2   
2 -11.84720  0.000013 -55.716202 -9.77658 -1502.5          7         2   
3 -10.01990  0.000013 -53.089802 -2.09684 -1502.5          7         2   
4 -15.35500  0.000027 -44.067402 -9.91305 -1502.5          7         2   

   module_id  
0          1  
1          1  
2          1  
3       

Building trainset CSV files:  75%|███████▌  | 6/8 [00:08<00:02,  1.50s/it]

   hit_id         particle_id         tx        ty      tz       tpx  \
0       1  261227057768300544 -91.832397 -14.82700 -1502.5 -0.529816   
1       2  743108163447816192 -80.046204  -6.74737 -1502.5 -0.427951   
2       3  198163125248196608 -75.078598  -2.82754 -1502.5 -0.483707   
3       4  801641489586192384 -74.884499  -9.24541 -1502.5 -0.518539   
4       5  689057683654836224 -71.709297   1.99082 -1502.5 -0.298442   

        tpy      tpz    weight          x         y       z  volume_id  \
0 -0.057568 -8.46056  0.000007 -91.821701 -14.82490 -1502.5          7   
1 -0.059080 -8.35119  0.000007 -80.034698  -6.72518 -1502.5          7   
2  0.004474 -8.91173  0.000007 -75.094200  -2.82553 -1502.5          7   
3 -0.089987 -9.81322  0.000008 -74.872398  -9.22787 -1502.5          7   
4 -0.009785 -6.11734  0.000007 -71.694000   2.00613 -1502.5          7   

   layer_id  module_id  
0         2          1  
1         2          1  
2         2          1  
3         2          1

Building trainset CSV files:  88%|████████▊ | 7/8 [00:10<00:01,  1.55s/it]

   hit_id         particle_id         tx        ty      tz       tpx  \
0       1  238697583478833152 -86.806602 -14.74600 -1502.5 -0.247412   
1       2  671044522095935488 -68.705399 -10.91280 -1502.5 -0.526919   
2       3  774632811083595776 -56.143101  -9.52382 -1502.5 -1.053290   
3       4  945760182355361792 -60.311501  -4.98039 -1502.5 -1.179080   
4       5  837680800685096960 -89.389702  -8.82342 -1502.5 -1.395490   

        tpy        tpz    weight          x         y       z  volume_id  \
0 -0.068921  -4.570100  0.000006 -86.796204 -14.73040 -1502.5          7   
1 -0.108128 -11.914000  0.000007 -68.703499 -10.89900 -1502.5          7   
2 -0.156064 -28.250999  0.000017 -56.138302  -9.50887 -1502.5          7   
3 -0.120229 -30.253500  0.000016 -60.311401  -4.97231 -1502.5          7   
4 -0.106162 -23.752899  0.000018 -89.405701  -8.81687 -1502.5          7   

   layer_id  module_id  
0         2          1  
1         2          1  
2         2          1  
3         

Building trainset CSV files: 100%|██████████| 8/8 [00:12<00:00,  1.52s/it]
Building valset CSV files:   0%|          | 0/1 [00:00<?, ?it/s]

   hit_id         particle_id         tx        ty      tz       tpx  \
0       1  648522400790478848 -78.850998  -2.87520 -1502.5 -1.000040   
1       2  571962993831575552 -54.059601 -12.50020 -1502.5 -0.296292   
2       3  616996997240455168 -38.350201  -7.97174 -1502.5  0.185413   
3       4  225193038069104640 -61.355701  -8.06029 -1502.5 -0.194794   
4       5  261216543688359936 -66.956299  -2.61544 -1502.5 -0.384786   

        tpy        tpz    weight          x         y       z  volume_id  \
0 -0.012708 -19.087999  0.000013 -78.870903  -2.87199 -1502.5          7   
1 -0.085277  -8.503490  0.000009 -54.076099 -12.50630 -1502.5          7   
2 -0.012252  -3.274800  0.000007 -38.353199  -7.95869 -1502.5          7   
3 -0.047217  -4.974530  0.000009 -61.365799  -8.06457 -1502.5          7   
4 -0.036082  -8.718730  0.000008 -66.949799  -2.63594 -1502.5          7   

   layer_id  module_id  
0         2          1  
1         2          1  
2         2          1  
3         

Building valset CSV files: 100%|██████████| 1/1 [00:01<00:00,  1.43s/it]
Building testset CSV files:   0%|          | 0/1 [00:00<?, ?it/s]

   hit_id         particle_id         tx        ty      tz           tpx  \
0       1    4507997673881600 -59.765900 -10.77760 -1502.5     -0.311024   
1       2  779124109924630528 -94.742996  -3.75661 -1502.5     -0.614702   
2       3                   0 -51.960701  -1.83278 -1502.5 -13986.400391   
3       4  639514583060447232 -86.284103   1.14292 -1502.5     -0.237660   
4       5  418835727418130432 -54.114300 -10.76580 -1502.5     -0.312541   

           tpy           tpz    weight          x         y       z  \
0    -0.036885      -7.93309  0.000009 -59.747799 -10.79600 -1502.5   
1    -0.054572      -9.76703  0.000010 -94.744400  -3.77019 -1502.5   
2 -2798.229980 -999898.00000  0.000000 -51.980000  -1.85903 -1502.5   
3     0.034980      -3.86216  0.000007 -86.297501   1.15775 -1502.5   
4    -0.040851      -8.43179  0.000010 -54.101200 -10.75290 -1502.5   

   volume_id  layer_id  module_id  
0          7         2          1  
1          7         2          1  
2       

Building testset CSV files: 100%|██████████| 1/1 [00:01<00:00,  1.52s/it]
Building trainset graphs: 100%|██████████| 8/8 [00:17<00:00,  2.24s/it]
Building valset graphs: 100%|██████████| 1/1 [00:02<00:00,  2.05s/it]
Building testset graphs: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]


## Debug CSV build

In [94]:
# Load csv files
base_path = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_3_Dev/feature_store/valset/"
base_path_file = os.path.join(base_path, os.listdir(base_path)[0].split("-")[0])
hits, particles = pd.read_csv(base_path_file + "-truth.csv"), pd.read_csv(base_path_file + "-particles.csv")

In [95]:
hits

Unnamed: 0,hit_id,particle_id,tx,ty,tz,tpx,tpy,tpz,weight,x,...,region,cell_count,cell_val,leta,lphi,lx,ly,lz,geta,gphi
0,0,4507997673881600,-59.7659,-10.77760,-1502.5,-0.311024,-0.036885,-7.93309,0.000009,-59.7478,...,1.0,2.0,0.297907,1.623512,1.152572,0.05,0.11250,0.3,-1.623512,-2.644828
1,1,779124109924630528,-94.7430,-3.75661,-1502.5,-0.614702,-0.054572,-9.76703,0.000010,-94.7444,...,1.0,2.0,0.314380,1.623512,1.152572,0.05,0.11250,0.3,-1.623512,-2.644828
2,2,0,-51.9607,-1.83278,-1502.5,-13986.400000,-2798.230000,-999898.00000,0.000000,-51.9800,...,1.0,1.0,0.272378,2.091356,0.844154,0.05,0.05625,0.3,-2.091356,-2.336410
3,3,639514583060447232,-86.2841,1.14292,-1502.5,-0.237660,0.034980,-3.86216,0.000007,-86.2975,...,1.0,1.0,0.288525,2.091356,0.844154,0.05,0.05625,0.3,-2.091356,-2.336410
4,4,418835727418130432,-54.1143,-10.76580,-1502.5,-0.312541,-0.040851,-8.43179,0.000010,-54.1012,...,1.0,1.0,0.280084,2.091356,0.844154,0.05,0.05625,0.3,-2.091356,-2.336410
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118236,118236,653027443526860800,-884.7880,86.35920,2952.5,-0.269840,0.131374,1.00313,0.000008,-886.3220,...,6.0,2.0,2.000000,0.067239,1.547723,0.24,10.40000,0.7,0.067239,3.055688
118237,118237,121632443030568960,-939.4860,65.12530,2952.5,-0.262143,-0.078602,1.01966,0.000007,-939.7640,...,6.0,2.0,2.000000,0.067239,1.547723,0.24,10.40000,0.7,0.067239,3.055688
118238,118238,535928493096042496,-908.9590,91.80800,2952.5,-1.354840,0.034403,4.57033,0.000021,-906.8360,...,6.0,1.0,1.000000,0.067253,1.559258,0.12,10.40000,0.7,0.067252,3.067223
118239,118239,0,-946.3050,114.64000,2952.5,-115535.000000,19081.700000,993120.00000,0.000000,-947.0660,...,6.0,1.0,1.000000,0.067253,1.559258,0.12,10.40000,0.7,0.067252,3.067223


In [96]:
particles

Unnamed: 0,particle_id,particle_type,vx,vy,vz,px,py,pz,q,nhits,radius,pt
0,4503943224754176,211,-0.005737,-0.002144,47.4336,0.987541,0.447020,-0.519561,1,13,0.006125,1.084004
1,4504011944230912,211,-0.005737,-0.002144,47.4336,-0.320215,0.657894,-0.139844,1,13,0.006125,0.731684
2,4504149383184384,211,-0.005737,-0.002144,47.4336,0.306570,-0.039088,3.046840,1,11,0.006125,0.309052
3,4504218102661120,211,-0.005737,-0.002144,47.4336,-0.905416,0.232695,0.045861,1,12,0.006125,0.934840
4,4504286822137856,-211,-0.005737,-0.002144,47.4336,-1.348460,0.113065,0.082340,-1,12,0.006125,1.353192
...,...,...,...,...,...,...,...,...,...,...,...,...
11863,968277149716848642,2212,-184.650000,-160.423000,-2948.5000,0.262202,0.184957,0.191671,1,1,244.604100,0.320872
11864,968277905614307328,211,0.008127,0.020551,46.8981,0.079867,0.170747,-2.449680,1,0,0.022099,0.188503
11865,968277905631088641,2212,8.568010,16.958500,-199.8530,-0.064976,-0.113261,-0.152094,1,1,19.000040,0.130575
11866,968277905631092738,-211,8.568010,16.958500,-199.8530,0.184261,-0.162827,0.042949,-1,9,19.000040,0.245896


## Debug Graph build

In [131]:
base_path = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_3_Dev/feature_store/valset/"
base_path_file = os.path.join(base_path, os.listdir(base_path)[0].split("-")[0])
graph = torch.load(base_path_file + "-graph.pyg")

In [138]:
for key, value in graph.items():
    print(key, value.type())

hit_id torch.LongTensor
x torch.DoubleTensor
y torch.DoubleTensor
z torch.DoubleTensor
r torch.DoubleTensor
phi torch.DoubleTensor
eta torch.DoubleTensor
region torch.DoubleTensor
module_index torch.LongTensor
weight torch.DoubleTensor
cell_count torch.DoubleTensor
cell_val torch.DoubleTensor
leta torch.DoubleTensor
lphi torch.DoubleTensor
lx torch.DoubleTensor
ly torch.DoubleTensor
lz torch.DoubleTensor
geta torch.DoubleTensor
gphi torch.DoubleTensor
track_edges torch.LongTensor
particle_id torch.LongTensor
pt torch.DoubleTensor
radius torch.DoubleTensor
nhits torch.DoubleTensor


AttributeError: 'list' object has no attribute 'type'

In [106]:
# Load csv files
base_path = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_3_Dev/feature_store/valset/"
base_path_file = os.path.join(base_path, os.listdir(base_path)[0].split("-")[0])
hits, particles = pd.read_csv(f"{base_path_file}-truth.csv"), pd.read_csv(f"{base_path_file}-particles.csv")

In [107]:
hits

Unnamed: 0,hit_id,particle_id,tx,ty,tz,tpx,tpy,tpz,weight,x,...,region,cell_count,cell_val,leta,lphi,lx,ly,lz,geta,gphi
0,0,4507997673881600,-59.7659,-10.77760,-1502.5,-0.311024,-0.036885,-7.93309,0.000009,-59.7478,...,1.0,2.0,0.297907,1.623512,1.152572,0.05,0.11250,0.3,-1.623512,-2.644828
1,1,779124109924630528,-94.7430,-3.75661,-1502.5,-0.614702,-0.054572,-9.76703,0.000010,-94.7444,...,1.0,2.0,0.314380,1.623512,1.152572,0.05,0.11250,0.3,-1.623512,-2.644828
2,2,0,-51.9607,-1.83278,-1502.5,-13986.400000,-2798.230000,-999898.00000,0.000000,-51.9800,...,1.0,1.0,0.272378,2.091356,0.844154,0.05,0.05625,0.3,-2.091356,-2.336410
3,3,639514583060447232,-86.2841,1.14292,-1502.5,-0.237660,0.034980,-3.86216,0.000007,-86.2975,...,1.0,1.0,0.288525,2.091356,0.844154,0.05,0.05625,0.3,-2.091356,-2.336410
4,4,418835727418130432,-54.1143,-10.76580,-1502.5,-0.312541,-0.040851,-8.43179,0.000010,-54.1012,...,1.0,1.0,0.280084,2.091356,0.844154,0.05,0.05625,0.3,-2.091356,-2.336410
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118236,118236,653027443526860800,-884.7880,86.35920,2952.5,-0.269840,0.131374,1.00313,0.000008,-886.3220,...,6.0,2.0,2.000000,0.067239,1.547723,0.24,10.40000,0.7,0.067239,3.055688
118237,118237,121632443030568960,-939.4860,65.12530,2952.5,-0.262143,-0.078602,1.01966,0.000007,-939.7640,...,6.0,2.0,2.000000,0.067239,1.547723,0.24,10.40000,0.7,0.067239,3.055688
118238,118238,535928493096042496,-908.9590,91.80800,2952.5,-1.354840,0.034403,4.57033,0.000021,-906.8360,...,6.0,1.0,1.000000,0.067253,1.559258,0.12,10.40000,0.7,0.067252,3.067223
118239,118239,0,-946.3050,114.64000,2952.5,-115535.000000,19081.700000,993120.00000,0.000000,-947.0660,...,6.0,1.0,1.000000,0.067253,1.559258,0.12,10.40000,0.7,0.067252,3.067223


In [108]:
if "barcode" in particles.columns:
    particles = particles.assign(primary=(particles.barcode < 200000).astype(int))

hits["nhits"] = hits.groupby("particle_id")["particle_id"].transform("count")
assert all(vertex in particles.columns for vertex in ["vx", "vy", "vz"]), "Particles must have vertex information!"
particle_features = config["feature_sets"]["track_features"] + ["vx", "vy", "vz"]

assert "particle_id" in hits.columns and "particle_id" in particles.columns, "Hits and particles must have a particle_id column!"
hits = hits.merge(
    particles[particle_features],
    on="particle_id",
    how="left",
)

hits["particle_id"] = hits["particle_id"].fillna(0).astype(int)
hits.loc[hits.particle_id == 0, "nhits"] = -1

In [109]:
def calc_eta(r, z):
    theta = np.arctan2(r, z)
    return -1.0 * np.log(np.tan(theta / 2.0))  

assert all(col in hits.columns for col in ["x", "y", "z"]), "Need to add (x,y,z) features"
r = np.sqrt(hits.x**2 + hits.y**2)
phi = np.arctan2(hits.y, hits.x)
eta = calc_eta(r, hits.z)
hits = hits.assign(r=r, phi=phi, eta=eta)

In [110]:
noise_hits = hits[hits.particle_id == 0].drop_duplicates(subset="hit_id")
signal_hits = hits[hits.particle_id != 0]

non_duplicate_noise_hits = noise_hits[~noise_hits.hit_id.isin(signal_hits.hit_id)]
hits = pd.concat([signal_hits, non_duplicate_noise_hits], ignore_index=True)
# Sort hits by hit_id for ease of processing
hits = hits.sort_values("hit_id").reset_index(drop=True)

In [111]:
hits

Unnamed: 0,hit_id,particle_id,tx,ty,tz,tpx,tpy,tpz,weight,x,...,pt,radius,nhits_y,vx,vy,vz,nhits,r,phi,eta
0,0,4507997673881600,-59.7659,-10.77760,-1502.5,-0.311024,-0.036885,-7.93309,0.000009,-59.7478,...,0.310680,0.006125,12.0,-0.005737,-0.002144,47.4336,,60.715346,-2.962829,-3.902244
1,1,779124109924630528,-94.7430,-3.75661,-1502.5,-0.614702,-0.054572,-9.76703,0.000010,-94.7444,...,0.623976,0.041959,13.0,0.027341,-0.031828,-12.7654,,94.819384,-3.101820,-3.457053
2,2,0,-51.9607,-1.83278,-1502.5,-13986.400000,-2798.230000,-999898.00000,0.000000,-51.9800,...,,,,,,,-1.0,52.013233,-3.105844,-4.056834
3,3,639514583060447232,-86.2841,1.14292,-1502.5,-0.237660,0.034980,-3.86216,0.000007,-86.2975,...,0.235532,0.019557,15.0,0.018649,-0.005888,-100.0070,,86.305266,3.128178,-3.550966
4,4,418835727418130432,-54.1143,-10.76580,-1502.5,-0.312541,-0.040851,-8.43179,0.000010,-54.1012,...,0.311082,0.024466,10.0,-0.024296,0.002882,-8.8052,,55.159448,-2.945394,-3.998142
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118236,118236,653027443526860800,-884.7880,86.35920,2952.5,-0.269840,0.131374,1.00313,0.000008,-886.3220,...,0.328799,0.036514,13.0,0.033733,-0.013978,-37.2638,,890.525495,3.044392,1.913747
118237,118237,121632443030568960,-939.4860,65.12530,2952.5,-0.262143,-0.078602,1.01966,0.000007,-939.7640,...,0.356721,28.582443,14.0,-25.846100,12.203900,54.3581,,942.018672,3.072391,1.860059
118238,118238,535928493096042496,-908.9590,91.80800,2952.5,-1.354840,0.034403,4.57033,0.000021,-906.8360,...,1.398116,0.010199,12.0,0.010198,0.000182,-64.5827,,911.452703,3.040900,1.891531
118239,118239,0,-946.3050,114.64000,2952.5,-115535.000000,19081.700000,993120.00000,0.000000,-947.0660,...,,,,,,,-1.0,953.987988,3.021055,1.848038


In [112]:
signal = hits[(hits.particle_id != 0)]

# Sort by increasing distance from production
signal = signal.assign(
    R=np.sqrt(
        (signal.x - signal.vx) ** 2
        + (signal.y - signal.vy) ** 2
        + (signal.z - signal.vz) ** 2
    )
)

signal = signal.sort_values("R").reset_index(drop=False)

# Group by particle ID
if "module_columns" not in config or config["module_columns"] is None:
    module_columns = ["barrel_endcap", "hardware", "layer_disk", "eta_module", "phi_module"]
else:
    module_columns = config["module_columns"]

In [113]:
signal_index_list = (signal.groupby(
        ["particle_id"] + module_columns,
        sort=False,
    )["index"]
    .agg(lambda x: list(x))
    .groupby(level=0)
    .agg(lambda x: list(x)))

track_index_edges = []
for row in signal_index_list.values:
    for i, j in zip(row[:-1], row[1:]):
        track_index_edges.extend(list(product(i, j)))

track_index_edges = np.array(track_index_edges).T

In [114]:
track_edges = hits.hit_id.values[track_index_edges]

In [116]:
track_edges.shape

(2, 90611)

In [117]:
track_edges

array([[21971, 21986, 29518, ..., 72404, 79462, 86008],
       [21986, 29518, 29529, ..., 79462, 86008, 92141]])

In [119]:
hits.iloc[track_edges[0]]

Unnamed: 0,hit_id,particle_id,tx,ty,tz,tpx,tpy,tpz,weight,x,...,pt,radius,nhits_y,vx,vy,vz,nhits,r,phi,eta
21971,21971,4503943224754176,28.8417,12.76730,32.31800,0.994373,0.430884,-0.519385,0.000026,28.8319,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,31.543632,0.417680,0.898516
21986,21986,4503943224754176,30.2185,13.36330,31.59900,0.994718,0.429804,-0.519498,0.000022,30.2130,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,33.038900,0.416606,0.850229
29518,29518,4503943224754176,65.3056,28.05650,13.44260,1.004610,0.408024,-0.516908,0.000019,65.2978,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,71.076889,0.406039,0.188066
29529,29529,4503943224754176,66.9907,28.73980,12.57580,1.004790,0.407196,-0.517115,0.000016,66.9846,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,72.893042,0.405403,0.171608
35511,35511,4503943224754176,107.9010,44.71120,-8.37895,1.014040,0.381551,-0.517537,0.000014,107.8970,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,116.796670,0.392902,-0.071705
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40796,40796,968277905631092738,160.0290,-65.62190,-167.44300,0.231874,-0.067544,0.047333,0.000000,160.0290,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,172.961362,-0.389156,-0.858595
72410,72410,968277905631092738,243.1100,-80.66200,-152.23800,0.236887,-0.018097,0.044495,0.000000,243.1010,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,256.140563,-0.320455,-0.565862
72404,72404,968277905631092738,249.5710,-81.10240,-151.02500,0.237135,-0.013266,0.044210,0.000000,249.5560,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,262.411392,-0.314308,-0.548305
79462,79462,968277905631092738,348.7990,-74.11990,-132.47600,0.232537,0.046194,0.040148,0.000000,348.7940,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,356.585797,-0.209433,-0.363257


In [120]:
hits.iloc[track_edges[1]]

Unnamed: 0,hit_id,particle_id,tx,ty,tz,tpx,tpy,tpz,weight,x,...,pt,radius,nhits_y,vx,vy,vz,nhits,r,phi,eta
21986,21986,4503943224754176,30.2185,13.36330,31.59900,0.994718,0.429804,-0.519498,0.000022,30.2130,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,33.038900,0.416606,0.850229
29518,29518,4503943224754176,65.3056,28.05650,13.44260,1.004610,0.408024,-0.516908,0.000019,65.2978,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,71.076889,0.406039,0.188066
29529,29529,4503943224754176,66.9907,28.73980,12.57580,1.004790,0.407196,-0.517115,0.000016,66.9846,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,72.893042,0.405403,0.171608
35511,35511,4503943224754176,107.9010,44.71120,-8.37895,1.014040,0.381551,-0.517537,0.000014,107.8970,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,116.796670,0.392902,-0.071705
41955,41955,4503943224754176,159.8280,63.31700,-34.83370,1.024010,0.349530,-0.518727,0.000011,159.8230,...,1.084004,0.006125,13.0,-0.005737,-0.002144,47.4336,,171.911234,0.377244,-0.201250
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72410,72410,968277905631092738,243.1100,-80.66200,-152.23800,0.236887,-0.018097,0.044495,0.000000,243.1010,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,256.140563,-0.320455,-0.565862
72404,72404,968277905631092738,249.5710,-81.10240,-151.02500,0.237135,-0.013266,0.044210,0.000000,249.5560,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,262.411392,-0.314308,-0.548305
79462,79462,968277905631092738,348.7990,-74.11990,-132.47600,0.232537,0.046194,0.040148,0.000000,348.7940,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,356.585797,-0.209433,-0.363257
86008,86008,968277905631092738,500.2760,-8.75188,-104.35900,0.191771,0.138197,0.038414,0.000000,500.2690,...,0.245896,19.000040,9.0,8.568010,16.958500,-199.8530,,500.346218,-0.017569,-0.207953


In [397]:
assert (hits[hits.hit_id.isin(track_edges.flatten())].particle_id == 0).sum() == 0, "There are hits in the track edges that are noise"

In [398]:
assert (hits.pdgId.values[track_index_edges][0] == hits.pdgId.values[track_index_edges][1]).all()

In [399]:
track_features = {}
for track_feature in ["particle_id", "pt", "radius", "primary", "nhits", "pdgId"]:
    assert (hits[track_feature].values[track_index_edges][0] == hits[track_feature].values[track_index_edges][1]).all()
    track_features[track_feature] = hits[track_feature].values[track_index_edges[0]]

In [401]:
unique_hid = np.unique(hits.hit_id)
hid_mapping = np.zeros(unique_hid.max() + 1).astype(int)
hid_mapping[unique_hid] = np.arange(len(unique_hid))
hits = hits.drop_duplicates(subset="hit_id").sort_values("hit_id")
assert (hits.hit_id == unique_hid).all(), "If hit IDs are not sequential, this will mess up graph structure!"

track_edges = hid_mapping[track_edges]

In [402]:
assert ((hits.particle_id.values[track_edges[0]] != track_features["particle_id"]) & (hits.particle_id.values[track_edges[1]] != track_features["particle_id"])).sum() < 100, "The number of shared EDGES is unusually high!"

In [407]:
for region_label, conditions in config["region_labels"].items():
    for condition_column, condition in conditions.items():
        condition_mask = np.logical_and.reduce([hits[condition_column] == condition for condition_column, condition in conditions.items()])
        hits.loc[condition_mask, "region"] = region_label

assert (hits.region.isna()).sum() == 0, "There are hits that do not belong to any region!"

### 4. Inspect Outputs

In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
test_csv_path = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/testset/event000000194-truth.csv"
test_hits = pd.read_csv(test_csv_path)

In [5]:
test_hits

Unnamed: 0,hit_id,x,y,z,cluster_index_1,cluster_index_2,particle_id,particle_id_1,particle_id_2,hardware,...,eta_angle_1,phi_angle_1,cluster_x_2,cluster_y_2,cluster_z_2,eta_angle_2,phi_angle_2,norm_z_2,region,ID
0,0,-36.6581,-4.29661,-263.00,0,-1,67480000184,67480000184,-1,PIXEL,...,0.982794,0.982794,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
1,1,-50.8359,-18.85710,-263.00,1,-1,0,0,-1,PIXEL,...,1.249050,1.249050,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
2,2,-43.2095,-13.32110,-263.00,2,-1,0,0,-1,PIXEL,...,0.321751,0.291457,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
3,3,-38.5501,-7.39981,-263.00,3,-1,67520000493,67520000493,-1,PIXEL,...,1.249050,0.982794,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
4,4,-41.2780,-10.57310,-263.00,4,-1,69550000406,69550000406,-1,PIXEL,...,0.982794,1.249050,-1.000,-1.000,-1.00,-1.00000,-1.000000,-1.0,1.0,158329674399744
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
361778,360259,954.5330,-138.13500,2854.25,496951,496965,67850001058,67850001058,67850001058,STRIP,...,1.324640,0.006379,927.959,-134.836,2860.75,1.32464,0.006379,-1.0,6.0,1927426291305283584
361779,360260,950.3910,-170.52700,2854.25,496953,496964,0,0,0,STRIP,...,1.115860,0.006380,922.790,-166.575,2860.75,1.32464,0.006379,-1.0,6.0,1927426291305283584
361780,360261,905.5270,-164.65300,2854.25,496954,496963,0,0,0,STRIP,...,0.311485,0.006384,922.625,-167.485,2860.75,1.11611,0.006379,-1.0,6.0,1927426291305283584
361781,360262,893.0260,-162.12200,2854.25,496954,496964,0,0,0,STRIP,...,0.311485,0.006384,922.790,-166.575,2860.75,1.32464,0.006379,-1.0,6.0,1927426291305283584


In [2]:
test_pyg_path = "/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/testset/event000000194-graph.pyg"
test_graph = torch.load(test_pyg_path)

In [3]:
test_graph

Data(hit_id=[360264], x=[360264], y=[360264], z=[360264], r=[360264], phi=[360264], eta=[360264], region=[360264], cluster_x_1=[360264], cluster_y_1=[360264], cluster_z_1=[360264], cluster_x_2=[360264], cluster_y_2=[360264], cluster_z_2=[360264], norm_x=[360264], norm_y=[360264], norm_z_1=[360264], eta_angle_1=[360264], phi_angle_1=[360264], eta_angle_2=[360264], phi_angle_2=[360264], norm_z_2=[360264], track_edges=[2, 145405], particle_id=[145405], pt=[145405], radius=[145405], primary=[145405], nhits=[145405], pdgId=[145405], config=[1])

In [12]:
test_graph = test_graph.to("cuda")

In [16]:
print(test_graph.config)

[{'stage': 'data_reading', 'model': 'AthenaReader', 'input_dir': '/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/athena_100_events', 'stage_dir': '/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/feature_store/', 'module_lookup_path': '/global/cfs/cdirs/m3443/data/GNN4ITK/CommonFrameworkExamples/Example_1_Dev/Modules_geo_10events.txt', 'feature_sets': {'hit_features': ['hit_id', 'x', 'y', 'z', 'r', 'phi', 'eta', 'region', 'cluster_x_1', 'cluster_y_1', 'cluster_z_1', 'cluster_x_2', 'cluster_y_2', 'cluster_z_2', 'norm_x', 'norm_y', 'norm_z_1', 'eta_angle_1', 'phi_angle_1', 'eta_angle_2', 'phi_angle_2', 'norm_z_2'], 'track_features': ['particle_id', 'pt', 'radius', 'primary', 'nhits', 'pdgId']}, 'region_labels': {1: {'hardware': 'PIXEL', 'barrel_endcap': -2}, 2: {'hardware': 'STRIP', 'barrel_endcap': -2}, 3: {'hardware': 'PIXEL', 'barrel_endcap': 0}, 4: {'hardware': 'STRIP', 'barrel_endcap': 0}, 5: {'hardware': 'PIXEL', 'barrel_endcap'