In [1]:
import os
import sys
import random
import time
import argparse
sys.path.append('../')

import pickle
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from matplotlib import pyplot as plt
from sklearn.cluster import DBSCAN

from models.dataset import Dataset
from models.interaction_network import InteractionNetwork
from models.graph import Graph, save_graphs, load_graph

In [2]:
# Global figure typography: DejaVu Sans (sans-serif family) at 10 pt
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['DejaVu Sans'],
    'font.size': 10,
})

# Default font style for matplotlib's mathtext rendering
plt.rcParams['mathtext.default'] = 'regular'

# Hex colour palette shared by the plots in this notebook
colors = [
    '#377eb8', '#ff7f00', '#4daf4a',
    '#f781bf', '#a65628', '#984ea3',
    '#999999', '#e41a1c', '#dede00',
]

In [3]:
# Checkpoints of the trained models used for the paper
model_dir = '/tigress/jdezoort/IN_paper_models/'
models = os.listdir(model_dir)
model_paths = [model_dir + checkpoint for checkpoint in models]

# Starting edge-score discriminants keyed by pt threshold tag; the evaluation
# cell re-derives the discriminant on validation graphs, so these are only
# placeholders.
discs = {
    '2': 0.346,
    '1p5': 0.38,
    '1': 0.225,
    '0p75': 0.203,
    '0p6': 0.2,
}
device = 'cpu'

# Result tables, indexed [model, graph set]:
#   rows    - the 5 trained models (2, 1.5, 1, 0.75 and 0.6 GeV graphs)
#   columns - 6 graph sets; the original note lists only the same five
#             thresholds, the sixth is presumably 0.5 GeV -- TODO confirm
# Each metric has a mean table and a matching "_std" table.
(overall_losses, overall_losses_std,
 overall_accs, overall_accs_std,
 overall_tpr, overall_tpr_std,
 overall_tnr, overall_tnr_std) = [np.zeros((5, 6)) for _ in range(8)]

In [4]:
def _confusion_counts(scores, is_true, is_fake, threshold):
    """Count edge classifications at a given score threshold.

    Args:
        scores: tensor of per-edge network outputs.
        is_true: boolean tensor, True where the target label equals 1.
        is_fake: boolean tensor, True where the target label equals 0.
        threshold: score cut; strictly greater counts as predicted positive,
            strictly smaller as predicted negative (a score exactly equal to
            the threshold is counted in neither class).

    Returns:
        Tuple (n_tp, n_tn, n_fp, n_fn) of Python ints.
    """
    above, below = scores > threshold, scores < threshold
    n_tp = torch.sum(is_true & above).item()
    n_tn = torch.sum(is_fake & below).item()
    n_fp = torch.sum(is_fake & above).item()
    n_fn = torch.sum(is_true & below).item()
    return n_tp, n_tn, n_fp, n_fn


