In this notebook, we evaluate the 1D-CNN ROI finder over different ADC ranges.

In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras 
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import load_model
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

In [2]:
np.random.seed(seed=42)  # fix NumPy's global RNG so re-runs are reproducible
wireplane = "V"  # wire plane used to build every data/model file path in this notebook

def filter_signal_ADC(data, clean_data, roi_targets, adc_value, filter_by_max):
    """Select waveforms by the peak ADC of their clean counterpart.

    Row ``i`` of ``data``, ``clean_data`` and ``roi_targets`` is kept when row
    ``i`` of ``clean_data`` either contains no signal at all (every sample is
    zero) or passes the peak cut:

    * ``filter_by_max=False``: ``max(clean_data[i]) > adc_value`` (lower cut)
    * ``filter_by_max=True``:  ``max(clean_data[i]) < adc_value`` (upper cut)

    Both bounds are exclusive. Signal-free rows are kept by both cuts so the
    selection still contains negatives for the ROC evaluation.

    Args:
        data: array-like of waveforms, indexed along axis 0 like ``clean_data``.
        clean_data: array-like of clean waveforms; the cut is computed per row.
        roi_targets: per-waveform labels, indexed along axis 0 like ``clean_data``.
        adc_value: ADC threshold for the cut.
        filter_by_max: ``True`` for an upper cut, ``False`` for a lower cut.

    Returns:
        ``(data_wf, clean_wf, roi_tar)`` numpy arrays holding the selected rows
        in their original order. An empty selection keeps the trailing
        dimensions (e.g. shape ``(0, n_samples)``) instead of collapsing to ``(0,)``.
    """
    data = np.asarray(data)
    clean_data = np.asarray(clean_data)
    roi_targets = np.asarray(roi_targets)

    # Reduce over every axis except the waveform axis (axis 0).
    sample_axes = tuple(range(1, clean_data.ndim))
    peak = clean_data.max(axis=sample_axes)
    # "No signal" means every clean sample is zero. A sum()==0 test would also
    # match waveforms whose positive and negative samples cancel (e.g. bipolar
    # signals), letting them bypass the ADC cut.
    no_signal = ~clean_data.any(axis=sample_axes)

    passes_cut = peak < adc_value if filter_by_max else peak > adc_value
    # Positional indices (not a boolean mask) so extra trailing rows in
    # data/roi_targets are ignored, as with the original per-row loop.
    keep = np.flatnonzero(passes_cut | no_signal)
    return data[keep], clean_data[keep], roi_targets[keep]

Load the test set.

In [3]:
# Test inputs and their ROI labels, plus the per-plane mean/scale that are
# later used to standardize the model inputs.
x_test = np.load(f'../processed_data/x_test_{wireplane}.npy')
y_test = np.load(f'../processed_data/y_test_ROI_{wireplane}.npy')
mean = np.load(f'../latest_models/mean_{wireplane}_nu.npy')
std = np.load(f'../latest_models/scale_{wireplane}_nu.npy')

In [4]:
# Sanity check: shape and container type of the loaded test inputs.
print(*(x_test.shape, type(x_test)))

(100000, 200) <class 'numpy.ndarray'>


Load the trained model.

In [5]:
# Trained 1D-CNN ROI finder for this wire plane (Keras HDF5 file).
model = load_model(f'../latest_models/model_{wireplane}plane_nu.h5')

Evaluate the model on the full test set

## Below, we evaluate the test set with ADC cuts. Note: at the moment, all ADC values are > 3.

In [6]:
# The ADC cuts are computed on the clean waveforms from the autoencoder (AE)
# test set, so that file has to be loaded alongside the ROI test set.
full_test_clean = np.load(f'../processed_data/y_test_AE_{wireplane}.npy')

New development: evaluate the model inside an ADC window and save its ROC curve.

