## _Inference after DNN Stage_

**_Inference_** is done using callbacks defined in **_LightningModules/GNN/Models/inference.py_**. The callbacks run during **_test_step()_**, _a.k.a._ model **_evaluation_**.

### How to run _Inference_?

1. _`traintrack config/pipeline_quickstart.yaml`_. One can use the `--inference` flag to run only the `test_step()` (this should work, but it failed when tried).
2. The _`infer.ipynb`_ notebook runs _pl.Trainer().test()_.

In [None]:
import sys, os, glob, yaml

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
from tqdm import tqdm
import trackml.dataset

In [None]:
import torch
import torchmetrics
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import itertools

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
os.environ['EXATRKX_DATA'] = os.path.abspath(os.curdir)

In [None]:
from LightningModules.GNN import InteractionGNN
from LightningModules.GNN import GNNBuilder, GNNMetrics
from LightningModules.GNN.Models.infer import GNNTelemetry

## _Classifier Evaluation_

Metrics to evaluate the GNN networks:

- Accuracy/ACC = $(TP+TN)/(TP+TN+FP+FN)$
- sensitivity, recall, hit rate, or true positive rate ($TPR = 1 - FNR$)
- specificity, selectivity or true negative rate ($TNR = 1 - FPR$)
- miss rate or false negative rate ($FNR = 1 - TPR$)
- fall-out or false positive rate ($FPR = 1 - TNR$)
- F1-score = $2 \times (\text{PPV} \times \text{TPR})/(\text{PPV} + \text{TPR})$
- Efficiency/Recall/Sensitivity/Hit Rate: $TPR = TP/(TP+FN)$
- Purity/Precision/Positive Predictive Value: $PPV = TP/(TP+FP)$
- AUC-ROC Curve $\equiv$ FPR ($x$-axis) vs. TPR ($y$-axis) plot
- AUC-PRC Curve $\equiv$ TPR ($x$-axis) vs. PPV ($y$-axis) plot


Use `tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()` to directly access TN, FP, FN and TP using Scikit-learn.
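The definitions above can be checked on a toy example. The following is a minimal sketch, not the notebook's own `compute_metrics`: the helper name `edge_metrics` and the toy labels are hypothetical, and it assumes binary labels (1 = true edge) with at least one positive and one negative edge so that no denominator is zero.

```python
import numpy as np

def edge_metrics(y_true, y_pred):
    """Confusion counts and derived rates for binary edge labels (1 = true edge)."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)

    # confusion-matrix entries
    tp = int(np.sum(y_true & y_pred))
    tn = int(np.sum(~y_true & ~y_pred))
    fp = int(np.sum(~y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))

    tpr = tp / (tp + fn)  # efficiency / recall / sensitivity
    tnr = tn / (tn + fp)  # specificity
    ppv = tp / (tp + fp)  # purity / precision
    return {
        "acc": (tp + tn) / (tp + tn + fp + fn),
        "tpr": tpr,
        "tnr": tnr,
        "fnr": 1 - tpr,
        "fpr": 1 - tnr,
        "ppv": ppv,
        "f1": 2 * ppv * tpr / (ppv + tpr),
    }

# hypothetical toy labels: 4 true edges, 4 fake edges
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 0, 0, 1, 0, 1]
m = edge_metrics(y_true, y_pred)
print(m)
```

Here TP = 3, TN = 3, FP = 1 and FN = 1, so accuracy, efficiency, purity and F1 all come out to 0.75.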

### _Test Dataset_

### _Load Checkpoint_

Lightning automatically saves a checkpoint in the current working directory, holding the state of the last training epoch. Here, the checkpoint was stored after training finished.

```
# load a LightningModule along with its weights & hyperparameters from a checkpoint
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.input_dir)
```

Note that the hyperparameters were saved when the **LightningModule** was initialized, i.e. via `self.save_hyperparameters(hparams)`.

```
# hyperparameters are saved under the "hyper_parameters" key of the checkpoint; to access them:
checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location=device)
print(checkpoint["hyper_parameters"])
```

One can also initialize the model with different hyperparameters (if they are saved).
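As a minimal sketch of that idea, the saved hyperparameters can be treated as a plain dictionary and updated before the model is built. The helper `override_hparams` and the example values are hypothetical and only illustrate the merge; they are not part of Lightning or of this repository.

```python
def override_hparams(saved, overrides):
    """Return a copy of `saved` updated with `overrides`, plus a report of changed keys."""
    # map each changed key to (old value, new value); new keys have old value None
    changed = {k: (saved.get(k), v) for k, v in overrides.items() if saved.get(k) != v}
    merged = dict(saved)  # copy, so the checkpoint's own dict is left untouched
    merged.update(overrides)
    return merged, changed

# hypothetical values standing in for checkpoint["hyper_parameters"]
saved = {"input_dir": "old/feature_store", "output_dir": "old/gnn_processed"}
overrides = {"input_dir": "run_quick/feature_store", "map_location": "cpu"}

merged, changed = override_hparams(saved, overrides)
print(changed)
```

Returning the changed keys makes it easy to see which settings differ from those used in training before the model is evaluated.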


For more details, consult [Lightning Checkpointing](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html).

### _Get Checkpoint Hparams_

- Either from the configs folder
- Or extract them from the checkpoint, which is preferable if the model is trained and evaluated on two different machines.

In [None]:
# load processing config file (trusted source)
config = None
config_file = os.path.join(os.curdir, 'LightningModules/DNN/configs/train_alldata_DNN.yaml')
with open(config_file) as f:
    try:
        config = yaml.load(f, Loader=yaml.FullLoader) # equiv: yaml.full_load(f)
    except yaml.YAMLError as e:
        print(e)

In [None]:
# print(config)

In [None]:
# Load Model Checkpoint
ckpnt_path = "run_all/lightning_models/lightning_checkpoints/GNNStudy/a58b2mlx/checkpoints/last.ckpt"
# ckpnt_path = "run_all/lightning_models/lightning_checkpoints/HypGNN/uibb0ir9/checkpoints/last.ckpt"

checkpoint = torch.load(ckpnt_path, map_location=device)
config = checkpoint["hyper_parameters"]

In [None]:
print(checkpoint["hyper_parameters"])

In [None]:
# View Hyperparameters
# print(config)

In [None]:
# One Can Modify Hyperparameters
config["checkpoint_path"] = ckpnt_path
config["input_dir"] = "run_quick/feature_store"
config["output_dir"] = "run_quick/gnn_processed"
config["artifact_library"] = "lightning_models/lightning_checkpoints"
config["train_split"] = [0, 0, 10000]
config["map_location"] = device

In [None]:
# View Hyperparameters (Modified)
# print(config)

### _Get Checkpoint Model_

In [None]:
# Init EdgeClassifier with New Config
model = InteractionGNN(config)

In [None]:
# model.hparams

In [None]:
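# Load Checkpoint with New Config (It Provides the Path and Other Parameters; Most Saved Hparams will be Overwritten)
# load_from_checkpoint is a classmethod, so call it on the class rather than on an instance
model = InteractionGNN.load_from_checkpoint(**config)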

### _(1) - Inference: Callbacks_

* _Test with LightningModule_

In [None]:
# Lightning Trainer
trainer = pl.Trainer(callbacks=[GNNBuilder()])

In [None]:
# Run TestStep
# trainer.test(model=model, verbose=True)

* _Test with LightningDataModule_

In [None]:
# from Predict import SttDataModule

In [None]:
# Prepare LightningDataModule
# dm = SttDataModule(config)

In [None]:
# dm.setup(stage='test')
# test_dataloaders = dm.test_dataloader

In [None]:
# Run TestStep with LightningDataModule
# trainer.test(model=model, dataloaders=None, ckpt_path=None, verbose=True, datamodule=dm)

### _(2) - Inference: Manual_

In [None]:
# from Predict import eval_model

- _How to get data using LightningModule?_

In [None]:
# run setup() for datasets
# model.setup(stage="fit")

In [None]:
# Method 1: Directly Get Test Dataset
# testset = model.testset

# Get single Batch
# batch = testset[0]

# OR, loop over
# for index, batch in enumerate(testset):
# for batch in testset:
#    print(index, batch)

In [None]:
# Method 2: Directly Get Test Dataloader
# test_dataloader = model.test_dataloader()

# Get single Batch
# batch = next(iter(test_dataloader))

# OR, loop over
# for batch_idx, batch in enumerate(test_dataloader):
# for batch in test_dataloader:
#    print(batch)

- _Run `eval_model()` on `test_dataloader()`_

In [None]:
# get testset or test_dataloader
# testset = model.testset
# test_dataloader = model.test_dataloader()

In [None]:
# evaluate model, returns torch tensors
# scores, truths = eval_model(model, test_dataloader)

### _(3) - Inference: GNNBuilder_

_If the **GNNBuilder** callback has been run during training, just load the data from `gnn_processed/test`, extract `scores` and `y_pid` (the truth labels), and run the following metrics_.

In [None]:
# fetch all files
inputdir = "run_all/gnn_processed/test"
gnn_files = sorted(glob.glob(os.path.join(inputdir, "*")))

- _Load all `truth` and `scores` from the `testset` of the `GNN` stage_

In [None]:
scoresl, truthsl = [], []

for e, gnn_file in enumerate(gnn_files):

    # logging
    if e != 0 and e % 1000 == 0:
        print("Processed Batches: ", e)

    gnn_data = torch.load(gnn_file, map_location=device)

    truth = gnn_data.y_pid
    score = gnn_data.scores
    # truncate the scores to the number of truth labels
    score = score[:truth.size(0)]
    
    # append each batch
    scoresl.append(score)
    truthsl.append(truth)

In [None]:
scores = torch.cat(scoresl)
truths = torch.cat(truthsl)

### _Evaluation Metrics_

In [None]:
from src.metric_utils import compute_metrics, plot_metrics
from src.metric_utils import plot_roc, plot_prc, plot_prc_thr, plot_epc, plot_epc_cut, plot_output

In [None]:
# save scores and truths as .npy files
# np.save("gnn_scores.npy", scores.numpy())
# np.save("gnn_truths.npy", truths.numpy())

In [None]:
# torch to numpy (move to CPU first, in case the tensors were loaded onto the GPU)
scores = scores.cpu().numpy()
truths = truths.cpu().numpy()

In [None]:
metrics = compute_metrics(scores,truths,threshold=0.5)

In [None]:
# Curves
# metrics.prc_precision, metrics.prc_recall, metrics.prc_thresh
# metrics.roc_tpr, metrics.roc_fpr, metrics.roc_thresh

In [None]:
metrics.accuracy

In [None]:
metrics.recall

In [None]:
metrics.precision

In [None]:
metrics.f1

### _(a) - Plot Metrics_

In [None]:
outname = "gnn"

In [None]:
plot_metrics(scores,truths, metrics, name=outname)

In [None]:
# ROC Curve
# plot_roc(metrics, name=outname)

In [None]:
# PR Curve
# plot_prc(metrics, name=outname)

In [None]:
# Built from PRC Curve
# plot_prc_thr(metrics, name=outname)

In [None]:
# EP Curve from ROC
# plot_epc(metrics, name=outname)

In [None]:
# Built from ROC Curve
# plot_epc_cut(metrics, name=outname)

In [None]:
# Model output: True and False
# plot_output(scores, truths, threshold=0.5, name=outname)

### _(b) - S/B Suppression_

Background rejection rate (1/FPR) is given as $1/\epsilon_{bkg}$, where $\epsilon_{bkg}$ is the fraction of fake edges that pass the classification requirement. Signal efficiency $\epsilon_{sig}$ (TPR, i.e. recall) is defined as the number of true edges above a given classification score cut divided by the total number of true edges. In summary:

- Signal Efficiency = $\epsilon_{sig}$ = TPR = Recall
- Background Rejection = $1 - \epsilon_{bkg}$ = 1 - FPR = TNR (an alternative convention, not plotted here)
- Background Rejection Rate = $1/\epsilon_{bkg}$ = 1/FPR


First apply an edge score cut to binarize the `scores`; call the result `preds`. Then count the number of false and true edges that pass this cut, and from those counts calculate the background rejection rate and the signal efficiency. To make a plot, one can do these calculations batch by batch on the test dataset.
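The procedure just described can be sketched for a single cut as follows. This is an illustration under stated assumptions rather than the notebook's actual method (the cells in this section take TPR and FPR from the ROC curve instead): the helper `sb_at_cut` and the toy scores are hypothetical, and the rejection rate is reported as infinity when no fake edge passes the cut.

```python
import numpy as np

def sb_at_cut(scores, truths, cut):
    """Signal efficiency, background efficiency and rejection rate at one score cut."""
    scores = np.asarray(scores, dtype=float)
    truths = np.asarray(truths).astype(bool)

    preds = scores > cut  # binarized scores

    n_sig_pass = np.sum(preds & truths)   # true edges passing the cut
    n_bkg_pass = np.sum(preds & ~truths)  # fake edges passing the cut

    eps_sig = n_sig_pass / np.sum(truths)   # TPR
    eps_bkg = n_bkg_pass / np.sum(~truths)  # FPR
    # assumption: report inf rather than dividing by zero
    rejection = np.inf if eps_bkg == 0 else 1.0 / eps_bkg
    return float(eps_sig), float(eps_bkg), float(rejection)

# hypothetical toy scores: 4 true edges, 4 fake edges
scores_toy = [0.9, 0.8, 0.4, 0.7, 0.2, 0.1, 0.6, 0.3]
truths_toy = [1, 1, 1, 0, 0, 0, 1, 0]

for cut in (0.25, 0.5, 0.75):
    print(cut, sb_at_cut(scores_toy, truths_toy, cut))
```

At a cut of 0.5, three of the four true edges and one of the four fake edges pass, giving $\epsilon_{sig} = 0.75$, $\epsilon_{bkg} = 0.25$ and a rejection rate of 4. Tightening the cut trades signal efficiency for background rejection.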

In [None]:
sig = metrics.roc_tpr

In [None]:
fpr = metrics.roc_fpr

In [None]:
bkg_rejection = 1/fpr

In [None]:
# keep only points with signal efficiency above 0.3 (other choices: 0.2 or 0.5)
sig_mask = sig > 0.3

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8,6))
ax.plot(sig[sig_mask], bkg_rejection[sig_mask], color="blue")

# Axes Params
# raw strings keep "\e" from being read as an (invalid) escape sequence
ax.set_xlabel(r"Signal Efficiency ($\epsilon_{sig}$)", fontsize=16)
ax.set_ylabel(r"Background Rejection ($1/\epsilon_{bkg}$)", fontsize=16)
ax.set_yscale('log')
ax.tick_params(axis='both', which='major', labelsize=12)
ax.tick_params(axis='both', which='minor', labelsize=12)
ax.grid(True)

# Figure Params
fig.tight_layout()
fig.savefig(outname+"_SB.pdf")

### _Plot Test Event_

### _TensorBoard Logger_