## Import Libraries

In [None]:
from os import path
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import scipy.io as sio

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.feature import power, amplitude, sigfilt
from brainpipe.visual import *
from brainpipe.statistics import *
from scipy.stats import *

## User variables

In [None]:
# POWER & STATS PARAMETERS
# nfreq: number of frequency indices looped over; nperm: number of label permutations
nfreq, nperm = 7, 100

# Baseline placeholder. NOTE: the decoding cell further down rebinds `bsl`
# to the list ['None'] before it is used.
bsl = None

# PATH TO DATA (input power features and output classification folder)
st = study('Olfacto')
path_pow = path.join(st.path, 'feature/0_Power_Encoding_EpiPerf_4500_expi_noart/')
save_path = path.join(st.path, 'classified/0_Classif_Power_Poor_Detailed_EpiPerf_4500_expi_noart')

## Power Decoding - Poor vs. Detailed Encoding
### For ALL time points

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold as SKFold
from sklearn.metrics import roc_auc_score
from numpy.random import permutation

from os import makedirs

# Conditions to contrast, recording phases, baselines (used as a filename
# component) and subjects to process.
conds, phases, bsl, subjects = ['poor', 'detailed'], ['odor'], ['None'], ['FERJ', 'MICP', 'VACJ', 'SEMC', 'LEFC', 'PIRJ', 'CHAF']
color_codes = ['gold', 'crimson']


def _cv_auc(X, y, skf):
    """Return the mean test-fold ROC AUC of an LDA classifier over the folds of `skf`.

    Parameters
    ----------
    X : ndarray, shape (n_trials, 1)
        Single feature (power in one time window) for every trial.
    y : ndarray, shape (n_trials,)
        Class label of every trial.
    skf : cross-validator
        Object whose ``split(X, y)`` yields (train_index, test_index) pairs.
    """
    score_cv = []
    for train_index, test_index in skf.split(X, y):
        clf = LDA()
        clf.fit(X=X[train_index], y=y[train_index])
        y_pred = clf.predict(X[test_index])
        # NOTE(review): the AUC is computed from hard class predictions, which for
        # two classes equals the balanced accuracy; clf.decision_function(...) would
        # give a threshold-free AUC. Kept as-is so results stay comparable.
        score_cv.append(roc_auc_score(y[test_index], y_pred, average='weighted'))
    return np.mean(score_cv)


def _decoding_auc(x, y, n_rep=10, n_splits=5):
    """Return the cross-validated AUC of every time window.

    Each window (column of `x`) is decoded on its own, `n_rep` times, each time
    with a freshly shuffled stratified `n_splits`-fold split.

    Returns
    -------
    ndarray, shape (n_rep, n_windows)
    """
    auc = np.empty((n_rep, x.shape[1]))
    for t_idx in range(x.shape[1]):
        X = x[:, t_idx].reshape(-1, 1)
        for rep in range(n_rep):
            skf = SKFold(n_splits=n_splits, random_state=None, shuffle=True)
            auc[rep, t_idx] = _cv_auc(X, y, skf)
    return auc


def _permutation_auc(x, y, nperm, n_splits=5):
    """Return the null distribution of AUCs obtained with shuffled labels.

    For each time window and each of the `nperm` permutations, the labels are
    shuffled and decoded with one pass of a shuffled stratified K-fold split.
    The splitter is created here instead of being inherited from the real
    decoding loop.

    Returns
    -------
    ndarray, shape (nperm, n_windows)
    """
    skf = SKFold(n_splits=n_splits, random_state=None, shuffle=True)
    perm_scores = np.empty((nperm, x.shape[1]))
    for t_idx in range(x.shape[1]):
        X = x[:, t_idx].reshape(-1, 1)
        for p_idx in range(nperm):
            y_perm = y[permutation(len(y))]
            perm_scores[p_idx, t_idx] = _cv_auc(X, y_perm, skf)
    return perm_scores


def _auc_pvalues(auc, th_0_05, th_0_01):
    """Map the mean AUC of every window (column of `auc`) to a coarse p-value code.

    Returns a list holding, per window, 0.009 if the mean AUC exceeds the
    p<0.01 threshold, 0.04 if it exceeds the p<0.05 threshold, else 1.
    """
    auc_pvals = []
    for i in range(auc.shape[1]):
        mean_auc = np.mean(auc[:, i])
        if mean_auc > th_0_01:
            auc_pvals.append(0.009)
        elif mean_auc > th_0_05:
            auc_pvals.append(0.04)
        else:
            auc_pvals.append(1)
    return auc_pvals


def _has_significant_cluster(auc_pvals, alpha=0.05, min_len=3):
    """Return True if at least `min_len` consecutive windows have p < `alpha`."""
    underp = np.where(np.ravel(auc_pvals) < alpha)[0]
    # Split the significant indices wherever they stop being consecutive.
    clusters = np.split(underp, np.where(np.diff(underp) != 1)[0] + 1)
    return any(len(k) >= min_len for k in clusters)


