## Import Libraries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import scipy.io as sio

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.feature import power, amplitude, sigfilt
from brainpipe.visual import *
from brainpipe.statistics import *
from scipy.stats import *

from os import path
from mne.stats import *
from mne.baseline import rescale
from mne.filter import filter_data
import time

## User variables

In [4]:
# where to find data
st = study('Olfacto')
score = 'Epi' #'Rec'
if score == 'Epi':
    path_pow = path.join(st.path, 'feature/7_Power_E1E2_Odor_Good_Bad_700_100_EpiScore/')
    save_path = path.join(st.path, 'classified/1_Classif_Power_EpiScore_all_electrodes_win700_step100/new/')

# ANALYSIS PARAMETERS
classif = 'lda'
nfreq = 5
alpha = 0.05
minsucc = 5 #nb of continuous samples to be significant
bsl = ['exp','rest']

-> Olfacto loaded


## Power Decoding - Good Bad Odors Encoding

In [5]:
test = False

if test == True:
    n_elec = {'PIRJ' :1}
    subjects = ['PIRJ']
else :
    subjects = ['MICP','VACJ','SEMC','PIRJ','LEFC','CHAF', 'FERJ'] 
    n_elec = {
    'CHAF' : 69,
    'VACJ' : 84, 
    'SEMC' : 66,
    'PIRJ' : 71,
    'LEFC' : 152,
    'MICP' : 76,
    'FERJ' : 88,
        }
