## Import Libraries

In [1]:
from os import path
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import scipy.io as sio

from brainpipe.classification import *
from brainpipe.system import study
from brainpipe.feature import power, amplitude, sigfilt, phase
from brainpipe.visual import *
from brainpipe.statistics import *
from scipy.stats import *

## User variables

In [2]:
# PATH TO DATA
st = study('Olfacto')
path_data = path.join(st.path, 'feature/7_Phase_E1E2_Odor_Good_Bad_EpiScore_Expi/')
save_path = path.join(st.path, 'classified/1_Classif_Phase_EpiScore_sel_electrodes_win700_step100_expi/')

# POWER & STATS PARAMETERS
nfreq = 4
nperm = 100

-> Olfacto loaded


## Phase Decoding - Good Bad Odors Encoding

In [3]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold as SKFold
from sklearn.metrics import roc_auc_score
from numpy.random import permutation

phases, subjects = ['odor'],['LEFC','CHAF','VACJ','SEMC','FERJ','MICP','PIRJ']
for su, phase in product(subjects,phases):
    #=========================== Load Power files (nfreq, nelec, nwin, ntrial) =================================    
    bad_data = np.load(path.join(path_data, su+'_'+phase+'_bad_bipo_sel_phase.npz'))['phase']
    good_data = np.load(path.join(path_data, su+'_'+phase+'_good_bipo_sel_phase.npz'))['phase']
    names = np.load(path.join(path_data, su+'_'+phase+'_bad_bipo_sel_phase.npz'))['labels']
    channels = np.load(path.join(path_data, su+'_'+phase+'_bad_bipo_sel_phase.npz'))['channels']
    freq_names = np.load(path.join(path_data, su+'_'+phase+'_bad_bipo_sel_phase.npz'))['fname']
    nelec = bad_data.shape[0]
    #print (su, 'bad shape: ', bad_data.shape, 'good shape: ', good_data.shape)

    # =========================== Select Power for all elec 1 freq =================================                 
    for elec in range(nelec):
        for freq in range(nfreq):
            print('computing ML on -->', su, 'elec', elec, '/', nelec, 'for freq', freq)
            # 1 freq // Bad-Good conditions before (nelec, nfreq,nwin, ntrials)(after ntrials, elec, nwins)
            bad_data_elec = bad_data[elec,freq,:,:].swapaxes(0,1)
            good_data_elec = good_data[elec,freq,:,:].swapaxes(0,1)
            print ('data elec ', bad_data_elec.shape, good_data_elec.shape)
            nwin, freq_name = good_data_elec.shape[1],freq_names[freq]
            channel, label = channels[elec], names[elec]
            print ('freq',freq_name)
    # =============================  Classification Computation ============================================================           
            # create a data matrix, concatenate along the trial dimension
            x = np.concatenate((bad_data_elec, good_data_elec), axis=0)
            #print ('Size of the concatenated data: ', x.shape, 'Number time windows : ', x.shape[2])
            #create label vector (0 for rest and 1 for odor)
            y = np.asarray([0]*bad_data_elec.shape[0] + [1]*good_data_elec.shape[0])
            #print ('Size of label for classif: ', len(y))

            auc = np.array([])
            for t in range(x.shape[1]):
                X = x[:,t]
                X = X.reshape(-1, 1)
                score_rep = []
                for i in range(10):
                    skf = SKFold(n_splits=10, random_state=i)
                    skf.get_n_splits(X, y)
                    score_cv = []
                    for train_index, test_index in skf.split(X, y):
                        clf = LDA()
                        X_train, X_test = X[train_index], X[test_index]
                        y_train, y_test = y[train_index], y[test_index]
                        clf.fit(X=X_train, y=y_train)
                        y_pred = clf.predict(X_test)
                        score_cv.append(roc_auc_score(y_test,y_pred,average='weighted'))
                    score_rep.append(np.mean(score_cv))
                score_rep = np.asarray(score_rep).reshape(1,len(score_rep))
                auc = np.vstack((auc, score_rep)) if np.size(auc) else score_rep
            auc = np.swapaxes(auc,0,1)

            perm_scores = np.array([])
            for t in range(x.shape[1]):
                X = x[:,t]
                X = X.reshape(-1, 1)
                perm_rep = []
                for perm in range(nperm):
                    y_perm = y[permutation(len(y))]
                    score_cv = []
                    for train_index, test_index in skf.split(X, y_perm):
                        clf = LDA()
                        X_train, X_test = X[train_index], X[test_index]
                        y_train, y_test = y_perm[train_index], y_perm[test_index]
                        clf.fit(X=X_train, y=y_train)
                        y_pred = clf.predict(X_test)
                        score_cv.append(roc_auc_score(y_test,y_pred,average='weighted'))
                    perm_rep.append(np.mean(score_cv))
                perm_rep = np.asarray(perm_rep).reshape(1,len(perm_rep))
                perm_scores = np.vstack((perm_scores, perm_rep)) if np.size(perm_scores) else perm_rep
                #print(perm_scores.shape)
            perm_scores = np.swapaxes(perm_scores,0,1)

            th_0_05_perm = perm_pvalue2level(perm_scores, p=0.05, maxst=True)
            th_0_01_perm = perm_pvalue2level(perm_scores, p=0.01, maxst=True)
            print('th_perm : ', th_0_05_perm[0], th_0_01_perm[0])

    # ========================== Create a pvalue vector for uac measure ========================
            auc_pvals = []
            for i in range(auc.shape[1]):
                if np.mean(auc[:,i]) > th_0_01_perm[0]:
                    auc_pvals.append(0.009)
                elif np.mean(auc[:,i]) > th_0_05_perm[0]:
                    auc_pvals.append(0.04)
                else:
                    auc_pvals.append(1)
            print (auc_pvals)

    # ============================== PLOT POWER ANALYSIS + STATS & DECODING ACCURACY ==================================================
            # plot and figure parameters
            xfmt = ScalarFormatter(useMathText=True)
            xfmt.set_powerlimits((0,3))
            fig = plt.figure(1,figsize=(7,7))
            title = 'Phase for '+su+' Bad/Good '+str(channel)+' '+str(label)+' ('+str(elec)+')'
            fig.suptitle(title, fontsize=12)
            # Time vector to plot power
            step = 5500/bad_data_elec.shape[1]
            times_plot = np.arange(0, 5500, step)
            #print('step and time',step,times_plot)

            # Plot the POW + STATS
            plt.subplot(211)
            #print(len(times_plot),X.shape)
            BorderPlot(times_plot, x, y=y, kind='sem', alpha=0.2, color=['b','m'], 
                       linewidth=2, ncol=1, xlabel='Time (ms)',ylabel = r'Phase', legend=['bad','good'])
            addLines(plt.gca(), vLines=[975,4000], vColor=['r']*2, vWidth=[2]*2, hLines=[0], 
                 hColor=['#000000'], hWidth=[2])
            rmaxis(plt.gca(), ['right', 'top'])
            addPval(plt.gca(),auc_pvals, p=0.05, x=times_plot, y=x.mean(), color='r', lw=2, minsucc=2)
            addPval(plt.gca(),auc_pvals, p=0.01, x=times_plot, y=x.mean(), color='orange', lw=2, minsucc=2)
            plt.legend(loc=0, handletextpad=0.1, frameon=False)
            plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))

            # Plot DA for the POW
            plt.subplot(212)
            BorderPlot(times_plot, auc, color='b', kind='sd',xlabel='Time (ms)', 
                       ylim=[0,1], ylabel='Decoding accuracy (%)',linewidth=2, alpha=0.3)
            rmaxis(plt.gca(), ['right', 'top'])
            addLines(plt.gca(), vLines=[975,4000], vWidth=[2]*2, vColor=['r']*2, hLines=[0.5], 
                 hColor=['#000000'], hWidth=[2])
            plt.legend(loc=0, handletextpad=0.1, frameon=False)   
            plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))
            plt.plot(times_plot, th_0_05_perm*np.ones(len(times_plot)), '--', color='r', linewidth=2)
            plt.plot(times_plot, th_0_01_perm*np.ones(len(times_plot)), '--', color='orange', linewidth=2)

        
        #Filenames to save
            name_auc = (save_path+su+'/'+str(freq)+'_'+freq_name+'/auc/'+su+'_'+freq_name+'_phase_auc_Good_Bad_elec_'+str(channel)+'_'+str(label)+'_('+str(elec)+').npy')
            name_th_0_05_perm = (save_path+su+'/'+str(freq)+'_'+freq_name+'/auc/'+su+'_'+freq_name+'_phase_th_0_05_perm_Good_Bad_elec_'+str(channel)+'_'+str(label)+'_('+str(elec)+').npy')
            name_th_0_01_perm = (save_path+su+'/'+str(freq)+'_'+freq_name+'/auc/'+su+'_'+freq_name+'_phase_th_0_01_perm_Good_Bad_elec_'+str(channel)+'_'+str(label)+'_('+str(elec)+').npy')
            plot_name = (save_path+su+'/'+str(freq)+'_'+freq_name+'/fig/'+su+'_'+freq_name+'_Phase_Good_Bad_elec_'+str(channel)+'_'+str(label)+'_('+str(elec)+').png')            
            #Save plots
            np.save(name_auc, auc)
            np.save(name_th_0_05_perm, th_0_05_perm[0])
            np.save(name_th_0_01_perm, th_0_01_perm[0])
            plt.savefig(plot_name, dpi=300, bbox_inches='tight')
            plt.clf()
            plt.close() 
            del bad_data_elec, good_data_elec, X, auc
    del bad_data, good_data

computing ML on --> LEFC elec 0 / 139 for freq 0
data elec  (23, 49) (35, 49)
freq VLFC
th_perm :  0.616666666667 0.708333333333
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
not legend
computing ML on --> LEFC elec 0 / 139 for freq 1
data elec  (23, 49) (35, 49)
freq delta
th_perm :  0.608333333333 0.716666666667
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.04, 1, 1, 1, 0.04, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
not legend
computing ML on --> LEFC elec 0 / 139 for freq 2
data elec  (23, 49) (35, 49)
freq theta
th_perm :  0.620833333333 0.729166666667
[1, 1, 1, 1, 1, 1, 1, 0.04, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
not legend
computing ML on --> LEFC elec 0 / 139 for freq 3
data elec  (23, 49) (35, 49)
freq alpha
th_perm :  0.620833333333 0.72916

KeyboardInterrupt: 