## _Inference after GNN Stage_

**_Inference_** is done using callbacks defined in **_LightningModules/GNN/Models/inference.py_**. The callbacks run during the **_test_step()_**, _a.k.a._ model **_evaluation_**.

### How to run _Inference_?

1. _`traintrack config/pipeline_quickstart.yaml`_. One can use the `--inference` flag to run only the `test_step()` (this should work, but it failed when tried).
2. The _`infer.ipynb`_ notebook runs _pl.Trainer().test()_.

In [1]:
import sys, os, glob, yaml

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import pprint
from tqdm import tqdm
import trackml.dataset

In [4]:
import torch
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import itertools

In [5]:
from LightningModules.GNN import InteractionGNN
from LightningModules.GNN import GNNBuilder, GNNMetrics
from LightningModules.GNN.Models.infer import GNNTelemetry

In [6]:
pp = pprint.PrettyPrinter(indent=2)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
os.environ['EXATRKX_DATA'] = os.path.abspath(os.curdir)

## _Classifier Evaluation_

Metrics to evaluate the GNN networks:

- Accuracy/ACC = $(TP+TN)/(TP+TN+FP+FN)$
- sensitivity, recall, hit rate, or true positive rate ($TPR = 1 - FNR$)
- specificity, selectivity or true negative rate ($TNR = 1 - FPR$)
- miss rate or false negative rate ($FNR = 1 - TPR$)
- fall-out or false positive rate ($FPR = 1 - TNR$)
- F1-score = $2 \times (\text{PPV} \times \text{TPR})/(\text{PPV} + \text{TPR})$
- Efficiency/Recall/Sensitivity/Hit Rate: $TPR = TP/(TP+FN)$
- Purity/Precision/Positive Predictive Value: $PPV = TP/(TP+FP)$
- AUC-ROC Curve $\equiv$ FPR ($x$-axis) vs. TPR ($y$-axis) plot
- AUC-PRC Curve $\equiv$ TPR ($x$-axis) vs. PPV ($y$-axis) plot
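
The scalar metrics in the list above all follow from the four confusion-matrix counts. A minimal sketch in plain Python, assuming the counts are already known and that no denominator is zero; the function name `classification_metrics` and the example counts are hypothetical and not part of the repository:

```python
def classification_metrics(tp, tn, fp, fn):
    """Return the metrics listed above from confusion-matrix counts.

    Assumes tp + fn > 0, tn + fp > 0 and tp + fp > 0 (no zero denominators).
    """
    tpr = tp / (tp + fn)  # efficiency / recall / sensitivity
    tnr = tn / (tn + fp)  # specificity
    ppv = tp / (tp + fp)  # purity / precision
    return {
        "ACC": (tp + tn) / (tp + tn + fp + fn),
        "TPR": tpr,
        "TNR": tnr,
        "FNR": 1.0 - tpr,
        "FPR": 1.0 - tnr,
        "PPV": ppv,
        "F1": 2.0 * ppv * tpr / (ppv + tpr),
    }


# Hypothetical counts for illustration only
metrics = classification_metrics(tp=80, tn=900, fp=10, fn=20)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

The ROC and PRC curves are obtained by repeating this computation while sweeping the score cut.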

### _Test Dataset_

### _Load Checkpoint_

Lightning automatically saves a checkpoint in the current working directory, holding the state of the last training epoch. Here, the checkpoint was stored once training had finished.

```
# load a LightningModule along with its weights & hyperparameters from a checkpoint
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.input_dir)
```

Note that the hyperparameters were saved when the **LightningModule** was initialized, i.e. via `self.save_hyperparameters(hparams)`.

```
# hyperparameters are saved under the "hyper_parameters" key in the checkpoint; to access them:
checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location=device)
print(checkpoint["hyper_parameters"])
```

One can also initialize the model with different hyperparameters (if they are saved).


For more details, consult [Lightning Checkpointing](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html).

### _Get Hyperparameter Config File_

- Either take it from the configs folder
- Or extract it from the checkpoint, which is preferable if the model is trained and evaluated on two different machines.

In [8]:
# load processing config file (trusted source)
config = None
config_file = os.path.join(os.curdir, 'LightningModules/GNN/configs/train_alldata_GNN.yaml')
with open(config_file) as f:
    try:
        config = yaml.load(f, Loader=yaml.FullLoader) # equiv: yaml.full_load(f)
    except yaml.YAMLError as e:
        print(e)

In [9]:
# pp.pprint(config)

In [10]:
# Load Model Checkpoint
ckpnt_path = "run_all/lightning_models/lightning_checkpoints/GNNStudy/version_1/checkpoints/last.ckpt"
checkpoint = torch.load(ckpnt_path, map_location=device)
pp.pprint(checkpoint.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])


In [11]:
# View Hyperparameters
hparams = checkpoint["hyper_parameters"]

In [12]:
# One Can Modify Hyperparameters
hparams["checkpoint_path"] = ckpnt_path
hparams["input_dir"] = "run/feature_store"
hparams["output_dir"] = "run/gnn_processed"
hparams["artifact_library"] = "lightning_models/lightning_checkpoints"
hparams["train_split"] = [0, 0, 10000]
hparams["map_location"] = device
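
Overriding entries in the saved hyperparameters amounts to a dictionary update in which the new values win and untouched keys keep their saved values. A minimal sketch of that merge with plain dictionaries; the key `hidden` and all values here are hypothetical placeholders, not taken from the actual checkpoint:

```python
# Hyperparameters as they might have been saved on the training machine (hypothetical)
saved_hparams = {
    "input_dir": "/old/machine/feature_store",
    "train_split": [8000, 1000, 1000],
    "hidden": 128,
}

# Values to change for evaluation on the current machine
overrides = {
    "input_dir": "run/feature_store",
    "train_split": [0, 0, 10000],
}

# Later entries win, so the overrides replace the saved values
merged = {**saved_hparams, **overrides}
print(merged)
```

Building a new dictionary instead of editing the loaded one in place keeps the original values available for comparison.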

In [13]:
# View Hyperparameters
# pp.pprint(hparams)

In [14]:
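# Init InteractionGNN from the checkpoint. load_from_checkpoint is a classmethod,
# so it is called on the class; extra keys in hparams override the saved values.
model = InteractionGNN.load_from_checkpoint(**hparams)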

In [15]:
model.state_dict

<bound method Module.state_dict of InteractionGNN(
  (node_encoder): Sequential(
    (0): Linear(in_features=3, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
  (edge_encoder): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
  (edge_network): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_featu

In [16]:
# Init Lightning Trainer
# trainer = pl.Trainer(callbacks=[GNNTelemetry()])

In [17]:
# Run Test Loop
# trainer.test(model=model, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None)

## _Prediction_

- Create a DataModule
- Get Predict Dataloader
- Load a Checkpoint and Predict (alternatively, use `predict_step()` in the LightningModule)

In [72]:
import typing
from LightningModules.GNN.utils.data_utils import split_datasets, load_dataset

In [74]:
class SttDataModule(pl.LightningDataModule):
    """DataModules are a way of decoupling data-related hooks from the LightningModule"""
    def __init__(self, hparams):
        super().__init__()
        
        # Save hyperparameters
        self.save_hyperparameters(hparams)
        
        # Set workers from hparams
        self.n_workers = (
            self.hparams["n_workers"]
            if "n_workers" in self.hparams
            else len(os.sched_getaffinity(0))
        )
        
        self.data_split = (
            self.hparams["train_split"]
            if "train_split" in self.hparams
            else [0,0,5000]
        )
        
        self.trainset, self.valset, self.testset = None, None, None
        self.predset = None
        
        
    def print_params(self):
        pp.pprint(self.hparams)
        
    def setup(self, stage=None):
        
        if stage == "fit" or stage is None:
            self.trainset, self.valset, self.testset = split_datasets(**self.hparams)

        if stage == "test" or stage is None:
            print("Number of Test Events: ", self.hparams['train_split'][2])
            self.testset = load_dataset(self.hparams["input_dir"], self.data_split[2])
            
        if stage == "predict" or stage is None:  # Lightning passes "predict" here
            print("Number of Pred Events: ", self.hparams['train_split'][2])
            self.predset = load_dataset(self.hparams["input_dir"], self.data_split[2])

    def train_dataloader(self):
        if self.trainset is not None:
            return DataLoader(
                self.trainset, batch_size=1, num_workers=self.n_workers
            )  # , pin_memory=True, persistent_workers=True)
        else:
            return None

    def val_dataloader(self):
        if self.valset is not None:
            return DataLoader(
                self.valset, batch_size=1, num_workers=self.n_workers
            )  # , pin_memory=True, persistent_workers=True)
        else:
            return None

    def test_dataloader(self):
        if self.testset is not None:
            return DataLoader(
                self.testset, batch_size=1, num_workers=self.n_workers
            )  # , pin_memory=True, persistent_workers=True)
        else:
            return None

    #def predict_dataloader(self):
    #    if self.predset is not None:
    #        return DataLoader(
    #            self.predset, batch_size=1, num_workers=self.n_workers
    #        )  # , pin_memory=True, persistent_workers=True)
    #    else:
    #        return None
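
The worker-count fallback in `__init__` relies on `os.sched_getaffinity`, which is not available on every platform (for example, not on macOS or Windows). A minimal sketch of a portable variant, assuming the same intent of using all CPUs available to the process; the helper name `default_num_workers` is hypothetical:

```python
import os


def default_num_workers():
    """Number of CPUs usable by this process, with a portable fallback."""
    if hasattr(os, "sched_getaffinity"):
        # Respects CPU affinity masks, e.g. inside a batch job
        return len(os.sched_getaffinity(0))
    # os.cpu_count() may return None, so fall back to a single worker
    return os.cpu_count() or 1


n_workers = default_num_workers()
print("workers:", n_workers)
```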

In [70]:
# Prepare LightningDataModule (Error)
dm = SttDataModule(hparams)

In [38]:
# dm.print_params()

In [39]:
dm.setup(stage="test")

Number of Test Events:  10000


In [43]:
# loop over test_dataloader
# for batch in dm.test_dataloader():
#    print(batch)

In [44]:
# get test_dataloader
test_dataloader = dm.test_dataloader()

In [45]:
# get one batch
batch = next(iter(test_dataloader))

In [46]:
batch

DataBatch(x=[170, 3], pid=[170], layers=[170], event_file=[1], hid=[170], pt=[170], modulewise_true_edges=[2, 160], layerwise_true_edges=[2, 162], edge_index=[2, 1039], y_pid=[1039], batch=[170], ptr=[2])

In [48]:
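# 1 - Helper Function
def get_input_data(batch):
    """Return the node features with NaNs replaced by 0 (modifies batch.x in place)."""
    input_data = batch.x
    # NaN is the only value not equal to itself, so this mask selects the NaNs
    input_data[input_data != input_data] = 0

    return input_data


# 2 - Helper Function
def handle_directed(batch, edge_sample, truth_sample, directed=False):
    """Add the reversed copy of every edge; if directed, keep one orientation only."""

    # Append the flipped edges, so the edge and truth tensors double in length
    edge_sample = torch.cat([edge_sample, edge_sample.flip(0)], dim=-1)
    truth_sample = truth_sample.repeat(2)

    if directed:
        # Keep only edges whose source node has the smaller first feature, x[:, 0]
        direction_mask = batch.x[edge_sample[0], 0] < batch.x[edge_sample[1], 0]
        edge_sample = edge_sample[:, direction_mask]
        truth_sample = truth_sample[direction_mask]

    return edge_sample, truth_sample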
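
To make the two helper functions concrete, the same operations can be traced on a toy graph. This is a sketch with plain Python lists standing in for the tensors; the node features, edges and labels are hypothetical:

```python
import math

# NaN is the only float that is not equal to itself, so `v != v` selects NaNs
features = [[0.5, math.nan, 1.0], [math.nan, 0.2, 0.3]]
cleaned = [[0.0 if v != v else v for v in row] for row in features]

# Edge list in the [2, n_edges] layout of edge_index: row 0 = source, row 1 = target
edges = [[0, 1, 2], [1, 2, 3]]
truth = [1, 0, 1]

# Appending the reversed edges mirrors torch.cat([e, e.flip(0)], dim=-1)
doubled_edges = [edges[0] + edges[1], edges[1] + edges[0]]
doubled_truth = truth + truth  # mirrors truth.repeat(2)

# directed=True: keep edges whose source has the smaller first feature
first_feature = [10.0, 20.0, 30.0, 40.0]  # hypothetical x[:, 0] per node
mask = [first_feature[s] < first_feature[t] for s, t in zip(*doubled_edges)]
directed_edges = [
    [s for s, keep in zip(doubled_edges[0], mask) if keep],
    [t for t, keep in zip(doubled_edges[1], mask) if keep],
]
directed_truth = [y for y, keep in zip(doubled_truth, mask) if keep]

print(cleaned, doubled_edges, directed_edges, directed_truth)
```

With `directed=False` the edge count doubles, which is why a batch with 1039 edges yields 2078 model outputs.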

In [57]:
batch

DataBatch(x=[170, 3], pid=[170], layers=[170], event_file=[1], hid=[170], pt=[170], modulewise_true_edges=[2, 160], layerwise_true_edges=[2, 162], edge_index=[2, 1039], y_pid=[1039], batch=[170], ptr=[2])

In [50]:
# prepare model input
batch = next(iter(test_dataloader))
truth = batch.y_pid
edge_sample, truth_sample = handle_directed(batch, batch.edge_index, truth)
input_data = get_input_data(batch)

In [52]:
model.eval();

In [55]:
with torch.no_grad():
    output = model(input_data, edge_sample).squeeze()

In [56]:
output.shape

torch.Size([2078])

In [58]:
score = torch.sigmoid(output)

In [65]:
score

tensor([9.9993e-01, 1.1721e-04, 9.6153e-05,  ..., 8.4563e-05, 8.0483e-05,
        7.9189e-05])
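
The score is the logistic sigmoid of the raw model output (the logit), so it lies between 0 and 1, and a cut at 0.5 on the score is equivalent to a cut at 0 on the logit. A minimal sketch in plain Python; the logit values are hypothetical and `sigmoid` here is a stand-in for `torch.sigmoid`:

```python
import math


def sigmoid(z):
    """Numerically stable logistic function for a single logit."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # avoids overflow of exp(-z) for very negative z
    return ez / (1.0 + ez)


logits = [9.5, -9.05, 0.0, 2.0]  # hypothetical raw model outputs
scores = [sigmoid(z) for z in logits]

# Cutting the score at 0.5 selects the same edges as cutting the logit at 0
preds_from_scores = [s > 0.5 for s in scores]
preds_from_logits = [z > 0.0 for z in logits]
print(scores, preds_from_scores)
```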

In [66]:
batch.scores = score

In [60]:
preds = score > 0.5

In [61]:
preds

tensor([ True, False, False,  ..., False, False, False])
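
Given predictions and the matching truth labels, the efficiency and purity from the Classifier Evaluation section can be counted directly. A minimal sketch in plain Python with hypothetical scores and labels; in the notebook the corresponding inputs would be `score` and `truth_sample`, and the function name `efficiency_purity` is an assumption:

```python
def efficiency_purity(scores, truth, cut=0.5):
    """Return (efficiency, purity) of the edge selection `score > cut`."""
    preds = [s > cut for s in scores]
    tp = sum(p and t for p, t in zip(preds, truth))
    n_true = sum(truth)  # TP + FN
    n_selected = sum(preds)  # TP + FP
    efficiency = tp / n_true if n_true else 0.0
    purity = tp / n_selected if n_selected else 0.0
    return efficiency, purity


# Hypothetical edge scores and truth labels
scores = [0.99, 0.80, 0.40, 0.10, 0.70, 0.05]
truth = [True, True, True, False, False, False]

eff, pur = efficiency_purity(scores, truth)
eff_tight, pur_tight = efficiency_purity(scores, truth, cut=0.75)
print(eff, pur, eff_tight, pur_tight)
```

Raising the cut trades efficiency for purity, which is the trade-off that the PRC curve traces.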

In [67]:
batch

DataBatch(x=[170, 3], pid=[170], layers=[170], event_file=[1], hid=[170], pt=[170], modulewise_true_edges=[2, 160], layerwise_true_edges=[2, 162], edge_index=[2, 1039], y_pid=[1039], batch=[170], ptr=[2], scores=[2078])
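
Note that `scores` has 2078 entries while `edge_index` has 1039 columns, because every edge was scored in both orientations. One possible way to get a single score per original edge, assuming `directed=False` so that the first half of the scores holds the forward edges and the second half the reversed ones in the same order, is to average the two. A sketch in plain Python with hypothetical values:

```python
n_edges = 3

# First n_edges entries: forward edges; last n_edges entries: reversed copies
doubled_scores = [0.9, 0.2, 0.6, 0.7, 0.4, 0.8]

# Average the two orientations of each original edge
edge_scores = [
    0.5 * (doubled_scores[i] + doubled_scores[i + n_edges])
    for i in range(n_edges)
]
print(edge_scores)
```

Averaging is only one choice; taking the maximum or the forward score alone are equally simple alternatives.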

### _TensorBoard Logger_

In [69]:
# Load TensorBoard notebook extension
# %load_ext tensorboard

In [68]:
# %tensorboard --logdir=run_all/lightning_models/lightning_checkpoints/GNNStudy