## Import Libraries

In [1]:
from os import path
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import scipy.io as sio

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.feature import power, amplitude, sigfilt
from brainpipe.visual import *
from brainpipe.statistics import *
from scipy.stats import *

In [2]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold as SKFold
from sklearn.metrics import roc_auc_score
from numpy.random import permutation

conds,phases, subjects = ['low','high'],['odor'],['FERJ','LEFC']#,'MICP','VACJ','SEMC','LEFC','PIRJ']
color_codes = ['darkorange','blue']
#bsls = ['bsl1','bsl2']
bsl = ''
st = study('Olfacto')
freqs = 5
nperm = 1000

for su,phase in product(subjects,phases):
    path_pow = path.join(st.path, 'feature/0_Power_Encoding_EpiPerf_LowHigh/')
    save_path = path.join(st.path, 'classified/LDA_Power_E_EpiPerf_LowHigh'+bsl+'/')

    pow_list = []
    #=========================== Load Power files (nfreq, nelec, nwin, ntrial) =================================    
    mat0 = np.load(path.join(path_pow, su+'_'+phase+'_'+conds[0]+'_bipo_sel_physFT_pow'+bsl+'.npz'),
                  allow_pickle=True)
    names, channels, freq_names, time = mat0['Mai_RL'], mat0['channels'],mat0['fname'], mat0['time']
    time = time[20:47]-3
    #print(mat0['xpow'].shape,time.shape,time[17:42])
    pow_list.append(mat0['xpow'][:,:,20:47,:]) #17:52
    nelecs = mat0['xpow'].shape[1]
    mat1 = np.load(path.join(path_pow, su+'_'+phase+'_'+conds[1]+'_bipo_sel_physFT_pow'+bsl+'.npz'),
                   allow_pickle=True)
    pow_list.append(mat1['xpow'][:,:,20:47,:]) #17:52
    print (su, 'power shape: ', [pow.shape for pow in pow_list])
    # =========================== Select Power for 1 elec 1 freq =================================                 
    iterator = range(nelecs)
    for elec_num in iterator:#iterator
        for freq in range(4,freqs):#range(2,freqs):
            elec, elec_label, freq_name = channels[elec_num], names[elec_num], freq_names[freq]
            print ('elec ', elec, 'elec_label ', elec_label)
            #Filenames to save
            name_auc = (save_path+str(freq)+'_'+freq_name+'/auc/'+su +'_auc_'+conds[0]+'_'+conds[1]+'_'+str(elec_label)+'_('+str(elec_num)+').npy')
            name_perm = (save_path+str(freq)+'_'+freq_name+'/auc/'+su +'_perm_'+str(elec_label)+'_('+str(elec_num)+').npy')
            plot_name = (save_path+str(freq)+'_'+freq_name+'/fig/'+su +'_Power_'+conds[0]+'_'+conds[1]+'_'+str(elec)+'_'+str(elec_label)+'_('+str(elec_num)+').png')            
            
            if path.exists(name_auc):
                print(su,bsl,phase,elec_num,freq,'already computed')
            else:
                print('--» processing',su, 'elec', elec_num,'/',nelecs, 'freq',freq)
                pow_data_elec = []
                for i,power in enumerate(pow_list):
                    pow_data_elec.append(power[freq,elec_num].swapaxes(0,1))
                nwin = power.shape[1]

        # =============================  Classification Computation ============================================================           
                # create a data matrix, concatenate along the trial dimension
                x = np.concatenate(pow_data_elec, axis=0)
                print ('Size of the concatenated data: ', x.shape, 'Number time windows : ', x.shape[1])
                y = np.hstack([np.array([i]*len(power)) for i, power in enumerate(pow_data_elec)])
                print ('Size of label for classif: ', len(y))

                auc = np.array([])
                for t in range(x.shape[1]):
                    X = x[:,t]
                    X = X.reshape(-1, 1)
                    score_rep = []
                    for i in range(10):
                        k = 5
                        skf = SKFold(n_splits=k, random_state=None, shuffle=True)
                        skf.get_n_splits(X, y)
                        score_cv = []
                        for train_index, test_index in skf.split(X, y):
                            clf = LDA()
                            X_train, X_test = X[train_index], X[test_index]
                            y_train, y_test = y[train_index], y[test_index]
                            clf.fit(X=X_train, y=y_train)
                            y_pred = clf.predict(X_test)
                            score_cv.append(roc_auc_score(y_test,y_pred,average='weighted'))
                        score_rep.append(np.mean(score_cv))
                    score_rep = np.asarray(score_rep).reshape(1,len(score_rep))
                    auc = np.vstack((auc, score_rep)) if np.size(auc) else score_rep
                auc = np.swapaxes(auc,0,1)

                perm_scores = np.array([])
                for t in range(x.shape[1]):
                    X = x[:,t]
                    X = X.reshape(-1, 1)
                    perm_rep = []
                    for perm in range(nperm):
                        y_perm = y[permutation(len(y))]
                        score_cv = []
                        for train_index, test_index in skf.split(X, y_perm):
                            clf = LDA()
                            X_train, X_test = X[train_index], X[test_index]
                            y_train, y_test = y_perm[train_index], y_perm[test_index]
                            clf.fit(X=X_train, y=y_train)
                            y_pred = clf.predict(X_test)
                            score_cv.append(roc_auc_score(y_test,y_pred,average='weighted'))
                        perm_rep.append(np.mean(score_cv))
                    perm_rep = np.asarray(perm_rep).reshape(1,len(perm_rep))
                    perm_scores = np.vstack((perm_scores, perm_rep)) if np.size(perm_scores) else perm_rep
                perm_scores = np.swapaxes(perm_scores,0,1)           
                th_0_05_perm = perm_pvalue2level(perm_scores, p=0.005, maxst=True)
                th_0_01_perm = perm_pvalue2level(perm_scores, p=0.001, maxst=True)
                print('th_perm 005: ', th_0_05_perm[0], '001',th_0_01_perm[0], 'auc_max', np.max(auc))

        # ============================== PLOT POWER ANALYSIS + STATS & DECODING ACCURACY ===================================================
                #plot and figure parameters
                xfmt = ScalarFormatter(useMathText=True)
                xfmt.set_powerlimits((0,3))
                fig = plt.figure(1,figsize=(7,7))
                title = 'Power-Stats-DA for '+su+' '+conds[0]+' vs '+conds[1]+' '+str(elec)+' '+str(elec_label)+' ('+str(elec_num)+')'
                fig.suptitle(title, fontsize=12)

                # Plot the POW + STATS
                plt.subplot(211)        
                BorderPlot(time, x, y=y, kind='sem', alpha=0.2, color=color_codes,linewidth=2, 
                          ncol=1, xlabel='Time (s)',ylabel = r'Power', legend=conds)
                rmaxis(plt.gca(), ['right', 'top'])
                addLines(plt.gca(), vLines=[0], vColor=['darkgray'], vWidth=[2])
                plt.legend(loc=0, handletextpad=0.1, frameon=False)
                plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

                # Plot DA for the POW
                plt.subplot(212)
                BorderPlot(time, auc, color='b', kind='sd',xlabel='Time (s)', ylim=[0.4,1.], ylabel='Decoding accuracy (%)',linewidth=2, alpha=0.3)
                rmaxis(plt.gca(), ['right', 'top'])
                addLines(plt.gca(), vLines=[0], vColor=['darkgray'], vWidth=[2])
                plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))
                plt.plot(time, th_0_05_perm*np.ones(len(time)), '--', color='r', linewidth=2)
                plt.plot(time, th_0_01_perm*np.ones(len(time)), '--', color='orange', linewidth=2)

                #Save plots
                np.save(name_auc, auc)
                np.save(name_perm, perm_scores)
                plt.savefig(plot_name, dpi=300, bbox_inches='tight')
                plt.clf()
                plt.close() 
                del X, auc, pow_data_elec
    del pow_list

