## _Building Graphs: Input Edges_

In [None]:
import glob, os, sys, yaml

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
import seaborn as sns
import trackml.dataset

In [None]:
import torch
from torch_geometric.data import Data
import itertools

In [None]:
# append parent dir
sys.path.append('..')

# local imports
from src import Compose_Event, Draw_Compose_Event

### _(+) - Input Data_

In [None]:
# input data
input_dir = '../train_all'

In [None]:
# List every file in the input directory (hits, cells, particles, truth CSVs).
all_files = os.listdir(input_dir)

# One prefix per '*-hits.csv' file: strip the suffix and prepend the directory.
suffix = '-hits.csv'
file_prefixes = [
    os.path.join(input_dir, name.replace(suffix, ''))
    for name in all_files
    if name.endswith(suffix)
]
file_prefixes.sort()

print("Number of Files: ", len(file_prefixes))

In [None]:
# file_prefixes[:10]

In [None]:
# NOTE: despite its name, event_id is used as a position in the sorted
# file_prefixes list; it is not matched against an event number in the file name.
event_id = 95191
event_prefix = file_prefixes[event_id]

In [None]:
# load an event
hits, tubes, particles, truth = trackml.dataset.load_event(event_prefix)

In [None]:
# hits.head()
# tubes.head()
# particles.head()
# truth.head()

### _(+) - Build Event_

- functions from _event_utils.py_

In [None]:
event = Compose_Event(event_prefix,skewed=False)
Draw_Compose_Event(event,figsize=(10,10));

## _Build Graphs_

### _(A) - True Edges (Layerwise)_

**True Graph** is the ground truth for the GNN. It is built by creating edges between _`hits`_ of the same particle in adjacent layers.

For this purpose, use the _`true_edges, hits = get_layerwise_edges(event)`_ function from _`event_utils.py`_.

In [None]:
from LightningModules.Processing.utils.event_utils import get_layerwise_edges

In [None]:
true_edges, hits = get_layerwise_edges(event)

### _(B) - Input Edges (Layerwise)_

**Input Graph** is the training input to the GNN. It is built from edges between hits of all particles in adjacent layers.

- use same `hits` from `get_layerwise_edges()`
- make `get_input_graph()` function similar to `get_layerwise_edges()`
- add to Data variable.

In [None]:
# layer_groups.size()
# layer_groups.groups
# layer_groups.first()
# layer_groups.last()
# layer_groups.ngroups
# layer_groups.groups.keys()

In [None]:
def select_segments(hits1, hits2, filtering=True):
    """Return candidate segments (hit index pairs) between two groups of hits.

    Every hit of ``hits1`` is paired with every hit of ``hits2`` that shares
    its ``event_id``. The original index of each frame is carried along as
    the ``index_1`` / ``index_2`` columns of the result.

    With ``filtering=True`` only pairs whose absolute ``sector_id`` difference
    is below 2 or exactly 5 are kept (5 is presumably the wrap-around between
    the first and last sector; confirm against the detector geometry).
    With ``filtering=False`` all pairs are returned.
    """
    columns = ['event_id', 'r', 'phi', 'isochrone', 'sector_id']
    left = hits1[columns].reset_index()
    right = hits2[columns].reset_index()

    # All possible pairs of hits within the same event
    pairs = left.merge(right, on='event_id', suffixes=('_1', '_2'))
    candidates = pairs[['index_1', 'index_2']]

    if not filtering:
        return candidates

    sector_gap = (pairs['sector_id_1'] - pairs['sector_id_2']).abs()
    keep = (sector_gap < 2) | (sector_gap == 5)
    return candidates[keep]

def construct_graph(hits, layer_pairs, filtering=True):
    """Construct the candidate segments of one graph (e.g. from one event).

    Parameters
    ----------
    hits : pandas.DataFrame
        Hits of the event. They are grouped by the ``layer_id`` column and
        passed group-wise to ``select_segments``.
    layer_pairs : iterable of (layer1, layer2)
        Pairs of layer ids between which segments are built.
    filtering : bool, default True
        Forwarded unchanged to ``select_segments``.

    Returns
    -------
    list of pandas.DataFrame
        One frame of segments per layer pair that has hits on both layers,
        in the order of ``layer_pairs``. The frames are not concatenated.
    """
    # Imported locally: the notebook's import cells do not provide `logging`,
    # so the skip branch below would otherwise fail with a NameError.
    import logging

    # Loop over layer pairs and construct segments
    layer_groups = hits.groupby('layer_id')
    segments = []
    for (layer1, layer2) in layer_pairs:

        # Find the hits on both layers
        try:
            hits1 = layer_groups.get_group(layer1)
            hits2 = layer_groups.get_group(layer2)
        # If an event has no hits on a layer, we get a KeyError.
        # In that case we just skip to the next layer pair
        except KeyError as e:
            logging.info('skipping empty layer: %s', e)
            continue

        # Construct the segments
        segments.append(select_segments(hits1, hits2, filtering))

    # Segments are returned per layer pair; callers that need a single frame
    # can combine them with pd.concat(segments).
    return segments

