In [1]:
import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import os
import pickle
import torch

import sklearn.metrics as metrics
from scipy.sparse import coo_matrix
from scipy.stats import mode


import matplotlib.cm as cm
from main_script.utils import load_config
from main_script.utils import load_checkpoint

In [2]:
config_file = '/home/tingting/trigger_detection/results/noise_weight/config.pkl'
config = pickle.load(open(config_file, 'rb'))

In [3]:
if config['model_name'] == 'GNN_ip':
    from models.ip_GNN import IpGNN
    model = IpGNN(**config['model'])
if config['model_name'] == 'GNN_vp':
    from models.vp_GNN import VpGNN
    model = VpGNN(**config['model'])
if config['model_name'] == 'GNN_Diffpool' or config['model_name'] == 'GNN_Diffpool_trackinfo':
    from models.GNN_diffpool import GNNDiffpool
    model = GNNDiffpool(**config['model'])
if config['model_name'] == 'GNNPairDiffpool' or config['model_name'] == 'GNNPairDiffpool_affinityloss':
    from models.GNN_pair_diffpool import GNNPairDiffpool
    model = GNNPairDiffpool(**config['model'])
if config['model_name'] == 'Diffpool':
    from models.Diffpool import Diffpool
    model = Diffpool(**config['model'])
if config['model_name'] == 'Dense_GNN_Diffpool':
    from models.DenseGNNDiffpool import DenseGNNDiffpool
    model = DenseGNNDiffpool(**config['model'])

  m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  m.bias.data = init.constant(m.bias.data, 0.0)
  m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  m.bias.data = init.constant(m.bias.data, 0.0)


In [4]:
result_dir = '/home/tingting/trigger_detection/results/noise_weight'
checkpoint_dir = os.path.join(result_dir, 'checkpoints')
checkpoint_file = sorted([os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir) if f.startswith('model_checkpoint')])
# checkpoint_file = checkpoint_file[207]
# checkpoint_file = checkpoint_file[330]
checkpoint_file = checkpoint_file[197]
print(checkpoint_file)
model = load_checkpoint(checkpoint_file, model)
print('Successfully reloaded!')

/home/tingting/trigger_detection/results/noise_weight/checkpoints/model_checkpoint_198.pth.tar
Successfully reloaded!


In [20]:
# Test Settings
#test_dir1 = os.path.expandvars('physics_data/nontrigger_event/NN')
test_dir1 = '/home/tingting/tracking/tracking_result_with_noise_In'
test_dir2 = '/home/tingting/tracking/tracking_result_with_noise_D0'
#test_dir1 = os.path.expandvars('physics_data/trigger_event')
test_samples = 20000
batch_size = 1

# Load testing data
from dataloaders.dataloader_for_tracking_result import HitGraphDataset
from dataloaders.dataloader_for_tracking_result import JetsBatchSampler
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

test_dataset = HitGraphDataset(input_dir=test_dir1, n_samples=test_samples, n_input_dir=2, input_dir2=test_dir2, n_samples2=test_samples)
test_batch_sampler = JetsBatchSampler(test_dataset.n_hits, batch_size)
collate_fn = Batch.from_data_list
test_data_loader = DataLoader(test_dataset, batch_sampler=test_batch_sampler, collate_fn=collate_fn)
print('Loaded %g inference samples' % len(test_data_loader.dataset))

Loaded 40000 inference samples


In [21]:
DEVICE = 'cuda'
model.to(DEVICE)
# model_track.to(DEVICE)
test_loss = 0

preds = []
labels = []
event_labels = []
count = 0
for batch in test_data_loader:
    count += 1
    batch_size = batch.batch[-1]+1
    hits = batch.x#hits_info. : (N1+ N2 + ...)*input_features 
    edge_index = batch.edge_index #GNN-edge
    e = batch.e
    trig = batch.trigger #trigger or not, 0 for NN, 1 for trigger, 2 for ND
    event_labels.append(trig.long())
    trig = (trig == 1)
    trig = trig.to(DEVICE, torch.float)
    labels.append(trig.long().cpu().numpy())
    

    # One Train step on the current batch
    hits = hits.to(DEVICE, torch.float)
    edge_index = edge_index.to(DEVICE, torch.long)
    e = e.to(DEVICE, torch.float)
    batch.batch = batch.batch.to(DEVICE, torch.long)



    with torch.set_grad_enabled(False):
        ip_pred = model(hits, edge_index, batch.batch, batch_size, e)
        ip_pred = ip_pred.squeeze(1)
        # ip_pred_track = ip_pred_track.squeeze(1)
        preds.append((ip_pred).cpu().data.numpy())     
        loss = model.get_loss(ip_pred, trig)
    test_loss += loss.item() * batch_size


labels = np.hstack(labels)
preds = np.hstack(preds)
event_labels = np.hstack(event_labels)
result = {'prec': metrics.precision_score(labels, preds>0),
            'recall': metrics.recall_score(labels, preds>0),
            'acc': metrics.accuracy_score(labels, preds>0),
            'F1': metrics.f1_score(labels, preds>0)}
print(result)

{'prec': 0.6993868064484225, 'recall': 0.70715, 'acc': 0.7016, 'F1': 0.7032469792650788}


In [20]:
model

Diffpool(
  (input_network): Sequential(
    (0): Linear(in_features=4, out_features=60, bias=True)
    (1): Tanh()
  )
  (loss_func): BCELoss()
  (ip_pred_diffpool): SoftPoolingGcnEncoder(
    (conv_first): GraphConv()
    (conv_block): ModuleList()
    (conv_last): GraphConv()
    (act): Tanh()
    (pred_model): Sequential(
      (0): Linear(in_features=48, out_features=50, bias=True)
      (1): Tanh()
      (2): Linear(in_features=50, out_features=1, bias=True)
    )
    (conv_first_after_pool): ModuleList(
      (0): GraphConv()
      (1): GraphConv()
    )
    (conv_block_after_pool): ModuleList(
      (0): ModuleList()
      (1): ModuleList()
    )
    (conv_last_after_pool): ModuleList(
      (0): GraphConv()
      (1): GraphConv()
    )
    (assign_conv_first_modules): ModuleList(
      (0): GraphConv()
      (1): GraphConv()
    )
    (assign_conv_block_modules): ModuleList(
      (0): ModuleList()
      (1): ModuleList()
    )
    (assign_conv_last_modules): ModuleList(
     