# Example of Track Parameter Regression with GNN

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer

# Put the parent directory on the import path; the project-local
# `LightningModules` package imported further down presumably lives there.
sys.path.append("..")

# Name of the compute device: "cuda" when torch reports a usable GPU, "cpu" otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# THIS NOTEBOOK IS NOT YET FINISHED! BEWARE

Example notebook for running GNN on Track Parameter regression task

- [X] Load node regression model
- [X] This model should have attention, with output of fully connected layers per node
- [X] Model should take `regression_base` as base class, which runs (e.g.) MSE loss on each node vs. pT
- [X] Train

This may not train well, so we can tweak with some more sophistication:

- [ ] Load node+edge regression model
- [ ] This model is as above, but includes edge classification to stabilise training
- [ ] Model should take `regression_base` as base class, where we can now turn the edge-classification stabilisation on/off with a hyperparameter
- [ ] Train

## Pytorch Lightning Model

As in the case of metric learning, we store all of the model logic in Pytorch Lightning modules. We import this class.

In [3]:
from LightningModules.GNN.Models.agnn_regression import AGNNRegression

### Construct PyLightning model

An ML model typically has many knobs to turn, as well as locations of data, some training preferences, and so on. For convenience, let's put all of these parameters into a YAML file and load it.

In [4]:
# Load the run configuration (model knobs, data locations, training preferences) from YAML.
# The encoding is pinned to UTF-8 so that parsing does not depend on the platform's locale default.
# NOTE(review): FullLoader accepts more than plain scalars/lists/dicts; if this file could ever come
# from an untrusted source, switch to yaml.safe_load.
with open("example_gnn.yaml", encoding="utf-8") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)

We plug these parameters into the constructor of the `AGNNRegression` Lightning Module. This doesn't **do** anything yet - it merely creates the object.

In [5]:
# Build the Lightning module from the loaded hyperparameters; no training happens at this point.
model = AGNNRegression(hparams)

## Train Node-only GNN

We train in exactly the same way as we trained the metric learning model. The only difference is that this trainer may take a little longer, since it is spreading information through a graph N times in each training step.

In [6]:
# Train for at most 30 epochs. Request one GPU only when CUDA was detected by the
# setup cell (`device`); a hard-coded gpus=1 makes the Trainer fail on CPU-only machines.
trainer = Trainer(gpus=1 if device == "cuda" else 0, max_epochs=30)
trainer.fit(model)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Set SLURM handle signals.

  | Name           | Type        | Params
-----------------------------------------------
0 | input_network  | Sequential  | 9.0 K 
1 | edge_network   | EdgeNetwork | 18.6 K
2 | node_network   | NodeNetwork | 17.2 K
3 | output_network | Sequential  | 12.7 K
-----------------------------------------------
57.5 K    Trainable params
0         Non-trainable params
57.5 K    Total params
0.230     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  edge_eff = torch.tensor(edge_true_positive / edge_true)
  edge_pur = torch.tensor(edge_true_positive / edge_positive)






HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

### Test GNN

In [7]:
# Run the test loop with the Trainer's default checkpoint selection; the returned
# metrics are shown as the cell output but not stored.
# NOTE(review): the following cell repeats the evaluation with ckpt_path=None and keeps
# the result, so this call looks redundant - confirm whether both are needed.
trainer.test()



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

Step: {'loss': tensor(1.1561, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([0., 1., 0., ..., 1., 1., 1.], dtype=float32), 'node_accuracy': tensor(0.1650, device='cuda:0')}
Step: {'loss': tensor(5.7185, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([0., 1., 1., ..., 1., 1., 0.], dtype=float32), 'node_accuracy': tensor(0.1651, device='cuda:0')}
Step: {'loss': tensor(2.2994, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([0., 0., 0., ..., 1., 0., 0.], dtype=float32), 'node_accuracy': tensor(0.1649, device='cuda:0')}
Step: {'loss': tensor(4.2551, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([1., 1., 0., ..., 1., 0., 0.], dtype=float32), 'node_accuracy': tensor(0.1576, device='cuda:0')}
Step: {'loss': tensor(1.7970, device='cuda:0'), 'edge_preds'

[{'val_loss': 3.139658212661743,
  'edge_eff': 0.0,
  'edge_pur': nan,
  'node_accuracy': 0.16130855679512024,
  'current_lr': 2.700000004551839e-05}]

In [8]:
# Run the test loop again and keep the returned metrics in `test_results`.
# NOTE(review): ckpt_path=None presumably evaluates the weights currently in memory
# rather than a saved checkpoint - confirm against the installed pytorch_lightning version.
test_results = trainer.test(ckpt_path=None)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

Step: {'loss': tensor(1.1561, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([0., 1., 0., ..., 1., 1., 1.], dtype=float32), 'node_accuracy': tensor(0.1650, device='cuda:0')}
Step: {'loss': tensor(5.7185, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([0., 1., 1., ..., 1., 1., 0.], dtype=float32), 'node_accuracy': tensor(0.1651, device='cuda:0')}
Step: {'loss': tensor(2.2994, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([0., 0., 0., ..., 1., 0., 0.], dtype=float32), 'node_accuracy': tensor(0.1649, device='cuda:0')}
Step: {'loss': tensor(4.2551, device='cuda:0'), 'edge_preds': array([False, False, False, ..., False, False, False]), 'edge_truth': array([1., 1., 0., ..., 1., 0., 0.], dtype=float32), 'node_accuracy': tensor(0.1576, device='cuda:0')}
Step: {'loss': tensor(1.7970, device='cuda:0'), 'edge_preds'

## Train Node+Edge GNN