## _Evaluation Metrics_

_If the **GNNBuilder** callback has been run during training, just load the data from `dnn_processed/test`, extract `scores` and `y_pid ~ truth`, and run the following metrics_.

In [None]:
import sys, os, glob, yaml

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
from tqdm import tqdm
import trackml.dataset

In [None]:
import torch
import torchmetrics
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import itertools

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Make the parent directory importable (the `src` package is imported from it below).
sys.path.append(os.pardir)

### Evaluation Definitions

Metrics to evaluate the GNN networks:

- Accuracy/ACC = $(TP+TN)/(TP+TN+FP+FN)$
- sensitivity, recall, hit rate, or true positive rate ($TPR = 1 - FNR$)
- specificity, selectivity or true negative rate ($TNR = 1 - FPR$)
- miss rate or false negative rate ($FNR = 1 - TPR$)
- fall-out or false positive rate ($FPR = 1 - TNR$)
- F1-score = $2 \times (\text{PPV} \times \text{TPR})/(\text{PPV} + \text{TPR})$
- Efficiency/Recall/Sensitivity/Hit Rate: $TPR = TP/(TP+FN)$
- Purity/Precision/Positive Predictive Value: $PPV = TP/(TP+FP)$
- AUC-ROC Curve $\equiv$ FPR ($x-$axis) v.s. TPR ($y-$axis) plot
- AUC-PRC Curve $\equiv$ TPR ($x-$axis) v.s. PPV ($y-$axis) plot


Use _`tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()`_ to directly access TN, FP, FN and TP using Scikit-learn.

### Classifier Evaluation

In [None]:
# fetch all files
# Directory holding the per-event files with stored `scores` and `y_pid`.
# Uncomment exactly one `inputdir` line to choose which model's output to evaluate.
# inputdir = "run_all/gnn_processed/test"
# inputdir = "run_all/dnn_processed_bn/test"
# inputdir = "run_all/dnn_processed_ln/test"

# HypGNN (FWP + Filtering)
# NOTE(review): the "_nf" suffix sits under the "Filtering" label and the plain
# directory under "No Filtering" — confirm the two labels are not swapped.
# inputdir = "run_all/fwp_gnn_processed_nf/pred"

# HypGNN (FWP + No Filtering)
inputdir = "run_all/fwp_gnn_processed/pred"

In [None]:
# Collect every entry in the input directory, sorted so the event order is reproducible.
test_files = sorted(glob.glob(os.path.join(inputdir, "*")))

# Fail early with a clear message instead of an IndexError in a later cell.
if not test_files:
    raise FileNotFoundError(f"No files found in '{inputdir}'")

In [None]:
# Let's test an event: load the first file and display the stored object.
# NOTE: torch.load unpickles the file, so only point `inputdir` at files you trust.
torch.load(test_files[0], map_location=device)

### _Append Scores and Truths_
- _Load all `truth` and `scores` from the `testset` from the `DNN` stage_

In [None]:
# Accumulate the edge scores and truth labels of every event.
# Each tensor is moved to the CPU before it is stored: the events are loaded onto
# `device`, and the later `.numpy()` conversion only works for CPU tensors.
scoresl, truthsl = [], []

for path in tqdm(test_files, desc="Loading events"):

    # read test events e.g. gnn_processed/test
    data = torch.load(path, map_location=device)

    # get truths and scores
    truth = data.y_pid
    # keep only the first truth.size(0) scores so both tensors have equal length
    # (presumably `scores` covers a doubled/bidirectional edge list — TODO confirm)
    score = data.scores[:truth.size(0)]

    # append each batch (detached, on CPU)
    scoresl.append(score.detach().cpu())
    truthsl.append(truth.detach().cpu())

In [None]:
# Merge the per-event tensors into one flat tensor each (concatenation along dim 0).
scores, truths = torch.cat(scoresl, dim=0), torch.cat(truthsl, dim=0)

In [None]:
# save scores and truths as .npy files
# np.save("scores.npy", scores.numpy())
# np.save("truths.npy", truths.numpy())

### _Compute Metrics_

In [None]:
from src.metric_utils import compute_metrics, plot_metrics
from src.metric_utils import plot_roc, plot_prc, plot_prc_thr, plot_epc, plot_epc_cut, plot_output

In [None]:
# torch to numpy
# `.cpu()` is required when the events were loaded onto a CUDA device, and the
# tensor check keeps this cell re-runnable once the names already hold arrays.
if torch.is_tensor(scores):
    scores = scores.detach().cpu().numpy()
if torch.is_tensor(truths):
    truths = truths.detach().cpu().numpy()

In [None]:
# Evaluate the classifier on all accumulated edges with a score threshold of 0.5.
metrics = compute_metrics(scores, truths, threshold=0.5)

In [None]:
# display the accuracy stored on the metrics object computed above
metrics.accuracy

