## GNN Stage

In [None]:
import glob, os, sys, yaml

In [None]:
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
import seaborn as sns
import trackml.dataset

In [None]:
import torch
from torch_geometric.data import Data
import itertools

In [None]:
from src import draw_event
from src import compose_event

In [None]:
# Export the directory the notebook runs from (absolute path) as EXATRKX_DATA.
os.environ.update({'EXATRKX_DATA': os.path.abspath(os.curdir)})

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the GNN stage configuration (trusted source) into `config`.
config_file = os.path.join(os.curdir, 'LightningModules/GNN/train_quickstart_GNN.yaml')
with open(config_file) as f:
    try:
        # NOTE: FullLoader can build more than plain data types; it is only
        # acceptable here because the file is trusted. Prefer yaml.safe_load
        # for any config that comes from outside the repository.
        config = yaml.load(f, Loader=yaml.FullLoader) # equiv: yaml.full_load(f)
    except yaml.YAMLError as e:
        print(e)
        # Re-raise: swallowing the error would leave `config` undefined and
        # the following cells would fail with an unrelated NameError.
        raise

## _Input Data_
---

In [None]:
from LightningModules.Processing.utils.event_utils import graph_intersection
from LightningModules.Processing.utils.draw_utils import draw_proc_event, cylindrical_to_cartesian

In [None]:
# Path to feature_store, use os.path.expandvars for ${HOME}
# (expands environment variables embedded in the configured `input_dir`).
FEATURE_DATA = os.path.expandvars(config['input_dir'])
# Only the last path component is printed, not the full path.
print("FEATURE_DATA: {}".format(os.path.basename(FEATURE_DATA)))

In [None]:
# examine an event, give integer value to event_id
event_id = 1
# The event file is named after the integer id and lives directly in FEATURE_DATA.
# NOTE: torch.load unpickles the file, so only load feature stores you trust.
feature_data = torch.load(os.path.join(FEATURE_DATA, str(event_id)), map_location=device)
# NOTE(review): if feature_data is a torch_geometric Data object, len() presumably
# counts its stored attributes rather than hits — confirm.
print("Length of Data: {}".format(len(feature_data)))

In [None]:
# Echo the loaded event to see which attributes it carries.
feature_data

In [None]:
# First ten entries of the event's `x` attribute.
feature_data.x[:10]

In [None]:
# Shape of the `y_pid` attribute.
feature_data.y_pid.shape

In [None]:
# Container type of `y_pid`.
type(feature_data.y_pid)

## _GNN Stage_

- For the GNN stage, I will run the `InteractionGNN` model.

In [None]:
from LightningModules.GNN.Models.interaction_gnn import InteractionGNN

In [None]:
# see params used in this stage (as loaded from the YAML file, before any overrides)
pp = pprint.PrettyPrinter(indent=2)
pp.pprint(config)

In [None]:
# change some params here
# NOTE(review): names and split sizes are wrapped in an extra list level;
# presumably the model expects that nesting — confirm against the model code.
config['datatype_names'] = [['train', 'val', 'test']]
config['datatype_split'] = [[800, 100, 100]]
# Re-root input and output under EXATRKX_DATA. FEATURE_DATA was derived from the
# previous config['input_dir'] and is not refreshed by this cell.
config['input_dir']  = os.path.join(os.environ['EXATRKX_DATA'],'run/feature_store')
config['output_dir'] = os.path.join(os.environ['EXATRKX_DATA'],'run/gnn_processed/quickstart_example')

In [None]:
# print the config again to inspect the current (possibly overridden) values
pp = pprint.PrettyPrinter(indent=2)
pp.pprint(config)

- _EDA_ :: Let's test `load_event()`; it raised an error.

In [None]:
def select_data(events, pt_background_cut, pt_signal_cut, noise):
    """Apply pT cuts to the edges of one or more events, in place.

    Args:
        events: A single event or a list of events. Each event must expose
            ``pt``, ``edge_index`` and ``y``; ``weights`` and
            ``signal_true_edges`` are filtered as well when they are present.
        pt_background_cut: An edge is kept only if every hit it connects has
            ``pt`` strictly greater than this value.
        pt_signal_cut: Applied the same way to ``signal_true_edges``, but only
            when it is larger than ``pt_background_cut``.
        noise: Unused. Kept so that existing callers keep working.

    Returns:
        The list of events (a lone event is wrapped in a one-element list).
        The events themselves are modified in place. Nothing is modified when
        neither cut is positive.
    """
    # Handle event in batched form
    if not isinstance(events, list):
        events = [events]

    # With both cuts at zero (or below) the events are returned untouched.
    if not (pt_background_cut > 0 or pt_signal_cut > 0):
        return events

    # NOTE: Cutting background by pT BY DEFINITION removes noise
    for event in events:
        # Reduce over the first axis of pt[edge_index]: one boolean per edge.
        edge_mask = (event.pt[event.edge_index] > pt_background_cut).all(0)
        event.edge_index = event.edge_index[:, edge_mask]
        event.y = event.y[edge_mask]

        # Look optional attributes up with getattr rather than through
        # event.__dict__: that also finds attributes an event keeps in an
        # internal store, and tolerates an attribute that is set to None.
        weights = getattr(event, "weights", None)
        # Only per-edge weights (same length as the mask) can be filtered.
        if weights is not None and weights.shape[0] == edge_mask.shape[0]:
            event.weights = weights[edge_mask]

        if pt_signal_cut > pt_background_cut:
            signal_true_edges = getattr(event, "signal_true_edges", None)
            if signal_true_edges is not None:
                signal_mask = (event.pt[signal_true_edges] > pt_signal_cut).all(0)
                event.signal_true_edges = signal_true_edges[:, signal_mask]

    return events

