# Evaluation of Graph Neural Network Tracker 

In [None]:
import os
import numpy as np
import yaml

from nb_utils import (compute_metrics, plot_metrics, draw_sample_xy, draw_sample, load_summaries)

from xtracker.gnn_tracking.ImTracker import ImTracker 
from xtracker.gnn_tracking.TrackingSolver import TrackingSolver
from xtracker.gnn_tracking.TrackingGame import TrackingGame as Game
from xtracker.gnn_tracking.pytorch.NNet import NNetWrapper 
from xtracker.utils import dotdict
from itertools import cycle
from xtracker.datasets import get_data_loaders

In [None]:
# Limit CPU usage on Jupyter by capping the number of OpenMP threads.
# NOTE(review): OpenMP/BLAS runtimes usually read OMP_NUM_THREADS once, when they are
# first loaded. numpy (and whatever the xtracker modules import) is already imported in
# the cell above, so this assignment may have no effect. Consider setting it before
# those imports — confirm.
os.environ['OMP_NUM_THREADS'] = '4'

In [None]:
# Show the current working directory: the config path opened below is relative to it.
%pwd

In [None]:
# Select matplotlib's interactive notebook backend for the figures drawn below.
%matplotlib notebook

## Load the data

In [None]:
# Read the tracking configuration (path is relative to the notebook's working directory).
with open('./../examples/configs/belle2_vtx.yaml', encoding='utf-8') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Build the train/validation loaders from the 'data' section of the config.
data_args = dotdict(config['data'])
train_data_loader, valid_data_loader = get_data_loaders(**data_args, input_dir=config['global']['graph_dir'])
assert valid_data_loader is not None
assert train_data_loader is not None


def _endless(loader):
    """Yield items from ``loader`` forever, restarting it after each full pass.

    ``itertools.cycle`` keeps a private copy of every element it has yielded. For a
    loader of graph batches, every batch consumed so far would stay in memory, and
    later passes would replay the first pass instead of re-iterating the loader.
    This generator re-iterates ``loader`` instead and keeps nothing.

    A one-shot iterator cannot be restarted, so for those we fall back to ``cycle``.
    An empty loader ends the generator immediately, as ``cycle`` does, instead of
    spinning forever.
    """
    from collections.abc import Iterator

    if isinstance(loader, Iterator):
        yield from cycle(loader)
        return
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


# Endless iterators: the cells below pull events with next(...).
valid_data_loader = _endless(valid_data_loader)
train_data_loader = _endless(train_data_loader)


game = Game(train_data_loader, valid_data_loader)

## Length scales for plots

In [None]:
# Per-feature factors (r, phi, z) applied to the hit features before plotting.
n_phi_sections = 1
feature_scale_r, feature_scale_phi, feature_scale_z = (
    config['selection'][key]
    for key in ('feature_scale_r', 'feature_scale_phi', 'feature_scale_z')
)

# NOTE(review): r and z use their configured scale directly, whereas phi uses
# pi / n_phi_sections / feature_scale_phi — confirm this matches how phi was normalised.
feature_scale = np.array([
    feature_scale_r,
    np.pi / n_phi_sections / feature_scale_phi,
    feature_scale_z,
])

## Load the trained model

In [None]:
# Load the trained neural net from the 'best.pth.tar' file in the configured
# checkpoint directory (environment variables in the configured path are expanded).
n1 = NNetWrapper()
checkpoint_dir = os.path.expandvars(config['training']['checkpoint'])
n1.load_checkpoint(checkpoint_dir, 'best.pth.tar')

# Build a tracker around the loaded net, configured from the 'model' config section.
tracker_args = dotdict(config['model'])
tracker = ImTracker(game, n1, tracker_args)

## Training history

In [None]:
# Load the training summaries stored in the checkpoint directory.
# NOTE(review): presumably a pandas DataFrame with one row per epoch (the
# best-epoch lookup below uses idxmax/.loc on it) — confirm.
summaries = load_summaries(checkpoint_dir)