In [7]:
# adc_max = 0 means no max cut is applied
def eval_cut_(full_test_clean, x_test, adc_min, adc_max, y_true=None):
    """Evaluate the ROI model on test waveforms inside an ADC window and save the ROC curve.

    Waveforms are selected with ``filter_signal_ADC`` using the clean waveforms.
    The selection keeps rows whose peak clean ADC is > ``adc_min`` and, when
    ``adc_max != 0``, also < ``adc_max``. It also keeps the rows that
    ``filter_signal_ADC`` treats as signal-free. The selected inputs are
    standardized with the notebook-level ``mean``/``std`` and scored by the
    notebook-level ``model``. The ROC curve is then saved under
    ``../roc_curves/<wireplane>/``.

    Args:
        full_test_clean: clean waveforms aligned row-for-row with ``x_test``.
        x_test: model inputs (waveforms) to evaluate.
        adc_min: exclusive lower bound on the clean peak ADC.
        adc_max: exclusive upper bound on the clean peak ADC; 0 disables it.
        y_true: ROI labels aligned with ``x_test``. Defaults to the notebook-level
            ``y_test`` so existing calls keep working.

    Returns:
        The ROC AUC of the selection, or NaN when the selection does not contain
        both label classes. In that case no plot is saved.
    """
    import os

    # Pass y_true explicitly whenever x_test is not the global test set,
    # otherwise inputs and labels silently misalign.
    y_true = np.asarray(y_test if y_true is None else y_true)
    print(x_test.shape, y_true.shape)

    test_, clean_, y_test_ = filter_signal_ADC(x_test, full_test_clean, y_true, adc_min, False)
    print(test_.shape, y_test_.shape)
    if adc_max != 0:
        test_, clean_, y_test_ = filter_signal_ADC(test_, clean_, y_test_, adc_max, True)
    print(test_.shape, y_test_.shape)

    if adc_max == 0:
        range_label = '(ADC > ' + str(adc_min) + ')'
        file_suffix = '_roc_adc_gt_' + str(adc_min)
    else:
        range_label = '(' + str(adc_min) + ' < ADC < ' + str(adc_max) + ')'
        file_suffix = '_roc_adc_' + str(adc_min) + '-' + str(adc_max)

    # A ROC curve needs both positives and negatives. A narrow window that only
    # contains signal-free waveforms would otherwise give a NaN curve/AUC.
    labels_present = np.unique(y_test_)
    if labels_present.size < 2:
        print('Skipping ROC for ' + range_label + ': only labels '
              + str(labels_present) + ' present')
        return float('nan')

    x_test_scaled = (test_ - mean) / std
    # NOTE(review): assumes one score per waveform (e.g. an (n, 1) sigmoid output);
    # flatten to the 1-D score vector roc_curve expects.
    scores = np.ravel(model.predict(x_test_scaled, batch_size=4096))
    fpr, tpr, _ = roc_curve(y_test_, scores)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, label='auc: ' + str(round(roc_auc, 3)))
    ax.set_title("ROC Curve - Test Dataset Plane " + wireplane + ' ' + range_label)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc='center')

    out_dir = '../roc_curves/' + wireplane
    os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory is missing
    fig.savefig(out_dir + '/plane_' + wireplane + file_suffix + '.png', facecolor='w')
    plt.close(fig)  # many figures are created in the sweeps below; free each one
    return roc_auc

In [8]:
# Lower-cut sweep: ADC > 3, 4, ..., 15 with no upper cut (adc_max = 0).
for adc_min in range(3, 16):
    eval_cut_(full_test_clean, x_test, adc_min, 0)

(100000, 200) (100000,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(96954, 200) (96954,)
(96954, 200) (96954,)
(100000, 200) (100000,)
(94692, 200) (94692,)
(94692, 200) (94692,)
(100000, 200) (100000,)
(92725, 200) (92725,)
(92725, 200) (92725,)
(100000, 200) (100000,)
(91086, 200) (91086,)
(91086, 200) (91086,)
(100000, 200) (100000,)
(89633, 200) (89633,)
(89633, 200) (89633,)
(100000, 200) (100000,)
(88363, 200) (88363,)
(88363, 200) (88363,)
(100000, 200) (100000,)
(87284, 200) (87284,)
(87284, 200) (87284,)
(100000, 200) (100000,)
(86265, 200) (86265,)
(86265, 200) (86265,)
(100000, 200) (100000,)
(85322, 200) (85322,)
(85322, 200) (85322,)
(100000, 200) (100000,)
(84470, 200) (84470,)
(84470, 200) (84470,)
(100000, 200) (100000,)
(83666, 200) (83666,)
(83666, 200) (83666,)
(100000, 200) (100000,)
(82888, 200) (82888,)
(82888, 200) (82888,)


In [9]:
# Upper-cut sweep with a fixed lower cut: 3 < ADC < 4, 5, ..., 15.
for adc_max in range(4, 16):
    eval_cut_(full_test_clean, x_test, 3, adc_max)

(100000, 200) (100000,)
(100000, 200) (100000,)
(50396, 200) (50396,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(53442, 200) (53442,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(55704, 200) (55704,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(57671, 200) (57671,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(59310, 200) (59310,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(60763, 200) (60763,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(62033, 200) (62033,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(63112, 200) (63112,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(64131, 200) (64131,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(65074, 200) (65074,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(65926, 200) (65926,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(66730, 200) (66730,)


In [10]:
# Lower-cut sweep with a fixed upper cut: 3, 4, ..., 14 < ADC < 15.
for adc_min in range(3, 15):
    eval_cut_(full_test_clean, x_test, adc_min, 15)

(100000, 200) (100000,)
(100000, 200) (100000,)
(66730, 200) (66730,)
(100000, 200) (100000,)
(96954, 200) (96954,)
(63684, 200) (63684,)
(100000, 200) (100000,)
(94692, 200) (94692,)
(61422, 200) (61422,)
(100000, 200) (100000,)
(92725, 200) (92725,)
(59455, 200) (59455,)
(100000, 200) (100000,)
(91086, 200) (91086,)
(57816, 200) (57816,)
(100000, 200) (100000,)
(89633, 200) (89633,)
(56363, 200) (56363,)
(100000, 200) (100000,)
(88363, 200) (88363,)
(55093, 200) (55093,)
(100000, 200) (100000,)
(87284, 200) (87284,)
(54014, 200) (54014,)
(100000, 200) (100000,)
(86265, 200) (86265,)
(52995, 200) (52995,)
(100000, 200) (100000,)
(85322, 200) (85322,)
(52052, 200) (52052,)
(100000, 200) (100000,)
(84470, 200) (84470,)
(51200, 200) (51200,)
(100000, 200) (100000,)
(83666, 200) (83666,)
(50396, 200) (50396,)
