## Calculate ERPs

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from os import path
from matplotlib.ticker import ScalarFormatter, MaxNLocator
#%matplotlib notebook
import seaborn as sns

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.visual import *
from brainpipe.statistics import *

from mne.baseline import rescale
from mne.filter import filter_data
import scipy.stats as stats


## Bootstrap statistics

In [2]:
def boot_matrix(z, B):
    """Bootstrap sample
    Returns all bootstrap samples in a matrix, random trials (some repeated)
    Do not change the size of the matrix always len(z) » len(z)*B"""
    n = len(z)  # sample size
    idz = np.random.randint(0, n, size=(B, n))  # indices to pick for all boostrap samples
    return z[idz]

def bootstrap_mean(x, B=10000, alpha=0.05):
    """Bootstrap standard error and (1-alpha)*100% c.i. for the population mean
    Returns bootstrapped standard error and different types of confidence intervals"""
    # 2D array, loop for each time points
    x_boot_wins, x_sd_boot, x_ci_boot = np.array([]),np.array([]),np.array([])
    for i in range(x.shape[1]):
        # Deterministic things
        x_win = x[:,i]
        n = len(x_win)  # sample size
        orig = x_win.mean()  # sample mean
        se_mean = x_win.std()/np.sqrt(n) # standard error of the mean
        qt = stats.t.ppf(q=1 - alpha/2, df=n - 1) # Student quantile
        # Generate boostrap distribution of sample mean
        xboot = boot_matrix(x_win, B=B)
        xboot = xboot[:,:,np.newaxis]
        sampling_distribution = xboot.mean(axis=1)
        # Standard error and sample quantiles
        se_mean_boot = sampling_distribution.std()
        quantile_boot = np.percentile(sampling_distribution, q=(100*alpha/2, 100*(1-alpha/2)))
        # Concatenate all values
        x_boot_wins = np.concatenate((x_boot_wins,xboot), axis=2) if np.size(x_boot_wins) else xboot
        x_sd_boot = np.vstack((x_sd_boot,se_mean_boot)) if np.size(x_sd_boot) else se_mean_boot
        x_ci_boot = np.vstack((x_ci_boot,quantile_boot)) if np.size(x_ci_boot) else quantile_boot
    print(x_boot_wins.shape, x_sd_boot.shape, x_ci_boot.shape)
    return x_boot_wins, x_sd_boot, x_ci_boot

## User variables