-> Olfacto loaded
FERJ power shape:  [(5, 85, 27, 17), (5, 85, 27, 14)]
elec  a2-a1 elec_label  Amg
--» processing FERJ elec 0 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.775 001 0.85 auc_max 0.8666666666666668
elec  a3-a2 elec_label  Amg
--» processing FERJ elec 1 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.775 001 0.8666666666666666 auc_max 0.9666666666666666
elec  a8-a7 elec_label  MTG
--» processing FERJ elec 2 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.775 001 0.8666666666666666 auc_max 0.8916666666666668
elec  a9-a8 elec_label  MTG
--» processing FERJ elec 3 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.775 001 0.875 auc_max 0.7499999999999999
elec  a10-a9 el

th_perm 005:  0.7750000000000001 001 0.875 auc_max 0.875
elec  d'7-d'6 elec_label  FuG
--» processing FERJ elec 35 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.7666666666666667 001 0.8916666666666666 auc_max 0.675
elec  d'8-d'7 elec_label  FuG
--» processing FERJ elec 36 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.7833333333333334 001 0.85 auc_max 0.8416666666666668
elec  d'9-d'8 elec_label  FuG
--» processing FERJ elec 37 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.7833333333333334 001 0.9 auc_max 0.85
elec  d'10-d'9 elec_label  ITG
--» processing FERJ elec 38 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.775 001 0.8416666666666668 auc_max 0.8083333333333332
elec  d

