In [2]:
import os
import sys
import random
import time
import argparse
sys.path.append('../')

import pickle
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from sklearn.cluster import DBSCAN

from models.dataset import Dataset
from models.interaction_network import InteractionNetwork
from models.graph import Graph, save_graphs, load_graph

In [3]:
def calc_dphi(phi1, phi2):
    """Computes phi2-phi1 given in range [-pi,pi]"""
    dphi = phi2 - phi1
    dphi[dphi > np.pi] -= 2*np.pi
    dphi[dphi < -np.pi] += 2*np.pi
    return dphi

def calc_eta(r, z):
    theta = np.arctan2(r, z)
    return -1. * np.log(np.tan(theta / 2.))

In [4]:
pt_cut = 1
use_cuda = False
construction = 'heptrkx_plus'
epoch = 48
disc = 0.3
model = "../trained_models/train1_40hu_{}_epoch{}_{}GeV.pt".format(construction, epoch, pt_cut)
print("model={0}".format(model))

# load in test graph paths
graph_indir = "../../hitgraphs_2/{}_{}/".format(construction, pt_cut)
graph_files = np.array(os.listdir(graph_indir))

device = torch.device("cuda" if use_cuda else "cpu")

train_kwargs = {'batch_size': 1}
test_kwargs = {'batch_size': 1}

n_graphs = len(graph_files)
IDs = np.arange(n_graphs)
#np.random.shuffle(IDs)
partition = {'test':  graph_files[IDs[1000:1400]]}

params = {'batch_size': 1, 'shuffle': True, 'num_workers': 6}
test_set = Dataset(graph_indir, partition['test'])
test_loader = torch.utils.data.DataLoader(test_set, **params)

interaction_network = InteractionNetwork(3, 4, 4)
interaction_network.load_state_dict(torch.load(model, map_location=torch.device('cpu')))
interaction_network.eval()

model=../trained_models/train1_40hu_heptrkx_plus_epoch48_1GeV.pt


InteractionNetwork(
  (phi_R1): RelationalModel(
    (layers): Sequential(
      (0): Linear(in_features=10, out_features=40, bias=True)
      (1): ReLU()
      (2): Linear(in_features=40, out_features=40, bias=True)
      (3): ReLU()
      (4): Linear(in_features=40, out_features=4, bias=True)
    )
  )
  (phi_R2): RelationalModel(
    (layers): Sequential(
      (0): Linear(in_features=10, out_features=40, bias=True)
      (1): ReLU()
      (2): Linear(in_features=40, out_features=40, bias=True)
      (3): ReLU()
      (4): Linear(in_features=40, out_features=1, bias=True)
    )
  )
  (phi_O): ObjectModel(
    (layers): Sequential(
      (0): Linear(in_features=7, out_features=40, bias=True)
      (1): ReLU()
      (2): Linear(in_features=40, out_features=40, bias=True)
      (3): ReLU()
      (4): Linear(in_features=40, out_features=3, bias=True)
    )
  )
)