In [None]:
# display the recall stored on the metrics object computed above
metrics.recall

In [None]:
# display the precision stored on the metrics object computed above
metrics.precision

In [None]:
# display the F1-score stored on the metrics object computed above
metrics.f1

### _(a) - Plot Metrics_

In [None]:
# tag passed as `name=` to every plotting helper below
outname = "fwp"

In [None]:
# plot_metrics(scores,truths, metrics, name=outname)

In [None]:
# ROC Curve
# plot_roc(metrics, name=outname)

In [None]:
# PR Curve
# plot_prc(metrics, name=outname)

In [None]:
# Built from PRC Curve
# plot_prc_thr(metrics, name=outname)

In [None]:
# EP Curve from ROC
# (presumably "efficiency vs. purity" — confirm against src.metric_utils.plot_epc)
plot_epc(metrics, name=outname)

In [None]:
# Built from ROC Curve
# (presumably efficiency/purity as a function of the score cut — confirm in src.metric_utils)
plot_epc_cut(metrics, name=outname)

In [None]:
# Model output: True and False
# NOTE: this uses threshold=0.9, whereas `metrics` above was computed with threshold=0.5.
plot_output(scores, truths, threshold=0.9, name=outname)

- _filter the **bumpy** region_

In [None]:
# aliases of the full score / truth arrays; `preds` is narrowed to the score window below
preds, targets = scores, truths

In [None]:
# define a mask around this region: indices of edges scoring strictly between 0.6 and 0.8.
# Built from the untouched `scores` array rather than `preds`, because `preds` is
# overwritten by the filtering cell and would give a wrong mask if this cell is re-run.
mask = np.where((scores < 0.8) & (scores > 0.6))[0]

In [None]:
# filter, preds and targets
# Index the full `scores` / `truths` arrays instead of reassigning `preds` from itself,
# so re-running this cell does not apply the mask to already-filtered data.
preds = scores[mask]
labels = truths[mask]

In [None]:
# Histogram of the model output inside the selected score window,
# drawn separately for true ("Real") and fake ("Fake") edges.
fig, axs = plt.subplots(figsize=(8, 6), nrows=1, ncols=1)

# Shared histogram settings: 25 step-style bins over [0, 1], log-scaled counts.
binning = {'bins': 25, 'range': (0, 1), 'histtype': 'step', 'log': True}
axs.hist(preds[labels == True], label='Real', **binning)    # true class
axs.hist(preds[labels == False], label='Fake', **binning)   # false class

# Axis cosmetics
# axs.set_title("Classifier Output", fontsize=15)
axs.set_xlabel('Model Output', size=20)
axs.set_ylabel('Counts', size=20)
axs.tick_params(axis='both', which='both', labelsize=12)    # major and minor ticks
# axs.set_ylim(ymin=.005)
axs.legend(loc='upper center', fontsize=14)
fig.tight_layout()

### _(b) - S/B Suppression_

Background rejection rate (1/FPR) is given as $1/\epsilon_{bkg}$, where $\epsilon_{bkg}$ is the fraction of fake edges that pass the classification requirement. Signal efficiency (TPR ~ Recall), $\epsilon_{sig}$, is defined as the number of true edges above a given classification score cut over the total number of true edges. What do we have?

- Signal Efficiency = $\epsilon_{sig}$ = TPR ~ Recall 
- Background Rejection = $1 - \epsilon_{bkg}$ = $1 - FPR$ = TNR
- Background Rejection Rate = $1/\epsilon_{bkg}$ = 1/FPR


First apply an edge score cut to binarize the `scores`; we will call the result `preds`. Then count the number of false and true edges that pass this cut, and calculate the background rejection rate and signal efficiency. To make a plot, one can do the calculations batch by batch on the test dataset.

In [None]:
# signal efficiency: the TPR values stored on the metrics object (`roc_tpr`)
sig = metrics.roc_tpr

In [None]:
# Background rejection rate = 1/FPR.
# A ROC curve normally contains a point at FPR = 0 (presumably true for `roc_fpr`
# as well — TODO confirm). There the rejection rate is infinite by definition, so
# silence NumPy's divide-by-zero warning; the resulting values are unchanged.
with np.errstate(divide='ignore'):
    bkg_rejection = 1/metrics.roc_fpr

In [None]:
# keep only the points with signal efficiency above 0.3
# (0.2 and 0.5 were other cut values considered)
sig_mask = sig > 0.3

In [None]:
# Background rejection rate versus signal efficiency, restricted to `sig_mask`.
fig, ax = plt.subplots(figsize=(8, 6), nrows=1, ncols=1)
ax.plot(
    sig[sig_mask],
    bkg_rejection[sig_mask],
    color="blue",
    label="Interaction GNN",
)

