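# Example of a Graph Neural Network

Build a GNN edge classifier from `example_gnn.yaml`, inspect one processed ITk event, measure the peak GPU memory of a forward pass, and train the model.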

In [1]:
# Automatically reload edited modules (e.g. the local LightningModules package)
# before each cell executes, so model-code changes take effect without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# Make the parent directory importable so the `LightningModules` package resolves.
sys.path.append('..')

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Models: Attention GNN and Interaction-Network GNN

In [3]:
from LightningModules.GNN.Models.agnn import ResAGNN
from LightningModules.GNN.Models.checkpoint_agnn import CheckpointedResAGNN
from LightningModules.GNN.Models.interaction_multistep_gnn import CheckpointedInteractionMultistepGNN

In [4]:
# Load the model/training hyperparameters.
# safe_load builds only plain YAML types (dicts, lists, strings, numbers) and
# never constructs arbitrary Python objects, unlike FullLoader.
# NOTE(review): if example_gnn.yaml relies on Python-specific tags
# (e.g. !!python/tuple), switch back to yaml.load(f, Loader=yaml.FullLoader).
with open("example_gnn.yaml", encoding="utf-8") as f:
    hparams = yaml.safe_load(f)

In [5]:
# Build exactly one architecture; all three classes take the same `hparams` dict.
# Previously each class was instantiated in turn and only the last survived in
# `model`, so the first two constructions were wasted work.
# Change MODEL_NAME to train a different architecture.
MODEL_CLASSES = {
    "ResAGNN": ResAGNN,
    "CheckpointedResAGNN": CheckpointedResAGNN,
    "CheckpointedInteractionMultistepGNN": CheckpointedInteractionMultistepGNN,
}
MODEL_NAME = "CheckpointedInteractionMultistepGNN"
model = MODEL_CLASSES[MODEL_NAME](hparams)

### Dataset

In [10]:
%%time
# Prepare the model's data for the "fit" stage. Later cells read
# `model.trainset`, which is presumably populated here (TODO confirm in the
# LightningModule's setup()).
model.setup(stage="fit")

CPU times: user 26.6 s, sys: 2.29 s, total: 28.9 s
Wall time: 17.7 s


In [12]:
# Load one processed event for inspection. By default it is read directly from
# disk; set SAMPLE_PATH = None to use the first event of model.trainset instead.
# (Previously both were run back to back, and the trainset sample was
# immediately overwritten by the file load.)
# NOTE: torch.load unpickles the file, so only point this at trusted data.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which may refuse to
# unpickle a graph Data object; pass weights_only=False there if needed.
SAMPLE_PATH = "/project/projectdirs/m3443/data/ITk-upgrade/processed/filter_processed/1_GeV_unweighted_high_eff/train/0001"
sample = model.trainset[0] if SAMPLE_PATH is None else torch.load(SAMPLE_PATH)

In [13]:
# Display the event's repr: its attribute names and tensor shapes.
sample

Data(cell_data=[341982, 11], edge_index=[2, 34700], event_file="/project/projectdirs/m3443/data/ITk-upgrade/processed/full_events_v3/event000000001", hid=[341982], modulewise_true_edges=[2, 15408], nhits=[341982], pid=[341982], primary=[341982], pt=[341982], x=[341982, 3], y=[34700], y_pid=[1558445])

In [14]:
# Positively-labelled edges divided by the number of modulewise truth edges.
# This is presumably the edge-list efficiency, assuming `y` flags true edges
# (TODO confirm).
n_labelled = sample.y.sum()
n_truth = sample.modulewise_true_edges.size(1)
n_labelled / n_truth

tensor(0.9523)

In [15]:
# Positively-labelled edges divided by all edges in edge_index.
# This is presumably the edge-list purity, assuming `y` flags true edges
# (TODO confirm).
n_positive = sample.y.sum()
n_edges = sample.edge_index.size(1)
n_positive / n_edges

tensor(0.4229)

In [16]:
# Short aliases for the event's edge list and its per-node `pid` values
# (presumably particle ids).
edges, pid = sample.edge_index, sample.pid

In [19]:
# Shape of the edge list: 2 rows (endpoint node indices) by one column per edge.
edges.size()

torch.Size([2, 34700])

In [20]:
# Count the edges whose two endpoint nodes share the same `pid` value.
# In the saved outputs this count (27367) is much larger than sample.y.sum()
# (~14.7k), so `y` is stricter than "same pid".
# NOTE(review): if pid 0 marks noise hits, noise-noise pairs are counted here
# too. TODO confirm.
src_pid = sample.pid[edges[0]]
dst_pid = sample.pid[edges[1]]
(src_pid == dst_pid).sum()

tensor(27367)

### Memory Test

In [6]:
%%time
# NOTE(review): this repeats the model.setup(stage="fit") call from the Dataset
# section and is presumably redundant under Restart & Run All. The execution
# counts restart at In [6] here, so the recorded timing below comes from a
# separate kernel session and is not comparable with the earlier one.
model.setup(stage="fit")

CPU times: user 2.34 s, sys: 289 ms, total: 2.63 s
Wall time: 1.84 s


In [7]:
# Put the first training event and the model on the same device so the memory
# measurement below runs entirely on `device`.
first_train_index = 0
sample = model.trainset[first_train_index].to(device)
model = model.to(device)

In [9]:
# Measure the peak GPU memory of one forward pass over `sample`.
# Autograd stays enabled, so if the model's parameters require grad (the
# default) the peak includes activations retained for a backward pass.
# The torch.cuda calls fail without a GPU, so the measurement is skipped on CPU.
if device == "cuda":
    torch.cuda.reset_peak_memory_stats()
    output = model(sample.x.to(device), sample.edge_index.to(device))
    # max_memory_allocated() returns bytes; 1024**3 bytes = 1 GiB (not gigabits).
    print(torch.cuda.max_memory_allocated()/1024**3, "GiB")
    # `output` keeps its autograd graph (and any saved activations) alive.
    # Release it so that memory is not held while training below.
    del output
else:
    print("CUDA unavailable: skipping peak GPU memory measurement")

12.287028789520264 Gb


### Train GNN

In [None]:
# Train with Weights & Biases logging.
# A GPU is requested only when `device` (chosen in the imports cell) is "cuda".
# A hard-coded gpus=1 fails on CPU-only machines.
# `gpus` is the PyTorch Lightning 1.x argument; 0 means train on the CPU.
MAX_EPOCHS = 50
logger = WandbLogger(project="ITk_1GeV_GNN", group="InitialTest")
trainer = Trainer(gpus=1 if device == "cuda" else 0, max_epochs=MAX_EPOCHS, logger=logger)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name                   | Type       | Params
------------------------------------------------------
0 | node_encoder           | Sequential | 7.1 K 
1 | edge_encoder           | Sequential | 19.3 K
2 | edge_network           | Sequential | 31.9 K
3 | node_network           | Sequential | 31.9 K
4 | output_edge_classifier | Sequential | 25.6 K
------------------------------------------------------
115 K     Trainable params
0         Non-trainable params
115 K     Total params
0.463     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  eff = torch.tensor(edge_true_positive / edge_true)
  pur = torch.tensor(edge_true_positive / edge_positive)
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]