# Code for Ice-Cube 3D CNN

- Oct 29, 2018: This code just makes plots for previously trained CNNs

In [1]:
import sys
import os

import matplotlib.pyplot as plt
import numpy as np
import glob
import pickle
import time

In [2]:
%matplotlib widget
# %matplotlib inline

Useful blog for keras conv3D: http://learnandshare645.blogspot.com/2016/06/3d-cnn-in-keras-action-recognition.html

In [3]:
# keras modules
import tensorflow.keras as keras
from tensorflow.keras import layers, models, optimizers, callbacks  # or tensorflow.keras as keras
import tensorflow as tf
from sklearn.utils import shuffle
from sklearn.metrics import roc_curve, auc, roc_auc_score
from tensorflow.keras.models import load_model



In [4]:
print(tf.__version__)
print(keras.__version__)

1.12.0
2.1.6-tf


## Modules

In [5]:
####################################
### Modules for viewing predicted data ###
def f_plt_hist(ypred):
    '''Show a 300-bin histogram of the predictions on a new figure.

    The shape of `ypred` is printed first. The x-axis is restricted to
    the interval [0, 1]. Nothing is returned.
    '''
    print(ypred.shape)

    plt.figure()
    # density stays at None, so the bars are raw counts per bin.
    plt.hist(ypred, bins=300, density=None)
    plt.xlim(0, 1)
    plt.show()
#     plt.close()

def f_get_prediction_info(y_pred,plot=False):
    ''' Print summary counts for an array of predictions.

        Reports how many entries are exactly 0.0, how many are exactly 1.0,
        and the total number of entries.

        y_pred : array of predictions; when it is 2-dimensional only
                 column 1 is examined.
        plot   : if True, also draw a histogram through f_plt_hist.
    '''

    # A 2-column prediction array is reduced to its second column
    if len(y_pred.shape)==2:
        scores=y_pred[:,1]
    else:
        scores=y_pred

    # Exact comparisons on purpose: only values equal to 0.0 / 1.0 count
    n_zero=int((scores==0.0).sum())
    n_one=int((scores==1.0).sum())
    n_all=scores.shape[0]

    print("Pred 0's:\t%s,\tPred 1's:\t%s,Total:\t %s" %(n_zero,n_one,n_all))

    if not plot:
        return
    ### Plot histogram ###
    f_plt_hist(scores)


def f_get_ydata_and_wts(data_dir,f1,f2):
    ''' Load previously extracted y-data and weights from two text files.

    data_dir : directory holding both files (with or without a trailing
               path separator)
    f1       : file name of the y-data inside data_dir
    f2       : file name of the weights inside data_dir

    returns : inpy,weights as arrays (as read by np.loadtxt)
    '''

    # os.path.join instead of string concatenation, so a directory given
    # without a trailing separator still resolves to the right file.
    inpy=np.loadtxt(os.path.join(data_dir,f1))
    wts=np.loadtxt(os.path.join(data_dir,f2))

    return inpy,wts
    
def f_plot_learning(history):
    '''Plot learning curves on a new figure: accuracy (top) and loss (bottom).

    history : dict mapping a metric name to its per-epoch values. It must
              contain 'loss' and 'val_loss', plus the accuracy pair under
              either 'acc'/'val_acc' or 'accuracy'/'val_accuracy'.

    Raises KeyError when a required metric is missing.
    '''
    # Keras versions differ in the name used for the accuracy metric;
    # accept both spellings so histories from either can be plotted.
    acc_key='acc' if 'acc' in history else 'accuracy'
    val_acc_key='val_'+acc_key

    # One x tick every second epoch, shared by both panels
    xlim=len(history[acc_key])
    ticks=np.arange(0,xlim,2)

    fig=plt.figure()

    # Plot training & validation accuracy values
    fig.add_subplot(2,1,1)
    plt.plot(history[acc_key],label='Train',marker='o')
    plt.plot(history[val_acc_key],label='Validation',marker='*')
    plt.ylabel('Accuracy')
    plt.xticks(ticks)
    # The curves carry labels, so show them on this panel as well
    plt.legend(loc='best')

    # Plot loss values
    fig.add_subplot(2,1,2)
    plt.plot(history['loss'],label='Train',marker='o')
    plt.plot(history['val_loss'],label='Validation',marker='*')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.xticks(ticks)
    plt.legend(loc='best')


def f_plot_roc_curve(fpr,tpr):
    '''
    Draw a ROC curve on a new figure and print its AUC.

    fpr, tpr : false-positive and true-positive rates, plotted as x and y.
    The x-axis is logarithmic; a y=x reference line is drawn for comparison.
    Nothing is returned.
    '''
    x_lo=10**-7

    plt.figure()
    plt.semilogx(fpr, tpr)

    # Zoom: x from x_lo to 1, y from 0 to 1
    plt.xlim([x_lo,1.0])
    plt.ylim([0,1.0])

    # Reference diagonal (y=x), sampled at 500 points on [0,1]
    diag=np.linspace(0,1,num=500)
    plt.plot(diag,diag)
    plt.show()

    # Area under the curve, printed after the figure is shown
    print("AUC: ",auc(fpr, tpr))


