In [None]:
## Import Libraries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import scipy.io as sio

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.feature import power, amplitude, sigfilt
from brainpipe.visual import *
from brainpipe.statistics import *
from scipy.stats import *

from os import path
from mne.stats import *
from mne.baseline import rescale
from mne.filter import filter_data
import time

## User variables

In [2]:
# where to find data
st = study('Olfacto')
score = 'Rec' #'Rec'
if score == 'Epi':
    path_data = path.join (st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_5s_Good_Bad_EpiScore/')
    save_path = path.join(st.path, 'classified/All_balanced_1_1000perm_DA_stats/')
if score == 'Rec':
    path_data = path.join (st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_Good_Bad_RecScore/')
    save_path = path.join(st.path, 'classified/0_Classif_ERP_RecScore_all_electrodes/')

# ANALYSIS PARAMETERS
low_pass_filter = 10.
sf = 512.
norm_mode = 'mean' #'ratio' 'mean' 'percent' 
baseline = [973 , 1024] #100ms before odor perception
data_to_use = [973, 1536] #1000ms after odor
time_points = data_to_use[1]-data_to_use[0]
classif = 'lda'
n_rep = 1 #bootstrap
alpha = 0.05
winSample = 10 #in samples = 20ms
minsucc = 3 #nb of continuous samples to be significant

-> Olfacto loaded


## ERPs Decoding - Good Bad Odors Encoding

In [3]:
test = True

if test == True:
    n_elec = {'PIRJ' :1}
    subjects = ['PIRJ']
else :
    subjects = ['MICP','VACJ','SEMC','PIRJ','LEFC','CHAF'] 
    n_elec = {
    'CHAF' : 107,
    'VACJ' : 139, 
    'SEMC' : 107,
    'PIRJ' : 106,
    'LEFC' : 193,
    'MICP' : 105,
        }

for su in subjects:
    #Load files
    data_bad = np.load(path.join(path_data, su+'_concat_odor_bad_bipo.npz'))
    data_good = np.load(path.join(path_data, su+'_concat_odor_good_bipo.npz'))
    data_bad, channels, names, data_good = data_bad['x'], data_bad['channel'], data_bad['label'], data_good['x']

    for elec in range(0,n_elec[su]):
    #for elec in n_elec[su]:
        # Select data for one elec + name :
        data_elec_bad = data_bad[elec,:,:]
        data_elec_good = data_good[elec,:,:]
        ntrials = str(data_elec_bad.shape[1])+'/'+ str(data_elec_good.shape[1])
        channel, name = channels[elec], names[elec]
        print (su, 'Channel : ', channel, 'Label : ', name,'Bad shape : ', 
               data_elec_bad.shape, 'Good shape : ', data_elec_good.shape)

        #Filter data for one elec (all trials):
        data_elec_bad = np.array(data_elec_bad, dtype='float64')
        data_elec_good = np.array(data_elec_good, dtype='float64')
        data_bad_to_filter = np.swapaxes(data_elec_bad, 0, 1)
        data_good_to_filter = np.swapaxes(data_elec_good, 0, 1)
        filtered_data_bad = filter_data(data_bad_to_filter, sfreq=512, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        filtered_data_good = filter_data(data_good_to_filter, sfreq=512, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
        print ('Size of filtered data bad :', filtered_data_bad.shape, 'filtered data good : ', filtered_data_good.shape,)

        #Normalize the non-averaged data (all trials)
        times = np.arange(filtered_data_bad.shape[1])
        print ('time points : ', times.shape)
        norm_filtered_data_bad = rescale(filtered_data_bad, times=times, baseline=baseline, mode=norm_mode)
        norm_filtered_data_good = rescale(filtered_data_good, times=times, baseline=baseline, mode=norm_mode)
        print ('Size norm & filtered data 0 : ', norm_filtered_data_bad.shape, norm_filtered_data_good.shape,)

        # Range of the data to compute
        data_range = range(data_to_use[0], data_to_use[1])
        bad_sel = norm_filtered_data_bad[:, data_range]
        good_sel = norm_filtered_data_good[:, data_range,]
        print ('-> Shape of bad data', bad_sel.shape, 'good data', good_sel.shape)
        
        # Average the signal on consecutive windows
        n_pts = bad_sel.shape[1]
        rmPoints = n_pts % winSample # Points to remove before splitting
        shapeRmPoints = np.arange(n_pts-rmPoints).astype(int) # Number of points for round division
        n_win = int(n_pts / winSample) # Number of segments

        # Split and average data (trials, n_pts)
        bad_split = np.array(np.split(bad_sel[:, shapeRmPoints], n_win, axis=1)) # n_win n_trials n_pts
        bad_split = np.mean(bad_split, axis=2).swapaxes(0,1) # n-trials n_win
        good_split = np.array(np.split(good_sel[:, shapeRmPoints], n_win, axis=1))
        good_split = np.mean(good_split, axis=2).swapaxes(0,1)

# ==========================  BALANCED CONDITIONS - Bootstrap  =====================================
        if bad_split.shape[0] > good_split.shape[0]:
            bad_sel_stat = bad_split[np.random.randint(bad_split.shape[0], size=good_split.shape[0]), :] #reshape bad_data to fit good_data shape
            good_sel_stat = good_split
        elif bad_split.shape[0] < good_split.shape[0]:
            bad_sel_stat = bad_split
            good_sel_stat = good_split[np.random.randint(good_split.shape[0], size=bad_split.shape[0]), :]
        else:
            bad_sel_stat, good_sel_stat = bad_split, good_split
        print ('balanced data : ', bad_sel_stat.shape, good_sel_stat.shape)
# ===========================  STATISTICS  =====================================
        # Permutations and t test of the data
        bad_perm, good_perm = perm_swap(bad_sel_stat, good_sel_stat, n_perm=1000, axis=0)
        bad_perm, good_perm = np.swapaxes(bad_perm,0,1), np.swapaxes(good_perm,0,1)
        print(bad_perm.shape, good_perm.shape)
        Tperm, _ = ttest_ind(bad_perm, good_perm, equal_var=False)
        print(Tperm.shape)
        thr_0_5_stat = [-perm_pvalue2level(Tperm, p=0.05, maxst=True)[0],perm_pvalue2level(Tperm, p=0.05, maxst=True)[0]]
        thr_0_1_stat = [-perm_pvalue2level(Tperm, p=0.01, maxst=True)[0],perm_pvalue2level(Tperm, p=0.01, maxst=True)[0]]
        thr_0_0_1_stat = [-perm_pvalue2level(Tperm, p=0.001, maxst=True)[0],perm_pvalue2level(Tperm, p=0.001, maxst=True)[0]]
        print(thr_0_5_stat,thr_0_1_stat,thr_0_0_1_stat)
        T0, _  = ttest_ind(bad_split, good_split, equal_var=False)
        print(T0.shape, T0.max(), T0.min())
        
        # Create the pvalue vector to plot
        pvals = []
        for i in range(T0.shape[0]):
            if T0[i] < thr_0_0_1_stat[0] or T0[i] > thr_0_0_1_stat[1]:
                pval = pvals.append(0.0009)
            elif T0[i] < thr_0_1_stat[0] or T0[i] > thr_0_1_stat[1]:
                pval = pvals.append(0.009)
            elif T0[i] < thr_0_5_stat[0] or T0[i] > thr_0_5_stat[1]:
                pval = pvals.append(0.04)
            else:
                pval = pvals.append(1)
                
# =============================  CLASSIFICATION COMPUTATION ============================================================           
        #create a data matrix, concatenate along the trial dimension
        bad_good = np.concatenate((bad_sel_stat, good_sel_stat), axis=0)
        print ('Size of the concatenated data: ', bad_good.shape, 'Number time windows : ', bad_good.shape[1])
        #create label vector (0 for rest and 1 for odor)
        y = [0]*bad_sel_stat.shape[0] + [1]*good_sel_stat.shape[0]
        print ('Size of label for classif: ', len(y))
        # Define a cross validation:
        cv = defCv(y, n_folds=10, cvtype='skfold', rep=10)
        # Define classifier technique
        clf = defClf(y=y, clf=classif)#,n_tree=200, random_state=100)
        #Classify rest and odor
        cl = classify(y, clf=clf, cvtype=cv)
        # Evaluate the classifier on data:
        da,pvalues,daperm = cl.fit(bad_good, n_perm=1000,method='full_rnd', mf=False)
        #print(pvalues.shape, pvalues.min(), pvalues.max())
        th_0_05_perm = perm_pvalue2level(daperm, p=0.05, maxst=True)
        th_0_01_perm = perm_pvalue2level(daperm, p=0.01, maxst=True)
        th_0_001_perm = perm_pvalue2level(daperm, p=0.001, maxst=True)
        print('th_perm : ', th_0_05_perm[0], th_0_01_perm[0], th_0_001_perm[0])

# ============================== PLOT ERPs ANALYSIS + STATS & DECODING ACCURACY ===================================================
        # plot and figure parameters
        xfmt = ScalarFormatter(useMathText=True)
        xfmt.set_powerlimits((0,3))
        fig = plt.figure(1,figsize=(7,7))
        title = 'ERPs-Stats-DA for '+su+' Good/Bad '+str(channel)+' '+str(name)+' ('+str(elec)+') ntrials:'+str(ntrials)
        fig.suptitle(title, fontsize=12)
        times_plot = 1000 * np.arange((baseline[0] - baseline[1]), len(shapeRmPoints)-baseline[1]+baseline[0],winSample) / sf

        # Plot the ERPs + STATS
        plt.subplot(211)
        BorderPlot(times_plot, bad_good, y=y, kind='sem', alpha=0.2, color=['b','m'], 
                   linewidth=2, ncol=1, xlabel='Time (ms)',ylabel = r' $\mu$V', legend=['bad','good'])
        addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0], 
                 hColor=['#000000'], hWidth=[2])
        addPval(plt.gca(), pvals, p=0.05, x=times_plot, y=2, color='orange', lw=2, minsucc=minsucc)
        addPval(plt.gca(), pvals, p=0.01, x=times_plot, y=2, color='orangered', lw=2,minsucc=minsucc)
        addPval(plt.gca(), pvals, p=0.001, x=times_plot, y=2, color='r', lw=2,minsucc=minsucc)
        rmaxis(plt.gca(), ['right', 'top'])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

        # Plot DA for the ERPs
        plt.subplot(212)
        BorderPlot(times_plot, da, color='b', kind='sem',xlabel='Time (ms)', 
                   ylim=[da.min()-10,da.max()+10], ylabel='Decoding accuracy (%)',
                   linewidth=2, alpha=0.3)
        rmaxis(plt.gca(), ['right', 'top'])
        addLines(plt.gca(), vLines=[0], vWidth=[2], vColor=['r'], hLines=[50], 
                 hColor=['#000000'], hWidth=[2])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)   
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))
        plt.plot(times_plot, th_0_05_perm*np.ones(len(times_plot)), '--', color='orange', linewidth=2)
        plt.plot(times_plot, th_0_01_perm*np.ones(len(times_plot)), '--', color='orangered', linewidth=2)
        plt.plot(times_plot, th_0_001_perm*np.ones(len(times_plot)), '--', color='r', linewidth=2)
        
        # Criteria to be significant
        pvals = np.ravel(pvals)
        underp = np.where(pvals < alpha)[0]
        pvsplit = np.split(underp, np.where(np.diff(underp) != 1)[0]+1)
        signif = [True for k in pvsplit if len(k) >= minsucc]
        
        #Save plots and stats
        if len(signif) >=1:
            name_t0 = (save_path +str(round((winSample*1000)/512))+'ms/Significant/stat/'+su +'_t0_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_tperm = (save_path +str(round((winSample*1000)/512))+'ms/Significant/stat/'+su +'_tperm_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_pval = (save_path +str(round((winSample*1000)/512))+'ms/Significant/stat/'+su +'_pvals_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_da = (save_path +str(round((winSample*1000)/512))+'ms/Significant/da/'+su +'_da_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_daperm = (save_path +str(round((winSample*1000)/512))+'ms/Significant/da/'+su +'_daperm_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            plot_name = (save_path +str(round((winSample*1000)/512))+'ms/Significant/fig'+su +'_ERPs_'  + score +'_'+str(name)+'_('+str(elec)+').png')
        else:
            name_t0 = (save_path +str(round((winSample*1000)/512))+'ms/Not_Significant/stat/'+su +'_t0_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_tperm = (save_path +str(round((winSample*1000)/512))+'ms/Not_Significant/stat/'+su +'_tperm_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_pval = (save_path +str(round((winSample*1000)/512))+'ms/Not_Significant/stat/'+su +'_pvals_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_da = (save_path +str(round((winSample*1000)/512))+'ms/Not_Significant/da/'+su +'_da_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            name_daperm = (save_path +str(round((winSample*1000)/512))+'ms/Not_Significant/da/'+su +'_daperm_' + score +'_'+str(name)+'_('+str(elec)+').npy')
            plot_name = (save_path +str(round((winSample*1000)/512))+'ms/Not_Significant/fig'+su +'_ERPs_'  + score +'_'+str(name)+'_('+str(elec)+').png')
        
        np.save(name_t0, T0)
        np.save(name_tperm, Tperm)
        np.save(name_pval, pvals)
        np.save(name_da, da)
        np.save(name_daperm, daperm)
        plt.savefig(plot_name, dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close() 
        del bad_sel, good_sel, good_sel_stat, bad_sel_stat, bad_split, good_split
    del data_bad, data_good, channels, names

PIRJ Channel :  b2-b1 Label :  aHC Bad shape :  (2560, 13) Good shape :  (2560, 15)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 1352 samples (2.641 sec) selected
Size of filtered data bad : (13, 2560) filtered data good :  (15, 2560)
time points :  (2560,)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Size norm & filtered data 0 :  (13, 2560) (15, 2560)
-> Shape of bad data (13, 563) good data (15, 563)
balanced data :  (13, 56) (13, 56)
(13, 1000, 56) (13, 1000, 56)
(1000, 56)
[-1.8404192093221812, 1.8404192093221812] [-2.6882365729993345, 2.6882365729993345] [-4.2575867002976704, 4.2575867002976704]
(56,) 1.32703039944 -2.19189729898
Size of the concatenated data:  (26, 56) Number time windows :  56
Size of label for classif:  26
th_perm :  69.2307692308 76.9230769231 92.307692



NameError: name 'label' is not defined