In this notebook, we evaluate the 1D-CNN ROI finder on the test set at different ADC ranges.

In [1]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import auc
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import load_model

In [2]:
# Wire-plane label; it is spliced into the data/model file names and the plot
# titles and output paths used further down.
wireplane = 'Z'
# Seed NumPy's global random number generator.
np.random.seed(42)

def filter_signal_ADC(data, clean_data, roi_targets, adc_value, filter_by_max):
    """Select the waveforms whose clean-signal peak passes an ADC threshold.

    Waveform ``i`` is kept when the sum of ``clean_data[i]`` is exactly zero
    (such waveforms are kept regardless of the threshold), or when its peak
    value satisfies the requested cut:

    * ``filter_by_max`` truthy  -> keep if ``max(clean_data[i]) < adc_value``
    * ``filter_by_max`` falsy   -> keep if ``max(clean_data[i]) > adc_value``

    Both comparisons are strict, so a peak equal to ``adc_value`` is rejected
    by either cut.

    Parameters
    ----------
    data : array-like
        Waveforms to be filtered; indexed along the first axis in step with
        ``clean_data``.
    clean_data : array-like
        Reference waveforms on which the cut is evaluated.
    roi_targets : array-like
        Per-waveform targets, filtered in step with ``data``.
    adc_value : number
        Threshold compared against each clean waveform's maximum.
    filter_by_max : bool
        Chooses between the upper-bound (``<``) and lower-bound (``>``) cut.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data, clean_data, roi_targets)`` restricted to the kept waveforms,
        in their original order. When nothing is kept the arrays still carry
        their trailing dimensions (e.g. ``(0, n_ticks)``), so downstream
        broadcasting against per-tick arrays keeps working.
    """
    data = np.asarray(data)
    clean_data = np.asarray(clean_data)
    roi_targets = np.asarray(roi_targets)

    n_wf = clean_data.shape[0]
    if n_wf == 0:
        # Nothing to reduce over; an empty mask yields empty selections.
        keep = np.zeros(0, dtype=bool)
    else:
        # Collapse every axis except the first: one peak / one sum per waveform.
        per_wf_axes = tuple(range(1, clean_data.ndim))
        peak = clean_data.max(axis=per_wf_axes)
        total = clean_data.sum(axis=per_wf_axes)

        passes_cut = (peak < adc_value) if filter_by_max else (peak > adc_value)
        keep = passes_cut | (total == 0)

    # Only the first n_wf entries of data / roi_targets pair up with clean_data.
    return data[:n_wf][keep], clean_data[keep], roi_targets[:n_wf][keep]

Load testing set

In [3]:
# Test inputs and the matching ROI targets for the selected wire plane.
x_test = np.load('../processed_data/x_test_' + wireplane + '.npy')
y_test = np.load('../processed_data/y_test_ROI_' + wireplane + '.npy')
# Mean / scale arrays saved next to the model; the evaluation code applies
# them as (x - mean) / std before running inference.
mean = np.load('../latest_models/mean_' + wireplane + '_nu.npy')
std = np.load('../latest_models/scale_' + wireplane + '_nu.npy')

In [4]:
# Sanity check: report the shape and container type of the loaded test inputs.
print(x_test.shape, type(x_test))

(100000, 200) <class 'numpy.ndarray'>


Load trained model

In [5]:
# Load the saved Keras model (.h5 file) for the selected wire plane.
model = load_model('../latest_models/model_' + wireplane + 'plane_nu.h5')

Evaluate model with full testing set

## Below we evaluate the testing set with ADC cuts. Note that, at the moment, all ADC values are > 3.

In [6]:
# Load the clean-waveform test set saved for the AE (presumably the
# autoencoder -- confirm). The ADC cuts below are evaluated on these clean
# waveforms, not on the model inputs in x_test.
full_test_clean = np.load('../processed_data/y_test_AE_' + wireplane + '.npy')

New development

