## Calculate ERPs

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from os import path
from matplotlib.ticker import ScalarFormatter, MaxNLocator
%matplotlib notebook
import seaborn as sns

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.visual import *
from brainpipe.statistics import *

from mne.baseline import rescale
from mne.filter import filter_data
import scipy.stats as stats
from scipy.stats import *
# from mne.stats import *
import time

## Bootstrap statistics

In [18]:
def boot_matrix(z, B):
    """Bootstrap sample
    Returns all bootstrap samples in a matrix, random trials (some repeated)
    Do not change the size of the matrix always len(z) » B*len(z)"""
    n = len(z)  # sample size
    idz = np.random.randint(0, n, size=(B, n))  # indices to pick for all boostrap samples
    print(idz[0], len(idz[0]))
    return z[idz]

def bootstrap_mean(x, B=1000, alpha=0.05, plot=False):
    """Bootstrap standard error and (1-alpha)*100% c.i. for the population mean
    Returns bootstrapped standard error and different types of confidence intervals"""
    # Deterministic things
    n = len(x)  # sample size
    orig = x.mean()  # sample mean
    se_mean = x.std()/np.sqrt(n) # standard error of the mean
    qt = stats.t.ppf(q=1 - alpha/2, df=n - 1) # Student quantile
    # Generate boostrap distribution of sample mean
    xboot = boot_matrix(x, B=B)
    sampling_distribution = xboot.mean(axis=1)
    # Standard error and sample quantiles
    se_mean_boot = sampling_distribution.std()
    quantile_boot = np.percentile(sampling_distribution, q=(100*alpha/2, 100*(1-alpha/2)))
    # RESULTS
    print("Estimated mean:", orig)
    print("Classic standard error:", se_mean)
    print("Classic student c.i.:", orig + np.array([-qt, qt])*se_mean)
    print("\nBootstrap results:")
    print("Standard error:", se_mean_boot)
    print("t-type c.i.:", orig + np.array([-qt, qt])*se_mean_boot)
    print("Percentile c.i.:", quantile_boot)
    print("Basic c.i.:", 2*orig - quantile_boot[::-1])

    if plot:
        plt.hist(sampling_distribution, bins="fd")

[ 3 19 24  9 10 27  2  8 12  4 12 19  7 24 10 21 24  3 16 26 26 15 11 23  8
 25 27 15 29 17] 30
(100, 30)


## User variables