print('\nTraining summaries:')
# Bare expression on the last line: Jupyter displays it as the cell output.
summaries

In [None]:
# Find the best epoch: the row label with the highest pit_nnet_score.
best_idx = summaries.pit_nnet_score.idxmax()
# Indexing with a list keeps a one-row table (assuming pandas), which Jupyter renders as output.
summaries.loc[[best_idx]]

In [None]:
# Uncomment the line below to replace the trained tracker with the MC solution (TrackingSolver)

#tracker =  TrackingSolver(game)

# Evaluate tracker on individual events 

In [None]:
%%time
# Run the tracker on the next validation event, print its score and trigger output
# next to board.trig, and plot the prediction.

graph = next(valid_data_loader)
board = game.getInitBoardFromBatch(graph)

# tracker.process returns (pred, score, trig); pred is what gets plotted below.
pred, score, trig = tracker.process(board)
print('score=', score)
print('pred trigger ', trig)
print('true trigger ', board.trig)

# NOTE(review): multiplying by feature_scale presumably maps the normalised hit
# features back to plotting units — confirm against the feature preprocessing.
# cut=0.5 presumably is the threshold applied to pred when drawing — confirm.
draw_sample_xy(board.x * feature_scale, board.edge_index, pred, board.y, cut=0.5, mconly=False, fullonly=False,
              figsize=(9, 9)
)
draw_sample(board.x * feature_scale, board.edge_index, pred, board.y, cut=0.5, mconly=False, fullonly=False,
           figsize=(9, 6))




In [None]:
# Run the tracker on five more validation events and print each event's score.
# The loop index is not needed, so it is bound to `_`.
for _ in range(5):
    graph = next(valid_data_loader)
    board = game.getInitBoardFromBatch(graph)

    pred, score, trig = tracker.process(board)
    print('score=', score)

# Evaluate tracker with statistics

In [None]:
def predict_sample(data_loader, game, tracker, verbose=False, n=12, threshold=0.5):
    """Run ``tracker`` on ``n`` events and collect its predictions and targets.

    Args:
        data_loader: source of graph batches. An iterator (such as the endless
            validation iterator used in this notebook) is advanced in place, so
            repeated calls see new events. A plain iterable is iterated afresh.
        game: object whose ``getInitBoardFromBatch(graph)`` turns a batch into a board.
        tracker: object whose ``process(board)`` returns ``(pred, score, trig)``.
        verbose: if True, print accuracy, precision and recall for each event.
        n: number of events to process.
        threshold: cut forwarded to ``compute_metrics`` for the verbose printout.

    Returns:
        ``(preds, targets)``: two lists with one entry per event, holding the
        tracker prediction and the matching ``board.y``.

    Raises:
        StopIteration: if ``data_loader`` runs out before ``n`` events are read.
    """
    # iter() returns an iterator unchanged, so shared iterators keep advancing across calls.
    batches = iter(data_loader)
    preds, targets = [], []
    for _ in range(n):
        board = game.getInitBoardFromBatch(next(batches))
        # Only the prediction is collected; score and trigger are not used here.
        pred, _score, _trig = tracker.process(board)

        if verbose:
            metrics = compute_metrics([pred], [board.y], threshold=threshold)
            print('Accuracy:  %.4f' % metrics.accuracy)
            print('Precision: %.4f' % metrics.precision)
            print('Recall:    %.4f' % metrics.recall)

        preds.append(pred)
        targets.append(board.y)
    return preds, targets

In [None]:
%%time
# Apply the model
test_preds, test_targets = predict_sample(valid_data_loader, game=game, tracker=tracker,
                                                              verbose=False, n=32) 


In [None]:
# Summarise the collected predictions at a fixed classification threshold.
threshold = 0.5
test_metrics = compute_metrics(test_preds, test_targets, threshold=threshold)

print('Faster Test set results with threshold of', threshold)
print('\n'.join([
    'Accuracy:  %.4f' % test_metrics.accuracy,
    'Precision: %.4f' % test_metrics.precision,
    'Recall:    %.4f' % test_metrics.recall,
]))