In [7]:
# adc_max = 0 means no max cut is applied
def eval_cut_(full_test_clean, x_test, adc_min, adc_max):
    """Save the model's ROC curve for an ADC-selected subset of the test set.

    The test waveforms are restricted with ``filter_signal_ADC`` to those
    whose clean peak is above ``adc_min`` and, unless ``adc_max`` is 0, also
    below ``adc_max``. The surviving inputs are standardised as
    ``(x - mean) / std``, scored with ``model.predict`` and the resulting ROC
    curve (with its AUC in the legend) is written to
    ``../roc_curves/<wireplane>/``.

    Besides its arguments this function reads the notebook-level names
    ``y_test``, ``mean``, ``std``, ``model`` and ``wireplane``. ``y_test``
    must therefore line up row-for-row with the ``x_test`` passed in.

    Parameters
    ----------
    full_test_clean : numpy.ndarray
        Clean waveforms on which the ADC cuts are evaluated.
    x_test : numpy.ndarray
        Model inputs, one row per entry of ``full_test_clean``.
    adc_min : number
        Lower ADC bound (strict ``>``).
    adc_max : number
        Upper ADC bound (strict ``<``); ``0`` disables the upper cut.

    Returns
    -------
    None
        The figure is saved to disk and closed; nothing is returned. If no
        waveform survives the cuts a message is printed and no file is
        written.
    """
    print(x_test.shape, y_test.shape)

    test_, clean_, y_test_ = filter_signal_ADC(x_test, full_test_clean, y_test, adc_min, False)
    print(test_.shape, y_test_.shape)
    has_max_cut = adc_max != 0
    if has_max_cut:
        test_, clean_, y_test_ = filter_signal_ADC(test_, clean_, y_test_, adc_max, True)
    print(test_.shape, y_test_.shape)

    if len(test_) == 0:
        # Nothing to score: scaling / predict / roc_curve would all fail on
        # an empty selection, so bail out instead of raising mid-scan.
        print('No waveforms pass the ADC cuts; skipping ROC curve.')
        return

    x_test_scaled = (test_ - mean) / std
    all_infer = model.predict(x_test_scaled, batch_size=4096)
    fpr_keras, tpr_keras, _ = roc_curve(y_test_, all_infer)
    roc_auc = auc(fpr_keras, tpr_keras)

    # Title suffix and file tag depend only on whether an upper cut is active.
    if has_max_cut:
        cut_label = ' (' + str(adc_min) + ' < ADC < ' + str(adc_max) + ')'
        file_tag = '_roc_adc_' + str(adc_min) + '-' + str(adc_max)
    else:
        cut_label = ' (ADC > ' + str(adc_min) + ')'
        file_tag = '_roc_adc_gt_' + str(adc_min)
    filename = '../roc_curves/' + wireplane + '/plane_' + wireplane + file_tag + '.png'

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr_keras, tpr_keras, label='auc: ' + str(round(roc_auc, 3)))
    ax.set_title("ROC Curve - Test Dataset Plane " + wireplane + cut_label)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc='center')

    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    fig.savefig(filename, facecolor='w')
    # Close explicitly: this function is called in loops and would otherwise
    # accumulate open figures.
    plt.close(fig)

In [8]:
# Lower-bound-only scan: ADC > 3, 4, ..., 15 (a max of 0 disables the upper cut).
for adc_min in range(3, 16):
    eval_cut_(full_test_clean, x_test, adc_min, 0)

(100000, 200) (100000,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(96230, 200) (96230,)
(96230, 200) (96230,)
(100000, 200) (100000,)
(93169, 200) (93169,)
(93169, 200) (93169,)
(100000, 200) (100000,)
(90669, 200) (90669,)
(90669, 200) (90669,)
(100000, 200) (100000,)
(88534, 200) (88534,)
(88534, 200) (88534,)
(100000, 200) (100000,)
(86753, 200) (86753,)
(86753, 200) (86753,)
(100000, 200) (100000,)
(85312, 200) (85312,)
(85312, 200) (85312,)
(100000, 200) (100000,)
(84016, 200) (84016,)
(84016, 200) (84016,)
(100000, 200) (100000,)
(82873, 200) (82873,)
(82873, 200) (82873,)
(100000, 200) (100000,)
(81924, 200) (81924,)
(81924, 200) (81924,)
(100000, 200) (100000,)
(81081, 200) (81081,)
(81081, 200) (81081,)
(100000, 200) (100000,)
(80322, 200) (80322,)
(80322, 200) (80322,)
(100000, 200) (100000,)
(79681, 200) (79681,)
(79681, 200) (79681,)


In [9]:
# Fixed lower bound of 3, upper bound scanned over 4, 5, ..., 15.
for adc_max in range(4, 16):
    eval_cut_(full_test_clean, x_test, 3, adc_max)

(100000, 200) (100000,)
(100000, 200) (100000,)
(50092, 200) (50092,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(53862, 200) (53862,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(56923, 200) (56923,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(59423, 200) (59423,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(61558, 200) (61558,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(63339, 200) (63339,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(64780, 200) (64780,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(66076, 200) (66076,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(67219, 200) (67219,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(68168, 200) (68168,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(69011, 200) (69011,)
(100000, 200) (100000,)
(100000, 200) (100000,)
(69770, 200) (69770,)


In [10]:
# Fixed upper bound of 15, lower bound scanned over 3, 4, ..., 14.
for adc_min in range(3, 15):
    eval_cut_(full_test_clean, x_test, adc_min, 15)

(100000, 200) (100000,)
(100000, 200) (100000,)
(69770, 200) (69770,)
(100000, 200) (100000,)
(96230, 200) (96230,)
(66000, 200) (66000,)
(100000, 200) (100000,)
(93169, 200) (93169,)
(62939, 200) (62939,)
(100000, 200) (100000,)
(90669, 200) (90669,)
(60439, 200) (60439,)
(100000, 200) (100000,)
(88534, 200) (88534,)
(58304, 200) (58304,)
(100000, 200) (100000,)
(86753, 200) (86753,)
(56523, 200) (56523,)
(100000, 200) (100000,)
(85312, 200) (85312,)
(55082, 200) (55082,)
(100000, 200) (100000,)
(84016, 200) (84016,)
(53786, 200) (53786,)
(100000, 200) (100000,)
(82873, 200) (82873,)
(52643, 200) (52643,)
(100000, 200) (100000,)
(81924, 200) (81924,)
(51694, 200) (51694,)
(100000, 200) (100000,)
(81081, 200) (81081,)
(50851, 200) (50851,)
(100000, 200) (100000,)
(80322, 200) (80322,)
(50092, 200) (50092,)