th_perm 005:  0.775 001 0.875 auc_max 0.8666666666666666
elec  o3-o2 elec_label  OFC
--» processing FERJ elec 70 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.775 001 0.875 auc_max 0.875
elec  o4-o3 elec_label  OFC
--» processing FERJ elec 71 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.7833333333333334 001 0.9 auc_max 0.875
elec  o5-o4 elec_label  OFC
--» processing FERJ elec 72 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.775 001 0.8499999999999999 auc_max 0.9333333333333332
elec  o6-o5 elec_label  OFC
--» processing FERJ elec 73 / 85 freq 4
Size of the concatenated data:  (31, 27) Number time windows :  27
Size of label for classif:  31
th_perm 005:  0.7833333333333333 001 0.9 auc_max 0.8416666666666666
elec  o7-o6 elec_label  OFC
--» processing

elec  d8-d7 elec_label  FuG
--» processing LEFC elec 19 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_perm 005:  0.6778571428571429 001 0.7654761904761905 auc_max 0.5571428571428572
elec  d9-d8 elec_label  ITG
--» processing LEFC elec 20 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_perm 005:  0.6804761904761905 001 0.7561904761904763 auc_max 0.7054761904761906
elec  d10-d9 elec_label  ITG
--» processing LEFC elec 21 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_perm 005:  0.6935714285714286 001 0.7611904761904762 auc_max 0.6757142857142857
elec  d11-d10 elec_label  ITG
--» processing LEFC elec 22 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_perm 005:  0.6876190476190477 001 0.7578571428571428 auc_max 0.7769047619047621
e

th_perm 005:  0.6778571428571428 001 0.8028571428571428 auc_max 0.6771428571428572
elec  s4-s3 elec_label  SFG
--» processing LEFC elec 53 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_perm 005:  0.6754761904761903 001 0.7697619047619046 auc_max 0.6721428571428572
elec  s5-s4 elec_label  SFG
--» processing LEFC elec 54 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_perm 005:  0.6778571428571428 001 0.7871428571428571 auc_max 0.7421428571428572
elec  s6-s5 elec_label  SFG
--» processing LEFC elec 55 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_perm 005:  0.6735714285714286 001 0.7521428571428572 auc_max 0.6935714285714286
elec  s7-s6 elec_label  SFG
--» processing LEFC elec 56 / 60 freq 4
Size of the concatenated data:  (55, 27) Number time windows :  27
Size of label for classif:  55
th_p