## Temporal generalization with MNE (one electrode at a time)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from brainpipe.system import study
from brainpipe.visual import *
from brainpipe.statistics import *
from os import path
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold as SKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import mne
from mne.decoding import (SlidingEstimator, GeneralizingEstimator,
                          cross_val_multiscore, LinearModel, get_coef)

In [None]:
# --- Temporal generalization, one electrode at a time (single subject) ---
# For every frequency band and every electrode, an LDA is trained on the power
# of each time window and tested on every window (MNE GeneralizingEstimator,
# 10-fold stratified CV, ROC AUC). The resulting (train x test) AUC matrix is
# plotted with a Bonferroni-corrected binomial significance contour and saved
# as a PNG per electrode.
from os import makedirs

st = study('Olfacto')
path_pow = path.join(st.path, 'feature/7_Power_E1E2_Odor_Good_Bad_EpiScore_Expi/')
path_save = path.join(st.path, 'classified/1_Classif_Power_EpiScore_sel_electrodes_win700_step100_expi/TG/')
bsl = 'None'
su = 'LEFC'

file_good = su+'_odor_good_bipo_sel_'+bsl+'_pow_TG.npz'
file_bad = su+'_odor_bad_bipo_sel_'+bsl+'_pow_TG.npz'
# Load power (nfreq,nelec,nwin,ntrials). Each archive is opened once and the
# context manager closes it instead of leaving several handles to the GC.
with np.load(path_pow+file_good) as data_good:
    pow_good = data_good['xpow']
    fnames = data_good['fname']
    channels = data_good['channels']
    labels = data_good['labels']
with np.load(path_pow+file_bad) as data_bad:
    pow_bad = data_bad['xpow']
print(pow_good.shape,pow_bad.shape)

# Generalization parameters: good-odor trials are labelled 1, bad-odor trials 2.
y = np.array([1]*pow_good.shape[3] + [2]*pow_bad.shape[3])
pow_all = np.concatenate((pow_good,pow_bad), axis=3)
# Sanity check (presumably): replacing y with random labels, e.g.
# np.random.randint(2, size=pow_all.shape[3]), should give chance-level scores.
print(y)
nelecs, time = pow_all.shape[1], np.arange(pow_all.shape[2])
n_times = len(time)
# Axis values for the plot on a 0-5500 scale (presumably ms -- confirm).
# linspace always returns exactly n_times points, whereas np.arange with a
# float step can return n_times + 1 because of rounding, which would no longer
# match the (n_times, n_times) score matrix.
times_plot = np.linspace(0, 5500, n_times, endpoint=False)

# Classifier and CV scheme do not change inside the loops; MNE/scikit-learn
# clone the estimator for every fit, so they can be built once.
clf = LDA()
# random_state only has an effect with shuffle=True; recent scikit-learn
# raises a ValueError when random_state is given without shuffling.
skf = SKFold(n_splits=10, shuffle=True, random_state=1)
time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=6)

for i, freq in enumerate(fnames):
    fname = str(i)+'_'+freq
    # One output sub-directory per frequency band; savefig does not create it.
    save_dir = path.join(path_save, fname)
    makedirs(save_dir, exist_ok=True)
    for elec in range(nelecs):
        # (nwin, ntrials) -> (ntrials, 1, nwin): a single feature per sample,
        # so each per-window classifier sees this electrode's power only.
        x = pow_all[i,elec,:,:].T[:, np.newaxis, :]
        print(x.shape, y.shape)

        # Compute decoding :
        print('-> Generalization')
        print('time', time.shape, 'y', y.shape,'x',x.shape)
        scores = cross_val_multiscore(time_gen, x, y, cv=skf, n_jobs=6)
        # Mean scores across cross-validation splits -> (n_times, n_times)
        scores = np.mean(scores, axis=0)

        print('-> Stats')
        # NOTE(review): bino_da2p is fed AUC*100 as if it were a decoding
        # accuracy -- confirm a binomial test is appropriate for AUC.
        da = scores*100
        # Bonferroni correction over all (train, test) window pairs.
        pval = bino_da2p(y, da) * n_times**2

        fig = plt.figure(figsize=(8.5, 7))
        p = tilerplot()
        # NOTE(review): ticks are at 700 while the 'Odor' line/label is at
        # 1300 -- confirm which onset is intended.
        p.plot2D(fig, scores, xvec=times_plot, yvec=times_plot, vmin=0.2, vmax=1,
                 xlabel='Generalization time', ylabel='Training time',
                 xticks=[0,700,4000], yticks=[0,700,4000], cblabel='Auc Score',
                 contour={'data': pval, 'level': [0.001], 'label': ['p<0.001'],
                          'colors': ['white']}, dpax=['left', 'bottom'],
                 rmax=['top', 'right'], cmap='viridis')

        ax = plt.gca()
        addLines(ax, vLines=[0,1300,4000], vColor=['#c47d7a']*3, vWidth=[2]*3, vShape=['-']*3,
                 hLines=[0,1300,4000], hColor=['#c47d7a']*3, hWidth=[2]*3, hShape=['-']*3,)
        ax.text(10, -500, 'Pre', rotation=90, color='#c47d7a', size=15, weight='bold')
        ax.text(1300, -500, 'Odor', rotation=90, color='#c47d7a', size=15, weight='bold')
        ax.text(4000, -500, 'Post', rotation=90, color='#c47d7a', size=15, weight='bold')

        out_file = 'TG_'+su+'_'+fname+'_'+channels[elec]+'_'+labels[elec]+'_('+str(elec)+')_lda.png'
        fig.savefig(path.join(save_dir, out_file), dpi=300, bbox_inches='tight')
        # Close exactly the figure created above to release its memory.
        plt.close(fig)
        del x, da, pval

