## _Graph Construction_

- _Heuristic Method_

In [None]:
import glob, os, sys, yaml

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
import seaborn as sns
import trackml.dataset

In [None]:
import torch
from torch_geometric.data import Data
import itertools

In [None]:
# Append the parent directory to the module search path; presumably this is
# where the local `src` package imported further down lives.
sys.path.append('..')

In [None]:
# Prefer a CUDA GPU when torch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# local imports
from src import SttCSVDataReader, SttTorchDataReader
from src import detector_layout
from src import Build_Event, Build_Event_Viz, Visualize_Edges
from src.math_utils import polar_to_cartesian

### _Input Data_

In [None]:
# Directory that is scanned below for the per-event `*-hits.csv` files
input_dir = '../data_all'

In [None]:
# Find all input data files (hits.csv, cells.csv, particles.csv, truth.csv)
all_files = os.listdir(input_dir)

# Derive one prefix per event from its "<prefix>-hits.csv" file. The suffix is
# sliced off the end rather than removed with str.replace(), which would also
# delete any earlier occurrence of the suffix inside the file name.
suffix = '-hits.csv'
file_prefixes = sorted(os.path.join(input_dir, f[:-len(suffix)])
                       for f in all_files if f.endswith(suffix))

print("Number of Files: ", len(file_prefixes))

In [None]:
# file_prefixes[:10]

In [None]:
# load an event
# hits, tubes, particles, truth = trackml.dataset.load_event(file_prefixes[0])

In [None]:
# hits.head()
# tubes.head()
# particles.head()
# truth.head()

### _Visualize Event_

In [None]:
# Event to look at; it is used further down to pick an entry of `file_prefixes`
# (and in the commented-out Build_Event call)
event_id = 95191

In [None]:
# compose event is exactly the same as select_hits()
# event = Build_Event(input_dir, event_id, noise=False, skewed=False, selection=False)

In [None]:
# visualize event
# Build_Event_Viz(event, figsize=(10,10), fig_type="pdf", save_fig=False)

## _Graph Construction: Heuristic Method_

Input graphs are the input to a neural network, so they contain both _`True`_ and _`False`_ edges constructed by either a _heuristic method_ or _metric learning_. For supervised learning, we need node features (_`x`_), an edge index (_`edge_index`_) and the corresponding ground truth (_`y`_).

Here we will explore a **_Heuristic Method_** to construct input graphs.

In [None]:
from LightningModules.Processing.utils.event_utils import select_hits
from LightningModules.Processing.utils.event_utils import get_layerwise_edges
from LightningModules.Processing.utils.event_utils import get_modulewise_edges

### _(A) - Layerwise True Edges_

- layerwise true edges work for high-momentum particles but fail when a particle re-enters the detector
- layerwise input edges are a bit inconsistent with layerwise true edges

In [None]:
# Get the event prefix using event_id.
# NOTE(review): event_id is used as a *position* in the sorted `file_prefixes`
# list; it is not matched against the file name. This presumably relies on the
# directory holding one consecutively numbered file per event, so that position
# and event number coincide - confirm, otherwise this selects a different
# event or raises IndexError.
event_prefix = file_prefixes[event_id]

In [None]:
# Load the hits of the chosen event with noise=False, skewed=False and the
# extra `selection` option switched off.
kwargs = dict(selection=False)
hits = select_hits(
    event_file=event_prefix,
    noise=False,
    skewed=False,
    **kwargs
)

In [None]:
# Build the layerwise true edges. The call returns two values: the edges and
# a new hits dataframe, which replaces `hits` from here on.
true_edges, hits = get_layerwise_edges(hits)

In [None]:
# Plot the hits (nodes) together with the layerwise true edges
Visualize_Edges(
    hits,
    true_edges,
    figsize=(10, 10),
    fig_type="pdf",
    save_fig=False,
)

### _(B) - Layerwise Input Edges_

**Input Graph** is the training input to the GNN. It is built from edges between hits of all particles, restricted to hits in adjacent layers.

- use same `hits` from `get_layerwise_edges()`
- make `get_input_graph()` function similar to `get_layerwise_edges()`
- add to PyG `Data` object.

In [None]:
import logging

In [None]:
# Reload the hits of the same event (noise=False, skewed=False, selection off)
# so that this section starts from a fresh dataframe.
kwargs = dict(selection=False)
hits = select_hits(
    event_file=event_prefix,
    noise=False,
    skewed=False,
    **kwargs
)

In [None]:
def select_segments(hits1, hits2, filtering=True, n_sectors=6):
    """Pair every hit of `hits1` with every hit of `hits2` from the same event.

    Parameters
    ----------
    hits1, hits2 : pandas.DataFrame
        Hits of the two groups to connect. Both must provide the columns
        `event_id`, `r`, `phi`, `isochrone` and `sector_id`. The DataFrame
        index identifies each hit; it is assumed to be unnamed so that
        `reset_index()` exposes it as a column called `index`.
    filtering : bool, default True
        If True, keep only pairs whose sector ids are equal or differ by one,
        plus pairs that differ by `n_sectors - 1` (first and last sector).
        If False, return all pairs.
    n_sectors : int, default 6
        Number of sectors; only used to recognise the first/last-sector pair.
        The default reproduces the previously hard-coded difference of 5.

    Returns
    -------
    pandas.DataFrame
        Columns `index_1` and `index_2` holding the index labels of the paired
        hits from `hits1` and `hits2` respectively.
    """
    # Start with all possible pairs of hits that share an event_id
    keys = ['event_id', 'r', 'phi', 'isochrone', 'sector_id']
    left = hits1[keys].reset_index()
    right = hits2[keys].reset_index()
    hit_pairs = left.merge(right, on='event_id', suffixes=('_1', '_2'))
    segments = hit_pairs[['index_1', 'index_2']]

    if not filtering:
        return segments

    # Absolute sector distance; n_sectors - 1 is the wrap-around case
    d_sector = (hit_pairs['sector_id_1'] - hit_pairs['sector_id_2']).abs()
    sector_mask = (d_sector < 2) | (d_sector == n_sectors - 1)
    return segments[sector_mask]

def construct_graph(hits, layer_pairs, filtering=True):
    """Build the candidate segments of one event, one table per layer pair.

    Parameters
    ----------
    hits : pandas.DataFrame
        Hits of the event; must contain a `layer_id` column plus whatever
        columns `select_segments()` reads.
    layer_pairs : iterable of (layer1, layer2)
        Pairs of `layer_id` values to connect.
    filtering : bool, default True
        Passed through to `select_segments()`.

    Returns
    -------
    list of pandas.DataFrame
        One `select_segments()` result per layer pair for which both layers
        have hits, in the order of `layer_pairs`. The tables are returned
        unconcatenated; combine them with `pd.concat` if a single table is
        needed.
    """
    layer_groups = hits.groupby('layer_id')
    segments = []
    for layer1, layer2 in layer_pairs:
        try:
            hits1 = layer_groups.get_group(layer1)
            hits2 = layer_groups.get_group(layer2)
        except KeyError as e:
            # get_group() raises KeyError when the event has no hits on a
            # layer; such a pair cannot produce segments, so skip it.
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logging.info('skipping empty layer: %s', e)
            continue

        # All hit pairs between the two layers
        segments.append(select_segments(hits1, hits2, filtering))

    return segments

In [None]:
# layer_groups = hits.groupby('layer_id')
# layer_groups.size()
# layer_groups.groups
# layer_groups.first()
# layer_groups.last()
# layer_groups.ngroups
# layer_groups.groups.keys()

In [None]:
# Count the distinct layer ids present in the selected hits
# (per the original note, 18 are expected when skewed layers are excluded)
n_layers = len(hits["layer_id"].unique())
print("total number of layers (w/o skewed): {}".format(n_layers))

In [None]:
# Pair each layer with the next one. The pairs are built from the layer ids
# that actually occur in `hits` (sorted) instead of assuming the ids are
# exactly 0..n_layers-1. For contiguous ids starting at 0 this yields the same
# pairs; with gaps or an offset it no longer produces pairs that reference
# layers without hits.
layers = np.sort(hits.layer_id.unique())
layer_pairs = np.stack([layers[:-1], layers[1:]], axis=1)
print("total number of layer pairs (w/o skewed): {}".format(layer_pairs.shape[0]))

In [None]:
# Build one DataFrame of (index_1, index_2) hit-index pairs per layer pair.
# filtering=False keeps every pair, i.e. no sector mask is applied.
segments = construct_graph(hits, layer_pairs, filtering=False)
len(segments)

In [None]:
# Combine the per-layer-pair tables into a single DataFrame and preview it
combined_segments = pd.concat(segments)
combined_segments.head()

In [None]:
# Take the first entry of `segments` (layer pair [0, 1] per the original note)
# and transpose it, so that row 0 holds index_1 and row 1 holds index_2.
input_graph = segments[0].to_numpy().T

In [None]:
# input_graph.shape

In [None]:
# input_graph[0]

In [None]:
# input_graph[1]

In [None]:
# hits.query("layer==0")

In [None]:
# hits.query("layer==1")

In [None]:
# input_graph.shape[1] == len(segments[0])

### _(+) - Plotting Input Edges_

In [None]:
# Plot the hits (nodes) together with the input edges of the first layer pair
Visualize_Edges(
    hits,
    input_graph,
    figsize=(10, 10),
    fig_type="pdf",
    save_fig=False,
)

### _(+) Sector-wise Filtering_

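* build edges only in neighbouring sectors, _i.e._ `|sector_id_i - sector_id_j| < 2` (`select_segments()` additionally accepts a difference of 5, the pair formed by the first and last sector)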

In [None]:
# Let's take the first layer pair; the cells below use the hits of these two layers
layer_pairs[0]

In [None]:
# Group the hits by layer so that individual layers can be fetched with get_group()
layer_groups = hits.groupby('layer_id')

In [None]:
# Hits of layer_id 0 and layer_id 1 (the two layers of the first layer pair)
hits1 = layer_groups.get_group(0)
hits2 = layer_groups.get_group(1)

In [None]:
# All combinations of a hit from `hits1` with a hit from `hits2` that share an
# event_id; reset_index() exposes the original hit index, which the suffixes
# turn into the columns index_1 / index_2.
keys = ['event_id', 'r', 'phi', 'isochrone', 'sector_id']
hit_pairs = pd.merge(
    hits1[keys].reset_index(),
    hits2[keys].reset_index(),
    on='event_id',
    suffixes=('_1', '_2'),
)
hit_pairs

In [None]:
# Keep pairs whose sectors are identical or adjacent. A difference of 5 is
# also accepted, matching the mask used in select_segments() (presumably the
# wrap-around between the first and last of six sectors).
dSector = hit_pairs['sector_id_1'] - hit_pairs['sector_id_2']
sector_mask = ((dSector.abs() < 2) | (dSector.abs() == 5))

In [None]:
# Show the boolean sector mask (True = the hit pair is kept)
sector_mask

In [None]:
# Hit-index pairs before the sector mask is applied
hit_pairs[['index_1', 'index_2']].head()

In [None]:
# Hit-index pairs that survive the sector mask
hit_pairs[['index_1', 'index_2']][sector_mask].head()

### _(C) - Modulewise True Edges_

In [None]:
# Reload the hits of the event (noise=False, skewed=False, selection off)
# before building the modulewise true edges.
kwargs = dict(selection=False)
hits = select_hits(
    event_file=event_prefix,
    noise=False,
    skewed=False,
    **kwargs
)

In [None]:
# Modulewise true edges. Unlike the get_layerwise_edges() call earlier, only
# one value (the edges) is unpacked here, so `hits` is not reassigned.
true_edges = get_modulewise_edges(hits)

In [None]:
# check dimensions
# true_edges.shape

In [None]:
# Plot the hits (nodes) together with the modulewise true edges
Visualize_Edges(
    hits,
    true_edges,
    figsize=(10, 10),
    fig_type="pdf",
    save_fig=False,
)

### _(D) - Modulewise Input Edges_

**Input Graph** is the training input to the GNN. It is built from edges between hits of all particles, restricted to hits in adjacent layers.

- use same `hits` from `get_modulewise_ordered_edges()`
- make `get_input_modulewise_edges()` function similar to `get_input_edges()`
- add to Data variable.

In [None]:
# Reload the hits of the event (noise=False, skewed=False, selection off)
# as the starting point for the modulewise input edges.
kwargs = dict(selection=False)
hits = select_hits(
    event_file=event_prefix,
    noise=False,
    skewed=False,
    **kwargs
)

In [None]:
# Preview the first rows of the freshly selected hits
hits.head()

In [None]:
# Build the `signal` table in one chain:
#   1. keep hits with a non-null, non-zero particle_id and a non-null vx
#   2. keep one hit per (particle_id, volume_id, layer_id, module_id)
#   3. move the original index into a column renamed to 'unsorted_index'
#   4. add the new positional index as an 'index' column (order of occurrence)
signal = (
    hits[
        ((~hits.particle_id.isna()) & (hits.particle_id != 0))
        & (~hits.vx.isna())
    ]
    .drop_duplicates(subset=["particle_id", "volume_id", "layer_id", "module_id"])
    .reset_index()
    .rename(columns={"index": "unsorted_index"})
    .reset_index(drop=False)
)

# Mark particle_id 0 as missing. Rows with particle_id == 0 were already
# removed by the filter above, so this is expected to match no rows.
signal.loc[signal["particle_id"] == 0, "particle_id"] = np.nan

In [None]:
# Preview the first rows of the filtered and re-indexed signal hits
signal.head()

In [None]:
# Group hits by particle_id, keeping the groups in order of first appearance (sort=False).
# NOTE(review): this groups the unfiltered `hits`, not the `signal` frame built above - confirm this is intended.
pid_groups = hits.groupby("particle_id", sort=False)

In [None]:
# Show the mapping from each particle_id to the index labels of its hits
pid_groups.groups

In [None]:
# Enumerate the distinct particle ids found in `signal` as 0..n_pids-1 and pair
# each position with its successor. These are positions, not particle_id values.
n_pids = len(signal["particle_id"].unique())
pids = np.arange(n_pids)
pid_pairs = np.stack((pids[:-1], pids[1:]), axis=-1)

In [None]:
# Show the consecutive (i, i+1) pairs
pid_pairs

In [None]:
# Group the hits by layer_id; the loop below fetches groups with get_group()
layer_groups = hits.groupby('layer_id')

In [None]:
# Collect candidate edges between consecutive groups, keeping only hit pairs
# whose sector ids differ by less than 2 or by exactly 5 (presumably the
# wrap-around between the first and last sector).
# NOTE(review): the pairs iterated here come from `pid_pairs` (positional
# particle indices), yet they are used as keys into `layer_groups`, which is
# grouped by `layer_id`. get_group() raises KeyError for any key that is not
# an existing layer_id. Confirm whether `layer_pairs` or a particle/module
# grouping was intended.
edges = []
for (g1, g2) in pid_pairs:
    hits1 = layer_groups.get_group(g1)
    hits2 = layer_groups.get_group(g2)
    
    # all combinations of hits from the two groups that share an event_id
    keys = ['event_id', 'r', 'phi', 'isochrone', 'sector_id']
    hit_pairs = hits1[keys].reset_index().merge(hits2[keys].reset_index(), on='event_id', suffixes=('_1', '_2'))
    
    # sector-wise filtering of the pairs
    dSector = (hit_pairs['sector_id_1'] - hit_pairs['sector_id_2'])
    sector_mask = ((dSector.abs() < 2) | (dSector.abs() == 5))
    e = hit_pairs[['index_1', 'index_2']][sector_mask]
        
    edges.append(e)

In [None]:
# Stack the per-pair edge tables into one DataFrame (columns index_1, index_2)
input_edges = pd.concat(edges)

In [None]:
# Convert to a NumPy array and transpose: row 0 = index_1, row 1 = index_2
input_edges = input_edges.to_numpy().T

In [None]:
# Check dimensions: expected (2, number_of_edges) after the transpose
input_edges.shape

In [None]:
# visualize nodes and edges
# Visualize_Edges (hits, input_edges, figsize=(10,10), fig_type="pdf", save_fig=False)