# Example of Graph Neural Network

In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import wandb

import warnings

warnings.filterwarnings("ignore")
sys.path.append("../../..")
device = "cuda" if torch.cuda.is_available() else "cpu"

## Setup

In [2]:
from LightningModules.GNN.Models.interaction_gnn import InteractionGNN
from LightningModules.GNN.Models.interaction_gnn_PyG import PyG_GNN

In [3]:
with open("hetero_gnn.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)

In [4]:
model = PyG_GNN(hparams)

## Heterogeneous Data Structure

In [5]:
model.setup(stage="fit")

In [6]:
from torch_geometric.data import HeteroData

In [8]:
sample = model.trainset[0]

In [9]:
sample

Data(x=[305249, 3], cell_data=[305249, 11], pid=[305249], event_file='/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/full_events_v4/event000016762', hid=[305249], pt=[305249], primary=[305249], nhits=[305249], modules=[305249], modulewise_true_edges=[2, 120909], signal_true_edges=[2, 13450], edge_index=[2, 25187], y=[25187], y_pid=[25187], pid_signal=[25187])

In [11]:
csv_event_file = sample.event_file
particles = pd.read_csv(csv_event_file + "-particles.csv")
truth = pd.read_csv(csv_event_file + "-truth.csv")


In [28]:
pixel_hits = truth[truth["hardware"] == "PIXEL"][["hit_id", "x", "y", "z"]]
strip_hits = truth[truth["hardware"] == "STRIP"][["hit_id", "x", "y", "z", "cluster_x_1", "cluster_y_1", "cluster_z_1", "cluster_x_2", "cluster_y_2", "cluster_z_2"]]

In [25]:
hid_df = pd.DataFrame({"hit_id": sample.hid})

In [29]:
pixel_hits = hid_df.merge(pixel_hits, on="hit_id", how="inner").values
strip_hits = hid_df.merge(strip_hits, on="hit_id", how="inner").values

In [30]:
pixel_hits

array([[ 0.00000e+00, -3.75019e+01, -3.16355e+00, -2.63000e+02],
       [ 1.00000e+00, -5.33225e+01, -1.68997e+01, -2.63000e+02],
       [ 2.00000e+00, -3.95529e+01, -1.55590e+01, -2.63000e+02],
       ...,
       [ 2.33297e+05,  2.11053e+01,  2.94327e+02,  2.84200e+03],
       [ 2.33298e+05,  3.62114e+01,  3.01760e+02,  2.84200e+03],
       [ 2.33299e+05,  2.65350e+01,  2.89890e+02,  2.84200e+03]])

In [31]:
strip_hits

array([[ 2.33300e+05, -4.00329e+02,  3.40820e+01, ..., -4.13910e+02,
         3.59886e+01, -1.50775e+03],
       [ 2.33301e+05, -3.89551e+02,  2.32969e+01, ..., -3.93287e+02,
         2.35296e+01, -1.50775e+03],
       [ 2.33302e+05, -3.84950e+02,  1.96286e+01, ..., -3.93471e+02,
         2.02203e+01, -1.50775e+03],
       ...,
       [ 3.05247e+05,  9.25272e+02, -1.78305e+02, ...,  9.20740e+02,
        -1.77556e+02,  2.86075e+03],
       [ 3.05247e+05,  9.25272e+02, -1.78305e+02, ...,  9.20740e+02,
        -1.77556e+02,  2.86075e+03],
       [ 3.05248e+05,  8.91353e+02, -1.71082e+02, ...,  9.21013e+02,
        -1.76135e+02,  2.86075e+03]])

In [32]:
data = HeteroData()

# 0. Load the event_file

# 1. Divide data into pixel and strip volumes

# 2. For pixel volumes, attach data structure
data["pixel"].x = pixel_hits[:, 1:]

# 3. For strip volumes, gather cluster features

# 4. For strip volumes, attach data structure
data["strip"].x = strip_hits[:, 1:]

# 

In [None]:
data["pixel", "connects", "pixel"].edge_index = 

In [33]:
data

HeteroData(
  [1mpixel[0m={ x=[233754, 3] },
  [1mstrip[0m={ x=[72737, 9] }
)

## Train GNN

In [5]:
logger = WandbLogger(
    project=hparams["project"], group="InitialTest", save_dir=hparams["artifacts"]
)
trainer = Trainer(
    gpus=1, max_epochs=hparams["max_epochs"], logger=logger
)  # , precision=16)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type       | Params
------------------------------------------------------
0 | node_encoder           | Sequential | 33.5 K
1 | node_network           | Sequential | 49.5 K
2 | edge_conv              | GINConv    | 49.5 K
3 | output_edge_classifier | Sequential | 66.0 K
------------------------------------------------------
149 K     Trainable params
0         Non-trainable params
149 K     Total params
0.596     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]