def f_plot_fit(inpy,wts,model_dict,model_loc,plt_roc=True,plt_learning=True,model_desc=True,plt_pred=True):
    '''
    Plot fit results for one previously trained and tested model.
    Steps:
    - Read model from its .h5 file and the training history from its pickle
    - Describe model structure
    - Plot learning history
    - Read test labels, test weights and predictions saved by a previous test
    - Get prediction information: number of zeros and ones. Plot of histogram.
    - Roc curve: Compute and plot roc curve. Store results in dictionary.

    Arguments:
    inpy, wts    : kept so existing callers keep working; they are not used.
                   The labels and weights that enter the ROC computation are
                   read from 'y-test_model-<name>.test' and
                   'wts-test_model-<name>.test' inside model_loc.
    model_dict   : dictionary with at least the key 'name'. It is updated in
                   place with 'fpr', 'tpr', 'threshold', 'model', 'history'.
    model_loc    : directory holding the model, history, prediction and
                   test files (with or without a trailing separator).
    plt_roc      : plot the ROC curve and print the AUC
    plt_learning : plot the learning curves
    model_desc   : print the model summary
    plt_pred     : plot the histogram of predictions

    Returns the same model_dict object.
    Raises AssertionError when the model, history or prediction file is
    missing, or when labels and predictions differ in length.
    '''

    model_name=str(model_dict['name']) # string for the model

    # os.path.join copes with model_loc given with or without a trailing separator
    path_model=os.path.join(model_loc,'model_{0}.h5'.format(model_name))
    path_history=os.path.join(model_loc,'history_{0}.pickle'.format(model_name))
    path_pred=os.path.join(model_loc,'y-predict_model-{0}.pred'.format(model_name))
    path_ydata=os.path.join(model_loc,'y-test_model-{0}.test'.format(model_name))
    path_wts=os.path.join(model_loc,'wts-test_model-{0}.test'.format(model_name))

    ########################
    ## Read model and history

    ### Check if files exist
    assert os.path.exists(path_model),"Model not saved"
    assert os.path.exists(path_history),"History not saved"

    model=load_model(path_model)
    # NOTE: pickle.load can execute arbitrary code stored in the file;
    # only open history files that come from a trusted source.
    with open(path_history,'rb') as f:
        history=pickle.load(f)

    ########################
    if model_desc: model.summary()
    # Plot learning of trained model
    if plt_learning: f_plot_learning(history)

    ########################
    # Get test predictions, labels and weights saved alongside the model
    print("Using test prediction from previous test",path_pred)

    assert os.path.exists(path_pred),"y-predictions not saved"
    y_pred=np.loadtxt(path_pred)
    ydata=np.loadtxt(path_ydata)
    file_wts=np.loadtxt(path_wts)
    # Compare against the labels loaded here, not against a notebook global
    assert ydata.shape[0]==y_pred.shape[0],"Data %s and prediction arrays %s are not of the same size"%(ydata.shape,y_pred.shape)

    # Condition for the case when the prediction is a 2column array
    if len(y_pred.shape)==2: y_pred=y_pred[:,1]

    ## Prints details of predictions
    f_get_prediction_info(y_pred,plot=plt_pred)

    fpr,tpr,threshold=roc_curve(ydata,y_pred,sample_weight=file_wts)
    model_dict['fpr'], model_dict['tpr'], model_dict['threshold']=fpr,tpr,threshold
    print(fpr.shape,tpr.shape,threshold.shape)
    if plt_roc: f_plot_roc_curve(fpr,tpr)

    model_dict['model'],model_dict['history']=model,history

    return model_dict




## Read part of test data

In [6]:
if __name__=='__main__':

    # Directory holding the saved models together with their test data
    model_loc='/global/project/projectdirs/dasrepo/vpa/ice_cube/data_for_cnn/results_data/final_7models_march17_2019/'

    ### Only the y-data and weights of the test set are read here; they were
    ### saved along with each model.
    ### Note!: the test file data is the same for all models, so the files
    ### of the first model are used. ###
    f1='y-test_model-1.test'
    f2='wts-test_model-1.test'
#     f1,f2='y-test_model-15.test','wts-test_model-15.test'

    inpy,wts=f_get_ydata_and_wts(model_loc,f1,f2)

    # Full-length slices of the loaded arrays
    test_y=inpy[:]
    test_wts=wts[:]
    

In [7]:
# Sanity check: the four arrays should report the same shape
print(*(arr.shape for arr in (inpy,wts,test_y,test_wts)))

