## Calculate ERPs

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from os import path
from matplotlib.ticker import ScalarFormatter, MaxNLocator
#%matplotlib notebook

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.visual import *
from brainpipe.statistics import *

from mne.baseline import rescale
from mne.filter import filter_data
from scipy.stats import *
# from mne.stats import *
import time

## User variables

In [2]:
# where to find data
st = study('Olfacto')
score = 'Rec'  # one of the keys of `score_dirs` below: 'Epi' or 'Rec'

# (input database folder, output feature folder) for each score type,
# both relative to the study root `st.path`.
score_dirs = {
    'Epi': ('database/TS_E_all_by_odor_th40_art400_30_250_5s_Good_Bad_EpiScore/',
            'feature/ERP_Good_Bad_100ms_rescale_filtered_stats_EM/Stats_No_Bootstrap_Unpaired/'),
    'Rec': ('database/TS_E_all_by_odor_th40_art400_30_250_Good_Bad_RecScore/',
            'feature/ERP_Good_Bad_100ms_rescale_filtered_stats_Rec/Stats_No_Bootstrap_Unpaired/'),
}
# Fail here with a clear message instead of leaving path_data / save_path
# undefined (or stale from a previous run) when `score` is mistyped.
if score not in score_dirs:
    raise ValueError("Unknown score %r, expected one of %s" % (score, sorted(score_dirs)))
path_data = path.join(st.path, score_dirs[score][0])
save_path = path.join(st.path, score_dirs[score][1])

# ANALYSIS PARAMETERS
low_pass_filter = 10.  # cut-off passed as h_freq to the low-pass filter
sf = 512.  # sampling frequency used for filtering and for the time axis
norm_mode = 'mean' #'ratio' 'mean' 'percent'
baseline = [973 , 1024] #100ms before odor perception (51 samples at 512 Hz)
data_to_use = [973, 1536] # from baseline start up to 1000ms after odor (sample 1024 + 512)
time_points = data_to_use[1]-data_to_use[0]
winSample = 10 #in samples = 20ms
alpha = 0.05  # p-value threshold used by the significance criterion
minsucc = 3  # minimum number of consecutive significant windows

-> Olfacto loaded


## Plot ERPs for Odor groups

In [4]:
# Set to True for a quick run restricted to the first electrodes of one subject.
test = False

if not test:
    # Full analysis: every subject with its number of electrodes.
    subjects = ['VACJ', 'SEMC', 'PIRJ', 'LEFC', 'MICP', 'CHAF']
    n_elec = dict(
        CHAF=107,
        VACJ=139,
        SEMC=107,
        PIRJ=106,
        LEFC=193,
        MICP=105,
    )
else:
    # Debug run: a single subject, first 6 electrodes only.
    subjects = ['VACJ']
    n_elec = dict(VACJ=6)

# For every subject and electrode: low-pass filter and baseline-correct the
# single trials of the "bad" and "good" conditions, average them on consecutive
# windows of `winSample` samples, compare the two conditions with a Welch
# t-test against permutation-based thresholds, then save the statistics and a
# two-panel figure (t-values on top, ERPs below).

# Window length in ms, used to name the output sub-folder ('20' for 10 samples at 512 Hz).
win_ms = str(round((winSample * 1000) / sf))

