# Evaluation of Graph Neural Network Tracker 

In [None]:
# System imports
import os
import sys
import pprint
from collections import deque

# External imports
import matplotlib.pyplot as plt
import numpy as np

# Limit CPU usage on Jupyter
os.environ['OMP_NUM_THREADS'] = '4'

# Local imports
from nb_utils import (compute_metrics, plot_metrics, draw_sample_xy, draw_sample, load_summaries)

# Per-feature scale factors; feature_scale is multiplied with board.x before
# plotting further down. The entry order is presumably (r, phi, z), matching
# the variable names — confirm against the graph-building code.
n_phi_sections = 1
feature_scale_r = 15.0
feature_scale_phi =  1.0
feature_scale_z = 50.0

# The phi entry is pi / n_phi_sections, further divided by feature_scale_phi.
feature_scale = np.array([feature_scale_r, np.pi / n_phi_sections / feature_scale_phi, feature_scale_z])


import yaml
# Load the run configuration; its 'data', 'model' and 'global' sections are
# used below. The path is relative to the notebook's working directory.
# NOTE: FullLoader is acceptable for this trusted local file; prefer
# yaml.safe_load if the config could ever come from an untrusted source.
with open('./../examples/configs/belle2_vtx.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


In [None]:
# Show the current working directory; the relative './../examples/...' paths
# used in this notebook are resolved against it.
%pwd 

In [None]:
# Use the interactive notebook backend for the matplotlib figures drawn below.
%matplotlib notebook

## Load the trained model

In [None]:
# Need to adjust this to output of graph generation
# NOTE(review): data_dir is not referenced by the data loaders in the next
# cell, which read config['global']['graph_dir'] instead; confirm which one
# is intended to control the input graphs.
data_dir = './../examples/data/hitgraphs_belle2_vtx'  

# Need to adjust this to checkpoints from training on graphs
# Read by load_summaries(...) and NNet.load_checkpoint(...) further down.
checkpoint_dir = './../examples/data/model_vtx/'

In [None]:
from xtracker.gnn_tracking.ImTracker import ImTracker 


from xtracker.gnn_tracking.TrackingSolver import TrackingSolver
from xtracker.gnn_tracking.TrackingGame import TrackingGame as Game
from xtracker.gnn_tracking.pytorch.NNet import NNetWrapper as NNet
import numpy as np
from xtracker.utils import dotdict
from torch.utils.data import DataLoader
from itertools import cycle
from xtracker.datasets import get_data_loaders
from xtracker.gnn_tracking.TrackingLogic import Board



# Data-loader and tracker settings from the YAML config, wrapped in dotdict
# (presumably for attribute-style access — confirm in xtracker.utils).
data_args = dotdict( config['data'] )
tracker_args = dotdict( config['model'] )   

    
# Data loaders for training and validation 
# NOTE(review): the input directory comes from config['global']['graph_dir'];
# the data_dir variable set earlier is not used here.
train_data_loader, valid_data_loader = get_data_loaders(**data_args,  input_dir=config['global']['graph_dir'])
assert valid_data_loader is not None
assert train_data_loader is not None

# Wrap the loaders in itertools.cycle so next() can be called indefinitely
# (provided the loader yields at least one batch).
# NOTE(review): cycle() saves a copy of every batch from the first pass and
# replays them in that same order, so all batches end up held in memory and
# any per-pass shuffling of the underlying loader is not repeated.
valid_data_loader = cycle(valid_data_loader)
train_data_loader = cycle(train_data_loader)


# The game consumes batches from both cycled loaders.
game = Game(train_data_loader, valid_data_loader)

In [None]:
# Load the training summaries stored in checkpoint_dir and display them.
# Presumably a pandas DataFrame, given the .idxmax()/.loc use in the next
# cell — confirm in nb_utils.load_summaries.
summaries = load_summaries(checkpoint_dir)

print('\nTraining summaries:')
summaries

In [None]:
# Find the best epoch
# Index label of the summary row with the highest pit_nnet_score; that single
# row is displayed for reference.
# NOTE(review): the checkpoint loaded below is fixed to 'best.pth.tar' and is
# not selected via best_idx.
best_idx = summaries.pit_nnet_score.idxmax()
summaries.loc[[best_idx]]

## Use neural net tracker 

In [None]:
# Load neural net
# Restore the network from the 'best.pth.tar' file in checkpoint_dir.
n1 = NNet()
n1.load_checkpoint(checkpoint_dir,'best.pth.tar')

# Build a tracker that uses the loaded network, configured by the 'model'
# section of the config.
tracker = ImTracker(game, n1, tracker_args)

## Use MC tracker

In [None]:
# Set use_mc_tracker to True to replace the neural-net tracker built in the
# previous cell with the MC solution (TrackingSolver); set it to False to keep
# evaluating the GNN tracker. Defaults to True, so running the notebook top
# to bottom evaluates the MC solution.
use_mc_tracker = True

if use_mc_tracker:
    tracker = TrackingSolver(game)

# Evaluate tracker on individual events 

In [None]:
%%time
# Run the current tracker on one validation event and plot the result.
# Each execution of this cell consumes one batch from valid_data_loader.

graph = next(valid_data_loader)  
board = game.getInitBoardFromBatch(graph)


pred, score, trig = tracker.process(board)   
print('score=', score)
print('pred trigger ', trig)
print('true trigger ', board.trig)

# board.x is multiplied by feature_scale before drawing (presumably to map the
# stored features back to detector units — confirm); 0.5 is the cut applied
# to the predictions in the plots.
draw_sample_xy(board.x * feature_scale, board.edge_index, pred, board.y, cut=0.5, mconly=False, fullonly=False,
              figsize=(9, 9)
)
draw_sample(board.x * feature_scale, board.edge_index, pred, board.y, cut=0.5, mconly=False, fullonly=False, 
           figsize=(9, 6))




In [None]:
# Print the tracker score for the next five validation events; this consumes
# five batches from valid_data_loader.
for i in range(5):
    graph = next(valid_data_loader)  
    board = game.getInitBoardFromBatch(graph)

    pred, score, trig = tracker.process(board)   
    print('score=', score)

# Evaluate tracker with statistics

In [None]:
def predict_sample(data_loader, game, tracker, verbose=False, n=12, threshold=0.5):
    """Run ``tracker`` on ``n`` events drawn from ``data_loader``.

    Args:
        data_loader: Iterator of batches. ``next()`` is called once per event,
            so a finite iterator with fewer than ``n`` items raises
            ``StopIteration``.
        game: Object whose ``getInitBoardFromBatch(batch)`` builds a board.
        tracker: Object whose ``process(board)`` returns ``(pred, score, trig)``.
        verbose: If True, print accuracy, precision and recall for each event.
        n: Number of events to process.
        threshold: Threshold passed to ``compute_metrics`` for the per-event
            metrics printed in verbose mode (previously hard-coded to 0.5).

    Returns:
        Tuple ``(preds, targets)`` of lists with one entry per event: the
        tracker's ``pred`` and the corresponding ``board.y``.
    """
    preds, targets = [], []
    for _ in range(n):
        graph = next(data_loader)
        board = game.getInitBoardFromBatch(graph)

        # Only the prediction is collected; score and trigger are unused here.
        pred, _score, _trig = tracker.process(board)

        if verbose:
            metrics = compute_metrics([pred], [board.y], threshold=threshold)
            print('Accuracy:  %.4f' % metrics.accuracy)
            print('Precision: %.4f' % metrics.precision)
            print('Recall:    %.4f' % metrics.recall)

        preds.append(pred)
        targets.append(board.y)
    return preds, targets

In [None]:
%%time
# Apply the model
# Run whichever tracker is currently bound to `tracker` (the MC solver if the
# MC-tracker cell above replaced it) on 32 validation events, collecting
# predictions and targets for the aggregate metrics in the next cell.
test_preds, test_targets = predict_sample(valid_data_loader, game=game, tracker=tracker,
                                                              verbose=False, n=32) 


In [None]:
# Aggregate metrics over all events collected above; threshold is passed to
# compute_metrics.
threshold = 0.5 
test_metrics = compute_metrics(test_preds, test_targets, threshold=threshold)

print('Faster Test set results with threshold of', threshold)
print('Accuracy:  %.4f' % test_metrics.accuracy)
print('Precision: %.4f' % test_metrics.precision)
print('Recall:    %.4f' % test_metrics.recall)