In [1]:
import numpy as np
import pandas as pd

import os
import pickle
import glob
import matplotlib.pyplot as plt
import time
import keras
import imageio
from scipy import signal
from keras.models import load_model
from sklearn.metrics import mean_squared_error as MSE
from sklearn.metrics import mean_absolute_error as MAE
from sklearn.metrics import explained_variance_score
from sklearn.metrics import roc_curve, auc
from skimage.transform import resize
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes



In [2]:
save_figures = False
file_ending = '.png'
model_string = 'NMDA'

dataset_folder = 'data'

models_folder     = os.path.join(dataset_folder, 'Models')
morphology_folder = os.path.join(dataset_folder, 'Morphology')
test_data_folder  = os.path.join(dataset_folder, 'Data_test')
auxiliary_folder  = os.path.join(dataset_folder, 'Auxiliary')

model_filename           = os.path.join(models_folder, 'NMDA_TCN__DWT_7_128_153__model.h5')
model_metadata_filename  = os.path.join(models_folder, 'NMDA_TCN__DWT_7_128_153__training.pickle')
morphology_filename      = os.path.join(morphology_folder, 'morphology_dict.pickle')
NN_illustration_filename = os.path.join(auxiliary_folder, 'TCN_7_layers.png')
test_files               = sorted(glob.glob(os.path.join(test_data_folder, '*_128_simulationRuns*_6_secDuration_*')))

print('-----------------------------------------------')
print('finding files: model, morphology and test data')
print('-----------------------------------------------')
print('model found          : "%s"' %(model_filename.split('/')[-1]))
print('model metadata found : "%s"' %(model_metadata_filename.split('/')[-1]))
print('morphology found     : "%s"' %(morphology_filename.split('/')[-1]))
print('number of test files is %d' %(len(test_files)))
print('-----------------------------------------------')

-----------------------------------------------
finding files: model, morphology and test data
-----------------------------------------------
model found          : "NMDA_TCN__DWT_7_128_153__model.h5"
model metadata found : "NMDA_TCN__DWT_7_128_153__training.pickle"
morphology found     : "morphology_dict.pickle"
number of test files is 12
-----------------------------------------------


In [3]:
experiment_dict = pickle.load(open(test_files[0], "rb" ), encoding='latin1')
parse_multiple_sim_experiment_files(experiment_dict)

NameError: name 'parse_multiple_sim_experiment_files' is not defined

In [4]:
##%% helper functions

def bin2dict(bin_spikes_matrix):
    spike_row_inds, spike_times = np.nonzero(bin_spikes_matrix)
    row_inds_spike_times_map = {}
    for row_ind, syn_time in zip(spike_row_inds,spike_times):
        if row_ind in row_inds_spike_times_map.keys():
            row_inds_spike_times_map[row_ind].append(syn_time)
        else:
            row_inds_spike_times_map[row_ind] = [syn_time]

    return row_inds_spike_times_map


def dict2bin(row_inds_spike_times_map, num_segments, sim_duration_ms):
    
    bin_spikes_matrix = np.zeros((num_segments, sim_duration_ms), dtype='bool')
    for row_ind in row_inds_spike_times_map.keys():
        for spike_time in row_inds_spike_times_map[row_ind]:
            bin_spikes_matrix[row_ind,spike_time] = 1.0
    
    return bin_spikes_matrix

def parse_sim_experiment_file(sim_experiment_file):
    
    print('-----------------------------------------------------------------')
    print("loading file: '" + sim_experiment_file.split("\\")[-1] + "'")
    loading_start_time = time.time()
    experiment_dict = pickle.load(open(sim_experiment_file, "rb" ), encoding='latin1')
    
    # gather params
    num_simulations = len(experiment_dict['Results']['listOfSingleSimulationDicts'])
    num_segments    = len(experiment_dict['Params']['allSegmentsType'])
    sim_duration_ms = experiment_dict['Params']['totalSimDurationInSec'] * 1000
    num_ex_synapses  = num_segments
    num_inh_synapses = num_segments
    num_synapses = num_ex_synapses + num_inh_synapses
    
    # collect X, y_spike, y_soma
    X = np.zeros((num_synapses,sim_duration_ms,num_simulations), dtype='bool')
    y_spike = np.zeros((sim_duration_ms,num_simulations), dtype=np.float32)
    y_soma  = np.zeros((sim_duration_ms,num_simulations), dtype=np.float32)
    for k, sim_dict in enumerate(experiment_dict['Results']['listOfSingleSimulationDicts']):
        X_ex  = dict2bin(sim_dict['exInputSpikeTimes'] , num_segments, sim_duration_ms)
        X_inh = dict2bin(sim_dict['inhInputSpikeTimes'], num_segments, sim_duration_ms)
        X[:,:,k] = np.vstack((X_ex,X_inh))
        spike_times = (sim_dict['outputSpikeTimes'].astype(float) - 0.5).astype(int)
        y_spike[spike_times,k] = 1.0
        y_soma[:,k] = sim_dict['somaVoltageLowRes']

    loading_duration_sec = time.time() - loading_start_time 
    print('loading took %.3f seconds' %(loading_duration_sec))
    print('-----------------------------------------------------------------')

    return X, y_spike, y_soma