for b in bsl:
    for su in subjects:
        #load power files (nfreq, nelec, nwin, ntrial)
        bad_data = np.load(path.join(path_pow, su+'_concat_odor_bad_bipo_new_'+b+'_power.npz'))['xpow']
        good_data = np.load(path.join(path_pow, su+'_concat_odor_good_bipo_new_'+b+'_power.npz'))['xpow']
        names = np.load(path.join(path_pow, su+'_concat_odor_bad_bipo_new_'+b+'_power.npz'))['labels']
        channels = np.load(path.join(path_pow, su+'_concat_odor_bad_bipo_new_'+b+'_power.npz'))['channels']
        freq_names = np.load(path.join(path_pow, su+'_concat_odor_bad_bipo_new_'+b+'_power.npz'))['fname']
        print (su, 'bad shape: ', bad_data.shape, 'good shape: ', good_data.shape)

    # ==========================  BALANCED CONDITIONS - Bootstrap  =====================================
        if bad_data.shape[3] > good_data.shape[3]:
            bad_stat = bad_data[:,:,:,np.random.randint(bad_data.shape[3], size=good_data.shape[3])]
            good_stat = good_data
        elif bad_data.shape[3] < good_data.shape[3]:
            bad_stat = bad_data
            good_stat = good_data[:,:,:,np.random.randint(good_data.shape[3], size=bad_data.shape[3])]
        else:
            bad_stat, good_stat = bad_data, good_data
        ntrials = bad_stat.shape[3]
        print ('balanced data : ', bad_stat.shape, good_stat.shape)

    # =========================== SELECT Power for 1 elec 1 freq =================================                 
        for elec_num in range(n_elec[su]):
            for freq in range(nfreq):
                # load power files for 1 elec // 1 freq // Bad-Good conditions
                bad_data_elec = bad_stat[freq,elec_num].swapaxes(0,1)
                good_data_elec = good_stat[freq,elec_num].swapaxes(0,1)
                print ('data elec ', bad_data_elec.shape, good_data_elec.shape)
                nwin = good_data.shape[1]
                elec, elec_label, freq_name = channels[elec_num], names[elec_num], freq_names[freq]
                print ('elec ', elec, 'elec_label ', elec_label)

    # ===========================  STATISTICS  =====================================
                # Permutations and t test of the data
                bad_perm, good_perm = perm_swap(bad_data_elec, good_data_elec, n_perm=1000, axis=0)
                bad_perm, good_perm = np.swapaxes(bad_perm,0,1), np.swapaxes(good_perm,0,1)
                print('data permuted', bad_perm.shape, good_perm.shape)
                Tperm, _ = ttest_ind(bad_perm, good_perm, equal_var=False)
                print('T perm', Tperm.shape)
                thr_0_5_stat = [-perm_pvalue2level(Tperm, p=0.05, maxst=True)[0],perm_pvalue2level(Tperm, p=0.05, maxst=True)[0]]
                thr_0_1_stat = [-perm_pvalue2level(Tperm, p=0.01, maxst=True)[0],perm_pvalue2level(Tperm, p=0.01, maxst=True)[0]]
                thr_0_0_1_stat = [-perm_pvalue2level(Tperm, p=0.001, maxst=True)[0],perm_pvalue2level(Tperm, p=0.001, maxst=True)[0]]
                print('treshold stats', thr_0_5_stat,thr_0_1_stat,thr_0_0_1_stat)
                T0, _  = ttest_ind(bad_data_elec, good_data_elec, equal_var=False)
                print('Obs stats',T0.shape, T0.max(), T0.min())

                # Create the pvalue vector to plot
                pvals = []
                for i in range(T0.shape[0]):
                    if T0[i] < thr_0_0_1_stat[0] or T0[i] > thr_0_0_1_stat[1]:
                        pval = pvals.append(0.0009)
                    elif T0[i] < thr_0_1_stat[0] or T0[i] > thr_0_1_stat[1]:
                        pval = pvals.append(0.009)
                    elif T0[i] < thr_0_5_stat[0] or T0[i] > thr_0_5_stat[1]:
                        pval = pvals.append(0.04)
                    else:
                        pval = pvals.append(1)
                print (pvals)

    # =============================  CLASSIFICATION COMPUTATION ============================================================           
                #create a data matrix, concatenate along the trial dimension
                bad_good = np.concatenate((bad_data_elec, good_data_elec), axis=0)
                print ('Size of the concatenated data: ', bad_good.shape, 'Number time windows : ', bad_good.shape[1])
                #create label vector (0 for rest and 1 for odor)
                y = [0]*bad_data_elec.shape[0] + [1]*good_data_elec.shape[0]
                print ('Size of label for classif: ', len(y))
                # Define a cross validation:
                cv = defCv(y, n_folds=10, cvtype='skfold', rep=10)
                # Define classifier technique
                clf = defClf(y=y, clf=classif)#,n_tree=200, random_state=100)
                #Classify rest and odor
                cl = classify(y, clf=clf, cvtype=cv)
                # Evaluate the classifier on data:
                da,pvalues,daperm = cl.fit(bad_good, n_perm=100,method='full_rnd', mf=False)
                #print(pvalues.shape, pvalues.min(), pvalues.max())
                th_0_05_perm = perm_pvalue2level(daperm, p=0.05, maxst=True)
                th_0_01_perm = perm_pvalue2level(daperm, p=0.01, maxst=True)
                print('th_perm : ', th_0_05_perm[0], th_0_01_perm[0])

        # ============================== PLOT POWER ANALYSIS + STATS & DECODING ACCURACY ===================================================
                # plot and figure parameters
                xfmt = ScalarFormatter(useMathText=True)
                xfmt.set_powerlimits((0,3))
                fig = plt.figure(1,figsize=(7,7))
                title = 'Power-Stats-DA for '+su+' Bad/Good '+str(elec)+' '+str(elec_label)+' ('+str(elec_num)+') ntrials:'+str(ntrials)
                fig.suptitle(title, fontsize=12)
                # Time vector to plot power
                step = 3500/ bad_data_elec.shape[1]
                times_plot = np.arange(-500, 3000, step)

                # Plot the POW + STATS
                plt.subplot(211)
                bad_good_to_plot = bad_good * 100
                BorderPlot(times_plot, bad_good_to_plot, y=y, kind='sem', alpha=0.2, color=['b','m'], 
                           linewidth=2, ncol=1, xlabel='Time (ms)',ylabel = r'Power change (%)', legend=['bad','good'])
                addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0], 
                         hColor=['#000000'], hWidth=[2])
                addPval(plt.gca(), pvals, p=0.05, x=times_plot, y=2, color='orange', lw=2, minsucc=minsucc)
                addPval(plt.gca(), pvals, p=0.01, x=times_plot, y=2, color='r', lw=2,minsucc=minsucc)
                addPval(plt.gca(), pvals, p=0.001, x=times_plot, y=2, color='g', lw=2,minsucc=minsucc)
                rmaxis(plt.gca(), ['right', 'top'])
                plt.legend(loc=0, handletextpad=0.1, frameon=False)
                plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

                # Plot DA for the POW
                plt.subplot(212)
                BorderPlot(times_plot, da, color='b', kind='sem',xlabel='Time (ms)', 
                           ylim=[da.min()-10,da.max()+10], ylabel='Decoding accuracy (%)',
                           linewidth=2, alpha=0.3)
                rmaxis(plt.gca(), ['right', 'top'])
                addLines(plt.gca(), vLines=[0], vWidth=[2], vColor=['r'], hLines=[50], 
                         hColor=['#000000'], hWidth=[2])
                plt.legend(loc=0, handletextpad=0.1, frameon=False)   
                plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))
                plt.plot(times_plot, th_0_05_perm*np.ones(len(times_plot)), '--', color='orange', linewidth=2)
                plt.plot(times_plot, th_0_01_perm*np.ones(len(times_plot)), '--', color='r', linewidth=2)
                # Criteria to be significant
                pvals = np.ravel(pvals)
                underp = np.where(pvals < alpha)[0]
                pvsplit = np.split(underp, np.where(np.diff(underp) != 1)[0]+1)
                signif = [True for k in pvsplit if len(k) >= minsucc]

                #Save plots and stats
                name_t0 = (save_path+'All_balanced_1_100perm_DA_stats_sametrials_'+b+'/'+str(freq)+'_'+freq_name+'/stat/'+su +'_t0_' + score +'_'+str(elec_label)+'_('+str(elec_num)+').npy')
                name_tperm = (save_path+'All_balanced_1_100perm_DA_stats_sametrials_'+b+'/'+str(freq)+'_'+freq_name+'/stat/'+su +'_tperm_' + score +'_'+str(elec_label)+'_('+str(elec_num)+').npy')
                name_pval = (save_path+'All_balanced_1_100perm_DA_stats_sametrials_'+b+'/'+str(freq)+'_'+freq_name+'/stat/'+su +'_pvals_' + score +'_'+str(elec_label)+'_('+str(elec_num)+').npy')
                name_da = (save_path+'All_balanced_1_100perm_DA_stats_sametrials_'+b+'/'+str(freq)+'_'+freq_name+'/da/'+su +'_da_' + score +'_'+str(elec_label)+'_('+str(elec_num)+').npy')
                name_th_0_05_perm = (save_path+'All_balanced_1_100perm_DA_stats_sametrials_'+b+'/'+str(freq)+'_'+freq_name+'/da/'+su +'_th_0_05_perm_' + score +'_'+str(elec_label)+'_('+str(elec_num)+').npy')
                name_th_0_01_perm = (save_path+'All_balanced_1_100perm_DA_stats_sametrials_'+b+'/'+str(freq)+'_'+freq_name+'/da/'+su +'_th_0_01_perm_' + score +'_'+str(elec_label)+'_('+str(elec_num)+').npy')
                plot_name = (save_path+'All_balanced_1_100perm_DA_stats_sametrials_'+b+'/'+str(freq)+'_'+freq_name+'/fig/'+su +'_Power_'  + score +'_'+str(elec_label)+'_('+str(elec_num)+').png')            

                np.save(name_t0, T0)
                np.save(name_tperm, Tperm)
                np.save(name_pval, pvals)
                np.save(name_da, da)
                np.save(name_th_0_05_perm, th_0_05_perm[0])
                np.save(name_th_0_01_perm, th_0_01_perm[0])
                plt.savefig(plot_name, dpi=300, bbox_inches='tight')
                plt.clf()
                plt.close() 
                del bad_data_elec, good_data_elec, bad_perm, good_perm, bad_good, da, pvalues, daperm,
        del bad_data, good_data, bad_stat, good_stat

MICP bad shape:  (5, 76, 19, 46) good shape:  (5, 76, 19, 37)
balanced data :  (5, 76, 19, 37) (5, 76, 19, 37)
data elec  (37, 19) (37, 19)
elec  a'2-a'1 elec_label  Amg&Amg-pPirT
data permuted (37, 1000, 19) (37, 1000, 19)
T perm (1000, 19)
treshold stats [-1.7880919505117279, 1.7880919505117279] [-2.6663851346447576, 2.6663851346447576] [-3.6939487262696225, 3.6939487262696225]
Obs stats (19,) 2.96120125583 -3.84952486423
[1, 1, 1, 0.04, 0.04, 0.04, 0.04, 1, 1, 1, 1, 0.04, 0.009, 0.009, 0.0009, 0.009, 1, 1, 1]
Size of the concatenated data:  (74, 19) Number time windows :  19
Size of label for classif:  74
th_perm :  62.1621621622 71.6216216216


TypeError: Can't convert 'list' object to str implicitly