for su in subjects:
    # The two files only depend on the subject, so they are read once here
    # rather than once per electrode; the archives are closed right after
    # the arrays have been extracted.
    badname = su+'_concat_odor_bad_bipo.npz'
    goodname = su+'_concat_odor_good_bipo.npz'
    with np.load(path.join(path_data, badname)) as mat_bad:
        data_bad, channel, label = mat_bad['x'], mat_bad['channel'], mat_bad['label']
    with np.load(path.join(path_data, goodname)) as mat_good:
        data_good = mat_good['x']

    for elec in range(n_elec[su]):
        # Select data for one elec + name :
        data_elec_bad = data_bad[elec,:,:]
        data_elec_good = data_good[elec,:,:]
        ntrials = str(data_elec_bad.shape[1])+'/'+ str(data_elec_good.shape[1]) #to be displayed on figures
        print ('Channel : ', channel[elec], 'Label : ', label[elec], 'N_trials :', ntrials,
               'Bad shape : ', data_elec_bad.shape, 'Good shape : ', data_elec_good.shape)

        # Filter data for one elec (all trials). The data are cast to float64
        # and transposed so that the last axis is time before filtering.
        data_elec_bad = np.array(data_elec_bad, dtype='float64')
        data_elec_good = np.array(data_elec_good, dtype='float64')
        data_bad_to_filter = np.swapaxes(data_elec_bad, 0, 1)
        data_good_to_filter = np.swapaxes(data_elec_good, 0, 1)
        # Use the configured sampling rate `sf` instead of a second hard-coded value.
        filtered_data_bad = filter_data(data_bad_to_filter, sfreq=sf, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        filtered_data_good = filter_data(data_good_to_filter, sfreq=sf, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        print ('Size of filtered data bad :', filtered_data_bad.shape, 'filtered data good : ', filtered_data_good.shape,)

        # Normalize the non-averaged data (all trials). `times` is expressed in
        # samples so that `baseline` can be given as sample indices.
        times = np.arange(filtered_data_bad.shape[1])
        print ('time points : ', times.shape)
        norm_filtered_data_bad = rescale(filtered_data_bad, times=times, baseline=baseline, mode=norm_mode)
        norm_filtered_data_good = rescale(filtered_data_good, times=times, baseline=baseline, mode=norm_mode)
        print ('Size norm & filtered data 0 : ', norm_filtered_data_bad.shape, norm_filtered_data_good.shape,)

        # Range of the data to compute
        data_range = range(data_to_use[0], data_to_use[1])
        # Select a time window in the data
        bad_sel = norm_filtered_data_bad[:, data_range]
        good_sel = norm_filtered_data_good[:, data_range]

        # Average the signal on consecutive windows
        n_pts = bad_sel.shape[1]
        rmPoints = n_pts % winSample # Points to remove before splitting
        shapeRmPoints = np.arange(n_pts-rmPoints).astype(int) # Indices kept so that the length is a multiple of winSample
        n_win = n_pts // winSample # Number of segments

        # Split and average data (trials, n_pts)
        bad_split = np.array(np.split(bad_sel[:, shapeRmPoints], n_win, axis=1)) # n_win n_trials n_pts
        bad_split = np.mean(bad_split, axis=2).swapaxes(0,1) # n-trials n_win
        good_split = np.array(np.split(good_sel[:, shapeRmPoints], n_win, axis=1))
        good_split = np.mean(good_split, axis=2).swapaxes(0,1)

        # ===========================  STATISTICS  =====================================
        # Permutations and t test of the data
        # NOTE(review): perm_swap is presumably stochastic and no seed is set in
        # this notebook, so the thresholds may differ between runs - confirm.
        bad_perm, good_perm = perm_swap(bad_split, good_split, n_perm=1000, axis=0)
        bad_perm, good_perm = np.swapaxes(bad_perm,0,1), np.swapaxes(good_perm,0,1)
        print(bad_perm.shape, good_perm.shape)
        Tperm, _ = ttest_ind(bad_perm, good_perm, equal_var=False)
        print(Tperm.shape)
        # Each level is computed once and mirrored to get a two-sided threshold.
        # NOTE(review): maxst=True presumably requests a maximum-statistic
        # correction across windows - confirm against brainpipe.
        level_0_5 = perm_pvalue2level(Tperm, p=0.05, maxst=True)[0]
        level_0_1 = perm_pvalue2level(Tperm, p=0.01, maxst=True)[0]
        level_0_0_1 = perm_pvalue2level(Tperm, p=0.001, maxst=True)[0]
        thr_0_5_corr = [-level_0_5, level_0_5]
        thr_0_1_corr = [-level_0_1, level_0_1]
        thr_0_0_1_corr = [-level_0_0_1, level_0_0_1]
        print(thr_0_5_corr,thr_0_1_corr,thr_0_0_1_corr)
        T0, _  = ttest_ind(bad_split, good_split, equal_var=False)
        print(T0.shape, T0.max(), T0.min())

        # Create the pvalue vector to plot: one placeholder p-value per window,
        # chosen just below the strictest threshold the observed t-value exceeds.
        pvals = []
        for t_obs in T0:
            if t_obs < thr_0_0_1_corr[0] or t_obs > thr_0_0_1_corr[1]:
                pvals.append(0.0009)
            elif t_obs < thr_0_1_corr[0] or t_obs > thr_0_1_corr[1]:
                pvals.append(0.009)
            elif t_obs < thr_0_5_corr[0] or t_obs > thr_0_5_corr[1]:
                pvals.append(0.04)
            else:
                pvals.append(1)

        # ========================== PREPARE PLOTS AND SAVE STATS =========================================
        # plot and figure parameters
        xfmt = ScalarFormatter(useMathText=True)
        xfmt.set_powerlimits((0,3))
        fig = plt.figure(1,figsize=(9,6))
        title = 'ERP and Stats for '+su+' '+score+' Good/Bad '+ channel [elec] +' '+label[elec]+' ('+str(elec)+') ntrials:'+str(ntrials)
        fig.suptitle(title, fontsize=12)
        # Time axis in ms, one point per averaged window, with 0 at the end of the baseline
        times_plot = 1000 * np.arange((baseline[0] - baseline[1]), len(shapeRmPoints)-baseline[1]+baseline[0],winSample) / sf

        # Plot the STATS
        plt.subplot(211)
        addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0],
                 hColor=['#000000'], hWidth=[2])
        BorderPlot(times_plot, Tperm, kind='sd', alpha=0.2, linewidth=2,
           ncol=1, xlabel='Time (ms)',ylabel ='t values', color='b',
          ylim=[Tperm.min(),Tperm.max()])
        plt.plot(times_plot, T0, '-', color='orange', linewidth=2)
        rmaxis(plt.gca(), ['right', 'top'])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

        # Plot the ERPs
        plt.subplot(212)
        data_all = np.concatenate((bad_split,good_split), axis=0)
        label_bad = np.zeros(bad_split.shape[0], dtype='int64')
        label_good = np.ones(good_split.shape[0], dtype='int64')
        labels = np.concatenate((label_bad, label_good), axis=0)
        BorderPlot(times_plot, data_all, y=labels, kind='sem', alpha=0.2, color=['b', 'm'],
                   linewidth=2, ncol=1, xlabel='Time (ms)',ylabel = r' $\mu$V',
                   legend = ['bad', 'good'])
        print(pvals)
        addPval(plt.gca(), pvals, p=0.05, x=times_plot, y=2, color='orange', lw=2, minsucc=minsucc)
        addPval(plt.gca(), pvals, p=0.01, x=times_plot, y=2, color='orangered', lw=2,minsucc=minsucc)
        addPval(plt.gca(), pvals, p=0.001, x=times_plot, y=2, color='r', lw=2,minsucc=minsucc)
        addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0],
                 hColor=['#000000'], hWidth=[2])
        rmaxis(plt.gca(), ['right', 'top'])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

        # Criteria to be significant: at least `minsucc` consecutive windows
        # with a p-value below `alpha`.
        pvals = np.ravel(pvals)
        underp = np.where(pvals < alpha)[0]
        pvsplit = np.split(underp, np.where(np.diff(underp) != 1)[0]+1)
        signif = [True for k in pvsplit if len(k) >= minsucc ]

        # Save plots and stats; electrodes meeting the criterion go to the
        # 'Significant/' sub-folder. The folders are expected to exist already.
        out_dir = save_path + win_ms + 'ms/'
        if len(signif) >= 1:
            out_dir = out_dir + 'Significant/'
        elec_tag = '_' + score + '_' + label[elec] + '_(' + str(elec) + ')'
        name_tval_obs = out_dir + su + '_tval_obs' + elec_tag + '.npy'
        name_tvals_perm = out_dir + su + '_tvals_perm' + elec_tag + '.npy'
        name_pval = out_dir + su + '_pvals' + elec_tag + '.npy'
        plot_name = out_dir + su + '_ERPs' + elec_tag + '.png'

        np.save(name_tval_obs, T0)
        np.save(name_tvals_perm, Tperm)
        np.save(name_pval, pvals)
        plt.savefig(plot_name, dpi=300, bbox_inches='tight')
        # Clear and close the figure so figures do not accumulate across electrodes
        plt.clf()
        plt.close()