In [21]:
# where to find data
st = study('Olfacto')
score = 'Rec' #'Rec'
if score == 'Epi':
    path_data = path.join (st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_5s_Good_Bad_EpiScore/')
    save_path = path.join(st.path, 'feature/ERP_Good_Bad_100ms_rescale_filtered_stats_EM/')
if score == 'Rec':
    path_data =path.join (st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_Good_Bad_RecScore/')
    save_path = path.join(st.path, 'feature/ERP_Good_Bad_100ms_rescale_filtered_stats_Rec/')
# ANALYSIS PARAMETERS
low_pass_filter = 10.
sf = 512.
norm_mode = 'mean' #'ratio' 'mean' 'percent' 
baseline = [973 , 1024] #100ms before odor perception
data_to_use = [973, 1536] #1000ms after odor
time_points = data_to_use[1]-data_to_use[0]
winSample = 10 #in samples = 20ms
alpha = 0.05
n_rep = 1000
minsucc = 3
statistic = 'boot' #'perm'

-> Olfacto loaded


In [9]:
x = np.concatenate([np.random.exponential(size=200), np.random.normal(size=100)])
print(x.shape)
n = len(x)
reps = 10000
xb = np.random.choice(x, (n, reps))
print(xb.shape)
mb = xb.mean(axis=0)
mb.sort()

(300,)
(300, 10000)


## Plot ERPs for Odor groups

In [33]:
test = False

if test == True:
    n_elec = {'VACJ' :1}
    subjects = ['VACJ']
else :
    subjects = ['VACJ','SEMC','PIRJ','LEFC','MICP','CHAF'] 
    n_elec = {
    'CHAF' : 107,
    'VACJ' : 139, 
    'SEMC' : 107,
    'PIRJ' : 106,
    'LEFC' : 193,
    'MICP' : 105,
        }

for su in subjects:
    for elec in range(0, n_elec[su],1):
        #Load files
        badname = su+'_concat_odor_bad_bipo.npz'
        goodname = su+'_concat_odor_good_bipo.npz'
        data_bad = np.load(path.join(path_data, badname))
        data_good = np.load(path.join(path_data, goodname))
        data_bad, channel, label, data_good = data_bad['x'], data_bad['channel'], data_bad['label'], data_good['x']

        # Select data for one elec + name :
        data_elec_bad = data_bad[elec,:,:]
        data_elec_good = data_good[elec,:,:]
        ntrials = str(data_elec_bad.shape[1])+'/'+ str(data_elec_good.shape[1]) #to be displayed on figures
        print ('Channel : ', channel[elec], 'Label : ', label[elec], 'N_trials :', ntrials, 
               'Bad shape : ', data_elec_bad.shape, 'Good shape : ', data_elec_good.shape)

        #Filter data for one elec (all trials):
        data_elec_bad = np.array(data_elec_bad, dtype='float64')
        data_elec_good = np.array(data_elec_good, dtype='float64')
        data_bad_to_filter = np.swapaxes(data_elec_bad, 0, 1)
        data_good_to_filter = np.swapaxes(data_elec_good, 0, 1)
        filtered_data_bad = filter_data(data_bad_to_filter, sfreq=512, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        filtered_data_good = filter_data(data_good_to_filter, sfreq=512, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        print ('Size of filtered data bad :', filtered_data_bad.shape, 'filtered data good : ', filtered_data_good.shape,)

        #Normalize the non-averaged data (all trials)
        times = np.arange(filtered_data_bad.shape[1])
        print ('time points : ', times.shape)
        norm_filtered_data_bad = rescale(filtered_data_bad, times=times, baseline=baseline, mode=norm_mode)
        norm_filtered_data_good = rescale(filtered_data_good, times=times, baseline=baseline, mode=norm_mode)
        print ('Size norm & filtered data 0 : ', norm_filtered_data_bad.shape, norm_filtered_data_good.shape,)
        
        # Range of the data to compute
        data_range = range(data_to_use[0], data_to_use[1])
        # Select a time window in the data
        bad_sel = norm_filtered_data_bad[:, data_range]
        good_sel = norm_filtered_data_good[:, data_range,]
        
        # Average the signal on consecutive windows
        n_pts = bad_sel.shape[1]
        rmPoints = n_pts % winSample # Points to remove before splitting
        shapeRmPoints = np.arange(n_pts-rmPoints).astype(int) # Number of points for round division
        n_win = int(n_pts / winSample) # Number of segments

        # Split and average data (trials, n_pts)
        bad_split = np.array(np.split(bad_sel[:, shapeRmPoints], n_win, axis=1)) # n_win n_trials n_pts
        bad_split = np.mean(bad_split, axis=2).swapaxes(0,1) # n-trials n_win
        good_split = np.array(np.split(good_sel[:, shapeRmPoints], n_win, axis=1))
        good_split = np.mean(good_split, axis=2).swapaxes(0,1)

        # ========================== BALANCED CONDITIONS =====================================
        #for i in range(n_rep):
            #if bad_split.shape[0] > good_split.shape[0]:
                #bad_sel_stat = bad_split[np.random.randint(bad_split.shape[0], size=good_split.shape[0]), :] #reshape bad_data to fit good_data shape
                #good_sel_stat = good_split
            #elif bad_split.shape[0] < good_split.shape[0]:
                #bad_sel_stat = bad_split
                #good_sel_stat = good_split[np.random.randint(good_split.shape[0], size=bad_split.shape[0]), :]
            #else:
                #bad_sel_stat, good_sel_stat = bad_split, good_split
            
        # ===========================  STATISTICS  =====================================
        if statistic == 'perm':
            # Permutations and t test of the data
            bad_resample, good_resample = perm_swap(bad_sel_stat, good_sel_stat, n_perm=1000, axis=0)
            bad_resample, good_resample = np.swapaxes(bad_resample,0,1), np.swapaxes(good_resample,0,1)
            print('perm : ', bad_resample.shape, good_resample.shape)
            T0, _  = ttest_ind(bad_sel_stat, good_sel_stat, equal_var=False)
            print(T0.shape, T0.max(), T0.min())
        
        if statistic == 'boot':
            # Bootstrap and Welsh t-test
            bad_resample = boot_matrix(bad_split - bad_split.mean(), B=1000) # important centering step to get sampling distribution under the null
            good_resample = boot_matrix(good_split - good_split.mean(), B=1000)
            bad_resample, good_resample = np.swapaxes(bad_resample,0,1), np.swapaxes(good_resample,0,1)
            print('boot : ', bad_resample.shape, good_resample.shape)
            T0, _  = ttest_ind(bad_split, good_split, equal_var=False)
            print(T0.shape, T0.max(), T0.min())
            
        Trep, _ = ttest_ind(bad_resample, good_resample, equal_var=False)
        print(Trep.shape)
        thr_0_5_corr = [-perm_pvalue2level(Trep, p=0.05, maxst=True)[0],perm_pvalue2level(Trep, p=0.05, maxst=True)[0]]
        thr_0_1_corr = [-perm_pvalue2level(Trep, p=0.01, maxst=True)[0],perm_pvalue2level(Trep, p=0.01, maxst=True)[0]]
        thr_0_0_1_corr = [-perm_pvalue2level(Trep, p=0.001, maxst=True)[0],perm_pvalue2level(Trep, p=0.001, maxst=True)[0]]
        print(thr_0_5_corr,thr_0_1_corr,thr_0_0_1_corr)
        
        # Create the pvalue vector to plot
        pvals = []
        for i in range(T0.shape[0]):
            if T0[i] < thr_0_0_1_corr[0] or T0[i] > thr_0_0_1_corr[1]:
                pval = pvals.append(0.0009)
            elif T0[i] < thr_0_1_corr[0] or T0[i] > thr_0_1_corr[1]:
                pval = pvals.append(0.009)
            elif T0[i] < thr_0_5_corr[0] or T0[i] > thr_0_5_corr[1]:
                pval = pvals.append(0.04)
            else:
                pval = pvals.append(1)
            
        # ========================== PREPARE PLOTS AND SAVE STATS =========================================
        # plot and figure parameters
        xfmt = ScalarFormatter(useMathText=True)
        xfmt.set_powerlimits((0,3))
        fig = plt.figure(1,figsize=(9,6))
        title = 'ERP and Stats for '+su+' '+score+' Good/Bad '+ channel [elec] +' '+label[elec]+' ('+str(elec)+') ntrials:'+str(ntrials)
        fig.suptitle(title, fontsize=12)
        times_plot = 1000 * np.arange((baseline[0] - baseline[1]), len(shapeRmPoints)-baseline[1]+baseline[0],winSample) / sf
#         times_plot = 1000 * np.arange((baseline[0] - baseline[1]), data_to_use[1]-baseline[1]) / sf
        
        # Plot the ERPs
        plt.subplot(211)
        addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0], 
                 hColor=['#000000'], hWidth=[2])
        BorderPlot(times_plot, Trep, kind='sd', alpha=0.2, linewidth=2, 
           ncol=1, xlabel='Time (ms)',ylabel ='t values', color='b',
          ylim=[Trep.min(),Trep.max()])
        plt.plot(times_plot, T0, '-', color='orange', linewidth=2)
        rmaxis(plt.gca(), ['right', 'top'])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

        # Plot DA for the ERPs
        plt.subplot(212)
        data_all = np.concatenate((bad_sel_stat,good_sel_stat), axis=0)
        label_bad = np.zeros(bad_sel_stat.shape[0], dtype='int64')
        label_good = np.ones(good_sel_stat.shape[0], dtype='int64')
        labels = np.concatenate((label_bad, label_good), axis=0)
        BorderPlot(times_plot, data_all, y=labels, kind='sem', alpha=0.2, color=['b', 'm'], 
                   linewidth=2, ncol=1, xlabel='Time (ms)',ylabel = r' $\mu$V', 
                   legend = ['bad', 'good'])
        print(pvals)
        addPval(plt.gca(), pvals, p=0.05, x=times_plot, y=2, color='orange', lw=2, minsucc=minsucc)
        addPval(plt.gca(), pvals, p=0.01, x=times_plot, y=2, color='orangered', lw=2,minsucc=minsucc)
        addPval(plt.gca(), pvals, p=0.001, x=times_plot, y=2, color='r', lw=2,minsucc=minsucc)
        addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0], 
                 hColor=['#000000'], hWidth=[2])
        rmaxis(plt.gca(), ['right', 'top'])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

        # Criteria to be significant
        pvals = np.ravel(pvals)
        underp = np.where(pvals < alpha)[0]
        pvsplit = np.split(underp, np.where(np.diff(underp) != 1)[0]+1)
        signif = [True for k in pvsplit if len(k) >= minsucc ]
        #Save plots and stats
        if len(signif) >=1:
                name_tval_obs = (save_path +str(round((winSample*1000)/512))+'ms/Significant/'+su +'_tval_obs_' + score +'_'+label[elec]+'_('+str(elec)+').npy')
                name_tvals_perm = (save_path +str(round((winSample*1000)/512))+'ms/Significant/'+su +'_tvals_perm_' + score +'_'+label[elec]+'_('+str(elec)+').npy')
                name_pval = (save_path +str(round((winSample*1000)/512))+'ms/Significant/'+su +'_pvals_' + score +'_'+label[elec]+'_('+str(elec)+').npy')
                plot_name = (save_path +str(round((winSample*1000)/512))+'ms/Significant/'+su +'_ERPs_'  + score +'_'+label[elec]+'_('+str(elec)+').png')
        else:
            name_tval_obs = (save_path +str(round((winSample*1000)/512))+'ms/'+su +'_tval_obs_' + score +'_'+label[elec]+'_('+str(elec)+').npy')
            name_tvals_perm = (save_path +str(round((winSample*1000)/512))+'ms/'+su +'_tvals_perm_' + score +'_'+label[elec]+'_('+str(elec)+').npy')
            name_pval = (save_path +str(round((winSample*1000)/512))+'ms/'+su +'_pvals_' + score +'_'+label[elec]+'_('+str(elec)+').npy')
            plot_name = (save_path +str(round((winSample*1000)/512))+'ms/'+su +'_ERPs_'  + score +'_'+label[elec]+'_('+str(elec)+').png')
        
        np.save(name_tval_obs, T0)
        np.save(name_tvals_perm, Trep)
        np.save(name_pval, pvals)
        plt.savefig(plot_name, dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close()


Channel :  b2-b1 Label :  mHC-Ent N_trials : 11/15 Bad shape :  (2560, 11) Good shape :  (2560, 15)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Size of filtered data bad : (11, 2560) filtered data good :  (15, 2560)
time points :  (2560,)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Size norm & filtered data 0 :  (11, 2560) (15, 2560)
boot :  (11, 1000, 56) (15, 1000, 56)
(56,) 1.60764360017 -2.15681817157
(1000, 56)
[-2.3956835577613167, 2.3956835577613167] [-3.1250864928293258, 3.1250864928293258] [-4.0600124283449865, 4.0600124283449865]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]