In [None]:
def load_dataset(input_dir, num, pt_background_cut, pt_signal_cut, noise):
    """Load the first ``num`` events from ``input_dir`` and apply pT cuts.

    Args:
        input_dir: Directory holding one saved event per file, or ``None``.
        num: Number of events to load. It is used as a slice bound on the
            sorted file list, so ``None`` loads every file.
        pt_background_cut: Forwarded to ``select_data``.
        pt_signal_cut: Forwarded to ``select_data``.
        noise: Forwarded to ``select_data``.

    Returns:
        The list of loaded events after ``select_data``, or ``None`` when
        ``input_dir`` is ``None``.
    """
    if input_dir is None:
        return None

    # Paths are sorted as strings, i.e. lexicographically ("10" before "2").
    event_paths = sorted(
        os.path.join(input_dir, name) for name in os.listdir(input_dir)
    )
    # NOTE: torch.load unpickles the files; only point this at trusted data.
    loaded_events = [
        torch.load(path, map_location=torch.device("cpu"))
        for path in event_paths[:num]
    ]
    return select_data(loaded_events, pt_background_cut, pt_signal_cut, noise)

In [None]:
# List the entries of the (overridden) input directory; os.listdir raises if it does not exist.
all_events = os.listdir(config['input_dir'])

In [None]:
# init the InteractionGNN with the full config dict (including the overrides made above)
model = InteractionGNN(config)

In [None]:
# NOTE(review): if InteractionGNN is a LightningModule, summarize() is believed to be
# deprecated in newer PyTorch Lightning releases — confirm against the pinned version.
model.summarize()

In [None]:
# dataset as accessed in model
# setup() is called by hand so that `model.trainset` can be inspected before any training run.
model.setup(stage="fit")

In [None]:
# Training split exposed by the model after setup().
trainset = model.trainset

In [None]:
# Take the first training event and split its feature matrix into columns.
example_data = trainset[0]
# x.T must unpack into exactly three rows, i.e. one per feature.
# NOTE(review): r and phi look like cylindrical coordinates; the meaning of `ir` is not evident here — confirm.
r, phi, ir = example_data.x.T

In [None]:
# Convert (r, phi) to Cartesian plotting coordinates. `x` and `y` here are
# plain plotting arrays, not the `x`/`y` attributes of the event.
# NOTE(review): phi is multiplied by pi, so it is presumably stored in units of pi — confirm.
x, y = r * np.cos(phi * np.pi), r * np.sin(phi * np.pi)

In [None]:
# Scatter plot of all hits of the example event in the x-y plane.
plt.figure(figsize=(8, 8))
plt.scatter(x, y, s=2)
# Title and axis labels are set in one tuple expression; the tuple is also what the cell echoes.
plt.title("Azimuthal View of Detector", fontsize=24), plt.xlabel(
    "x", fontsize=18
), plt.ylabel("y", fontsize=18)

In [None]:
# Boolean mask over edges: True where both end points carry the same particle id.
e = example_data.edge_index
pid = example_data.pid
# NOTE(review): if noise hits share one common pid value, edges between two
# noise hits are also marked True here — confirm how noise is labelled.
true_edges = pid[e[0]] == pid[e[1]]

In [None]:
# Draw the true edges (black) on top of the hit positions.
plt.figure(figsize=(8,8))
# plt.plot(x[e[:, ~true_edges]], y[e[:, ~true_edges]], c="r")
# Indexing with the masked edge_index gives 2-D arrays; plt.plot draws one line per column,
# i.e. one segment per edge (assumes edge_index has two rows — confirm).
plt.plot(x[e[:, true_edges]], y[e[:, true_edges]], c="k")
plt.scatter(x, y, s=5)
plt.title("Azimuthal View of Detector", fontsize=24), plt.xlabel(
    "x", fontsize=18
), plt.ylabel("y", fontsize=18)

In [None]:
# Draw a subsample of the false edges (red) on top of the hit positions.
plt.figure(figsize=(8,8))
# The slice 0:-1:5 keeps every fifth false edge and never includes the last one.
plt.plot(x[e[:, (~true_edges)][:, 0:-1:5]], y[e[:, (~true_edges)][:, 0:-1:5]], c="r")
plt.scatter(x, y, s=5)
plt.title("Azimuthal View of Detector", fontsize=24), plt.xlabel(
    "x", fontsize=18
), plt.ylabel("y", fontsize=18)

In [None]:
from pytorch_lightning import Trainer

In [None]:
# Single-epoch trainer for a quick smoke test.
# NOTE(review): the `device` selected earlier is not passed here, so hardware
# selection is left to the Trainer defaults — confirm this is intended.
trainer = Trainer(max_epochs=1)

In [None]:
# Run the training loop; no dataloaders are passed, so the model has to supply its own.
trainer.fit(model)