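In this notebook, we evaluate the 1D-CNN ROI finder on test waveforms restricted to different ADC ranges (selected by the peak ADC of the clean waveform), producing one ROC curve per range.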

In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras 
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import load_model
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

In [2]:
np.random.seed(42)  # fix NumPy's global RNG so any random draws are reproducible
wireplane = 'U'     # wire plane under study; used to build every data/model path

def filter_signal_ADC(data, clean_data, roi_targets, adc_value, filter_by_max):
    """Select waveforms by the peak ADC value of their clean counterpart.

    A row is kept when its clean waveform is entirely zero (no signal), or
    when its clean peak passes the cut:

    * ``filter_by_max=False`` -> keep if ``max(clean) > adc_value`` (lower cut)
    * ``filter_by_max=True``  -> keep if ``max(clean) < adc_value`` (upper cut)

    Parameters
    ----------
    data : array-like, shape (n, ...)
        Waveforms fed to the model, row-aligned with ``clean_data``.
    clean_data : array-like, shape (n, m)
        Clean waveforms used only to decide which rows pass the cut.
    roi_targets : array-like, shape (n, ...)
        Per-row targets, filtered with the same selection.
    adc_value : number
        ADC threshold (strict comparison).
    filter_by_max : bool
        True applies the upper cut, False the lower cut.

    Returns
    -------
    tuple of np.ndarray
        ``(data_wf, clean_wf, roi_tar)``: the selected rows of each input,
        in their original order.

    Raises
    ------
    IndexError
        If ``data`` or ``roi_targets`` do not have the same number of rows
        as ``clean_data``.
    """
    data = np.asarray(data)
    clean_data = np.asarray(clean_data)
    roi_targets = np.asarray(roi_targets)

    if len(clean_data) == 0:
        # Nothing to select; keep the trailing dimensions of each input.
        return data[:0], clean_data[:0], roi_targets[:0]

    peak = clean_data.max(axis=1)
    # "No signal" means every sample is zero. A "sum == 0" test would also
    # match bipolar waveforms whose positive and negative lobes cancel,
    # letting real signals bypass the ADC cut.
    no_signal = ~clean_data.any(axis=1)
    passes_cut = (peak < adc_value) if filter_by_max else (peak > adc_value)
    keep = passes_cut | no_signal
    return data[keep], clean_data[keep], roi_targets[keep]

Load the test set and the normalization constants (mean and scale) saved with the model

In [3]:
# Test waveforms and their ROI labels for the selected wire plane.
x_test = np.load(f'../processed_data/x_test_{wireplane}.npy')
y_test = np.load(f'../processed_data/y_test_ROI_{wireplane}.npy')
# Constants used to standardize model inputs: (x - mean) / std.
mean = np.load(f'../latest_models/mean_{wireplane}_nu.npy')
std = np.load(f'../latest_models/scale_{wireplane}_nu.npy')

In [4]:
print(f'{x_test.shape} {type(x_test)}')

(100000, 200) <class 'numpy.ndarray'>


Load the trained model

In [5]:
model = load_model(f'../latest_models/model_{wireplane}plane_nu.h5')

## Evaluate the test set with ADC cuts

Note: at the moment every signal waveform has a peak ADC above 3, so a lower cut of 3 keeps the full test set.

In [6]:
# The ADC cuts are applied on the clean waveforms of the autoencoder (AE)
# test set, so load them alongside the raw test inputs.
full_test_clean = np.load(f'../processed_data/y_test_AE_{wireplane}.npy')

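## ROC curves per ADC window

`eval_cut_` keeps the test waveforms whose clean peak ADC falls in the requested window (together with waveforms with no clean signal), then plots and saves the ROC curve. Note: the execution counts are out of order. The sweep labelled `In [11]` ran before this definition's latest execution (`In [17]`), so re-run the notebook top to bottom to reproduce the outputs.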

In [17]:
# adc_max = 0 (or None) means no max cut is applied
def eval_cut_(full_test_clean, x_test, adc_min, adc_max):
    """Plot and save the ROI-finder ROC curve for test waveforms in an ADC window.

    Rows are selected with ``filter_signal_ADC`` on the clean waveforms:
    peak > ``adc_min`` and, when ``adc_max`` is truthy, peak < ``adc_max``.
    Rows that ``filter_signal_ADC`` treats as "no signal" are always kept.

    Parameters
    ----------
    full_test_clean : np.ndarray
        Clean test waveforms, row-aligned with ``x_test`` and ``y_test``.
    x_test : np.ndarray
        Raw test waveforms fed to the model.
    adc_min : number
        Lower (exclusive) peak-ADC cut.
    adc_max : number
        Upper (exclusive) peak-ADC cut; 0 or None disables it.

    Notes
    -----
    Reads the notebook globals ``y_test``, ``mean``, ``std``, ``model`` and
    ``wireplane``. The figure is saved under ``../roc_curves/<wireplane>/``
    (created if missing) and then closed; nothing is returned.
    """
    import os

    print(x_test.shape, y_test.shape)

    test_, clean_, y_test_ = filter_signal_ADC(x_test, full_test_clean, y_test, adc_min, False)
    print(test_.shape, y_test_.shape)
    if adc_max:
        test_, clean_, y_test_ = filter_signal_ADC(test_, clean_, y_test_, adc_max, True)
    print(test_.shape, y_test_.shape)

    x_test_scaled = (test_ - mean) / std
    # predict() returns one score column per row; roc_curve expects a 1-D vector.
    scores = np.ravel(model.predict(x_test_scaled, batch_size=4096))
    fpr, tpr, _ = roc_curve(y_test_, scores)
    roc_auc = auc(fpr, tpr)

    if adc_max:
        window = ' (' + str(adc_min) + ' < ADC < ' + str(adc_max) + ')'
        suffix = str(adc_min) + '-' + str(adc_max)
    else:
        window = ' (ADC > ' + str(adc_min) + ')'
        suffix = 'gt_' + str(adc_min)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, label='auc: ' + str(round(roc_auc, 3)))
    ax.set_title("ROC Curve - Test Dataset Plane " + wireplane + window)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc='center')

    # All ROC curves for a plane go to the same per-plane directory.
    out_dir = os.path.join('../roc_curves', wireplane)
    os.makedirs(out_dir, exist_ok=True)
    filename = os.path.join(out_dir, 'plane_' + wireplane + '_roc_adc_' + suffix + '.png')
    fig.savefig(filename, facecolor='w')
    # Close explicitly: this is called in loops and would otherwise leak figures.
    plt.close(fig)

In [11]:
# Lower cut only (adc_max=0 disables the upper cut), swept over 3..15.
for adc_min in range(3, 16):
    eval_cut_(full_test_clean, x_test, adc_min, 0)

(100000, 200) (100000,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(96875, 200) (96875,)
(96875, 200) (96875,)
(100000, 200) (100000,)
(94438, 200) (94438,)
(94438, 200) (94438,)
(100000, 200) (100000,)
(92520, 200) (92520,)
(92520, 200) (92520,)
(100000, 200) (100000,)
(90830, 200) (90830,)
(90830, 200) (90830,)
(100000, 200) (100000,)
(89352, 200) (89352,)
(89352, 200) (89352,)
(100000, 200) (100000,)
(88136, 200) (88136,)
(88136, 200) (88136,)
(100000, 200) (100000,)
(87020, 200) (87020,)
(87020, 200) (87020,)
(100000, 200) (100000,)
(86007, 200) (86007,)
(86007, 200) (86007,)
(100000, 200) (100000,)
(85114, 200) (85114,)
(85114, 200) (85114,)
(100000, 200) (100000,)
(84265, 200) (84265,)
(84265, 200) (84265,)
(100000, 200) (100000,)
(83521, 200) (83521,)
(83521, 200) (83521,)
(100000, 200) (100000,)
(82757, 200) (82757,)
(82757, 200) (82757,)


In [18]:
# Fixed lower cut of 3; upper cut swept over 4..15.
for adc_max in range(4, 16):
    eval_cut_(full_test_clean, x_test, 3, adc_max)

(100000, 200) (100000,)
(100000, 200) (100000,)
(50097, 200) (50097,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(53222, 200) (53222,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(55659, 200) (55659,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(57577, 200) (57577,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(59267, 200) (59267,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(60745, 200) (60745,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(61961, 200) (61961,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(63077, 200) (63077,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(64090, 200) (64090,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(64983, 200) (64983,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(65832, 200) (65832,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(66576, 200) (66576,)


In [19]:
# Fixed upper cut of 15; lower cut swept over 3..14.
for adc_min in range(3, 15):
    eval_cut_(full_test_clean, x_test, adc_min, 15)

(100000, 200) (100000,)
(100000, 200) (100000,)
(66576, 200) (66576,)
(100000, 200) (100000,)
(96875, 200) (96875,)
(63451, 200) (63451,)
(100000, 200) (100000,)
(94438, 200) (94438,)
(61014, 200) (61014,)
(100000, 200) (100000,)
(92520, 200) (92520,)
(59096, 200) (59096,)
(100000, 200) (100000,)
(90830, 200) (90830,)
(57406, 200) (57406,)
(100000, 200) (100000,)
(89352, 200) (89352,)
(55928, 200) (55928,)
(100000, 200) (100000,)
(88136, 200) (88136,)
(54712, 200) (54712,)
(100000, 200) (100000,)
(87020, 200) (87020,)
(53596, 200) (53596,)
(100000, 200) (100000,)
(86007, 200) (86007,)
(52583, 200) (52583,)
(100000, 200) (100000,)
(85114, 200) (85114,)
(51690, 200) (51690,)
(100000, 200) (100000,)
(84265, 200) (84265,)
(50841, 200) (50841,)
(100000, 200) (100000,)
(83521, 200) (83521,)
(50097, 200) (50097,)
