## GNN Stage

In [1]:
import glob, os, sys, yaml

In [2]:
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import pprint
import seaborn as sns
import trackml.dataset

In [4]:
import torch
from torch_geometric.data import Data
import itertools

In [5]:
from src import draw_event
from src import compose_event

In [6]:
# set EXATRKX_DATA env variable
os.environ['EXATRKX_DATA'] = os.path.abspath(os.curdir)

In [7]:
# select a device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
# load processing config file (trusted source)
config_file = os.path.join(os.curdir, 'LightningModules/GNN/train_quickstart_GNN.yaml')
with open(config_file) as f:
    try:
        config = yaml.load(f, Loader=yaml.FullLoader) # equiv: yaml.full_load(f)
    except yaml.YAMLError as e:
        print(e)

## _Input Data_
---

In [9]:
from LightningModules.Processing.utils.event_utils import graph_intersection
from LightningModules.Processing.utils.draw_utils import draw_proc_event, cylindrical_to_cartesian

In [10]:
# Path to feature_store, use os.path.expandvars for ${HOME}
FEATURE_DATA = os.path.expandvars(config['input_dir'])
print("FEATURE_DATA: {}".format(os.path.basename(FEATURE_DATA)))

FEATURE_DATA: feature_store


In [11]:
# examine an event, give integer value to event_id
event_id = 1
feature_data = torch.load(os.path.join(FEATURE_DATA, str(event_id)), map_location=device)
print("Length of Data: {}".format(len(feature_data)))

Length of Data: 10


In [12]:
feature_data

Data(x=[164, 3], pid=[164], layers=[164], event_file='/home/adeel/current/3_deeptrkx/stttrkx-hsf/train_all/event0000000001', hid=[164], pt=[164], modulewise_true_edges=[2, 154], layerwise_true_edges=[2, 171], edge_index=[2, 896], y_pid=[896])

In [13]:
feature_data.x[:10]

tensor([[ 1.6627e-01,  9.9033e-01,  2.8715e-03],
        [ 1.6627e-01, -9.9033e-01,  3.4859e-04],
        [ 1.6810e-01,  4.7995e-02,  4.9377e-03],
        [ 1.7229e-01,  7.5164e-01,  2.6239e-03],
        [ 1.7523e-01, -1.0268e-01,  4.6810e-03],
        [ 1.7523e-01,  9.8164e-01,  3.1453e-03],
        [ 1.7523e-01, -9.8164e-01,  1.1627e-03],
        [ 1.7523e-01, -7.6935e-01,  1.1195e-04],
        [ 1.7610e-01,  3.6593e-02,  1.4773e-05],
        [ 1.7869e-01,  4.5309e-01,  7.3398e-04]])

In [14]:
feature_data.y_pid.shape

torch.Size([896])

In [15]:
type(feature_data.y_pid)

torch.Tensor

## _GNN Stage_

- For GNN stage, I will run the `InteractionGNN` model

In [16]:
from LightningModules.GNN.Models.interaction_gnn import InteractionGNN

In [17]:
# see params used in this stage
pp = pprint.PrettyPrinter(indent=2)
pp.pprint(config)

{ 'aggregation': 'sum_max',
  'callbacks': ['GNNTelemetry'],
  'cell_channels': 0,
  'datatype_names': [['train', 'val', 'test']],
  'datatype_split': [[800, 100, 100]],
  'edge_cut': 0.5,
  'emb_channels': 0,
  'factor': 0.3,
  'hidden': 128,
  'hidden_activation': 'ReLU',
  'input_dir': '${EXATRKX_DATA}/run/feature_store',
  'layernorm': True,
  'lr': 0.001,
  'max_epochs': 1,
  'n_graph_iters': 8,
  'nb_edge_layer': 3,
  'nb_node_layer': 3,
  'noise': False,
  'output_dir': '${EXATRKX_DATA}/run/gnn_processed/quickstart_example',
  'patience': 10,
  'project': 'GNNStudy',
  'pt_background_min': 0.0,
  'pt_signal_min': 0.0,
  'regime': [['pid']],
  'spatial_channels': 3,
  'warmup': 200,
  'weight': 2}


In [18]:
# change some params here
config['datatype_names'] = [['train', 'val', 'test']]
config['datatype_split'] = [[800, 100, 100]]
config['input_dir']  = os.path.join(os.environ['EXATRKX_DATA'],'run/feature_store')
config['output_dir'] = os.path.join(os.environ['EXATRKX_DATA'],'run/gnn_processed/quickstart_example')

In [19]:
# see params used in this stage
pp = pprint.PrettyPrinter(indent=2)
pp.pprint(config)

