## Import Libraries

In [None]:
from os import path
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import scipy.io as sio

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.feature import power, amplitude, sigfilt
from brainpipe.visual import *
from brainpipe.statistics import *
from scipy.stats import *

## ERP Decoding - Partial/Detailed Encoding
### Single-time-point decoding, for all time points

In [None]:
# Sanity-check cell: print the shapes of the main arrays.
# NOTE: `auc`, `x` and `y` are only created by the decoding cell below (and
# `auc` is deleted after each processed electrode), so on a fresh kernel a
# bare `print(auc.shape)` raises NameError. Only print what actually exists.
time = np.arange(2000) / 512  # 2000 samples / 512 (presumably the sampling rate in Hz)
print(time.shape)
for _var_name in ('auc', 'x', 'y'):
    _value = globals().get(_var_name)
    print(_var_name, _value.shape if _value is not None else 'not defined yet')

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold as SKFold
from sklearn.metrics import roc_auc_score
from numpy.random import permutation

from os import makedirs

# ----------------------------------------------------------------------------
# Single-time-point ERP decoding (conds[0] vs conds[1]) with an LDA classifier.
# For every subject and electrode, each time sample is decoded separately with
# repeated stratified K-fold CV. A chance-level threshold is derived from label
# permutations via perm_pvalue2level(..., maxst=True). AUC and permutation
# scores are saved as .npy files and a summary figure as .png. Electrodes whose
# AUC file already exists are skipped.
# NOTE: AUC is now computed from LDA decision values (true ROC AUC). Files
# produced by earlier runs used hard predictions and are not directly comparable.
# ----------------------------------------------------------------------------

conds, subjects = ['odor','no_odor'],['PIRJ']#,'PIRJ','SEMC','CHAF','FERJ','VACJ','MICP']#'MICP','VACJ','CHAF','SEMC','LEFC','PIRJ'
color_codes = ['darkorange','blue']  # one color per class, in conds order
st = study('Olfacto')
nperm = 100   # label permutations per time point
n_rep = 10    # repetitions of the cross-validation
n_folds = 10  # stratified CV folds
step = 'Encoding'

# Class i (y == i) is conds[i]. Build the legend from conds so the labels match
# the classes: ['Odor', 'No Odor']. Assumes BorderPlot assigns legend entries
# in ascending label order, like the colors.
legend_labels = [cond.replace('_', ' ').title() for cond in conds]

path_data = path.join(st.path, 'database/'+step+'_No_Odor/')
path_feat = path.join(st.path, 'feature/ERP_Odor_No_Odor/'+step+'_no_ds/')
save_path = path.join(st.path, 'classified/0_Classif_ERP_'+step+'_Odor_Inspi_100perm_no_ds/')
# np.save / plt.savefig do not create missing directories.
for sub_dir in ('auc', 'fig'):
    makedirs(path.join(save_path, sub_dir), exist_ok=True)

# With shuffle=True and random_state=None every .split() call draws new folds,
# so one splitter can be reused for all repetitions and permutations.
cv = SKFold(n_splits=n_folds, shuffle=True, random_state=None)


def _cv_auc(X, y, cv):
    """Return the mean ROC AUC of an LDA classifier over the folds of `cv`.

    X is a 2-D (n_trials, n_features) array and y the matching label vector.
    """
    fold_scores = []
    for train_index, test_index in cv.split(X, y):
        clf = LDA()
        clf.fit(X=X[train_index], y=y[train_index])
        # ROC AUC needs continuous scores; passing hard predictions
        # (clf.predict) collapses it to balanced accuracy.
        y_score = clf.decision_function(X[test_index])
        fold_scores.append(roc_auc_score(y[test_index], y_score))
    return np.mean(fold_scores)


def _auc_over_time(x, y, cv, n_iter, shuffle_labels=False):
    """Decode every time point (column of x) on its own.

    Returns an (n_iter, n_times) array: row i is the CV AUC of the i-th
    repetition, or of the i-th label permutation when shuffle_labels is True.
    """
    scores = np.empty((n_iter, x.shape[1]))
    for t in range(x.shape[1]):
        X = x[:, [t]]  # keep 2-D: (n_trials, 1)
        for i in range(n_iter):
            y_iter = y[permutation(len(y))] if shuffle_labels else y
            scores[i, t] = _cv_auc(X, y_iter, cv)
    return scores