Channel :  b2-b1 Label :  mHC-Ent N_trials : 11/15 Bad shape :  (2560, 11) Good shape :  (2560, 15)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Size of filtered data bad : (11, 2560) filtered data good :  (15, 2560)
time points :  (2560,)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Size norm & filtered data 0 :  (11, 2560) (15, 2560)
(11, 1000, 56) (15, 1000, 56)
(1000, 56)
[-1.8641699015968691, 1.8641699015968691] [-2.8870793003197006, 2.8870793003197006] [-3.8545404768497566, 3.8545404768497566]
(56,) 1.60764360017 -2.15681817157
[1, 1, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]




Channel :  b3-b2 Label :  mHC-PHG&mHC-Ent N_trials : 11/15 Bad shape :  (2560, 11) Good shape :  (2560, 15)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Size of filtered data bad : (11, 2560) filtered data good :  (15, 2560)
time points :  (2560,)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Size norm & filtered data 0 :  (11, 2560) (15, 2560)
(11, 1000, 56) (15, 1000, 56)
(1000, 56)
[-1.8807035479710639, 1.8807035479710639] [-2.8461735193995481, 2.8461735193995481] [-4.5328758563691869, 4.5328758563691869]
(56,) 2.16455192668 -2.2106178267
[1, 1, 1, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.04, 0.04, 0.04, 1, 1, 1, 1, 1, 0.04, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Channel :  b4-b3 Label

In [None]:
# Quick check of how many entries `pvals` currently holds.
n_pvals = len(pvals)
print(n_pvals)

In [None]:
# Scratch check of the "at least `minsucc` consecutive windows under `alpha`" criterion.
# Alternative test vector with isolated hits and one run of three:
#pvals = [1, 0.04, 1, 1, 0.04, 1, 1, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.04, 0.04, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
pvals = np.ravel([1,1,1,1,1])

# Indices of the windows whose p-value is under alpha
underp = np.where(pvals < alpha)[0]
# Cut that index list wherever two neighbours are not consecutive
cut_points = np.where(np.diff(underp) != 1)[0]+1
pvsplit = np.split(underp, cut_points)

for run in pvsplit:
    if len(run) >= minsucc:
        print('daji')
# Printed unconditionally once the loop has finished (the loop has no early exit)
print('')

succlst = [True for run in pvsplit if len(run) >= minsucc]
print(underp, pvsplit,pvsplit, succlst, len(succlst))
# if pvsplit[0].shape[0] >= minsucc: