In [None]:
## Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import scipy.io as sio

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.feature import power, amplitude, sigfilt
from brainpipe.visual import *
from brainpipe.statistics import *

from os import path
from mne.stats import *
from mne.baseline import rescale
from mne.filter import filter_data
import time

## User variables

In [None]:
# Where to find the input data and where to write the results.
st = study('Olfacto')
score = 'Rec'  # 'Epi' or 'Rec'
if score == 'Epi':
    path_data = path.join(st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_5s_Good_Bad_EpiScore/')
    save_path = path.join(st.path, 'classified/0_Classif_ERP_EpiScore_all_electrodes/')
elif score == 'Rec':
    path_data = path.join(st.path, 'database/TS_E_all_by_odor_th40_art400_30_250_Good_Bad_RecScore/')
    save_path = path.join(st.path, 'classified/0_Classif_ERP_RecScore_all_electrodes/')
else:
    # Fail immediately: otherwise path_data/save_path would be undefined on a
    # fresh kernel, or silently keep the values of a previous run.
    raise ValueError("score must be 'Epi' or 'Rec', got %r" % (score,))

# ANALYSIS PARAMETERS
low_pass_filter = 10.  # low-pass cut-off (h_freq) given to the filter
sf = 512.  # sampling frequency
norm_mode = 'mean'  # baseline normalisation mode: 'ratio', 'mean' or 'percent'
baseline = [973, 1024]  # baseline window in sample indices (~100 ms before odor perception at sf)
data_to_use = [973, 1536]  # analysed window in sample indices (baseline + 1000 ms after odor at sf)
time_points = data_to_use[1]-data_to_use[0]  # number of samples in the analysed window
classif = 'lda'  # classifier name
n_rep = 1  # number of bootstrap repetitions
alpha = 0.05
winSample = 10  # averaging window length in samples (~20 ms at sf)

## ERPs Decoding - Good Bad Odors Encoding

In [None]:
# Set to True for a quick run restricted to one subject and one electrode.
test = False

if test:
    # Quick run: only the first electrode of subject PIRJ
    subjects = ['PIRJ']
    n_elec = dict(PIRJ=1)
else:
    # Full run: subjects to process and number of electrodes per subject
    subjects = ['LEFC', 'CHAF']  # also available: 'MICP','VACJ','SEMC','PIRJ'
    n_elec = dict(
        CHAF=107,
        VACJ=139,
        SEMC=107,
        PIRJ=106,
        LEFC=193,
        MICP=105,
    )

# For every subject and every electrode: low-pass filter and baseline-normalise
# the single-trial signals, average them over consecutive windows, classify
# bad vs good odors in each window, then plot and save the results.
for su in subjects:
    # Load files; the context managers close the .npz archives once the
    # arrays have been read.
    with np.load(path.join(path_data, su+'_concat_odor_bad_bipo.npz')) as npz_bad:
        data_bad, channels, names = npz_bad['x'], npz_bad['channel'], npz_bad['label']
    with np.load(path.join(path_data, su+'_concat_odor_good_bipo.npz')) as npz_good:
        data_good = npz_good['x']

    for elec in range(0, n_elec[su]):
        # time.clock() was removed in Python 3.8; perf_counter() is its replacement
        tic = time.perf_counter()

        # Select data for one elec + name
        # (assumes arrays are (n_elec, n_pts, n_trials) -- TODO confirm)
        data_elec_bad = data_bad[elec, :, :]
        data_elec_good = data_good[elec, :, :]
        ntrials = str(data_elec_bad.shape[1])+'/'+ str(data_elec_good.shape[1])
        channel, name = channels[elec], names[elec]
        print (su, 'Channel : ', channel, 'Label : ', name,'Bad shape : ', 
               data_elec_bad.shape, 'Good shape : ', data_elec_good.shape)

        # Filter data for one elec (all trials). The axes are swapped so that
        # the last axis is the one that gets filtered.
        data_elec_bad = np.array(data_elec_bad, dtype='float64')
        data_elec_good = np.array(data_elec_good, dtype='float64')
        data_bad_to_filter = np.swapaxes(data_elec_bad, 0, 1)
        data_good_to_filter = np.swapaxes(data_elec_good, 0, 1)
        filtered_data_bad = filter_data(data_bad_to_filter, sfreq=sf, l_freq=None,
                                        h_freq=low_pass_filter, method='fir', phase='zero-double')
        filtered_data_good = filter_data(data_good_to_filter, sfreq=sf, l_freq=None,
                                         h_freq=low_pass_filter, method='fir', phase='zero-double')
        print ('Size of filtered data bad :', filtered_data_bad.shape, 'filtered data good : ', filtered_data_good.shape,)

        # Normalize the non-averaged data (all trials). `times` holds sample
        # indices, so `baseline` is interpreted in samples as well.
        times = np.arange(filtered_data_bad.shape[1])
        print ('time points : ', times.shape)
        norm_filtered_data_bad = rescale(filtered_data_bad, times=times, baseline=baseline, mode=norm_mode)
        norm_filtered_data_good = rescale(filtered_data_good, times=times, baseline=baseline, mode=norm_mode)
        print ('Size norm & filtered data 0 : ', norm_filtered_data_bad.shape, norm_filtered_data_good.shape,)

        # Range of the data to compute
        data_range = range(data_to_use[0], data_to_use[1])
        bad_sel = norm_filtered_data_bad[:, data_range]
        good_sel = norm_filtered_data_good[:, data_range]
        print ('-> Shape of bad data', bad_sel.shape, 'good data', good_sel.shape)

        # Average the signal on consecutive windows
        n_pts = bad_sel.shape[1]
        rmPoints = n_pts % winSample  # Points to remove before splitting
        shapeRmPoints = np.arange(n_pts-rmPoints).astype(int)  # Indices kept so that the split is even
        n_win = int(n_pts / winSample)  # Number of segments

        # Split and average data (trials, n_pts)
        bad_split = np.array(np.split(bad_sel[:, shapeRmPoints], n_win, axis=1))  # n_win n_trials n_pts
        bad_split = np.mean(bad_split, axis=2).swapaxes(0, 1)  # n_trials n_win
        good_split = np.array(np.split(good_sel[:, shapeRmPoints], n_win, axis=1))
        good_split = np.mean(good_split, axis=2).swapaxes(0, 1)

        # ==========================  BALANCED CONDITIONS - Bootstrap  =====================================
        da_rep, daperm_rep = np.array([]), np.array([])
        bad_rep, good_rep = np.array([]), np.array([])
        for i in range(n_rep):
            # Equalise the number of trials by randomly drawing, from the larger
            # class, as many trials as the smaller class contains.
            # NOTE(review): np.random.randint draws WITH replacement, so a trial can
            # be picked several times and may end up in both the training and the
            # test folds -- confirm this is intended. No random seed is set either,
            # so results differ from one run to the next.
            if bad_split.shape[0] > good_split.shape[0]:
                bad_sel_stat = bad_split[np.random.randint(bad_split.shape[0], size=good_split.shape[0]), :]
                good_sel_stat = good_split
            elif bad_split.shape[0] < good_split.shape[0]:
                bad_sel_stat = bad_split
                good_sel_stat = good_split[np.random.randint(good_split.shape[0], size=bad_split.shape[0]), :]
            else:
                bad_sel_stat, good_sel_stat = bad_split, good_split
            print ('balanced data : ', bad_sel_stat.shape, good_sel_stat.shape)

            # =============================  CLASSIFICATION COMPUTATION ============================
            # Create a data matrix, concatenate along the trial dimension
            bad_good = np.concatenate((bad_sel_stat, good_sel_stat), axis=0)

            # Create label vector (0 for bad odors and 1 for good odors)
            y = [0]*bad_sel_stat.shape[0] + [1]*good_sel_stat.shape[0]

            # Define a cross validation:
            cv = defCv(y, n_folds=10, cvtype='skfold', rep=10)
            # Define classifier technique
            clf = defClf(y=y, clf=classif)
            # Classify bad vs good odors
            cl = classify(y, clf=clf, cvtype=cv)
            # Evaluate the classifier on data (pvalue is returned but not used here):
            da, pvalue, daperm = cl.fit(bad_good, n_perm=1000, method='full_rnd', mf=False)

            # Stack da, daperm and the balanced data of every repetition
            da_rep = np.vstack((da_rep, da)) if np.size(da_rep) else da
            daperm_rep = np.vstack((daperm_rep, daperm)) if np.size(daperm_rep) else daperm
            bad_rep = np.vstack((bad_rep, bad_sel_stat)) if np.size(bad_rep) else bad_sel_stat
            good_rep = np.vstack((good_rep, good_sel_stat)) if np.size(good_rep) else good_sel_stat
        print ('Bootstrap da&data : ', 'da_rep',da_rep.shape, 'daperm_rep',daperm_rep.shape,
                  'bad rep', bad_rep.shape, 'good_rep', good_rep.shape)

        # Decoding-accuracy levels derived from the permutations at
        # p = 0.05, 0.01 and 0.001; the maximum of each is used as threshold.
        level_0_5 = perm_pvalue2level(daperm_rep, p=0.05, maxst=False)
        level_0_1 = perm_pvalue2level(daperm_rep, p=0.01, maxst=False)
        level_0_0_1 = perm_pvalue2level(daperm_rep, p=0.001, maxst=False)
        th_0_05 = level_0_5.max()
        th_0_01 = level_0_1.max()
        th_0_001 = level_0_0_1.max()
        print('levels', th_0_05, th_0_01, th_0_001)

        # ============================== PLOT ERPs ANALYSIS + STATS & DECODING ACCURACY ==================
        # data to plot
        bad_good_plot = np.concatenate((bad_rep, good_rep), axis=0)
        y_plot = [0]*bad_rep.shape[0] + [1]*good_rep.shape[0]

        # plot and figure parameters
        xfmt = ScalarFormatter(useMathText=True)
        xfmt.set_powerlimits((0,3))
        fig = plt.figure(1,figsize=(7,7))
        title = 'ERP and DA for '+su+' Good/Bad '+str(channel)+' '+str(name)+' ('+str(elec)+') ntrials:'+str(ntrials)
        fig.suptitle(title, fontsize=12)
        # Start of each averaging window in ms, with 0 at the end of the baseline
        times_plot = 1000 * np.arange((baseline[0] - baseline[1]), len(shapeRmPoints)-baseline[1]+baseline[0],winSample) / sf

        # Plot the ERPs
        plt.subplot(211)
        BorderPlot(times_plot, bad_good_plot, y=y_plot, kind='sem', alpha=0.2, color=['b','m'], 
                   linewidth=2, ncol=1, xlabel='Time (ms)',ylabel = r' $\mu$V', legend=['bad','good'])
        addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0], 
                 hColor=['#000000'], hWidth=[2])
        rmaxis(plt.gca(), ['right', 'top'])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

        # Plot DA for the ERPs
        plt.subplot(212)
        BorderPlot(times_plot, da_rep, color='b', kind='sem',xlabel='Time (ms)', 
                   ylim=[da_rep.min()-10,da_rep.max()+10], ylabel='Decoding accuracy (%)',
                   linewidth=2, alpha=0.3)
        rmaxis(plt.gca(), ['right', 'top'])
        addLines(plt.gca(), vLines=[0], vWidth=[2], vColor=['r'], hLines=[50], 
                 hColor=['#000000'], hWidth=[2])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))
        # Horizontal dashed lines at the three permutation thresholds
        plt.plot(times_plot, th_0_05*np.ones(len(times_plot)), '--', color='orange', 
                  linewidth=2)
        plt.plot(times_plot, th_0_01*np.ones(len(times_plot)), '--', color='orangered', 
                  linewidth=2)
        plt.plot(times_plot, th_0_001*np.ones(len(times_plot)), '--', color='r', 
                      linewidth=2)

        # Save da, daperm and the figure. Electrodes whose mean decoding accuracy
        # reaches the p=0.05 threshold in at least one window go to 'Significant/'.
        # The output directories must already exist.
        win_dir = str(round((winSample*1000)/sf))+'ms/'
        if (da_rep.mean(axis=0)).max() >= th_0_05:
            out_dir = save_path+win_dir+'Significant/'
        else:
            out_dir = save_path+win_dir
        suffix = '_Bad_vs_Good_ERP_'+classif+'_'+str(name)+'_('+str(elec)+')'
        np.save(out_dir+su+'_da'+suffix, da_rep)
        np.save(out_dir+su+'_daperm'+suffix, daperm_rep)
        fname = out_dir+su+'_da'+suffix+'.png'
        fig.savefig(fname, dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close()
        toc = time.perf_counter()
        print(round(toc-tic,2))
        del bad_sel, good_sel, good_sel_stat, bad_sel_stat
    del data_bad, data_good, channels, names