## _HitPairs using ExaTrkX Pipeline_

In [None]:
import glob, os, sys, yaml

In [None]:
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
pp = pprint.PrettyPrinter(indent=2)
import seaborn as sns
import trackml.dataset

In [None]:
import torch
from torch_geometric.data import Data
import itertools

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The pipeline's data root is the absolute path of the current working directory.
os.environ['EXATRKX_DATA'] = os.path.abspath(os.curdir)

### _2.1 - Config File_

In [None]:
# Load the processing config file (trusted source).
config_file = os.path.join(os.curdir, 'LightningModules/GNN/configs/train_quickstart_DNN.yaml')
with open(config_file) as f:
    try:
        # FullLoader is kept because the file ships with the project;
        # equivalent to yaml.full_load(f). Do not point this at untrusted input.
        config = yaml.load(f, Loader=yaml.FullLoader)
    except yaml.YAMLError as e:
        # Report the parse error, then stop: swallowing it would leave `config`
        # undefined and every later cell would fail with an unrelated NameError.
        print(e)
        raise

In [None]:
# pp.pprint(config)

In [None]:
# Overrides applied on top of the loaded quickstart config: project name,
# datatype names/split, and the input/output directories under $EXATRKX_DATA.
config.update({
    'project': 'DNNStudy',
    'datatype_names': ['train', 'val', 'test'],
    'datatype_split': [800, 100, 100],
    'input_dir': os.path.join(os.environ['EXATRKX_DATA'], 'run/feature_store'),
    'output_dir': os.path.join(os.environ['EXATRKX_DATA'], 'run/dnn_processed'),
})

In [None]:
# pp.pprint(config)

### _2.2 - Input Data_

In [None]:
# Resolve the event input directory (the "train" sub-directory of input_dir)
# and the output directory, creating the latter when it does not exist yet.
inputdir = os.path.expandvars(f"{config['input_dir']}/train")
outputdir = os.path.expandvars(config['output_dir'])
os.makedirs(outputdir, exist_ok=True)

In [None]:
# Collect every entry of the input directory in a stable (sorted) order.
all_files = sorted(glob.glob(os.path.join(inputdir, "*")))
# inputdir points at the "train" split, so label the count accordingly
# (the previous label said "Test").
print("Total Train Events: ", len(all_files))

In [None]:
all_files[:10]

In [None]:
# Pick a single file to inspect: index 5 of the sorted listing
# (raises IndexError if inputdir holds fewer than six entries).
filename = all_files[5]

In [None]:
# os.path.split(path) returns the pair (os.path.dirname(path), os.path.basename(path)),
# so the two halves can be queried directly.
print("os.path.dirname(path) : ", os.path.dirname(filename))
print("os.path.basename(path): ", os.path.basename(filename))

In [None]:
# Deserialize the selected file, mapping its stored tensors onto `device`.
feature_data = torch.load(filename, map_location=device)
print(f"Length of Data: {len(feature_data)}")

In [None]:
print(feature_data.keys)

In [None]:
print(feature_data)

In [None]:
# Pull out the two fields used in the cells below: node features and edge list.
x, edge_index = feature_data.x, feature_data.edge_index

### _Input Features to Network_

- The `forward()` function receives `x, edge_index` from outside, where `x = [r, phi, z]` holds the node features and `edge_index` contains the _edges_ (_aka_ node/hit pairs).
- However, `EdgeClassifier` needs $x_i, x_j$ for each edge, so one needs to concatenate the features of the two nodes in each edge.

Let's see how it can be achieved.

In [None]:
# Unpack edge_index along its first dimension into the two endpoint index
# tensors. This is the compact form of:
# start = edge_index[0]
# end = edge_index[1]
start, end = edge_index

In [None]:
# This yields True for every element
# start == edge_index[0]

In [None]:
# This yields True for every element
# end == edge_index[1]

In [None]:
edge_index.shape

In [None]:
x.shape

In [None]:
x[start].shape

In [None]:
x[end].shape

In [None]:
x[0]

In [None]:
x[4]

In [None]:
x[start][0]

In [None]:
x[end][0]

In [None]:
# Build one input row per edge: the features of the edge's first node
# (x[start]) followed by those of its second node (x[end]), joined along dim=1.
edge_inputs = torch.cat([x[start], x[end]], dim=1)

In [None]:
edge_inputs[0]

### _2.3 - Network Model_

In [None]:
from LightningModules.GNN.gnn_base import GNNBase

In [None]:
from LightningModules.GNN.Models.dense_network import EdgeClassifier

In [None]:
model = EdgeClassifier(config)

In [None]:
print(model)

### _2.4 - Training_

In [None]:
from pytorch_lightning import Trainer

In [None]:
# dataset as accessed in model
# setup('fit') is invoked by hand here, before trainer.fit(model) runs.
# NOTE(review): presumably this is what populates `model.trainset`, which is
# read further down — confirm against GNNBase.setup.
model.setup('fit')

In [None]:
trainer = Trainer(max_epochs=10)

In [None]:
trainer.fit(model)

### _Test Training Set_

In [None]:
trainset = model.trainset

In [None]:
# Take the first item of the training set and split its feature matrix into
# its three columns.
# NOTE(review): the third column is bound to `ir` here, while the notes earlier
# in this notebook describe x as [r, phi, z] — confirm which quantity it holds.
# It is not used in the cells below.
example_data = trainset[0]
r, phi, ir = example_data.x.T

In [None]:
# Polar -> Cartesian coordinates for plotting. phi is multiplied by np.pi before
# cos/sin, so it is presumably stored in units of pi — TODO confirm.
# NOTE: this rebinds `x`, which earlier in the notebook held feature_data.x.
x, y = r * np.cos(phi * np.pi), r * np.sin(phi * np.pi)

In [None]:
# Scatter all hits of the example in the x-y plane on a square canvas.
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(x, y, s=2)
ax.set_title("Azimuthal View of Detector", fontsize=24)
ax.set_xlabel("x", fontsize=18)
ax.set_ylabel("y", fontsize=18)

In [None]:
# Label each edge by comparing the pid of its two endpoints: an edge is
# "true" when both of its nodes carry the same pid. The result is a boolean
# mask with one entry per column of `e`.
e = example_data.edge_index
pid = example_data.pid
true_edges = pid[e[0]] == pid[e[1]]

In [None]:
# Draw every true edge in black together with the hit scatter. Indexing the
# coordinates with the (endpoint, edge) index array yields one column per
# edge, which plot() renders as one line per column.
fig, ax = plt.subplots(figsize=(8, 8))
true_pairs = e[:, true_edges]
ax.plot(x[true_pairs], y[true_pairs], c="k")
ax.scatter(x, y, s=5)
ax.set_title("Azimuthal View of Detector", fontsize=24)
ax.set_xlabel("x", fontsize=18)
ax.set_ylabel("y", fontsize=18)

In [None]:
# Draw a subsample of the false edges in red together with the hit scatter.
# The slice 0:-1:5 keeps every fifth false edge and never the final one;
# drawing all of them would be: e[:, ~true_edges].
fig, ax = plt.subplots(figsize=(8, 8))
false_pairs = e[:, (~true_edges)][:, 0:-1:5]
ax.plot(x[false_pairs], y[false_pairs], c="r")
ax.scatter(x, y, s=5)
ax.set_title("Azimuthal View of Detector", fontsize=24)
ax.set_xlabel("x", fontsize=18)
ax.set_ylabel("y", fontsize=18)