for i, model in enumerate(model_paths): # loop over each of the trained models
    print("Model:", model)

    # for each model, loop over the graph sets to test on (currently only 0.6 GeV)
    for j, test_pt in enumerate(['0p6']):
        print("Testing pt={} GeV".format(test_pt))

        # Recover the pt tag from the checkpoint file name (e.g. "..._1p5GeV.pt" -> "1p5").
        # Only the file name is parsed, so dots in directory names cannot interfere.
        model_pt = os.path.basename(model).split('.')[0].split('_')[-1].strip('GeV')
        if model_pt != '1p5': continue  # only the 1.5 GeV model is evaluated in this cell
        disc = discs[model_pt]          # placeholder; replaced by the validation result below
        interaction_network = InteractionNetwork(3, 4, 4)
        interaction_network.load_state_dict(torch.load(model, map_location=torch.device('cpu')))
        interaction_network.eval()

        # load the graphs, keeping only files whose parsed event number exceeds 2820
        # (the original note says this selects graphs belonging to train_2)
        construction = 'heptrkx_plus'
        graph_indir = "../../hitgraphs_2/{}_{}/".format(construction, test_pt)
        print('Sampling graphs from:', graph_indir)
        graph_files = np.array(os.listdir(graph_indir))
        train_2_mask = [(int(graph_file.split("00000")[1].split("_")[0]) > 2820)
                        for graph_file in graph_files]
        graph_files = graph_files[train_2_mask]
        n_graphs = len(graph_files)

        # randomly partition the graphs into 100 validation and 500 testing graphs.
        # NOTE(review): no random seed is set in the visible cells, so the split
        # differs between runs; the slices silently shrink if fewer than 1500
        # graphs pass the mask.
        IDs = np.arange(n_graphs)
        np.random.shuffle(IDs)
        partition = {'val': graph_files[IDs[500:600]],
                     'test':  graph_files[IDs[1000:1500]]}

        # dataloader settings shared by the validation and test loaders
        params = {'batch_size': 1, 'shuffle': True, 'num_workers': 6}
        if (test_pt=='0p5'): params['num_workers'] = 1
        val_set = Dataset(graph_indir, partition['val'])
        val_loader = torch.utils.data.DataLoader(val_set, **params)

        with torch.no_grad():
            print("...running validation")

            # run validation procedure to determine the best discriminant to use
            best_discs = []
            for data, target in val_loader:
                # grab data and targets
                X, Ra = data['X'].float().to(device), data['Ra'].float().to(device)
                Ri, Ro = data['Ri'].float().to(device), data['Ro'].float().to(device)
                target = target['y'].to(device)

                # inference
                output = interaction_network(X, Ra, Ri, Ro)
                scores = output.squeeze()
                is_true, is_fake = (target==1).squeeze(), (target==0).squeeze()

                # the best discriminant balances true positives and true negatives:
                # scan thresholds and keep the one minimising |TPR - TNR|
                diff, best_disc = 100, 0
                for cut in np.arange(0, 0.6, 0.001):
                    N_tp, N_tn, N_fp, N_fn = _confusion_counts(scores, is_true, is_fake, cut)
                    true_pos_rate = N_tp/(N_tp + N_fn)
                    true_neg_rate = N_tn/(N_tn + N_fp)
                    delta = abs(true_pos_rate - true_neg_rate)
                    if (delta < diff):
                        diff, best_disc = delta, cut

                best_discs.append(best_disc)
                del X, Ra, Ri, Ro, target

            # the working discriminant is the mean of the per-graph optima
            disc = np.mean(best_discs)
            print("...validation produces best disc {:.3f}+-{:.3f}".format(disc, np.std(best_discs)))

        # create a test dataloader
        test_set = Dataset(graph_indir, partition['test'])
        test_loader = torch.utils.data.DataLoader(test_set, **params)

        # for each test graph, track loss, accuracy, true positive rate, and true negative rate
        losses, accs = [], []
        true_pos_rates, true_neg_rates = [], []
        with torch.no_grad():
            print("...testing performance")

            for data, target in test_loader:

                # grab data and targets
                X, Ra = data['X'].float().to(device), data['Ra'].float().to(device)
                Ri, Ro = data['Ri'].float().to(device), data['Ro'].float().to(device)
                target = target['y'].to(device)

                # inference and loss calculation
                output = interaction_network(X, Ra, Ri, Ro)
                test_loss = F.binary_cross_entropy(output.squeeze(2), target,
                                                   reduction='mean').item()

                # confusion counts at the validated discriminant
                N_tp, N_tn, N_fp, N_fn = _confusion_counts(
                    output.squeeze(), (target==1).squeeze(), (target==0).squeeze(), disc)

                # accuracy = correctly classified edges / all edges, stored as a plain float
                accuracy = (N_tp + N_tn)/target.shape[1]
                losses.append(test_loss)
                accs.append(accuracy)

                # true positive rate, true negative rate
                true_pos_rates.append(N_tp/(N_tp + N_fn))
                true_neg_rates.append(N_tn/(N_tn + N_fp))

                del X, Ra, Ri, Ro, target

        # fill the global tables
        overall_losses[i][j] = np.mean(losses)
        overall_losses_std[i][j] = np.std(losses)
        overall_accs[i][j] = np.mean(accs)
        overall_accs_std[i][j] = np.std(accs)
        overall_tpr[i][j] = np.mean(true_pos_rates)
        overall_tpr_std[i][j] = np.std(true_pos_rates)
        overall_tnr[i][j] = np.mean(true_neg_rates)
        overall_tnr_std[i][j] = np.std(true_neg_rates)

        # every value is printed with four decimals ("{:.4f}")
        print("\n --> Results:")
        print(" --> overall loss {:.4f}+-{:.4f}, overall acc {:.4f}+-{:.4f}"
              .format(np.mean(losses), np.std(losses), np.mean(accs), np.std(accs)))
        print(" --> overall tpr {:.4f}+-{:.4f}, overall tnr {:.4f}+-{:.4f}\n"
              .format(np.mean(true_pos_rates), np.std(true_pos_rates),
                      np.mean(true_neg_rates), np.std(true_neg_rates)))

Model: /tigress/jdezoort/IN_paper_models/train1_40hu_heptrkx_plus_epoch30_0p75GeV.pt
Testing pt=0p6 GeV
Model: /tigress/jdezoort/IN_paper_models/train1_40hu_heptrkx_plus_epoch60_2GeV.pt
Testing pt=0p6 GeV
Model: /tigress/jdezoort/IN_paper_models/train1_40hu_heptrkx_plus_epoch48_1GeV.pt
Testing pt=0p6 GeV
Model: /tigress/jdezoort/IN_paper_models/train1_40hu_heptrkx_plus_epoch9_0p6GeV.pt
Testing pt=0p6 GeV
Model: /tigress/jdezoort/IN_paper_models/train1_40hu_heptrkx_plus_epoch60_1p5GeV.pt
Testing pt=0p6 GeV
Sampling graphs from: ../../hitgraphs_2/heptrkx_plus_0p6/
...running validation
...validation produces best disc 0.008+-0.004
...testing performance

 --> Results:
 --> overall loss 0.0360+-0.001792, overall acc 0.9799+-0.0018
 --> overall tpr 0.9792+-0.005641, overall tnr 0.9799+-0.0021



In [5]:
# models in order 2, 0.75, 1