def parse_multiple_sim_experiment_files(sim_experiment_files):
    
    for k, sim_experiment_file in enumerate(sim_experiment_files):
        X_curr, y_spike_curr, y_soma_curr = parse_sim_experiment_file(sim_experiment_file)
        
        if k == 0:
            X       = X_curr
            y_spike = y_spike_curr
            y_soma  = y_soma_curr
        else:
            X       = np.dstack((X,X_curr))
            y_spike = np.hstack((y_spike,y_spike_curr))
            y_soma  = np.hstack((y_soma,y_soma_curr))

    return X, y_spike, y_soma


def calc_AUC_at_desired_FP(y_test, y_test_hat, desired_false_positive_rate=0.01):
    fpr, tpr, thresholds = roc_curve(y_test.ravel(), y_test_hat.ravel())

    linear_spaced_FPR = np.linspace(0,1,num=20000)
    linear_spaced_TPR = np.interp(linear_spaced_FPR, fpr, tpr)
    
    desired_fp_ind = min(max(1,np.argmin(abs(linear_spaced_FPR-desired_false_positive_rate))),linear_spaced_TPR.shape[0]-1)
    
    return linear_spaced_TPR[:desired_fp_ind].mean()


def calc_TP_at_desired_FP(y_test, y_test_hat, desired_false_positive_rate=0.0025):
    fpr, tpr, thresholds = roc_curve(y_test.ravel(), y_test_hat.ravel())
    
    desired_fp_ind = np.argmin(abs(fpr-desired_false_positive_rate))
    if desired_fp_ind == 0:
        desired_fp_ind = 1

    return tpr[desired_fp_ind]


def exctract_key_results(y_spikes_GT, y_spikes_hat, y_soma_GT, y_soma_hat, desired_FP_list=[0.0025,0.0100]):
    
    # evaluate the model and save the results
    print('----------------------------------------------------------------------------------------')
    print('calculating key results...')
    
    evaluation_start_time = time.time()
    
    # store results in the hyper param dict and return it
    evaluations_results_dict = {}
    
    for desired_FP in desired_FP_list:
        TP_at_desired_FP  = calc_TP_at_desired_FP(y_spikes_GT, y_spikes_hat, desired_false_positive_rate=desired_FP)
        AUC_at_desired_FP = calc_AUC_at_desired_FP(y_spikes_GT, y_spikes_hat, desired_false_positive_rate=desired_FP)
        print('-----------------------------------')
        print('TP  at %.4f FP rate = %.4f' %(desired_FP, TP_at_desired_FP))
        print('AUC at %.4f FP rate = %.4f' %(desired_FP, AUC_at_desired_FP))
        TP_key_string = 'TP @ %.4f FP' %(desired_FP)
        evaluations_results_dict[TP_key_string] = TP_at_desired_FP
    
        AUC_key_string = 'AUC @ %.4f FP' %(desired_FP)
        evaluations_results_dict[AUC_key_string] = AUC_at_desired_FP
    
    print('--------------------------------------------------')
    fpr, tpr, thresholds = roc_curve(y_spikes_GT.ravel(), y_spikes_hat.ravel())
    AUC_score = auc(fpr, tpr)
    print('AUC = %.4f' %(AUC_score))
    print('--------------------------------------------------')
    
    soma_explained_variance_percent = 100.0*explained_variance_score(y_soma_GT.ravel(),y_soma_hat.ravel())
    soma_RMSE = np.sqrt(MSE(y_soma_GT.ravel(),y_soma_hat.ravel()))
    soma_MAE  = MAE(y_soma_GT.ravel(),y_soma_hat.ravel())
    
    print('--------------------------------------------------')
    print('soma explained_variance percent = %.2f%s' %(soma_explained_variance_percent, '%'))
    print('soma RMSE = %.3f [mV]' %(soma_RMSE))
    print('soma MAE = %.3f [mV]' %(soma_MAE))
    print('--------------------------------------------------')
    
    evaluations_results_dict['AUC'] = AUC_score
    evaluations_results_dict['soma_explained_variance_percent'] = soma_explained_variance_percent
    evaluations_results_dict['soma_RMSE'] = soma_RMSE
    evaluations_results_dict['soma_MAE'] = soma_MAE
    
    evaluation_duration_min = (time.time() - evaluation_start_time)/60
    print('finished evaluation. time took to evaluate results is %.2f minutes' %(evaluation_duration_min))
    print('----------------------------------------------------------------------------------------')
    
    return evaluations_results_dict