In [None]:
good_eff, tight_eff = [], []
with torch.no_grad():
    counter = 0
    for data, target in test_loader:
        
        # grab data and targets 
        X, Ra = data['X'].float().to(device), data['Ra'].float().to(device)
        Ri, Ro = data['Ri'].float().to(device), data['Ro'].float().to(device)
        pids = target['pid'][0].to(device)
        target = target['y'].to(device)
        
        # inference, loss calculation
        output = interaction_network(X, Ra.float(), Ri.float(), Ro.float())
        test_loss = F.binary_cross_entropy(output.squeeze(2), target,
                                           reduction='mean').item()
        accuracy = torch.sum(((target==1).squeeze() &
                              (output>disc).squeeze()) |
                             ((target==0).squeeze() &
                              (output<disc).squeeze())).float()/target.shape[1]

        print('loss={}, acc={}'.format(test_loss, accuracy))

        # shape up output, calculate true edges 
        output = output.squeeze()
        true_edges = (output>disc).squeeze()
        probs = torch.transpose(torch.unsqueeze(output, dim=0), 0, 1)
        
        # count hits per pid in each event, add indices to hits
        n_particles = len(np.unique(pids))
        pid_counts = {p.item(): torch.sum(pids==p).item() for p in pids}
        pid_label_map = {p.item(): -5 for p in pids}
        hit_idx = torch.unsqueeze(torch.arange(X[0].shape[1]), dim=0)
        X = torch.cat((hit_idx.float(), X[0]), dim=0)
    
        # separate segments into incoming and outgoing hit positions 
        feats_o = torch.matmul(Ro[0], torch.transpose(X, dim0=0, dim1=1))
        feats_o = torch.cat((feats_o, probs), dim=1)
        feats_o = feats_o[true_edges]
        feats_i = torch.matmul(Ri[0], torch.transpose(X, dim0=0, dim1=1))
        feats_i = torch.cat((feats_i, probs), dim=1)
        feats_i = feats_i[true_edges]
        
        # geometric quantities --> distance calculation 
        r_o, phi_o, z_o = feats_o[:,1], feats_o[:,2], feats_o[:,3]
        eta_o = calc_eta(r_o, z_o)
        r_i, phi_i, z_i = feats_i[:,1], feats_i[:,2], feats_i[:,3]
        eta_i = calc_eta(r_i, z_i)
        dphi, deta = calc_dphi(phi_o, phi_i), eta_i-eta_o
        distances = torch.sqrt((r_i*torch.cos(np.pi*phi_i) - r_o*torch.cos(np.pi*phi_o))**2 +
                               (r_i*torch.sin(np.pi*phi_i) - r_i*torch.sin(np.pi*phi_i))**2 +
                               (z_i-z_o)**2)
        #distances = torch.sqrt(dphi**2 + deta**2)
        
        dist_matrix = 10*torch.ones(X.shape[1], X.shape[1])
        for h in range(len(feats_i)):
            dist_matrix[int(feats_o[h][0])][int(feats_i[h][0])] = distances[h]
            
        # run DBScan
        eps, min_pts = 0.38, 1
        clustering = DBSCAN(eps=eps, min_samples=min_pts,
                            metric='precomputed').fit(dist_matrix)
        labels = clustering.labels_
        
        # count reconstructed particles from hit clusters 
        good_clusters, tight_clusters = 0, 0
        for label in np.unique(labels):  
            if label==-1: continue # ignore noise 
                
            # grab pids corresponding to hit cluster labels
            label_pids = pids[labels==label]
            main_pid = np.bincount(label_pids).argmax() # most frequent pid in cluster
            
            # fraction of hits with the most common pid 
            hit_fraction = len(label_pids[label_pids==main_pid])/len(label_pids)
            
            # check if pid has been mapped to cluster before 
            previous_label = pid_label_map[main_pid]
            if (previous_label > -1):
                print('label', label, 'mapped to duplicate particle')
                hit_fraction = 0
                
            pid_label_map[main_pid] = label
            if hit_fraction > 0.99:
                good_clusters+=1
                pid = label_pids[0].item()
                if pid_counts[pid] == len(label_pids):
                    tight_clusters += 1

        good_eff.append(good_clusters/n_particles)
        tight_eff.append(tight_clusters/n_particles)
        print("GOOD: {}/{}={}".format(good_clusters, n_particles, good_clusters/n_particles))
        print("TIGHT: {}/{}={}".format(tight_clusters, n_particles, tight_clusters/n_particles))
        
        counter += 1
        if (counter > 10): break

print("Good Eff: {}+/-{}", np.mean(good_eff), np.std(good_eff))
print("Tight Eff: {}+/-{}", np.mean(tight_eff), np.std(tight_eff))


loss=0.007400191389024258, acc=0.9971281290054321
label 74 mapped to duplicate particle
label 144 mapped to duplicate particle
label 153 mapped to duplicate particle
label 165 mapped to duplicate particle
label 215 mapped to duplicate particle
label 276 mapped to duplicate particle
label 337 mapped to duplicate particle
label 359 mapped to duplicate particle
label 389 mapped to duplicate particle
label 421 mapped to duplicate particle
label 426 mapped to duplicate particle
label 485 mapped to duplicate particle
label 522 mapped to duplicate particle
label 540 mapped to duplicate particle
label 588 mapped to duplicate particle
label 591 mapped to duplicate particle
label 650 mapped to duplicate particle
label 692 mapped to duplicate particle
label 708 mapped to duplicate particle
label 724 mapped to duplicate particle
label 787 mapped to duplicate particle
label 801 mapped to duplicate particle
label 824 mapped to duplicate particle
label 847 mapped to duplicate particle
label 850 mappe

loss=0.0052935294806957245, acc=0.9980215430259705
label 64 mapped to duplicate particle
label 99 mapped to duplicate particle
label 108 mapped to duplicate particle
label 133 mapped to duplicate particle
label 243 mapped to duplicate particle
label 250 mapped to duplicate particle
label 330 mapped to duplicate particle
label 332 mapped to duplicate particle
label 367 mapped to duplicate particle
label 376 mapped to duplicate particle
label 435 mapped to duplicate particle
label 457 mapped to duplicate particle
label 462 mapped to duplicate particle
label 499 mapped to duplicate particle
label 536 mapped to duplicate particle
label 572 mapped to duplicate particle
label 592 mapped to duplicate particle
label 648 mapped to duplicate particle
label 740 mapped to duplicate particle
label 816 mapped to duplicate particle
label 856 mapped to duplicate particle
label 889 mapped to duplicate particle
label 932 mapped to duplicate particle
label 1058 mapped to duplicate particle
label 1079 map