# In [1]:
import torch
import argparse
import ast
import array
import numpy as np
import os
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import roc_curve, accuracy_score, auc
from sklearn.cluster import AgglomerativeClustering

import ROOT
from ROOT import TFile
from ROOT import vector

from dataset import WaveformSliceDataset
from model import DnnModel, RnnModel
from utility import *

# Pick the compute device once at import time: the GPU when CUDA is usable,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def _threshold_at_fpr(fpr, tpr, thr, cut):
    """Return ``(threshold, tpr)`` at the last ROC point whose FPR is below *cut*.

    *fpr*, *tpr* and *thr* are the parallel arrays returned by
    ``sklearn.metrics.roc_curve``; *fpr* is non-decreasing, so a binary search
    is sufficient.  When no point lies below *cut* the first ROC point is
    returned.
    """
    # side='left' yields the first index with fpr >= cut; the point just before
    # it is the last one strictly below the cut (also when a point equals cut).
    index = max(int(np.searchsorted(fpr, cut, side='left')) - 1, 0)
    return thr[index], tpr[index]


def _cluster_times(times, distance_threshold):
    """Group detection times with agglomerative clustering and return the cluster means.

    The means are ordered by the first appearance of each cluster in *times*.
    Inputs with fewer than two entries are returned as-is, because
    ``AgglomerativeClustering`` cannot be fitted on fewer than two samples.
    """
    values = np.array(times)
    if values.size < 2:
        return [float(v) for v in values]

    cluster_ids = AgglomerativeClustering(
        n_clusters=None, distance_threshold=distance_threshold
    ).fit_predict(values.reshape(-1, 1))

    members = {}
    for cluster_id, value in zip(cluster_ids, values):
        members.setdefault(cluster_id, []).append(value)
    return [float(np.mean(group)) for group in members.values()]