In [None]:
# lets get unique pids with freq (~ hits).
sel_pids, sel_pids_fr = np.unique(hits.particle_id, return_counts=True)
print(sel_pids)

In [None]:
# get number of layers, without skewed layers its just 18
n_layers = hits.layer_id.unique().shape[0]
print("total number of layers (w/o skewed): {}".format(n_layers))

In [None]:
# Pair every layer with its successor: each row is (layer, layer + 1).
# NOTE(review): assumes the layer ids are the contiguous integers
# 0..n_layers-1 — confirm against hits.layer_id.
layers = np.arange(n_layers)
layer_pairs = np.column_stack((layers[:-1], layers[1:]))
layer_pairs

In [None]:
# returns a list of indices from layer pairs.
segments = construct_graph(hits, layer_pairs, filtering=True)

In [None]:
# Combine segments from all layer pairs
# segments = pd.concat(segments)
# segments.describe()

In [None]:
# let's see the first layer pair (0th element)
# with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
#    print(segments[0][["index_1", "index_2"]])

In [None]:
# take the segments of the first layer pair (element 0 of `segments`) and
# transpose them so that row 0 holds index_1 and row 1 holds index_2
edge_index = segments[0].to_numpy().T

In [None]:
edge_index.shape

In [None]:
edge_index[0]

In [None]:
edge_index[1]

In [None]:
# hits.query("layer==0")

In [None]:
# hits.query("layer==1")

In [None]:
edge_index.shape[1] == len(segments[0])

### _(+) - Plotting Input Edges_

In [None]:
from src.drawing import detector_layout
from src.utils_math import polar_to_cartesian

In [None]:
# Plot the detector layout, the hits of every particle and the input edges.
plt.close('all')
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot()

p_ids = np.unique(event.particle_id.values)
det = pd.read_csv("../src/stt.csv")
skw = det.query('skewed==0')  # rows flagged skewed==0
nkw = det.query('skewed==1')  # rows flagged skewed==1; one may look for +ve/-ve polarity

# detector layout: skewed==0 in light green, skewed==1 in coral
for straws, edge_colour in ((skw, 'lightgreen'), (nkw, 'coral')):
    plt.scatter(straws.x.values, straws.y.values, s=44, facecolors='none', edgecolors=edge_colour)

# particle tracks: one scatter series per particle id
for pid in sel_pids:
    track = hits[hits.particle_id == pid]
    ax.scatter(track.x.values, track.y.values, label='particle_id: %d' %pid)

# input edges: one line per (source, target) column of edge_index
for src, dst in zip(edge_index[0], edge_index[1]):
    pt1 = hits.iloc[src]
    pt2 = hits.iloc[dst]
    ax.plot([pt1.x, pt2.x], [pt1.y, pt2.y], color='k', alpha=0.3, lw=1.5)

# plotting params
ax.set_xlabel('x [cm]', fontsize=20)
ax.set_ylabel('y [cm]', fontsize=20)
# ax.set_title('Event ID # %d' % event_id)
ax.set_xlim(-41, 41)
ax.set_ylim(-41, 41)
ax.grid(False)
ax.legend(fontsize=11, loc='best')
fig.tight_layout()
# fig.savefig("input_edges.pdf")

In [None]:
# New plotting scheme: figure and axes come from detector_layout()
fig, ax = detector_layout(figsize=(10,10))

# particle tracks: one scatter series per particle id
for pid in sel_pids:
    track = hits[hits.particle_id == pid]
    ax.scatter(track.x.values, track.y.values, label='particle_id: %d' %pid)

# input edges: one line per (source, target) column of edge_index
for src, dst in zip(edge_index[0], edge_index[1]):
    pt1 = hits.iloc[src]
    pt2 = hits.iloc[dst]
    ax.plot([pt1.x, pt2.x], [pt1.y, pt2.y], color='k', alpha=0.3, lw=1.5)

# axis params
ax.legend(fontsize=12, loc='best')
fig.tight_layout()
# fig.savefig("input_edges.pdf")

### _(+) Sector-wise Filtering_

* build edges only in neighbouring sectors, _i.e._ `|sector_id_i - sector_id_j| < 2` (`select_segments()` additionally accepts `|sector_id_i - sector_id_j| == 5`)