for su in subjects:
    # =========================== Load ERP files =================================
    # Only channel names and labels are taken from the database file.
    with np.load(path_data+su+'_odor_'+conds[1]+'_bipo_sel_phys.npz') as mat:
        labels, channels = mat['Mai_RL'], mat['channels']
    # One ERP array per condition, in conds order. Axis 1 is the time axis
    # (it becomes x.shape[1] below); [:, 2:-1] crops it as in the original analysis.
    feat_list = [np.load(path.join(path_feat, su+'_'+cond+'_bipo_sel_phys.npy'))[:, 2:-1]
                 for cond in conds]
    time = np.arange(feat_list[0].shape[1])  # time axis in samples
    nelecs = feat_list[0].shape[0]
    print(su, 'power shape: ', [feat.shape for feat in feat_list])

    # =========================== Decode each electrode ==========================
    for elec_num in range(nelecs):
        elec, elec_label = channels[elec_num], labels[elec_num]
        print('elec ', elec, 'elec_label ', elec_label)
        # Filenames to save
        tag = str(elec_label)+'_('+str(elec_num)+')'
        name_auc = path.join(save_path, 'auc', su+'_auc_'+conds[0]+'_'+conds[1]+'_'+tag+'.npy')
        name_perm = path.join(save_path, 'auc', su+'_perm_'+tag+'.npy')
        plot_name = path.join(save_path, 'fig', su+'_ERP_'+conds[0]+'_'+conds[1]+'_'+str(elec)+'_'+tag+'.png')

        if path.exists(name_auc):
            print(su, elec_num, 'already computed')
            continue
        print('--» processing', su, 'elec', elec_num, '/', nelecs)

        # ======================= Classification computation =====================
        # Concatenate conditions along the trial axis: x is (n_trials, n_times),
        # y holds each trial's condition index (0 = conds[0], 1 = conds[1]).
        data_elec = [erp[elec_num].swapaxes(0, 1) for erp in feat_list]
        x = np.concatenate(data_elec, axis=0)
        print('Size of the concatenated data: ', x.shape, 'Number time windows : ', x.shape[1])
        y = np.hstack([np.full(len(trials), i) for i, trials in enumerate(data_elec)])
        print('Size of label for classif: ', len(y))

        auc = _auc_over_time(x, y, cv, n_rep)                                # (n_rep, n_times)
        perm_scores = _auc_over_time(x, y, cv, nperm, shuffle_labels=True)  # (nperm, n_times)
        th_0_01_perm = perm_pvalue2level(perm_scores, p=0.01, maxst=True)
        th_0_001_perm = perm_pvalue2level(perm_scores, p=0.001, maxst=True)
        print('th_perm 01: ', th_0_01_perm[0], '001', th_0_001_perm[0], 'auc_max', np.max(auc))

        # Save before plotting so a plotting error does not discard the
        # (expensive) permutation results.
        np.save(name_auc, auc)
        np.save(name_perm, perm_scores)

        # ======================= Plot ERP + decoding accuracy ===================
        fig = plt.figure(1, figsize=(7, 7))
        title = 'ERP-Stats-DA for '+su+' '+conds[0]+' vs '+conds[1]+' '+str(elec)+' '+str(elec_label)+' ('+str(elec_num)+')'
        fig.suptitle(title, fontsize=12)

        # ERPs of both conditions (SEM shading)
        plt.subplot(211)
        BorderPlot(time, x, y=y, kind='sem', alpha=0.2, color=color_codes, linewidth=2,
                   ncol=1, xlabel='Time (samples)', ylabel='ERP amplitude', legend=legend_labels)
        rmaxis(plt.gca(), ['right', 'top'])
        addLines(plt.gca(), vLines=[0], vColor=['darkgray'], vWidth=[2])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3, integer=True))

        # Decoding AUC across repetitions, with the p < 0.01 permutation threshold
        plt.subplot(212)
        BorderPlot(time, auc, color='b', kind='sd', xlabel='Time (samples)', ylim=[0.4, 1.],
                   ylabel='Decoding accuracy (AUC)', linewidth=2, alpha=0.3)
        rmaxis(plt.gca(), ['right', 'top'])
        addLines(plt.gca(), vLines=[0], vColor=['darkgray'], vWidth=[2])
        plt.gca().yaxis.set_major_locator(MaxNLocator(3, integer=True))
        plt.plot(time, th_0_01_perm*np.ones(len(time)), '--', color='r', linewidth=2)

        plt.savefig(plot_name, dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close()
        del auc, perm_scores
    del feat_list