def process_waveform(
    infile='source_test.root',
    intree='sim',
    outfile='pred',
    mfile='dnn.pth',
    start=0,
    length=-1,
    sign=1,
    threshold=0.95,
    with_tag=True,
    clustering_cut=-1.):
    """Run the DNN over waveform slices and write per-event counts to a ROOT file.

    For every event known to the slice dataset, each slice is scored by the
    model; slices scoring above *threshold* are counted (``ncount``) and their
    times stored (``xcount``).  Optionally the detected times are clustered
    (``ncount_cls`` / ``xcount_cls``).  The output tree is a clone of the input
    tree structure with these four branches added.

    Side effects: writes *outfile*, ``results/wf_<n>.png`` for the first ten
    events and, when *with_tag* is truthy, ``results/roc.png``,
    ``results/thr.png`` and two printed working points.

    Args:
        infile: Path of the input ROOT file.
        intree: Name of the tree inside *infile*.
        outfile: Path of the output ROOT file (recreated).
        mfile: Path of the saved ``state_dict`` loaded into ``DnnModel``.
        start: Forwarded to ``WaveformSliceDataset``.
        length: Forwarded to ``WaveformSliceDataset``.
        sign: Forwarded to ``WaveformSliceDataset`` and ``plot_waveform``.
        threshold: A slice is counted when its prediction is strictly greater
            than this value.
        with_tag: When truthy, the ``time`` and ``tag`` branches are read,
            labels are collected and ROC diagnostics are produced.
        clustering_cut: When ``> 0``, used as ``distance_threshold`` to cluster
            the detected times; otherwise clustering is skipped and
            ``ncount_cls`` stays 0.
    """
    #- Read the input file in ROOT and get key variables (waveform sample i, time, tag)
    f = TFile(infile)
    t = f.Get(intree)
    wf = ROOT.std.vector['double'](0)
    time = ROOT.std.vector['double'](0)
    tag = ROOT.std.vector['int'](0)
    t.SetBranchAddress('wf_i', wf)
    if with_tag:
        t.SetBranchAddress('time', time)
        t.SetBranchAddress('tag', tag)

    #- Set up the output file and structure
    os.makedirs('results', exist_ok=True)
    out_dir = os.path.dirname(outfile)
    if out_dir:
        # The output file may live outside 'results'; make sure its folder exists.
        os.makedirs(out_dir, exist_ok=True)
    fout = TFile(outfile, 'recreate')
    tout = t.CloneTree(0)
    ncount = array.array('i', [-1])
    xcount = vector['double'](0)
    ncount_cls = array.array('i', [-1])
    xcount_cls = vector['double'](0)
    tout.Branch('ncount', ncount, 'ncount/I')
    tout.Branch('xcount', xcount)
    tout.Branch('ncount_cls', ncount_cls, 'ncount_cls/I')
    tout.Branch('xcount_cls', xcount_cls)

    wf_slice_dataset = WaveformSliceDataset(infile, intree, start, length, nleft=5, nright=9, with_tag=with_tag, tag_method='default', sign=sign, debug=True)
    evtno2idx_dict = wf_slice_dataset.GetEventNoToIndexDict()

    model = DnnModel(embedding=True)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(mfile, map_location=device))
    # Inputs are moved to `device` below, so the model must live there too.
    model.to(device)
    model.eval()

    num_processed = 0
    labels = []
    predictions = []
    for evtno in tqdm(evtno2idx_dict, desc='Making predictions'):
        t.GetEntry(evtno)
        ncount[0] = 0
        xcount.clear()
        ncount_cls[0] = 0
        xcount_cls.clear()

        truth_time = None
        truth_tag = None
        if with_tag:
            # Keep only the entries flagged with a positive tag.
            truth_time = [value for value, flag in zip(time, tag) if flag > 0]
            truth_tag = [flag for flag in tag if flag > 0]

        with torch.no_grad():
            for idx in evtno2idx_dict[evtno]:
                x, y = wf_slice_dataset[idx]
                pred, _ = model(x.to(device))
                score = pred.item()
                if score > threshold:
                    ncount[0] += 1
                    xcount.push_back(wf_slice_dataset.GetWaveformSliceTime(idx))
                if with_tag:
                    labels.append(y.item())
                    predictions.append(score)

        if clustering_cut > 0:
            cluster_means = _cluster_times(xcount, clustering_cut)
            for mean_time in cluster_means:
                xcount_cls.push_back(mean_time)
            ncount_cls[0] = len(cluster_means)

        if num_processed < 10:
            plot_waveform(wf, sign=sign, time=xcount, truth_time=truth_time, truth_tag=truth_tag, filename='results/wf_{}.png'.format(num_processed))
        tout.Fill()
        num_processed += 1

    # Skip the ROC diagnostics when nothing was scored (roc_curve needs samples).
    if with_tag and labels:
        fpr, tpr, thr = roc_curve(labels, predictions)
        roc_auc = auc(fpr, tpr)

        fig, ax = plt.subplots(figsize=(12, 12))
        ax.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.3f})')
        ax.plot([0, 1], [0, 1], 'k--')
        ax.set_xlabel('FPR')
        ax.set_ylabel('TPR')
        ax.set_title('ROC Curve')
        # Without a legend the AUC label above is never drawn.
        ax.legend()
        fig.savefig('results/roc.png')
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(16, 8))
        ax.plot(thr, fpr, color='b', label='FPR')
        ax.set_xlabel('THR')
        ax.set_ylabel('FPR', color='b')
        ax.set_xlim((0, 1))
        ax.semilogy()
        ax.grid(True)
        ax2 = ax.twinx()
        ax2.plot(thr, tpr, color='r', label='TPR')
        ax2.set_ylabel('TPR', color='r')
        fig.savefig('results/thr.png')
        plt.close(fig)

        print('FPR = 0.01: {}'.format(_threshold_at_fpr(fpr, tpr, thr, 0.01)))
        print('FPR = 0.001: {}'.format(_threshold_at_fpr(fpr, tpr, thr, 0.001)))

    fout.WriteTObject(tout)
    # Close the output first (it owns the cloned tree), then the input file.
    fout.Close()
    f.Close()



# Script entry: run the prediction on the 'sim' tree of source_test.root and
# write the result to results/pred.root. Arguments are passed by keyword so
# each value is self-explanatory.
process_waveform(
    infile='source_test.root',
    intree='sim',
    outfile='results/pred.root',
    mfile='dnn.pth',
    start=0,
    length=100,
    sign=1,
    threshold=0.95,
    with_tag=1,  # truthy: read truth branches and produce ROC diagnostics
    clustering_cut=-1.,  # not positive: clustering is skipped
)

# :