### Temporal generalization (TG) with all electrodes at once

In [None]:
# --- Temporal generalization with all electrodes as features, per subject ---
# For every subject and frequency band, the classifier at each time window is
# trained on the power of all electrodes together and tested on every window
# (MNE GeneralizingEstimator, 10-fold stratified CV, ROC AUC). The AUC matrix
# is plotted with a Bonferroni-corrected binomial significance contour and
# saved as one PNG per subject and band.
from os import makedirs

st = study('Olfacto')
path_pow = path.join(st.path, 'feature/7_Power_E1E2_Odor_Good_Bad_EpiScore_Expi/')
path_save = path.join(st.path, 'classified/1_Classif_Power_EpiScore_sel_electrodes_win700_step100_expi/TG/')
bsl = 'None'
subjects = ['LEFC','CHAF','VACJ','SEMC','FERJ','MICP','PIRJ']
# savefig does not create missing directories.
makedirs(path_save, exist_ok=True)

# Classifier and CV scheme do not change inside the loops; MNE/scikit-learn
# clone the estimator for every fit, so they can be built once.
clf = LDA()
# random_state only has an effect with shuffle=True; recent scikit-learn
# raises a ValueError when random_state is given without shuffling.
skf = SKFold(n_splits=10, shuffle=True, random_state=1)
time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=6)

for su in subjects:
    file_good = su+'_odor_good_bipo_sel_'+bsl+'_pow_TG.npz'
    file_bad = su+'_odor_bad_bipo_sel_'+bsl+'_pow_TG.npz'
    # Load power (nfreq,nelec,nwin,ntrials). Each archive is opened once and
    # the context manager closes it deterministically.
    with np.load(path_pow+file_good) as data_good:
        pow_good = data_good['xpow']
        fnames = data_good['fname']
        channels = data_good['channels']
        labels = data_good['labels']
    with np.load(path_pow+file_bad) as data_bad:
        pow_bad = data_bad['xpow']
    print(pow_good.shape,pow_bad.shape)

    # Generalization parameters: good-odor trials are labelled 1, bad-odor 2.
    y = np.array([1]*pow_good.shape[3] + [2]*pow_bad.shape[3])
    pow_all = np.concatenate((pow_good,pow_bad), axis=3)
    nelecs, time = pow_all.shape[1], np.arange(pow_all.shape[2])
    n_times = len(time)
    # Axis values on a 0-5500 scale (presumably ms -- confirm). linspace always
    # yields exactly n_times points; a float-step np.arange may yield one more.
    times_plot = np.linspace(0, 5500, n_times, endpoint=False)
    # Electrode subset from an earlier analysis (not applied here):
    # sel_elecs = [25,65,5,16,17,18,19,23,7,11,12,26,28,29,30,34]

    for i,freq in enumerate(fnames):
        fname = str(i)+'_'+freq
        # (nelec, nwin, ntrials) -> (ntrials, nelec, nwin), i.e. the
        # (n_samples, n_features, n_times) layout MNE decoders expect.
        x = np.moveaxis(pow_all[i], -1, 0)
        print(x.shape, y.shape)

        # Compute decoding :
        print('-> Generalization')
        print('time', time.shape, 'y', y.shape,'x',x.shape)
        scores = cross_val_multiscore(time_gen, x, y, cv=skf, n_jobs=6)
        # Mean scores across cross-validation splits -> (n_times, n_times)
        scores = np.mean(scores, axis=0)
        print(scores.shape,len(time))

        print('-> Stats')
        # NOTE(review): bino_da2p is fed AUC*100 as if it were a decoding
        # accuracy -- confirm a binomial test is appropriate for AUC.
        da = scores*100
        # Bonferroni correction over all (train, test) window pairs.
        pval = bino_da2p(y, da) * n_times**2

        fig = plt.figure(figsize=(8.5, 7))
        p = tilerplot()
        p.plot2D(fig, scores, xvec=times_plot, yvec=times_plot, vmin=0.2, vmax=0.8,
                 xlabel='Generalization time', ylabel='Training time',
                 xticks=[0,975,4000], yticks=[0,975,4000], cblabel='Auc Score',
                 contour={'data': pval, 'level': [0.001], 'label': ['p<0.001'],
                          'colors': ['white']}, dpax=['left', 'bottom'],
                 rmax=['top', 'right'], cmap='viridis')

        ax = plt.gca()
        addLines(ax, vLines=[0,975,4000], vColor=['#c47d7a']*3, vWidth=[2]*3, vShape=['-']*3,
                 hLines=[0,975,4000], hColor=['#c47d7a']*3, hWidth=[2]*3, hShape=['-']*3,)
        ax.text(-10, -500, 'Pre', rotation=90, color='#c47d7a', size=15, weight='bold')
        ax.text(965, -500, 'Odor', rotation=90, color='#c47d7a', size=15, weight='bold')
        ax.text(3990, -500, 'Post', rotation=90, color='#c47d7a', size=15, weight='bold')

        out_file = 'TG_MF_sel_elec_'+su+'_'+fname+'_lda.png'
        fig.savefig(path.join(path_save, out_file), dpi=300, bbox_inches='tight')
        # Close exactly the figure created above to release its memory.
        plt.close(fig)
        del x, da, pval