def filter_and_exctract_key_results(y_spikes_GT, y_spikes_hat, y_soma_GT, y_soma_hat, desired_FP_list=[0.0025,0.0100], 
                                                                                      ignore_time_at_start_ms=500, 
                                                                                      num_spikes_per_sim=[0,24]):

    time_points_to_eval = np.arange(y_spikes_GT.shape[1]) >= ignore_time_at_start_ms
    simulations_to_eval = np.logical_and((y_spikes_GT.sum(axis=1) >= num_spikes_per_sim[0]),(y_spikes_GT.sum(axis=1) <= num_spikes_per_sim[1]))
    
    print('total amount of simualtions is %d' %(y_spikes_GT.shape[0]))
    print('percent of simulations kept = %.2f%s' %(100*simulations_to_eval.mean(),'%'))
    
    y_spikes_GT_to_eval  = y_spikes_GT[simulations_to_eval,:][:,time_points_to_eval]
    y_spikes_hat_to_eval = y_spikes_hat[simulations_to_eval,:][:,time_points_to_eval]
    y_soma_GT_to_eval    = y_soma_GT[simulations_to_eval,:][:,time_points_to_eval]
    y_soma_hat_to_eval   = y_soma_hat[simulations_to_eval,:][:,time_points_to_eval]
    
    return exctract_key_results(y_spikes_GT_to_eval, y_spikes_hat_to_eval, y_soma_GT_to_eval, y_soma_hat_to_eval, desired_FP_list=desired_FP_list)