In [None]:
# lets take first layer_pair and corresponding hits
layer_pairs[0]

In [None]:
layer_groups = hits.groupby('layer_id')

In [None]:
hits1 = layer_groups.get_group(0)
hits2 = layer_groups.get_group(1)

In [None]:
keys = ['event_id', 'r', 'phi', 'isochrone', 'sector_id']
hit_pairs = hits1[keys].reset_index().merge(hits2[keys].reset_index(), on='event_id', suffixes=('_1', '_2'))
hit_pairs

In [None]:
# Keep pairs in the same or a directly neighbouring sector. An absolute
# difference of 5 is accepted as well, so that this mask agrees with the one
# used in select_segments().
dSector = (hit_pairs['sector_id_1'] - hit_pairs['sector_id_2']).abs()
sector_mask = ((dSector < 2) | (dSector == 5))

In [None]:
sector_mask

In [None]:
hit_pairs[['index_1', 'index_2']].head()

In [None]:
hit_pairs[['index_1', 'index_2']][sector_mask].head()

### _(C) - Input Edges (Modulewise)_

**Input Graph** is the training input to the GNN. It is built from edges between hits of all particles in adjacent layers.

- use same `hits` from `get_modulewise_ordered_edges()`
- make `get_input_modulewise_edges()` function similar to `get_input_edges()`
- add to Data variable.

In [None]:
from LightningModules.Processing.utils.event_utils import select_hits
from LightningModules.Processing.utils.event_utils import get_modulewise_ordered_edges

In [None]:
kwargs = {"selection": False}

In [None]:
# select hits
hits = select_hits(event_file=event_prefix, noise=False, skewed=False, **kwargs)

In [None]:
hits.head()

In [None]:
# Keep hits that have a non-null, non-zero particle_id and a non-null vx
valid_pid = hits.particle_id.notna() & (hits.particle_id != 0)
valid_vertex = hits.vx.notna()

# ... and only the first hit per (particle, volume, layer, module)
signal = hits[valid_pid & valid_vertex].drop_duplicates(
    subset=["particle_id", "volume_id", "layer_id", "module_id"]
)

# Keep the order of occurrence: the old index is stored as 'unsorted_index',
# and the fresh 0..n-1 positions are exposed as a new 'index' column
signal = (
    signal.reset_index()
    .rename(columns={"index": "unsorted_index"})
    .reset_index(drop=False)
)

# Map particle_id 0 to NaN (rows with particle_id == 0 were already removed
# by the selection above, so this is not expected to match any row)
signal.loc[signal["particle_id"] == 0, "particle_id"] = np.nan

In [None]:
signal.head()

In [None]:
pid_groups = hits.groupby("particle_id", sort=False)

In [None]:
pid_groups.groups

In [None]:
n_pids = signal.particle_id.unique().shape[0]
pids = np.arange(n_pids)
pid_pairs = np.stack([pids[:-1], pids[1:]], axis=1)

In [None]:
pid_pairs

In [None]:
# Build the candidate edges for every consecutive pair in pid_pairs.
# NOTE(review): the pair values are looked up in layer_groups, which was
# created from the earlier `hits` frame (before select_hits() reassigned
# `hits`) — confirm this is the grouping intended for this section.
edges = []
for (g1, g2) in pid_pairs:
    # First group from g1, second from g2. Previously both were taken from
    # g2, which paired every group with itself and left g1 unused.
    hits1 = layer_groups.get_group(g1)
    hits2 = layer_groups.get_group(g2)

    # Same pairing and sector mask as before, via the shared helper
    edges.append(select_segments(hits1, hits2, filtering=True))

In [None]:
edges = pd.concat(edges)

In [None]:
edge_index = edges.to_numpy().T

In [None]:
# New plotting scheme: figure and axes come from detector_layout()
fig, ax = detector_layout(figsize=(10,10))

# particle tracks: one scatter series per particle id
for pid in sel_pids:
    track = hits[hits.particle_id == pid]
    ax.scatter(track.x.values, track.y.values, label='particle_id: %d' %pid)

# Plot input edges: one thin gray line per (source, target) column.
# NOTE(review): edge_index holds index labels taken from the frame behind
# layer_groups, while `hits` here is the frame returned by select_hits() and
# .iloc is positional — confirm that the two line up.
for src, dst in zip(edge_index[0], edge_index[1]):
    pt1 = hits.iloc[src]
    pt2 = hits.iloc[dst]
    ax.plot([pt1.x, pt2.x], [pt1.y, pt2.y], color='gray', linewidth=0.5)


# axis params
ax.legend(fontsize=12, loc='best')
fig.tight_layout()
# fig.savefig("input_edges.pdf")