{ 'aggregation': 'sum_max',
  'callbacks': ['GNNTelemetry'],
  'cell_channels': 0,
  'datatype_names': [['train', 'val', 'test']],
  'datatype_split': [[800, 100, 100]],
  'edge_cut': 0.5,
  'emb_channels': 0,
  'factor': 0.3,
  'hidden': 128,
  'hidden_activation': 'ReLU',
  'input_dir': '/home/adeak977/current/3_deeptrkx/stttrkx-hsf/run/feature_store',
  'layernorm': True,
  'lr': 0.001,
  'max_epochs': 1,
  'n_graph_iters': 8,
  'nb_edge_layer': 3,
  'nb_node_layer': 3,
  'noise': False,
  'output_dir': '/home/adeak977/current/3_deeptrkx/stttrkx-hsf/run/gnn_processed/quickstart_example',
  'patience': 10,
  'project': 'GNNStudy',
  'pt_background_min': 0.0,
  'pt_signal_min': 0.0,
  'regime': [['pid']],
  'spatial_channels': 3,
  'warmup': 200,
  'weight': 2}


- _EDA_ :: Let's test `load_event()`, it gave error

In [None]:
def select_data(events, pt_background_cut, pt_signal_cut, noise):
    """Select data after applying pt cuts OR return without applying it if pt's set to zero."""
    
    # Handle event in batched form
    if type(events) is not list:
        events = [events]

    # NOTE: Cutting background by pT BY DEFINITION removes noise
    if (pt_background_cut > 0) | (pt_signal_cut > 0):
        for event in events:

            edge_mask = (event.pt[event.edge_index] > pt_background_cut).all(0)
            event.edge_index = event.edge_index[:, edge_mask]
            event.y = event.y[edge_mask]

            if "weights" in event.__dict__.keys():
                if event.weights.shape[0] == edge_mask.shape[0]:
                    event.weights = event.weights[edge_mask]

            if (pt_signal_cut > pt_background_cut) and (
                "signal_true_edges" in event.__dict__.keys()
            ):
                signal_mask = (event.pt[event.signal_true_edges] > pt_signal_cut).all(0)
                event.signal_true_edges = event.signal_true_edges[:, signal_mask]

    return events

In [None]:
def load_dataset(input_dir, num, pt_background_cut, pt_signal_cut, noise):
    """Load data and apply pt cuts."""
    if input_dir is not None:
        all_events = os.listdir(input_dir)
        all_events = sorted([os.path.join(input_dir, event) for event in all_events])
        loaded_events = [
            torch.load(event, map_location=torch.device("cpu"))
            for event in all_events[:num]
        ]
        loaded_events = select_data(
            loaded_events, pt_background_cut, pt_signal_cut, noise
        )
        return loaded_events
    else:
        return None

    return included_edges, included_edges_mask # FIXME::ADAK: This will never execute.

In [23]:
all_events = os.listdir(config['input_dir'])

In [20]:
# init the InteractionGNN
model = InteractionGNN(config)

In [21]:
model.summarize()

  model.summarize()
  rank_zero_deprecation(


  | Name                   | Type       | Params
------------------------------------------------------
0 | node_encoder           | Sequential | 34.0 K
1 | edge_encoder           | Sequential | 66.4 K
2 | edge_network           | Sequential | 82.8 K
3 | node_network           | Sequential | 82.8 K
4 | output_edge_classifier | Sequential | 83.2 K
------------------------------------------------------
349 K     Trainable params
0         Non-trainable params
349 K     Total params
1.397     Total estimated model params size (MB)

In [22]:
# dataset as accessed in model
model.setup(stage="fit")

TypeError: join() argument must be str, bytes, or os.PathLike object, not 'list'

In [None]:
trainset = model.trainset

In [None]:
example_data = trainset[0]
r, phi, ir = example_data.x.T

In [None]:
x, y = r * np.cos(phi * np.pi), r * np.sin(phi * np.pi)

In [None]:
plt.figure(figsize=(8, 8))
plt.scatter(x, y, s=2)
plt.title("Azimuthal View of Detector", fontsize=24), plt.xlabel(
    "x", fontsize=18
), plt.ylabel("y", fontsize=18)

In [None]:
e = example_data.edge_index
pid = example_data.pid
true_edges = pid[e[0]] == pid[e[1]]

In [None]:
plt.figure(figsize=(8,8))
# plt.plot(x[e[:, ~true_edges]], y[e[:, ~true_edges]], c="r")
plt.plot(x[e[:, true_edges]], y[e[:, true_edges]], c="k")
plt.scatter(x, y, s=5)
plt.title("Azimuthal View of Detector", fontsize=24), plt.xlabel(
    "x", fontsize=18
), plt.ylabel("y", fontsize=18)

In [None]:
plt.figure(figsize=(8,8))
plt.plot(x[e[:, (~true_edges)][:, 0:-1:5]], y[e[:, (~true_edges)][:, 0:-1:5]], c="r")
plt.scatter(x, y, s=5)
plt.title("Azimuthal View of Detector", fontsize=24), plt.xlabel(
    "x", fontsize=18
), plt.ylabel("y", fontsize=18)

In [None]:
from pytorch_lightning import Trainer

In [None]:
trainer = Trainer(max_epochs=1)

In [None]:
trainer.fit(model)