# Evaluation of Graph Neural Network Trigger

In [None]:
import os
from itertools import cycle

import numpy as np
import torch
import yaml

from nb_utils import (compute_metrics, plot_metrics, draw_sample_xy, draw_sample, load_summaries)

from xtracker.utils import dotdict
from xtracker.datasets import get_data_loaders
from xtracker.gnn_tracking.TrackingGame import TrackingGame as Game
from xtracker.gnn_tracking.pytorch.NNet import NNetWrapper
from xtracker.gnn_tracking.pytorch.NNet import NNetWrapperTrigger
from xtracker.gnn_tracking.ImTracker import ImTracker
from xtracker.gnn_tracking.TrackingSolver import TrackingSolver
from xtracker.gnn_tracking.TrackingSolver import TrackingSolver

In [None]:
# Limit CPU usage on Jupyter.
# OpenMP-based libraries generally read OMP_NUM_THREADS only when their thread
# pool is first created. numpy and torch are already imported by the first
# cell, so setting the environment variable here may not take effect in this
# kernel. torch.set_num_threads applies the limit to torch at runtime.
os.environ['OMP_NUM_THREADS'] = '4'
torch.set_num_threads(4)

In [None]:
# Show the current working directory: the config paths used below are
# relative ('./../examples/configs/...'), so they resolve against it.
%pwd 

In [None]:
# Select matplotlib's interactive notebook (nbagg) backend.
%matplotlib notebook

## Load the trained models