def draw_weights(first_layer_weights, selected_filter_ind, set_ylabel, ax00,ax10, ax01,ax11, ax02,ax12):

    time_span, _, num_filters = first_layer_weights.shape
    
    weight_granularity = 0.06
    time_granularity = 20
    max_time_to_show = 40
    
    if use_filtered:
        first_layer_weights_filtered = signal.convolve(first_layer_weights, (1.0/filter_size)*np.ones((filter_size,1,1)), mode='valid')
        first_layer_weights = first_layer_weights_filtered
    
    if first_layer_weights.shape[0] >= max_time_to_show:
        first_layer_weights = first_layer_weights[:max_time_to_show]
    
    num_segments =  639
    basal_cutoff =  262
    tuft_cutoff  = [366,559]

    # invert if needed
    exc_sum = first_layer_weights[:12,:num_segments,selected_filter_ind].sum()
    inh_sum = first_layer_weights[:12,num_segments:,selected_filter_ind].sum()
    exc_minus_inh = exc_sum - inh_sum
    
    if exc_minus_inh < 0:
        first_layer_weights = -first_layer_weights
    
    upper_limit = max(np.percentile(abs(first_layer_weights[:,:,selected_filter_ind]),99.95),np.percentile(abs(first_layer_weights[:,:,selected_filter_ind]),0.05))
    xlims = [-5*int(first_layer_weights.shape[0]/5),0]
    
    ex_basal_syn_inds    = np.arange(basal_cutoff)
    ex_oblique_syn_inds  = np.hstack((np.arange(basal_cutoff,tuft_cutoff[0]),np.arange(tuft_cutoff[1],num_segments)))
    ex_tuft_syn_inds     = np.arange(tuft_cutoff[0],tuft_cutoff[1])
    inh_basal_syn_inds   = num_segments + ex_basal_syn_inds
    inh_oblique_syn_inds = num_segments + ex_oblique_syn_inds
    inh_tuft_syn_inds    = num_segments + ex_tuft_syn_inds
    
    basal_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_basal_syn_inds,selected_filter_ind].T)
    basal_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_basal_syn_inds,selected_filter_ind].T)
    basal_weights_example_filter     = np.concatenate((basal_weights_example_filter_ex,basal_weights_example_filter_inh),axis=0)
    oblique_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_oblique_syn_inds,selected_filter_ind].T)
    oblique_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_oblique_syn_inds,selected_filter_ind].T)
    oblique_weights_example_filter     = np.concatenate((oblique_weights_example_filter_ex, oblique_weights_example_filter_inh),axis=0)
    tuft_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_tuft_syn_inds,selected_filter_ind].T)
    tuft_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_tuft_syn_inds,selected_filter_ind].T)
    tuft_weights_example_filter     = np.concatenate((tuft_weights_example_filter_ex,tuft_weights_example_filter_inh),axis=0)
    
    time_axis = -np.arange(first_layer_weights.shape[0])
    
    ##%% create nice figure
    figure_xlims = xlims
    figure_xlims[0] = max(-40, figure_xlims[0])
    
    ex_basal_syn_inds    = np.arange(basal_cutoff)
    ex_oblique_syn_inds  = np.hstack((np.arange(basal_cutoff,tuft_cutoff[0]),np.arange(tuft_cutoff[1],num_segments)))
    ex_tuft_syn_inds     = np.arange(tuft_cutoff[0],tuft_cutoff[1])
    inh_basal_syn_inds   = num_segments + ex_basal_syn_inds
    inh_oblique_syn_inds = num_segments + ex_oblique_syn_inds
    inh_tuft_syn_inds    = num_segments + ex_tuft_syn_inds
    
    basal_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_basal_syn_inds,selected_filter_ind].T)
    basal_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_basal_syn_inds,selected_filter_ind].T)
    basal_weights_example_filter     = np.concatenate((basal_weights_example_filter_ex,basal_weights_example_filter_inh),axis=0)
    oblique_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_oblique_syn_inds,selected_filter_ind].T)
    oblique_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_oblique_syn_inds,selected_filter_ind].T)
    oblique_weights_example_filter     = np.concatenate((oblique_weights_example_filter_ex, oblique_weights_example_filter_inh),axis=0)
    tuft_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_tuft_syn_inds,selected_filter_ind].T)
    tuft_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_tuft_syn_inds,selected_filter_ind].T)
    tuft_weights_example_filter     = np.concatenate((tuft_weights_example_filter_ex,tuft_weights_example_filter_inh),axis=0)
    
    combined_filter = np.concatenate((basal_weights_example_filter_ex,oblique_weights_example_filter_ex,tuft_weights_example_filter_ex,
                                      basal_weights_example_filter_inh,oblique_weights_example_filter_inh,tuft_weights_example_filter_inh),axis=0)
    
    ##%% draw 2 x 3 (basal,oblique,tuft) matrix
    ex_basal_syn_inds    = np.arange(basal_cutoff)
    ex_oblique_syn_inds  = np.hstack((np.arange(basal_cutoff,tuft_cutoff[0]),np.arange(tuft_cutoff[1],num_segments)))
    ex_tuft_syn_inds     = np.arange(tuft_cutoff[0],tuft_cutoff[1])
    inh_basal_syn_inds   = num_segments + ex_basal_syn_inds
    inh_oblique_syn_inds = num_segments + ex_oblique_syn_inds
    inh_tuft_syn_inds    = num_segments + ex_tuft_syn_inds
    
    basal_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_basal_syn_inds,selected_filter_ind].T)
    basal_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_basal_syn_inds,selected_filter_ind].T)
    basal_weights_example_filter     = np.concatenate((basal_weights_example_filter_ex,basal_weights_example_filter_inh),axis=0)
    oblique_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_oblique_syn_inds,selected_filter_ind].T)
    oblique_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_oblique_syn_inds,selected_filter_ind].T)
    oblique_weights_example_filter     = np.concatenate((oblique_weights_example_filter_ex, oblique_weights_example_filter_inh),axis=0)
    tuft_weights_example_filter_ex  = np.fliplr(first_layer_weights[:,ex_tuft_syn_inds,selected_filter_ind].T)
    tuft_weights_example_filter_inh = np.fliplr(first_layer_weights[:,inh_tuft_syn_inds,selected_filter_ind].T)
    tuft_weights_example_filter     = np.concatenate((tuft_weights_example_filter_ex,tuft_weights_example_filter_inh),axis=0)
    
    time_axis = -np.arange(first_layer_weights.shape[0])
    
    upper_limit = max(np.percentile(abs(first_layer_weights[:,:,selected_filter_ind]),99.8),np.percentile(abs(first_layer_weights[:,:,selected_filter_ind]),0.2))
    weights_ylims = np.array([-1.08,1.08]) * upper_limit
    
    weight_ticks_lims = (np.array(weights_ylims)/weight_granularity).astype(int) * weight_granularity
    
    ax00.axis('off')
    ax01.axis('off')
    ax02.axis('off')
    
    # basal
    weights_images = ax00.imshow(resize(basal_weights_example_filter, (combined_filter.shape[0], 200)),
                                 cmap='jet', vmin=weights_ylims[0],vmax=weights_ylims[1], aspect='auto')
    ax00.set_xticks([])
    ax00.set_ylabel('Synaptic index', fontsize=xylabels_fontsize)
    for ytick_label in ax00.get_yticklabels():
        ytick_label.set_fontsize(xytick_labels_fontsize)
    
    ax_colorbar = inset_axes(ax00, width="67%", height="6%", loc=2)
    cbar = plt.colorbar(weights_images, cax=ax_colorbar, orientation="horizontal", ticks=[weight_ticks_lims[0], 0, weight_ticks_lims[1]])
    ax_colorbar.xaxis.set_ticks_position("bottom")
    cbar.ax.tick_params(labelsize=xytick_labels_fontsize-2)
    ax00.text(10, 190, 'Weight (A.U)', color='k', fontsize=xytick_labels_fontsize+1, ha='left', va='top', rotation='horizontal')
    
    ax10.plot(time_axis, np.fliplr(basal_weights_example_filter_ex).T , c='r', alpha=all_traces_alpha)
    ax10.plot(time_axis, np.mean(np.fliplr(basal_weights_example_filter_ex).T, axis=1) , c='r', lw=mean_linewidth)
    ax10.plot(time_axis, np.fliplr(basal_weights_example_filter_inh).T, c='b', alpha=all_traces_alpha)
    ax10.plot(time_axis, np.mean(np.fliplr(basal_weights_example_filter_inh).T, axis=1) , c='b', lw=mean_linewidth)
    
    ax10.set_xlim(time_axis.min(),time_axis.max())
    ax10.set_ylim(weights_ylims[0],weights_ylims[1])
    if set_ylabel:
        ax10.set_ylabel('Weight (A.U)', fontsize=xylabels_fontsize)
    
    time_ticks_to_show = np.unique((np.array(time_axis)/time_granularity).astype(int) * time_granularity)
    ax10.set_xticks(time_ticks_to_show)
    
    weights_axis = np.linspace(weights_ylims[0],weights_ylims[1],10)
    weight_ticks_to_show = np.unique((np.array(weights_axis)/weight_granularity).astype(int) * weight_granularity)
    ax10.set_yticks(weight_ticks_to_show)
    
    
    ax10.spines['top'].set_visible(False)
    ax10.spines['right'].set_visible(False)
    
    for ytick_label in ax10.get_yticklabels():
        ytick_label.set_fontsize(xytick_labels_fontsize)
    for xtick_label in ax10.get_xticklabels():
        xtick_label.set_fontsize(xytick_labels_fontsize)
    
    # oblique
    weights_images = ax01.imshow(resize(oblique_weights_example_filter, (combined_filter.shape[0], 200)),
                                 cmap='jet', vmin=weights_ylims[0],vmax=weights_ylims[1], aspect='auto')
    ax01.set_xticks([])
    ax01.set_ylabel('Synaptic index', fontsize=xylabels_fontsize)
    for ytick_label in ax01.get_yticklabels():
        ytick_label.set_fontsize(xytick_labels_fontsize)
    
    ax11.plot(time_axis, np.fliplr(oblique_weights_example_filter_ex).T , c='r', alpha=all_traces_alpha)
    ax11.plot(time_axis, np.mean(np.fliplr(oblique_weights_example_filter_ex).T, axis=1) , c='r', lw=mean_linewidth)
    ax11.plot(time_axis, np.fliplr(oblique_weights_example_filter_inh).T, c='b', alpha=all_traces_alpha)
    ax11.plot(time_axis, np.mean(np.fliplr(oblique_weights_example_filter_inh).T, axis=1) , c='b', lw=mean_linewidth)
    
    ax11.set_xlim(time_axis.min(),time_axis.max())
    ax11.set_xlabel('Time before $t_0$ (ms)', fontsize=xylabels_fontsize);
    ax11.set_ylim(weights_ylims[0],weights_ylims[1])
    
    time_ticks_to_show = np.unique((np.array(time_axis)/time_granularity).astype(int) * time_granularity)
    ax11.set_xticks(time_ticks_to_show)
    
    ax11.spines['top'].set_visible(False)
    ax11.spines['right'].set_visible(False)
    ax11.spines['left'].set_visible(False)
    
    ax11.set_yticks([])
    for ytick_label in ax11.get_yticklabels():
        ytick_label.set_fontsize(xytick_labels_fontsize)
    for xtick_label in ax11.get_xticklabels():
        xtick_label.set_fontsize(xytick_labels_fontsize)
        
    # tuft
    #weights_images = ax02.imshow(tuft_weights_example_filter,cmap='jet', aspect='auto')
    weights_images = ax02.imshow(resize(tuft_weights_example_filter, (combined_filter.shape[0], 200)),
                                 cmap='jet', vmin=weights_ylims[0],vmax=weights_ylims[1], aspect='auto')
    ax02.set_xticks([])
    ax02.set_ylabel('Synaptic index', fontsize=xylabels_fontsize)
    for ytick_label in ax02.get_yticklabels():
        ytick_label.set_fontsize(xytick_labels_fontsize)
    
    ax12.plot(time_axis, np.fliplr(tuft_weights_example_filter_ex).T , c='r', alpha=all_traces_alpha)
    ax12.plot(time_axis, np.mean(np.fliplr(tuft_weights_example_filter_ex).T, axis=1) , c='r', lw=mean_linewidth)
    ax12.plot(time_axis, np.fliplr(tuft_weights_example_filter_inh).T, c='b', alpha=all_traces_alpha)
    ax12.plot(time_axis, np.mean(np.fliplr(tuft_weights_example_filter_inh).T, axis=1) , c='b', lw=mean_linewidth)
    
    ax12.set_xlim(time_axis.min(),time_axis.max())
    ax12.set_ylim(weights_ylims[0],weights_ylims[1])
    ax12.set_yticks([])
    
    time_ticks_to_show = np.unique((np.array(time_axis)/time_granularity).astype(int) * time_granularity)
    ax12.set_xticks(time_ticks_to_show)
    
    ax12.spines['top'].set_visible(False)
    ax12.spines['right'].set_visible(False)
    ax12.spines['left'].set_visible(False)
    
    for ytick_label in ax12.get_yticklabels():
        ytick_label.set_fontsize(xytick_labels_fontsize)
    for xtick_label in ax12.get_xticklabels():
        xtick_label.set_fontsize(xytick_labels_fontsize)