# Axis cosmetics: log-scaled rejection rate, labelled axes, grid and legend.
ax.set_xlabel("Signal Efficiency", fontsize=16)
ax.set_ylabel("Background Rejection Rate", fontsize=16)
ax.set_yscale('log')
ax.tick_params(axis='both', which='both', labelsize=12)   # major and minor ticks
ax.grid(True)
ax.legend(loc='upper right', fontsize=14)

# Figure-level layout
fig.tight_layout()
# fig.savefig(outname+"_SB.pdf")

### _(c) - Visualize Model Output_

In [None]:
from src.drawing import detector_layout
from src.utils_math import polar_to_cartesian

In [None]:
# index into `test_files` of the event to visualise (needs at least 16 files)
e = 15

In [None]:
# load graph
# Loaded on the CPU on purpose: every cell below converts tensors to NumPy for
# plotting, which is not possible for tensors living on a CUDA device.
graph = torch.load(test_files[e], map_location='cpu')

# get truths and scores
truth = graph.y_pid
# NOTE: from here on `scores` holds this single event's tensor and shadows the
# concatenated array used by the metric cells above.
scores = graph.scores[:truth.size(0)]
edges = graph.edge_index
eid = e

In [None]:
# sanity check: truth and scores should have the same length; edges is the edge index
truth.shape, scores.shape, edges.shape

In [None]:
# Convert this event's edge scores and truth flags to NumPy.
# Uses `scores` (the selected event); `score` is only a leftover of the accumulation
# loop and belongs to the last file read there. `.cpu()` makes this work on CUDA too.
preds, labels = scores.detach().cpu().numpy(), truth.detach().cpu().numpy()

In [None]:
# extract hit information
# The three columns of graph.x are unpacked as r, phi and ir; `.cpu()` keeps the
# NumPy conversion valid when the graph lives on a CUDA device.
r, phi, ir = graph.x.detach().cpu().T
ir = ir.numpy()*100  # NOTE(review): factor 100 presumably undoes a feature scaling — confirm
x, y = polar_to_cartesian(r.numpy(), phi.numpy())

In [None]:
def draw_sample_xy(hits, edges, preds, labels, cut=0.5, figsize=(16, 16)):
    """Draw one event on the detector layout with edges coloured by outcome.

    Every edge is classified from its score and its truth flag and drawn as:

    * false negative (true edge, score <= ``cut``): dashed blue line,
    * false positive (fake edge, score > ``cut``): faint red line,
    * true positive (true edge, score > ``cut``): black line.

    True negatives are not drawn.

    Parameters
    ----------
    hits : array-like, shape (n_hits, >= 2)
        Per-hit features; columns 0 and 1 are handed to ``polar_to_cartesian``
        as ``r`` and ``phi``. Any further columns are ignored.
    edges : array-like of int, shape (2, n_edges)
        Indices into ``hits`` of the two end points of each edge.
    preds : array-like, shape (>= n_edges,)
        Edge scores.
    labels : array-like, shape (n_edges,)
        Edge truth flags (1/True for a true edge, 0/False for a fake one).
    cut : float, optional
        Score threshold; an edge is selected when its score is strictly above it.
    figsize : tuple, optional
        Forwarded to ``detector_layout``.

    Returns
    -------
    tuple
        The ``(fig, ax)`` pair returned by ``detector_layout``.
    """
    # Accept NumPy arrays as well as CPU tensors.
    hits = np.asarray(hits)
    edges = np.asarray(edges)
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    # coordinate transformation
    x, y = polar_to_cartesian(hits[:, 0], hits[:, 1])

    # detector layout
    fig, ax = detector_layout(figsize=figsize)

    # Truth is decided independently of the score cut, so cut values outside
    # (0, 1) cannot flip which edges count as true.
    is_true = labels > 0.5
    is_selected = preds > cut

    # Draw the segments
    for j in range(labels.shape[0]):
        src, dst = edges[0, j], edges[1, j]
        xs = [x[src], x[dst]]
        ys = [y[src], y[dst]]

        if is_true[j] and not is_selected[j]:
            # False Negatives (includes scores exactly equal to `cut`)
            ax.plot(xs, ys, '--', color='b', lw=1.5, alpha=0.9)
        elif is_selected[j] and not is_true[j]:
            # False Positives
            ax.plot(xs, ys, '-', color='r', lw=1.5, alpha=0.15)
        elif is_selected[j] and is_true[j]:
            # True Positives
            ax.plot(xs, ys, '-', color='k', lw=1.5, alpha=0.9)

    return fig, ax

In [None]:
# Draw the selected event with a 0.9 score cut. Hits and edge indices are handed
# over as CPU NumPy arrays so the call also works when the graph sits on a CUDA device.
draw_sample_xy(graph.x.detach().cpu().numpy(), edges.cpu().numpy(), preds, labels, cut=0.9);