In [None]:
def getFeatureScales(config_path, n_phi_sections=1):
    """Return the (r, phi, z) feature scale factors read from a YAML config.

    Reads ``selection.feature_scale_r``, ``selection.feature_scale_phi`` and
    ``selection.feature_scale_z`` from the config at *config_path*, with
    environment variables in the path expanded first.

    Args:
        config_path: Path to the YAML configuration file.
        n_phi_sections: Divisor applied to pi in the phi scale (presumably
            the number of phi sections used when the graphs were built).
            Defaults to 1, the value previously hard-coded here.

    Returns:
        ``np.array([scale_r, pi / n_phi_sections / scale_phi, scale_z])``.

    Raises:
        KeyError: if a required config entry is missing.
    """
    config_path = os.path.expandvars(config_path)
    # YAML files are UTF-8 by spec; don't depend on the platform locale.
    with open(config_path, encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    selection = config['selection']
    feature_scale_r = selection['feature_scale_r']
    feature_scale_phi = selection['feature_scale_phi']
    feature_scale_z = selection['feature_scale_z']

    return np.array([
        feature_scale_r,
        np.pi / n_phi_sections / feature_scale_phi,
        feature_scale_z,
    ])

def getGame(config_path):
    """Build the tracking Game from the ``data`` section of a YAML config.

    The config at *config_path* (environment variables expanded) supplies
    the keyword arguments for ``get_data_loaders`` (``data``) and the graph
    input directory (``global.graph_dir``). Both loaders are wrapped in
    ``itertools.cycle`` before being handed to the Game.

    Raises:
        AssertionError: if either returned loader is ``None``.
    """
    with open(os.path.expandvars(config_path)) as stream:
        cfg = yaml.load(stream, Loader=yaml.FullLoader)

    train_loader, valid_loader = get_data_loaders(
        **dotdict(cfg['data']),
        input_dir=cfg['global']['graph_dir'],
    )
    assert valid_loader is not None
    assert train_loader is not None

    # NOTE(review): itertools.cycle saves every element of the first pass and
    # replays those saved copies afterwards. If these are torch DataLoaders,
    # all batches stay in memory and shuffling is not redone after the first
    # epoch. Confirm that this is acceptable.
    # The validation iterator is created first, as before.
    valid_stream = cycle(valid_loader)
    train_stream = cycle(train_loader)

    return Game(train_stream, valid_stream)

def setupTracker(game, config_path):
    """Load the trained tracking network and wrap it in an ImTracker.

    Args:
        game: The game object the tracker is built for (e.g. from ``getGame``).
        config_path: Path to the tracker YAML config; environment variables
            are expanded. Its ``training.checkpoint`` entry (also
            env-expanded) is the directory holding ``best.pth.tar``, and its
            ``model`` section is passed to ``ImTracker`` as arguments.

    Returns:
        The constructed ``ImTracker``.

    Raises:
        KeyError: if a required config entry is missing.
    """
    # Open the env-expanded path. The expansion used to be computed into an
    # unused variable while the raw path was opened, so "$VAR" paths failed.
    config_path = os.path.expandvars(config_path)
    with open(config_path, encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Load trained neural net
    net = NNetWrapper()
    checkpoint_dir = os.path.expandvars(config['training']['checkpoint'])
    net.load_checkpoint(checkpoint_dir, 'best.pth.tar')

    # Build a tracker
    tracker_args = dotdict(config['model'])
    tracker = ImTracker(game, net, tracker_args)

    return tracker


def setupTrigger(config_path):
    """Load the trained trigger network named by a YAML config.

    The config at *config_path* (environment variables expanded) must have a
    ``training.checkpoint`` entry: the directory (also env-expanded) from
    which ``best.pth.tar`` is loaded into a fresh ``NNetWrapperTrigger``.

    Returns:
        The ``NNetWrapperTrigger`` with the checkpoint loaded.
    """
    with open(os.path.expandvars(config_path)) as stream:
        cfg = yaml.load(stream, Loader=yaml.FullLoader)

    net = NNetWrapperTrigger()
    net.load_checkpoint(
        os.path.expandvars(cfg['training']['checkpoint']),
        'best.pth.tar',
    )
    return net

In [None]:
# The tracker and the trigger have separate YAML configs. Both paths are
# relative to the current working directory.
config_path_tracker = './../examples/configs/belle2_vtx.yaml'
config_path_trigger = './../examples/configs/belle2_vtx_trigger.yaml'

# The game (data loaders), the feature scales and the trigger network come
# from the trigger config. The tracker network comes from the tracker config.
game = getGame(config_path_trigger)
feature_scale = getFeatureScales(config_path_trigger)
tracker = setupTracker(game, config_path_tracker)
trigger = setupTrigger(config_path_trigger)

# Evaluate tracker and trigger on individual events

In [None]:
%%time

# Draw one event (board) from the game. training=False presumably selects the
# validation stream; confirm against TrackingGame.getInitBoard.
board = game.getInitBoard(training=False)

# Tracker pass: `pred` is drawn below together with board.edge_index, so it is
# presumably a per-edge prediction. `score` is only printed here.
pred, score, _ = tracker.process(board)   
# Trigger pass: the trigger network consumes the tracker's hit embedding.
# Element [0, 0] of its output is compared with element [0, 0] of board.trig.
x = tracker.embed_hits(board)
pred_trig = trigger.predict(x)[0, 0]
true_trig = board.trig.numpy()[0,0]

print('score=', score)
print('pred trigger ', pred_trig)
print('true trigger ', true_trig)

# Draw the event twice (draw_sample_xy presumably gives an x-y view). Hit
# features are multiplied by feature_scale first. cut=0.5 is presumably the
# threshold applied to `pred`; TODO confirm in nb_utils.
draw_sample_xy(board.x * feature_scale, board.edge_index, pred, board.y, cut=0.5, mconly=False, fullonly=False,
              figsize=(9, 9)
)
draw_sample(board.x * feature_scale, board.edge_index, pred, board.y, cut=0.5, mconly=False, fullonly=False, 
           figsize=(9, 6))


# Evaluate the trigger statistically on a sample of events

In [None]:
def predict_sample(n=200):
    """Run the trigger on ``n`` freshly drawn events and collect its outputs.

    Uses the notebook-level ``game``, ``tracker`` and ``trigger`` objects.
    Each event is drawn with ``game.getInitBoard(training=False)``, embedded
    with ``tracker.embed_hits`` and scored with ``trigger.predict``.

    Args:
        n: Number of events to draw. A non-positive value draws nothing.

    Returns:
        A ``(preds, targets)`` tuple of two lists, one entry per event. Each
        pred is element ``[0, 0]`` of the trigger output. Each target is
        element ``[0, 0]`` of ``board.trig``.
    """
    pairs = []
    for _ in range(n):
        event = game.getInitBoard(training=False)
        embedding = tracker.embed_hits(event)
        pairs.append((trigger.predict(embedding)[0, 0], event.trig.numpy()[0, 0]))

    preds = [pair[0] for pair in pairs]
    targets = [pair[1] for pair in pairs]
    return preds, targets

In [None]:
%%time
# Apply the model: run the trigger on 200 freshly drawn events.
test_preds, test_targets = predict_sample(n=200) 

# Score the sample. Predictions and targets are each wrapped in a one-element
# list. threshold=0.5 is presumably the decision cut on the trigger output;
# confirm in nb_utils.compute_metrics.
test_metrics = compute_metrics([test_preds], [test_targets], threshold=0.5)
print('Accuracy:  %.4f' % test_metrics.accuracy)
print('Precision: %.4f' % test_metrics.precision)
print('Recall:    %.4f' % test_metrics.recall)    

In [None]:
# Plot the metrics for the sample evaluated in the previous cell.
plot_metrics([test_preds], [test_targets], test_metrics, figsize=(9, 6))