Channel :  b3-b2 Label :  mHC-PHG&mHC-Ent N_trials : 11/15 Bad shape :  (2560, 11) Good shape :  (2560, 15)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Size of filtered data bad : (11, 2560) filtered data good :  (15, 2560)
time points :  (2560,)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Size norm & filtered data 0 :  (11, 2560) (15, 2560)
boot :  (11, 1000, 56) (15, 1000, 56)
(56,) 2.16455192668 -2.2106178267
(1000, 56)
[-3.8748825762902879, 3.8748825762902879] [-5.0740312057248609, 5.0740312057248609] [-7.6187062979701476, 7.6187062979701476]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Channel :  b4-b3 Label :  mHC-PH

In [None]:
print(len(pvals))

In [None]:
#pvals = [1, 0.04, 1, 1, 0.04, 1, 1, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.04, 0.04, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
pvals = [1,1,1,1,1]
pvals = np.ravel(pvals)
underp = np.where(pvals < alpha)[0]
pvsplit = np.split(underp, np.where(np.diff(underp) != 1)[0]+1)
for k in pvsplit:
    if len(k) >= minsucc:
        print('daji')
else:
    print('')
succlst = [True for k in pvsplit if len(k) >= minsucc ]
print(underp, pvsplit,pvsplit, succlst, len(succlst))
# if pvsplit[0].shape[0] >= minsucc: