[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HSF-reco-and-software-triggers/Tracking-ML-Exa.TrkX/blob/master/Examples/TrackML_Quickstart/colab_quickstart.ipynb)

# TrackML Quickstart

## Install Libraries

In [None]:
!git clone https://github.com/HSF-reco-and-software-triggers/Tracking-ML-Exa.TrkX.git
!pip install -e ./Tracking-ML-Exa.TrkX

In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
!conda install pandas scipy matplotlib cupy cudatoolkit=11.3 pytorch=1.10.2 pytorch-lightning pyg faiss-gpu cudf=21.12 cugraph=21.12 -c rapidsai -c nvidia -c pytorch -c pyg -c condaforge
!pip install seaborn bokeh 

# Import libraries

In [10]:
from trackml.dataset import load_event
from trackml.utils import add_momentum_quantities, add_position_quantities

# visualization packages
import matplotlib.pyplot as plt, seaborn as sns, bokeh
import plotly.express as px
from bokeh.io import output_notebook, show
from bokeh.plotting import figure, row
from bokeh.models import ColumnDataSource, HoverTool
output_notebook()

import numpy as np, pandas as pd
import os
from IPython.display import clear_output

## Download Data

In [None]:
!wget https://portal.nersc.gov/cfs/m3443/dtmurnane/TrackML_Example/train_sample.zip -O datasets/train_sample.zip
!unzip datasets/train_sample.zip -d datasets/

clear_output()

## Visualize the Data

The TrackML dataset contains simulated indepedent proton-proton collision events, each generating hundreds of particles, each of which hits cells and layers of the detector layers multiple times. The detector records the spatial coordinates and other auxillary information of these hits which, if properly connected, form tracks associated with the parent particle and the collision event from which it originates. The challenge and goal of this project is to associate each and every hit to one single track with optimal purity and efficiency, whose precise definition will be given later.

Each event in the dataset is represented by 4 f

In [40]:
DATA_DIR = '/global/cfs/cdirs/m3443/usr/pmtuan/ExaTrk-data/train_100_events'

In [41]:
# load data

hits, cells, particles, truth = load_event( os.path.join(DATA_DIR, 'event000001000') )

print(f'Number of particles in this event: {particles.shape[0]}')
print(f'Nunber of truth tracks: {truth["particle_id"].unique().shape[0]}')
print(f'Number of particles that exit the detector without a hit: {particles.shape[0] - truth["particle_id"].unique().shape[0]}')
print(f'Number of particles that exit the detector without a hit: {particles[ particles["nhits"]==0].shape[0]}')

Number of particles in this event: 12263
Nunber of truth tracks: 10566
Number of particles that exit the detector without a hit: 1697
Number of particles that exit the detector without a hit: 1698


Each entry in the particles data frame contains a unique identifier of the particle (particle_id), its charge (q), its initial position or vertex $(v_x, v_y, v_z)$, its initial momentum in GeV/c $(p_x, p_y, p_z)$ and its associated number of detector hits. 

Many particles do not leave behind any detector hits and obviously cannot be associated to any track. This is called "detector inefficiency". They are among "uninterested particles" and will be mostly filtered out by a simple momentum cut.

In [4]:
particles = add_momentum_quantities(particles)

particles.head()

Unnamed: 0,particle_id,vx,vy,vz,px,py,pz,q,nhits,pt,pphi,peta,p
0,4503668346847232,-0.009288,0.009861,-0.077879,-0.055269,0.323272,-0.203492,-1,8,0.327963,1.740126,-0.586301,0.385964
1,4503737066323968,-0.009288,0.009861,-0.077879,-0.948125,0.470892,2.01006,1,11,1.058622,2.680624,1.397417,2.271788
2,4503805785800704,-0.009288,0.009861,-0.077879,-0.886484,0.105749,0.683881,-1,0,0.892769,3.022863,0.705916,1.124602
3,4503874505277440,-0.009288,0.009861,-0.077879,0.257539,-0.676718,0.991616,1,12,0.724067,-1.207151,1.12013,1.227834
4,4503943224754176,-0.009288,0.009861,-0.077879,16.4394,-15.5489,-39.824902,1,3,22.627907,-0.757567,-1.330844,45.80442


### Visualize tracks

In [None]:
# inner join truth and huts on hit_id
truth_hits = pd.merge(truth,hits, how='inner', on='hit_id')

# get particles with relatively long tracks
long_track_particles = particles[ particles['nhits'] > 15 ].sample(20)
truth_hits = truth_hits[ truth_hits['particle_id'].isin(long_track_particles['particle_id']) ]

size = 800
# tools = 'box_select,lasso_select,box_zoom,reset,wheel_zoom'
source = ColumnDataSource(hits)

cmap = viridis(20)

p1 = figure(plot_height = size, plot_width = size, 
           title = 'Sample Tracks in Longitudinal View',
          x_axis_label = 'z', 
           y_axis_label = 'x', 
            # tools=tools
        )

p1.scatter('z', 'x', source=source, size=1, alpha=0.1, color='gray' )

p2 = figure(plot_height = size, plot_width = size, 
           title = 'Sample Tracks in Down-beampipe View',
          x_axis_label = 'z', 
           y_axis_label = 'x', 
            # tools=tools
        )

p2.scatter('x', 'y', source=source, size=1, alpha=0.1, color='gray' )

for idx, pid in enumerate( truth_hits['particle_id'].unique()):
    track = truth_hits[ truth_hits['particle_id'] == pid ].copy()
    track['r2'] = track['tpx']**2 + track['tpy']**2 + track['tpz']**2
    track['color']=idx
    track.sort_values(axis=0, by='r2', inplace=True)
    source = ColumnDataSource(track)
    p2.scatter( 'x', 'y', source=source, color=cmap[idx], size=5, line_dash='solid')
    p2.line( 'x', 'y', source=source, color=cmap[idx], line_width=2)
    p1.scatter( 'z', 'x', source=source, color=cmap[idx], size=5, line_dash='solid')
    p1.line( 'z', 'x', source=source, color=cmap[idx], line_width=2)

    