In [3]:
# where to find data
st = study('Olfacto')
score = 'Epi' #'Rec'
if score == 'Epi':
    path_data = path.join (st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_5s_Good_Bad_EpiScore/')
    save_path = path.join(st.path, 'feature/ERP_Good_Bad_100ms_rescale_filtered_stats_EM/Bootstrap_10000_STER_balanced/')
if score == 'Rec':
    path_data =path.join (st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_Good_Bad_RecScore/')
    save_path = path.join(st.path, 'feature/ERP_Good_Bad_100ms_rescale_filtered_stats_Rec/')
# ANALYSIS PARAMETERS
low_pass_filter = 10.
sf = 512.
norm_mode = 'mean' #'ratio' 'mean' 'percent' 
baseline = [973 , 1024] #100ms before odor perception
data_to_use = [973, 1536] #1000ms after odor
time_points = data_to_use[1]-data_to_use[0]
winSample = 10 #in samples = 20ms
alpha = 0.05
n_rep = 1000
minsucc = 5
statistic = 'boot' #'perm'

-> Olfacto loaded


## Plot ERPs for Odor groups

In [5]:
test = False

if test == True:
    n_elec = {'PIRJ' :2}
    subjects = ['PIRJ']
else :
    subjects = ['VACJ','SEMC','PIRJ','LEFC','MICP','CHAF'] 
    n_elec = {
    'CHAF' : 81,
    'VACJ' : 91, 
    'SEMC' : 81,
    'PIRJ' : 62,
    'LEFC' : 160,
    'MICP' : 79,
        }

for su in subjects:
    #Load files
    data_bad = np.load(path.join(path_data, su+'_concat_odor_bad_bipo_new.npz'))
    data_good = np.load(path.join(path_data, su+'_concat_odor_good_bipo_new.npz'))
    data_bad, channels, labels, data_good = data_bad['x'], data_bad['channel'], data_bad['label'], data_good['x']
    
# ==========================  BALANCED CONDITIONS =====================================
    if data_bad.shape[2] > data_good.shape[2]:
        bad_sel = data_bad[:,:,np.random.randint(data_bad.shape[2], size=data_good.shape[2])] #reshape bad_data to fit good_data shape
        good_sel = data_good
    elif data_bad.shape[2] < data_good.shape[2]:
        bad_sel = data_bad
        good_sel = data_good[:,:,np.random.randint(data_good.shape[2], size=data_bad.shape[2])]
    else:
        bad_sel, good_sel = data_bad, data_good
    print ('balanced data : ', bad_sel.shape, good_sel.shape)

# ========================= DATA for 1 elec PREPROCESS ====================================
    for elec in range(n_elec[su]):
        print(elec, bad_sel.shape)
        # Select data for one elec + name :
        data_elec_bad = bad_sel[elec,:,:]
        data_elec_good = good_sel[elec,:,:]
        ntrials = str(data_elec_bad.shape[1])+'/'+ str(data_elec_good.shape[1]) #to be displayed on figures
        print (su, 'Channel : ', channels[elec], 'Label : ', labels[elec], 'N_trials :', ntrials, 
               'Bad shape : ', data_elec_bad.shape, 'Good shape : ', data_elec_good.shape)
        #Filter data for one elec (all trials):
        data_elec_bad = np.array(data_elec_bad, dtype='float64')
        data_elec_good = np.array(data_elec_good, dtype='float64')
        data_bad_to_filter = np.swapaxes(data_elec_bad, 0, 1)
        data_good_to_filter = np.swapaxes(data_elec_good, 0, 1)
        filtered_data_bad = filter_data(data_bad_to_filter, sfreq=512, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        filtered_data_good = filter_data(data_good_to_filter, sfreq=512, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        print ('Size of filtered data bad :', filtered_data_bad.shape, 'filtered data good : ', filtered_data_good.shape,)
        # Normalize the non-averaged data (all trials)
        times = np.arange(filtered_data_bad.shape[1])
        print ('time points : ', times.shape)
        norm_filtered_data_bad = rescale(filtered_data_bad, times=times, baseline=baseline, mode=norm_mode)
        norm_filtered_data_good = rescale(filtered_data_good, times=times, baseline=baseline, mode=norm_mode)
        print ('Size norm & filtered data 0 : ', norm_filtered_data_bad.shape, norm_filtered_data_good.shape,)
        # Range of the data to compute
        data_range = range(data_to_use[0], data_to_use[1])
        # Select a time window in the data
        bad_win = norm_filtered_data_bad[:, data_range]
        good_win = norm_filtered_data_good[:, data_range,]
        # Average the signal on consecutive windows
        n_pts = bad_win.shape[1]
        rmPoints = n_pts % winSample # Points to remove before splitting
        shapeRmPoints = np.arange(n_pts-rmPoints).astype(int) # Number of points for round division
        n_win = int(n_pts / winSample) # Number of segments
        # Split and average data (trials, n_pts)
        bad_split = np.array(np.split(bad_win[:, shapeRmPoints], n_win, axis=1)) # n_win n_trials n_pts
        bad_split = np.mean(bad_split, axis=2).swapaxes(0,1) # n-trials n_win
        good_split = np.array(np.split(good_win[:, shapeRmPoints], n_win, axis=1))
        good_split = np.mean(good_split, axis=2).swapaxes(0,1)
        print('final data to use : ', bad_split.shape, good_split.shape)
        
#================================ BOOTSTRAP AND STATISTICS ====================================        
        # Bootstrap and Welsh t-test
        bad_boot, sd_bad_boot, ci_bad_boot = bootstrap_mean(bad_split, B=10000)
        good_boot, sd_good_boot, ci_good_boot = bootstrap_mean(good_split, B=10000)
        # Data to average
        bad_mean = np.mean(bad_split, axis=0)
        good_mean = np.mean(good_split, axis=0)
        # Create a threshold vector
        th_vals = []
        for t in range(bad_split.shape[1]):
            if ci_bad_boot[t,0] >= good_mean[t]:
                th_vals.append(0.04)
            elif ci_bad_boot[t,1] <= good_mean[t]:
                th_vals.append(0.04)
            else:
                th_vals.append(1)
        print(len(th_vals), th_vals)
          
# ========================== PREPARE PLOTS AND SAVE STATS =========================================
        # Plot and figure parameters
        xfmt = ScalarFormatter(useMathText=True)
        xfmt.set_powerlimits((0,3))
        fig = plt.figure(1,figsize=(9,6))
        rmaxis(plt.gca(), ['right', 'top'])
        title = 'Bootstrap STER for '+su+' '+score+' Good & Bad '+ channels[elec] +' '+labels[elec]+' ('+str(elec)+') ntrials:'+str(ntrials)
        fig.suptitle(title, fontsize=12)
        times_plot = 1000 * np.arange((baseline[0] - baseline[1]), len(shapeRmPoints)-baseline[1]+baseline[0],winSample) / sf
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))
        # Plot the Bootstrap STER with confidence intervals
        plt.plot(times_plot, bad_mean, '-', color='b', label='bad')
        plt.plot(times_plot, good_mean, '-', color='m', label='good')
        plt.legend(loc=0, handletextpad=0.1, frameon=False) 
        plt.xlabel('Time (ms)')
        plt.ylabel('Single trial evoked response (mV)')
        plt.fill_between(times_plot, ci_bad_boot[:,0],ci_bad_boot[:,1], alpha=0.2, color='b')
        plt.fill_between(times_plot, ci_good_boot[:,0],ci_good_boot[:,1], alpha=0.2, color='m')
        addLines(plt.gca(), vLines=[0], vWidth=[2], vColor=['#000000'], hLines=[0], 
                 hColor=['#000000'], hWidth=[2])
        addPval(plt.gca(), th_vals, p=0.05, x=times_plot, y=5, color='orange', lw=2, minsucc=minsucc)
        # Significance criteria to reach      
        pvals = np.ravel(th_vals)
        underp = np.where(pvals < 1)[0]
        pvsplit = np.split(underp, np.where(np.diff(underp) != 1)[0]+1)
        signif = [True for k in pvsplit if len(k) >= minsucc ]
        #Save plots and stats
        if len(signif) >=1:
            name_bad_boot = (save_path +'Significant/data/'+su +'_bad_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_sd_bad_boot = (save_path +'Significant/data/'+su +'_sd_bad_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_ci_bad_boot = (save_path +'Significant/data/'+su +'_ci_bad_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_good_boot = (save_path +'Significant/data/'+su +'_good_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_sd_good_boot = (save_path +'Significant/data/'+su +'_sd_good_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_ci_good_boot = (save_path +'Significant/data/'+su +'_ci_good_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            plot_name = (save_path +'Significant/fig/'+su +'_STER_' + score +'_'+labels[elec]+'_('+str(elec)+').png')
        else:
            name_bad_boot = (save_path +'Not_Significant/data/'+su +'_bad_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_sd_bad_boot = (save_path +'Not_Significant/data/'+su +'_sd_bad_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_ci_bad_boot = (save_path +'Not_Significant/data/'+su +'_ci_bad_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_good_boot = (save_path +'Not_Significant/data/'+su +'_good_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_sd_good_boot = (save_path +'Not_Significant/data/'+su +'_sd_good_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            name_ci_good_boot = (save_path +'Not_Significant/data/'+su +'_ci_good_boot_' + score +'_'+labels[elec]+'_('+str(elec)+').npy')
            plot_name = (save_path +'/Not_Significant/fig/'+su +'_STER_' + score +'_'+labels[elec]+'_('+str(elec)+').png')
        
        np.save(name_bad_boot, bad_boot)
        np.save(name_sd_bad_boot, sd_bad_boot)
        np.save(name_ci_bad_boot, ci_bad_boot)
        np.save(name_good_boot, good_boot)
        np.save(name_sd_good_boot, sd_good_boot)
        np.save(name_ci_good_boot, ci_good_boot)
        plt.savefig(plot_name, dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close()


balanced data :  (91, 2560, 13) (91, 2560, 13)
0 (91, 2560, 13)
VACJ Channel :  b3-b2 Label :  mHC-PHG&mHC-Ent N_trials : 13/13 Bad shape :  (2560, 13) Good shape :  (2560, 13)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Size of filtered data bad : (13, 2560) filtered data good :  (13, 2560)
time points :  (2560,)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Size norm & filtered data 0 :  (13, 2560) (13, 2560)
final data to use :  (13, 56) (13, 56)
(10000, 13, 56) (56, 1) (56, 2)
(10000, 13, 56) (56, 1) (56, 2)
56 [0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.04, 0.04, 0.04, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.04, 0.04, 

In [None]:
print(good_boot.shape,bad_boot.shape)
bad_boot_to_plot = bad_boot.reshape((bad_boot.shape[-1], -1))
good_boot_to_plot = good_boot.reshape((good_boot.shape[-1], -1))
print(good_boot_to_plot.shape,bad_boot_to_plot.shape)
bad_boot_to_plot = bad_boot.reshape((bad_boot.shape[-1], -1)).mean(axis=1)
good_boot_to_plot = good_boot.reshape((good_boot.shape[-1], -1)).mean(axis=1)
print(good_boot_to_plot.shape,bad_boot_to_plot.shape)
