## _Inference after DNN Stage_

**_Inference_** is done using callbacks defined in **_LightningModules/GNN/Models/inference.py_**. The callbacks run during **_test_step()_**, i.e. during model _**evaluation**_.

### How to run _Inference_?

1. _`traintrack config/pipeline_quickstart.yaml`_. The `--inference` flag is meant to run only `test_step()`; it should work, but it failed in our runs.
2. The _`infer.ipynb`_ notebook runs _`pl.Trainer().test()`_.

In [None]:
import sys, os, glob, yaml

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
from tqdm import tqdm
import trackml.dataset

In [None]:
import torch
import torchmetrics
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import itertools

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
os.environ['EXATRKX_DATA'] = os.path.abspath(os.curdir)

In [None]:
from LightningModules.DNN import EdgeClassifier, EdgeClassifier_BN, EdgeClassifier_LN
from LightningModules.DNN import GNNBuilder, GNNMetrics
from LightningModules.DNN.Models.infer import GNNTelemetry

## _Classifier Evaluation_

Metrics to evaluate the GNN networks:

- Accuracy/ACC = $(TP+TN)/(TP+TN+FP+FN)$
- sensitivity, recall, hit rate, or true positive rate ($TPR = 1 - FNR$)
- specificity, selectivity or true negative rate ($TNR = 1 - FPR$)
- miss rate or false negative rate ($FNR = 1 - TPR$)
- fall-out or false positive rate ($FPR = 1 - TNR$)
- F1-score = $2 \times (\text{PPV} \times \text{TPR})/(\text{PPV} + \text{TPR})$
- Efficiency/Recall/Sensitivity/Hit Rate: $TPR = TP/(TP+FN)$
- Purity/Precision/Positive Predictive Value: $PPV = TP/(TP+FP)$
- AUC-ROC: area under the ROC curve, i.e. the FPR ($x$-axis) vs. TPR ($y$-axis) plot
- AUC-PRC: area under the precision-recall curve, i.e. the TPR ($x$-axis) vs. PPV ($y$-axis) plot
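The AUC-ROC can be computed without building the curve at all, using the standard equivalence between the ROC area and the normalised Mann-Whitney U (rank-sum) statistic: it is the probability that a randomly chosen true edge scores higher than a randomly chosen fake edge. A minimal NumPy sketch; the function name `roc_auc_rank` is hypothetical and not part of the repo.

```python
import numpy as np


def roc_auc_rank(scores, truths):
    """Area under the ROC curve via the Mann-Whitney U statistic (ties count 1/2)."""
    scores = np.asarray(scores, dtype=float).ravel()
    truths = np.asarray(truths).ravel().astype(bool)
    n_pos = int(truths.sum())
    n_neg = truths.size - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUC is undefined when only one class is present")

    # 1-based ranks of the scores, averaging the ranks over tied scores
    order = np.argsort(scores, kind="mergesort")
    _, first, counts = np.unique(scores[order], return_index=True, return_counts=True)
    ranks = np.empty(scores.size)
    ranks[order] = np.repeat(first + (counts + 1) / 2.0, counts)

    # U statistic of the positives, normalised by the number of (true, fake) pairs
    u_pos = ranks[truths].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u_pos / (n_pos * n_neg))


print(roc_auc_rank([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

Sorting makes this $O(n \log n)$, so it scales to the millions of edges in a test set, unlike a naive loop over all (true, fake) pairs.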


With scikit-learn, `tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()` gives direct access to TN, FP, FN and TP for binary labels.
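To make the definitions above concrete, here is a small NumPy sketch that counts TN, FP, FN and TP in the same order as `confusion_matrix(...).ravel()` (for 0/1 labels) and derives the listed rates from them. The helper names `confusion_counts` and `rates` are hypothetical, and the toy labels are illustrative only.

```python
import numpy as np


def confusion_counts(y_true, y_pred):
    """Return (tn, fp, fn, tp) in the same order as
    sklearn.metrics.confusion_matrix(y_true, y_pred).ravel() for 0/1 labels."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tn = int(np.sum(~y_true & ~y_pred))
    fp = int(np.sum(~y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))
    tp = int(np.sum(y_true & y_pred))
    return tn, fp, fn, tp


def rates(tn, fp, fn, tp):
    """Metrics from the list above (assumes every denominator is non-zero)."""
    tpr = tp / (tp + fn)  # efficiency / recall / sensitivity
    tnr = tn / (tn + fp)  # specificity
    ppv = tp / (tp + fp)  # purity / precision
    return {
        "acc": (tp + tn) / (tp + tn + fp + fn),
        "tpr": tpr, "tnr": tnr, "fnr": 1 - tpr, "fpr": 1 - tnr,
        "ppv": ppv, "f1": 2 * ppv * tpr / (ppv + tpr),
    }


y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 0, 0, 1, 0, 1]
tn, fp, fn, tp = confusion_counts(y_true, y_pred)
print(tn, fp, fn, tp)  # 3 1 1 3
print(rates(tn, fp, fn, tp))
```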

### _Test Dataset_

### _Load Checkpoint_

Lightning automatically saves a checkpoint in the current working directory with the state of the last training epoch. We therefore have a checkpoint stored once training has finished.

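```python
# load a LightningModule along with its weights & hyperparameters from a checkpoint
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
# hyperparameters stored via save_hyperparameters() are available under model.hparams
print(model.hparams["input_dir"])
```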

Note that the hyperparameters were saved when the **LightningModule** was initialized, i.e. via `self.save_hyperparameters(hparams)`.

```python
# hyperparameters are saved under the "hyper_parameters" key of the checkpoint
checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location=device)
print(checkpoint["hyper_parameters"])
```

One can also initialize the model with different hyperparameters: keyword arguments passed to `load_from_checkpoint()` override the values saved in the checkpoint.


For more details, consult [Lightning Checkpointing](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html).

### _Get Checkpoint Hparams_

- Either load them from the `configs` folder,
- Or extract them from the checkpoint; this is preferred if the model is trained and evaluated on two different machines.

In [None]:
# load processing config file (trusted source)
config = None
config_file = os.path.join(os.curdir, 'LightningModules/DNN/configs/train_alldata_DNN.yaml')
with open(config_file) as f:
    try:
        config = yaml.load(f, Loader=yaml.FullLoader) # equiv: yaml.full_load(f)
    except yaml.YAMLError as e:
        print(e)

In [None]:
# print(config)

In [None]:
# Load Model Checkpoint

# Dense with LayerNorm: [1000,2000,2000,2000,1000,1]
# ckpnt_path = "run_all/lightning_models/lightning_checkpoints/DNNStudy/version_0/checkpoints/last.ckpt"

# Dense with BatchNorm: [128,512,128,1024,512,1]
ckpnt_path = "run_all/lightning_models/lightning_checkpoints/DNNStudy/fh48qczt/checkpoints/last.ckpt"

# Dense with LayerNorm: [128,128,1024,1024,128,1]
# ckpnt_path = "run_all/lightning_models/lightning_checkpoints/DNNStudy/avvy18g5/checkpoints/last.ckpt"

checkpoint = torch.load(ckpnt_path, map_location=device)
config = checkpoint["hyper_parameters"]

In [None]:
print(checkpoint["hyper_parameters"])

In [None]:
# View Hyperparameters
# print(config)

In [None]:
# One Can Modify Hyperparameters
config["checkpoint_path"] = ckpnt_path
config["input_dir"] = "run_quick/feature_store"
config["output_dir"] = "run_quick/dnn_processed_bn"
config["artifact_library"] = "lightning_models/lightning_checkpoints"
config["train_split"] = [0, 0, 20000]
config["map_location"] = device
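
Because `config = checkpoint["hyper_parameters"]` only binds a reference, the assignments above also modify `checkpoint` in place. If the original hyperparameters should stay intact, one option is to copy before overriding. A minimal sketch; `override_hparams` is a hypothetical helper, and the `saved` dict holds toy values rather than real checkpoint contents.

```python
import copy


def override_hparams(saved, **overrides):
    """Return a deep copy of `saved` with `overrides` applied, plus the changed keys.

    The original mapping (e.g. checkpoint["hyper_parameters"]) is left untouched.
    """
    new = copy.deepcopy(dict(saved))
    # record (old, new) for every key whose value actually changes
    changed = {k: (new.get(k), v) for k, v in overrides.items() if new.get(k) != v}
    new.update(overrides)
    return new, changed


saved = {"input_dir": "old/feature_store", "train_split": [80, 10, 10]}  # toy values
config, changed = override_hparams(
    saved, input_dir="run_quick/feature_store", train_split=[0, 0, 20000]
)
print(changed)
```

In the notebook this could replace the in-place edits, e.g. `config, changed = override_hparams(checkpoint["hyper_parameters"], checkpoint_path=ckpnt_path, map_location=device)`, and printing `changed` documents exactly which values differ from training.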

In [None]:
# View Hyperparameters (Modified)
print(config)

### _Get Checkpoint Model_

In [None]:
# Init EdgeClassifier with New Config
model = EdgeClassifier_BN(config)

In [None]:
# model.hparams

In [None]:
# Load the checkpoint with the modified config: it supplies checkpoint_path and map_location,
# and the remaining keys override (most of) the hyperparameters saved in the checkpoint
model = EdgeClassifier_BN.load_from_checkpoint(**config)

### _(1) - Inference: Callbacks_

* _Test with LightningModule_

In [None]:
# Lightning Trainer
trainer = pl.Trainer(callbacks=[GNNBuilder()])

In [None]:
# Run TestStep
trainer.test(model=model, verbose=True)

* _Test with LightningDataModule_

In [None]:
# from Predict import SttDataModule

In [None]:
# Prepare LightningDataModule
# dm = SttDataModule(config)

In [None]:
# dm.setup(stage='test')
# test_dataloaders = dm.test_dataloader()

In [None]:
# Run TestStep with LightningDataModule
# trainer.test(model=model, dataloaders=None, ckpt_path=None, verbose=True, datamodule=dm)

### _(2) - Inference: Manual_

In [None]:
# from Predict import eval_model

- How to get data using the LightningModule?

In [None]:
# run setup() for datasets
# model.setup(stage="fit")

- Run _`eval_model()`_ on _`test_dataloader()`_

### _(3) - Inference: GNNBuilder_

If the _**GNNBuilder**_ callback has been run during training, just load the data from `dnn_processed/test`, extract `scores` and `y_pid` (the truth labels), and run the metrics below.

- _Load all `truth` and `scores` of the test set from the `DNN` stage_
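The cells below use `scores` and `truths` without showing how they were collected. A minimal sketch, assuming each processed event exposes its edge scores under `scores` and the edge truth under `y_pid` (the keys named above); `collect_scores_truths` is a hypothetical helper, and the toy events are illustrative only.

```python
import numpy as np


def collect_scores_truths(events, score_key="scores", truth_key="y_pid"):
    """Concatenate per-event edge scores and truth labels into two flat arrays.

    `events` is any iterable of objects supporting event[key]: dicts of arrays work,
    and so should PyG `Data` objects holding CPU, detached tensors.
    """
    scores, truths = [], []
    for event in events:
        s = np.asarray(event[score_key], dtype=np.float32).ravel()
        t = np.asarray(event[truth_key]).ravel()
        if s.shape != t.shape:
            raise ValueError(f"{score_key} and {truth_key} differ in length")
        scores.append(s)
        truths.append(t)
    if not scores:
        raise ValueError("no events found")
    return np.concatenate(scores), np.concatenate(truths)


toy = [{"scores": [0.9, 0.2], "y_pid": [1, 0]},
       {"scores": [0.7], "y_pid": [1]}]
scores, truths = collect_scores_truths(toy)
print(scores.shape, truths.tolist())  # (3,) [1, 0, 1]
```

For the real test set, `events` could be a generator such as `(torch.load(f, map_location="cpu") for f in sorted(glob.glob("dnn_processed/test/*")))`; the file layout is an assumption, and recent PyTorch versions may need `weights_only=False` to unpickle `Data` objects. Since this returns NumPy arrays, the `.numpy()` conversions in the following cells would be skipped.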

In [None]:
# save scores and truths as .npy files
np.save("bn_dnn_scores.npy", scores.numpy())
np.save("bn_dnn_truths.npy", truths.numpy())

In [None]:
# torch to numpy
scores = scores.numpy()
truths = truths.numpy()

In [None]:
metrics = compute_metrics(scores,truths,threshold=0.5)
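
`compute_metrics` is neither defined nor imported in this notebook. A minimal NumPy stand-in, assuming it returns an object with the attribute names used in the cells below (`accuracy`, `recall`, `precision`, `f1` and the `roc_*`/`prc_*` curve arrays); the threshold convention and curve construction are assumptions, not the repo's implementation.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class EdgeMetrics:
    accuracy: float
    precision: float
    recall: float
    f1: float
    roc_fpr: np.ndarray
    roc_tpr: np.ndarray
    roc_thresh: np.ndarray
    prc_precision: np.ndarray
    prc_recall: np.ndarray
    prc_thresh: np.ndarray


def compute_metrics(scores, truths, threshold=0.5):
    """Hypothetical stand-in: metrics at one working point plus ROC/PR curves."""
    scores = np.asarray(scores, dtype=float).ravel()
    truths = np.asarray(truths).ravel().astype(bool)

    # Working point: an edge is predicted true if its score exceeds `threshold`
    preds = scores > threshold
    tp = int(np.sum(preds & truths))
    fp = int(np.sum(preds & ~truths))
    fn = int(np.sum(~preds & truths))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = float(np.mean(preds == truths))

    # Curves: sort by score (descending); each point keeps all edges with score >= thresh
    order = np.argsort(-scores, kind="mergesort")
    s, t = scores[order], truths[order]
    last = np.r_[np.flatnonzero(np.diff(s)), s.size - 1]  # last index of each tied-score run
    cum_tp = np.cumsum(t)[last]
    cum_fp = last + 1 - cum_tp
    n_pos = int(truths.sum())
    n_neg = truths.size - n_pos
    roc_tpr = cum_tp / max(n_pos, 1)
    roc_fpr = cum_fp / max(n_neg, 1)

    return EdgeMetrics(
        accuracy=accuracy, precision=precision, recall=recall, f1=f1,
        roc_fpr=roc_fpr, roc_tpr=roc_tpr, roc_thresh=s[last],
        prc_precision=cum_tp / (last + 1), prc_recall=roc_tpr, prc_thresh=s[last],
    )
```

scikit-learn's `roc_curve` and `precision_recall_curve` produce comparable points, with slightly different end-point and ordering conventions, and could be swapped in if scikit-learn is available.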

In [None]:
# Curves
# metrics.prc_precision, metrics.prc_recall, metrics.prc_thresh
# metrics.roc_tpr, metrics.roc_fpr, metrics.roc_thresh

In [None]:
metrics.accuracy

In [None]:
metrics.recall

In [None]:
metrics.precision

In [None]:
metrics.f1

### _Plot Test Event_

### _TensorBoard Logger_