p2.add_tools(BoxSelectTool())
p1.add_tools(BoxSelectTool())
p2.add_tools(LassoSelectTool())
p1.add_tools(LassoSelectTool())

show(row([p1, p2]))


### Particle kinematic distribution

In [None]:
size = 400
p=[]
titles = [
    'Momentum',
    'Transverse Momentum',
    'Pseudo Rapidity',
    'Azimuthal Angle', 
    'Hits'
]
variables = [
    'p', 'pt', 'peta', 'pphi', 'nhits'
]
x_labels = [
    r'$$p \, (GeV/c)$$',
    r'$$p_T \, (GeV/c)$$',
    r'$$\eta$$',
    r'$$\phi $$',
    r'$$n_{hit}$$'
]
y_label = 'Count'
ranges = [ [0, 15], [0,4], [-4,4], [-np.pi, np.pi], [0, 20] ]
dfs = []
for title, variable, x_label, r in zip( titles, variables, x_labels, ranges ):
    p.append(figure(plot_height = size, plot_width = size, 
           title = title,
          x_axis_label = x_label, 
           y_axis_label = y_label)
    )
    hist, edges = np.histogram(particles[variable], bins=20 if variable=='nhits' else 50, range=r)
    dfs.append( pd.DataFrame( { 'count': hist, 'left': edges[:-1], 'right': edges[1:] } ) )
    source = ColumnDataSource(dfs[-1])
    p[-1].quad(bottom=0., top='count', left='left', right='right', source=source)
    p[-1].add_tools(
        HoverTool(tooltips = [('Count', '@count'),
                                 (f'Range', '(@left{2.2f}, @right{2.2f})')])
    )

show(row(p))

### Momentum cut

In [55]:
particles = particles[ (particles['pt'] > 1. ) ]

print(f'Number of particles after momentum cut: {particles.shape[0]}')
print(f'Number of particles that exit the detector without a hit: {particles[ particles["nhits"]==0].shape[0]}')

Number of particles after momentum cut: 1383
Number of particles that exit the detector without a hit: 36


In [133]:
print(f'Number of hits recorded: {hits.shape[0]}')
print(f'Average number of hits per track: {hits.shape[0]/truth["particle_id"].unique().shape[0]: .2f}')

Number of hits recorded: 120939
Average number of hits per track:  11.45


### Detector geometry

In [5]:
hits = add_position_quantities(hits)
hits['abs_z'] = hits['z'].abs()

In [54]:
from bokeh.palettes import Plasma256, Spectral, brewer, d3, Plasma, inferno, viridis
from bokeh.models import ColorMapper, ColorBar, LinearColorMapper, BoxSelectTool, LassoSelectTool
from bokeh.transform import linear_cmap

In [None]:
size = 800
# tools = 'box_select,lasso_select,box_zoom,reset,wheel_zoom'
source = ColumnDataSource(hits)
cmap = LinearColorMapper(palette='Plasma256', low=0, high=hits['abs_z'].max()+300)
# linear_cmap(field_name = "abs_z", palette='Spectral11', low=0, high=hits['z'].max())

p1 = figure(plot_height = size, plot_width = size, 
           title = 'Longitudinal Spacepoint Distribution',
          x_axis_label = 'z', 
           y_axis_label = 'x', 
            # tools=tools
        )

p1.scatter('z', 'x', source=source, size=1, color={'field': 'abs_z', 'transform': cmap} )
p1.add_tools(BoxSelectTool())

p2 = figure(plot_height = size, plot_width = size + 90, 
           title = 'Down-beampipe Spacepoint Distribution',
          x_axis_label = 'x', 
           y_axis_label = 'y',
            # tools=tools
           )

p2.scatter('x', 'y', source=source, size=1, color={'field': 'abs_z', 'transform': cmap})

color_bar = ColorBar(color_mapper=cmap, title='|z|',
                     location=(0,0))
p2.add_layout(color_bar, 'right')
p2.add_tools(BoxSelectTool())
show(row([p1,p2]))

In [None]:
# 3D visualization of detector 

import plotly.express as px
import plotly.graph_objects as go

# inner join truth and huts on hit_id
truth_hits = pd.merge(truth,hits, how='inner', on='hit_id')

# get particles with relatively long tracks
long_track_particles = particles[ particles['nhits'] > 15 ].sample(20)
truth_hits = truth_hits[ truth_hits['particle_id'].isin(long_track_particles['particle_id']) ]

fig = px.scatter_3d(hits, x='x', y='y', z='z', opacity=0.2,
              # color='abs_z',
                    width=1800, height=800
    )
fig.update_traces(marker_size=1)

cmap = inferno(20)

for idx, pid in enumerate( truth_hits['particle_id'].unique()):
    track = truth_hits[ truth_hits['particle_id'] == pid ].copy()
    track['r2'] = track['tpx']**2 + track['tpy']**2 + track['tpz']**2
    track['color']=idx
    track.sort_values(axis=0, by='r2', inplace=True)
    fig.add_trace(
        go.Scatter3d(
            x=track['x'],
            y=track['y'],
            z=track['z'],
            mode='markers',
            marker=dict(
                size=3,
                color=cmap[idx]
            ),
            showlegend=False
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=track['x'],
            y=track['y'],
            z=track['z'],
            mode='lines',
            line=dict(
                dash='solid',
                width=2,
                color=cmap[idx]
            ),
            showlegend=False
        )
    )

fig.show()