## _Inference after GNN Stage_

**_Inference_** is done with callbacks defined in **_LightningModules/GNN/Models/inference.py_**. The callbacks run during **_test_step()_**, _a.k.a._ model _**evaluation**_.

### How to run _Inference_?

1. _`traintrack config/pipeline_quickstart.yaml`_. The `--inference` flag is meant to run only `test_step()`; it should work, but it failed when we tried it.
2. The _`infer.ipynb`_ notebook runs _pl.Trainer().test()_.

In [None]:
import sys, os, glob, yaml

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
from tqdm import tqdm
import trackml.dataset

In [None]:
import torch
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import itertools

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
os.environ['EXATRKX_DATA'] = os.path.abspath(os.curdir)

In [None]:
from LightningModules.GNN import InteractionGNN
from LightningModules.GNN import GNNBuilder, GNNMetrics

In [None]:
from LightningModules.GNN.Models.infer import GNNTelemetry

## _Classifier Evaluation_

Metrics to evaluate the GNN networks:

- Accuracy/ACC = $(TP+TN)/(TP+TN+FP+FN)$
- sensitivity, recall, hit rate, or true positive rate ($TPR = 1 - FNR$)
- specificity, selectivity or true negative rate ($TNR = 1 - FPR$)
- miss rate or false negative rate ($FNR = 1 - TPR$)
- fall-out or false positive rate ($FPR = 1 - TNR$)
- F1-score = $2 \times (\text{PPV} \times \text{TPR})/(\text{PPV} + \text{TPR})$
- Efficiency/Recall/Sensitivity/Hit Rate: $TPR = TP/(TP+FN)$
- Purity/Precision/Positive Predictive Value: $PPV = TP/(TP+FP)$
- AUC-ROC: area under the ROC curve, i.e. the FPR ($x$-axis) vs. TPR ($y$-axis) plot
- AUC-PRC: area under the precision-recall curve, i.e. the TPR ($x$-axis) vs. PPV ($y$-axis) plot
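A minimal pure-Python sketch of the definitions above, assuming binary edge labels `y_true` (1 = true edge) and GNN edge scores `y_score` cut at a hypothetical threshold of 0.5. The helper names `edge_metrics` and `roc_auc` are illustrative and are not part of `GNNMetrics`. The AUC uses the rank (Mann-Whitney) formulation, which equals the area under the ROC curve.

```python
def _safe_div(num, den):
    # Return 0.0 instead of raising when a class is absent
    return num / den if den else 0.0


def edge_metrics(y_true, y_score, threshold=0.5):
    """Confusion-matrix metrics for binary edge classification (illustrative helper)."""
    # A score at or above the threshold is predicted to be a true edge
    y_pred = [s >= threshold for s in y_score]
    truth = [bool(t) for t in y_true]
    tp = sum(p and t for p, t in zip(y_pred, truth))
    tn = sum((not p) and (not t) for p, t in zip(y_pred, truth))
    fp = sum(p and (not t) for p, t in zip(y_pred, truth))
    fn = sum((not p) and t for p, t in zip(y_pred, truth))

    tpr = _safe_div(tp, tp + fn)  # efficiency / recall
    tnr = _safe_div(tn, tn + fp)  # specificity
    ppv = _safe_div(tp, tp + fp)  # purity / precision
    return {
        "acc": _safe_div(tp + tn, tp + tn + fp + fn),
        "tpr": tpr,
        "tnr": tnr,
        "fnr": 1.0 - tpr,
        "fpr": 1.0 - tnr,
        "ppv": ppv,
        "f1": _safe_div(2 * ppv * tpr, ppv + tpr),
    }


def roc_auc(y_true, y_score):
    """ROC AUC = P(random true edge outscores random fake edge); ties count 1/2."""
    pos = [s for s, t in zip(y_score, y_true) if t]
    neg = [s for s, t in zip(y_score, y_true) if not t]
    # O(n_pos * n_neg) pairwise comparison: fine for illustration, too slow for full events
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return _safe_div(wins, len(pos) * len(neg))


m = edge_metrics([1, 1, 0, 0], [0.9, 0.4, 0.6, 0.1])
print(m["acc"], m["tpr"], m["ppv"], m["f1"])  # 0.5 0.5 0.5 0.5
print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

For full events with millions of edges, a vectorized library routine is the more practical choice; the loops here only make the formulas explicit.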

### _Test Dataset_

### _Load Checkpoint_

By default, Lightning saves a checkpoint in the current working directory with the state of the last training epoch. Here, the checkpoint was stored after training finished.

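```python
# Load a LightningModule along with its weights & hyperparameters from a checkpoint
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
# Hyperparameters stored via self.save_hyperparameters() are available in model.hparams
print(model.hparams["input_dir"])
```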

This works because the hyperparameters were saved when the **LightningModule** was initialized, i.e. via `self.save_hyperparameters(hparams)`.

```python
import torch

# Hyperparameters are saved under the "hyper_parameters" key of the checkpoint dict
checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location=device)
print(checkpoint["hyper_parameters"])
```

One can also initialize the model with different hyperparameters (provided they were saved): keyword arguments passed to `load_from_checkpoint()` override the saved values.
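Keyword arguments passed to `load_from_checkpoint()` take precedence over the hyperparameters stored in the checkpoint. The plain-dictionary sketch below only mimics that precedence; `merge_hparams` is a hypothetical helper (Lightning does the real merge internally), and the keys and values are made up for illustration.

```python
import copy


def merge_hparams(saved, **overrides):
    """Return a new dict: saved hyperparameters updated by overrides (hypothetical helper)."""
    merged = copy.deepcopy(saved)  # leave the checkpoint's dict untouched
    merged.update(overrides)       # overriding keys win; unknown keys are added
    return merged


# Made-up saved hyperparameters, for illustration only
saved = {"input_dir": "run_all/feature_store", "hidden": 64}
new = merge_hparams(saved, input_dir="run/feature_store")
print(new)    # {'input_dir': 'run/feature_store', 'hidden': 64}
print(saved)  # {'input_dir': 'run_all/feature_store', 'hidden': 64}
```

Copying first is a deliberate choice: editing `checkpoint["hyper_parameters"]` in place also changes the dict held by the loaded checkpoint object.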


For more details, consult [Lightning Checkpointing](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html).

### _Get Hyperparameter Config File_

- either from the `configs` folder,
- or from the checkpoint itself, which is preferable when the model is trained and evaluated on two different machines.

In [None]:
# load processing config file (trusted source)
config = None
config_file = os.path.join(os.curdir, 'LightningModules/GNN/configs/train_alldata_GNN.yaml')
with open(config_file) as f:
    try:
        config = yaml.load(f, Loader=yaml.FullLoader) # equiv: yaml.full_load(f)
    except yaml.YAMLError as e:
        print(e)

In [None]:
# print(config)

In [None]:
# Load Model Checkpoint
ckpnt_path = "run_all/lightning_models/lightning_checkpoints/GNNStudy/version_1/checkpoints/last.ckpt"
checkpoint = torch.load(ckpnt_path, map_location=device)
config = checkpoint["hyper_parameters"]

In [None]:
# View Hyperparameters
# print(config)

In [None]:
# One Can Modify Hyperparameters
config["checkpoint_path"] = ckpnt_path
config["input_dir"] = "run/feature_store"
config["output_dir"] = "run/gnn_processed"
config["artifact_library"] = "lightning_models/lightning_checkpoints"
config["datatype_split"] = [0, 0, 10000]
config["map_location"] = device

In [None]:
# View Hyperparameters (New)
# print(config)

In [None]:
# Init InteractionGNN with New Config
model = InteractionGNN(config)

In [None]:
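# Load Checkpoint with New Config: load_from_checkpoint is a classmethod, so call it on the class;
# config supplies checkpoint_path and map_location, and its other keys override the saved hparams
model = InteractionGNN.load_from_checkpoint(**config)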

In [None]:
# Init Lightning Trainer
trainer = pl.Trainer(callbacks=[GNNBuilder()])

In [None]:
# Run Test Loop
trainer.test(model=model, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None)

### _Test with LightningDataModule_

In [None]:
from LightningModules.GNN.utils.data_utils import split_datasets, load_dataset

In [None]:
class SttDataModule(pl.LightningDataModule):
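    def __init__(self, hparams):
        super().__init__()
        # Store hparams so self.hparams is populated (otherwise it stays empty and lookups fail)
        self.save_hyperparameters(hparams)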

        # Set workers from hparams
        self.n_workers = (
            self.hparams["n_workers"]
            if "n_workers" in self.hparams
            else len(os.sched_getaffinity(0))
        )

        # Instance Variables
        self.train_split = self.hparams["train_split"]
        self.trainset, self.valset, self.testset = None, None, None

    def setup(self, stage: str):

        if stage == "fit":
            self.trainset, self.valset, self.testset = split_datasets(**self.hparams)

        if stage == "test":
            print("Number of Test Events: ", self.hparams["train_split"][2])
            self.testset = load_dataset(self.hparams["input_dir"], self.train_split[2])

    def train_dataloader(self):
        if self.trainset is not None:
            return DataLoader(
                self.trainset, batch_size=1, num_workers=self.n_workers
            )  # , pin_memory=True, persistent_workers=True)
        else:
            return None

    def val_dataloader(self):
        if self.valset is not None:
            return DataLoader(
                self.valset, batch_size=1, num_workers=self.n_workers
            )  # , pin_memory=True, persistent_workers=True)
        else:
            return None

    def test_dataloader(self):
        if self.testset is not None:
            return DataLoader(
                self.testset, batch_size=1, num_workers=self.n_workers
            )  # , pin_memory=True, persistent_workers=True)
        else:
            return None
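`SttDataModule.setup` reads the number of test events from the third entry of `train_split`, which suggests the split is a list of event counts `[n_train, n_val, n_test]` (the notebook sets `datatype_split = [0, 0, 10000]` for a test-only run). Under that assumption, the helper below is a hypothetical, order-preserving stand-in for `split_datasets`/`load_dataset` that only slices a list of event file names; the real utilities may shuffle or load events differently, and the event names are made up.

```python
def split_event_files(files, split):
    """Slice an ordered list of event files into train/val/test by counts (hypothetical helper)."""
    n_train, n_val, n_test = split
    train = files[:n_train]
    val = files[n_train:n_train + n_val]
    # Slicing past the end is safe: asking for 10000 test events returns whatever exists
    test = files[n_train + n_val:n_train + n_val + n_test]
    return train, val, test


# Made-up event names, for illustration only
events = [f"event{i:09d}" for i in range(5)]
train, val, test = split_event_files(events, [0, 0, 10000])
print(len(train), len(val), len(test))  # 0 0 5
```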

In [None]:
# Prepare LightningDataModule (Error)
# dm = SttDataModule(config)

In [None]:
# dm.setup(stage='test')

In [None]:
# test_dataloaders = dm.test_dataloader()

In [None]:
# trainer.test(model=model, dataloaders=None, ckpt_path=None, verbose=True, datamodule=dm)

### _TensorBoard Logger_

In [None]:
# Load TensorBoard notebook extension
%load_ext tensorboard

In [None]:
%tensorboard --logdir=run_all/lightning_models/lightning_checkpoints/GNNStudy