In [None]:
#%% load test dataset

print('------------------------------------------------------------------------------------')
print('loading testing files...')
test_file_loading_start_time = time.time()

v_threshold = -55

# kaggle RAM only permits 6 files to be loaded at the same time, so we select only 6 
# (the 6th file was slected specifically to contain the trace from Figure 2 that was used in the paper)
test_files = test_files[:5] + [test_files[10]]

print('-------------')
print('will be loading the following files:')
[print(x) for x in test_files]
print('-------------')

# load test data
X_test , y_spike_test , y_soma_test  = parse_multiple_sim_experiment_files(test_files)
y_soma_test[y_soma_test > v_threshold] = v_threshold

test_file_loading_duration_min = (time.time() - test_file_loading_start_time)/60
print('time took to load data is %.3f minutes' %(test_file_loading_duration_min))
print('------------------------------------------------------------------------------------')

------------------------------------------------------------------------------------
loading testing files...
-------------
will be loading the following files:
data/Data_test/sim__saved_InputSpikes_DVTs__561_outSpikes__128_simulationRuns__6_secDuration__randomSeed_100520.p
data/Data_test/sim__saved_InputSpikes_DVTs__596_outSpikes__128_simulationRuns__6_secDuration__randomSeed_100511.p
data/Data_test/sim__saved_InputSpikes_DVTs__632_outSpikes__128_simulationRuns__6_secDuration__randomSeed_100525.p
data/Data_test/sim__saved_InputSpikes_DVTs__647_outSpikes__128_simulationRuns__6_secDuration__randomSeed_100519.p
data/Data_test/sim__saved_InputSpikes_DVTs__655_outSpikes__128_simulationRuns__6_secDuration__randomSeed_100518.p
data/Data_test/sim__saved_InputSpikes_DVTs__901_outSpikes__128_simulationRuns__6_secDuration__randomSeed_100523.p
-------------
-----------------------------------------------------------------
loading file: 'data/Data_test/sim__saved_InputSpikes_DVTs__561_outSpikes__1