(368857,) (368857,) (368857,) (368857,)


## Plot fits

In [10]:
# One result dictionary per processed model, in processing order
dict_list=[]
# for i in range(1,16):
for i in [1,4]:
    # Start from an empty record; f_plot_fit fills in the remaining fields
    model_dict={'name':None,'description':None,'model':None,'history':None,'fpr':None,'tpr':None,'threshold':None}
    model_dict['name']=str(i)
    print(i,model_dict)
    model_dict=f_plot_fit(
        test_y,test_wts,model_dict,model_loc,
        plt_roc=True,plt_learning=True,model_desc=True,plt_pred=True,
    )
    dict_list.append(model_dict)

1 {'name': '1', 'description': None, 'model': None, 'history': None, 'fpr': None, 'tpr': None, 'threshold': None}
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 10, 20, 60, 1)     0         
_________________________________________________________________
conv3d (Conv3D)              (None, 10, 20, 60, 10)    280       
_________________________________________________________________
max_pooling3d (MaxPooling3D) (None, 5, 10, 30, 10)     0         
_________________________________________________________________
dropout (Dropout)            (None, 5, 10, 30, 10)     0         
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 5, 10, 30, 10)     2710      
_________________________________________________________________
max_pooling3d_1 (MaxPooling3 (None, 2, 5, 15, 10)      0         
____________________________

FigureCanvasNbAgg()

Using test prediction from previous test /global/project/projectdirs/dasrepo/vpa/ice_cube/data_for_cnn/results_data/final_7models_march17_2019/y-predict_model-1.pred
Pred 0's:	2,	Pred 1's:	34,Total:	 368857
(368857,)


FigureCanvasNbAgg()

(363142,) (363142,) (363142,)


FigureCanvasNbAgg()

AUC:  0.9214806038496348
4 {'name': '4', 'description': None, 'model': None, 'history': None, 'fpr': None, 'tpr': None, 'threshold': None}
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 10, 20, 60, 1)     0         
_________________________________________________________________
conv3d_9 (Conv3D)            (None, 10, 20, 60, 40)    3880      
_________________________________________________________________
batch_normalization (BatchNo (None, 10, 20, 60, 40)    160       
_________________________________________________________________
max_pooling3d_9 (MaxPooling3 (None, 10, 10, 20, 40)    0         
_________________________________________________________________
dropout_7 (Dropout)          (None, 10, 10, 20, 40)    0         
_________________________________________________________________
conv3d_10 (Conv3D)           (None, 10, 10, 20, 40)    153640    
___

FigureCanvasNbAgg()

Using test prediction from previous test /global/project/projectdirs/dasrepo/vpa/ice_cube/data_for_cnn/results_data/final_7models_march17_2019/y-predict_model-4.pred
Pred 0's:	0,	Pred 1's:	50,Total:	 368857
(368857,)


FigureCanvasNbAgg()

(365936,) (365936,) (365936,)


FigureCanvasNbAgg()

AUC:  0.9596425513880648


## -----------------------------------------------

In [9]:
def f_plot_roc_curve_all(model_dict):
    '''
    Add one model's ROC curve to the current figure, labelled with its AUC.

    model_dict : dictionary with the keys 'name', 'fpr' and 'tpr'.

    No new figure is created and plt.show() is not called, so repeated calls
    overlay several models on the same axes. Nothing is returned.
    '''

    fpr,tpr=model_dict['fpr'],model_dict['tpr']
    auc_val = auc(fpr, tpr)

    plt.plot(fpr,tpr,label='Model='+str(model_dict['name'])+"\nAUC: "+str(auc_val))
    plt.xscale('log')
    # Zooms
    plt.xlim([10**-8,1.0])
    plt.ylim([0,1.0])

    ### Plot Physics benchmark point
    ## Physics benchmark [1.44576*10**-6 FPR, 0.04302 TPR]
    phys_x,phys_y=1.44576*10**-6,0.04302
    plt.scatter(phys_x,phys_y,color='red')

    plt.legend()
    # Axis labels: spelling of "positive" corrected
    plt.xlabel("False positive rate (1- Background rejection)")
    plt.ylabel("True positive rate (Signal Efficiency)")
# Overlay the ROC curves of the selected models on one new figure
plt.figure()
# f_plot_roc_curve_all(model)
for i in dict_list:
    # Skip every model whose number is not among 1, 3 and 7
    if int(i['name']) not in [1,3,7]:
        continue
    print(i['name'])
    f_plot_roc_curve_all(i)

FigureCanvasNbAgg()

2
3
4
5


## Trying widgets

In [None]:
# interact(f_text_plot, a=RadioButtons(
#     options=['1','2','3'],
#     value='3',
#     description='Plot type:',
#     disabled=False
# ));