def _plot_power_and_decoding(time, x, y, auc, th_0_05_perm, title, conds, color_codes):
    """Draw power per condition (top) and decoding AUC with the p<0.05 threshold (bottom) on figure 1."""
    fig = plt.figure(1, figsize=(7, 7))
    fig.suptitle(title, fontsize=12)

    # Power of the trials, grouped by label y (brainpipe BorderPlot, kind='sem')
    plt.subplot(211)
    BorderPlot(time, x, y=y, kind='sem', alpha=0.2, color=color_codes, linewidth=2,
               ncol=1, xlabel='Time (s)', ylabel=r'Power', legend=conds)
    rmaxis(plt.gca(), ['right', 'top'])
    plt.legend(loc=0, handletextpad=0.1, frameon=False)
    plt.gca().yaxis.set_major_locator(MaxNLocator(3, integer=True))

    # Decoding AUC across repetitions (kind='sd') and the p<0.05 permutation threshold
    plt.subplot(212)
    # NOTE(review): the plotted values are AUCs in [0, 1] although the label says "%".
    BorderPlot(time, auc, color='b', kind='sd', xlabel='Time (s)', ylim=[0.4, 1.],
               ylabel='Decoding accuracy (%)', linewidth=2, alpha=0.3)
    rmaxis(plt.gca(), ['right', 'top'])
    plt.gca().yaxis.set_major_locator(MaxNLocator(3, integer=True))
    plt.plot(time, th_0_05_perm*np.ones(len(time)), '--', color='r', linewidth=2)


for b, su, phase in product(bsl, subjects, phases):
    # =========================== Load Power files (nfreq, nelec, nwin, ntrial) ===========================
    poor_mat = np.load(path.join(path_pow, su+'_'+phase+'_poor_bipo_sel_phys_'+b+'_pow.npz'))
    names, channels, freq_names, time = poor_mat['labels'], poor_mat['channels'], poor_mat['fname'], poor_mat['time']
    poor_pow = poor_mat['xpow']
    nelecs = poor_pow.shape[1]
    detailedname = path.join(path_pow, su+'_'+phase+'_detailed_bipo_sel_phys_'+b+'_pow.npz')
    if not path.isfile(detailedname):
        # Message means: "no 'detailed' condition" for this subject
        print(su,'pas de condition detailed')
        continue
    pow_list = [poor_pow, np.load(detailedname)['xpow']]
    print (su, 'power shape: ', [p.shape for p in pow_list])

    for elec_num in range(nelecs):
        for freq in range(nfreq):
            print('--» processing',su, 'elec', elec_num,'/',nelecs, 'freq',freq)
            # One (ntrial, nwin) matrix per condition for this electrode/frequency.
            # The loop name is not `power`, so the brainpipe `power` import survives.
            pow_data_elec = [cond_pow[freq, elec_num].swapaxes(0, 1) for cond_pow in pow_list]
            elec, elec_label, freq_name = channels[elec_num], names[elec_num], freq_names[freq]
            print ('elec ', elec, 'elec_label ', elec_label)

            # ============================= Classification Computation =============================
            # Stack the trials of both conditions; label = index of the condition in pow_list.
            x = np.concatenate(pow_data_elec, axis=0)
            print ('Size of the concatenated data: ', x.shape, 'Number time windows : ', x.shape[1])
            y = np.hstack([np.array([i]*len(cond)) for i, cond in enumerate(pow_data_elec)])
            print ('Size of label for classif: ', len(y))

            auc = _decoding_auc(x, y)
            perm_scores = _permutation_auc(x, y, nperm)
            # maxst=True presumably yields a max-statistic (corrected) threshold — TODO confirm in brainpipe
            th_0_05_perm = perm_pvalue2level(perm_scores, p=0.05, maxst=True)
            th_0_01_perm = perm_pvalue2level(perm_scores, p=0.01, maxst=True)
            print('th_perm : ', th_0_05_perm[0], th_0_01_perm[0], 'auc_mean', auc.mean())

            # ========================== Create a p-value vector for the AUC measure ==========================
            auc_pvals = _auc_pvalues(auc, th_0_05_perm[0], th_0_01_perm[0])
            print (auc_pvals)

            # ============================== PLOT POWER + STATS & DECODING ACCURACY ==============================
            title = 'Power-Stats-DA for '+su+' Poor vs Detailed '+str(elec)+' '+str(elec_label)+' ('+str(elec_num)+')'
            _plot_power_and_decoding(time, x, y, auc, th_0_05_perm, title, conds, color_codes)

            # Filenames to save; the sub-directories are created when missing because
            # np.save and plt.savefig raise FileNotFoundError otherwise.
            freq_dir = save_path+'/'+str(freq)+'_'+freq_name
            for sub in ('auc', 'fig'):
                makedirs(freq_dir+'/'+sub, exist_ok=True)
            name_auc = freq_dir+'/auc/'+su+'_auc_Poor_Detailed_'+str(elec_label)+'_('+str(elec_num)+').npy'
            name_th_0_05_perm = freq_dir+'/auc/'+su+'_th_0_05_perm_'+str(elec_label)+'_('+str(elec_num)+').npy'
            name_th_0_01_perm = freq_dir+'/auc/'+su+'_th_0_01_perm_'+str(elec_label)+'_('+str(elec_num)+').npy'
            plot_name = freq_dir+'/fig/'+su+'_Power_Poor_Detailed_'+str(elec)+'_'+str(elec_label)+'_('+str(elec_num)+').png'

            # Significant = at least 3 consecutive windows with p < 0.05
            if _has_significant_cluster(auc_pvals):
                makedirs(freq_dir+'/signif', exist_ok=True)
                plot_sig = freq_dir+'/signif/'+su+'_Power_Poor_Detailed_'+str(elec)+'_'+str(elec_label)+'_('+str(elec_num)+').png'
                plt.savefig(plot_sig, dpi=300, bbox_inches='tight')

            # Save results and plots
            np.save(name_auc, auc)
            np.save(name_th_0_05_perm, th_0_05_perm[0])
            np.save(name_th_0_01_perm, th_0_01_perm[0])
            plt.savefig(plot_name, dpi=300, bbox_inches='tight')
            plt.clf()
            plt.close()
            del auc, pow_data_elec
    del pow_list