## Import Libraries

In [1]:
from os import path, makedirs
from itertools import product
import numpy as np
from brainpipe.system import study

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold as SKFold
from sklearn.model_selection import permutation_test_score
from sklearn.metrics import roc_auc_score

In [2]:
# Sanity check: print the time vector stored in the 'low' and 'high' feature
# files of one subject, for every recording phase.
st = study('Ripples')
pow_path = path.join(st.path, 'feature/')
reps = ['Encoding/', 'Retrieval_new_odors/', 'Retrieval_new_rec/']
conds = ['low', 'high']
freqs = ['ripple', 'low_freq', 'HFA']
subjects = ['LEFC']
nperm = 100

for rep in reps:
    # File-name template for this phase, filled with (subject, condition).
    pow_name = path.join(pow_path, rep, '{}_cond={}_bipo_feat_norm.npz')

    for su in subjects:
        print(rep)
        low_file = pow_name.format(su, conds[0])
        high_file = pow_name.format(su, conds[1])
        mat0 = np.load(low_file, allow_pickle=True)
        mat1 = np.load(high_file, allow_pickle=True)
        # Show both time axes, 'low' condition first.
        for loaded in (mat0, mat1):
            print(loaded['time'])

-> Ripples loaded
Encoding/
[0.05 0.1  0.15 0.2  0.25 0.3  0.35 0.4  0.45 0.5  0.55 0.6  0.65 0.7
 0.75 0.8  0.85 0.9  0.95 1.   1.05 1.1  1.15 1.2  1.25 1.3  1.35 1.4
 1.45 1.5  1.55 1.6  1.65 1.7  1.75 1.8  1.85 1.9  1.95 2.   2.05 2.1
 2.15 2.2  2.25 2.3  2.35 2.4  2.45 2.5  2.55 2.6  2.65 2.7  2.75 2.8
 2.85 2.9  2.95 3.   3.05 3.1  3.15 3.2  3.25 3.3  3.35 3.4  3.45 3.5
 3.55 3.6  3.65 3.7  3.75 3.8  3.85 3.9  3.95 4.   4.05 4.1  4.15 4.2
 4.25 4.3  4.35 4.4  4.45 4.5  4.55 4.6  4.65 4.7  4.75 4.8  4.85 4.9
 4.95 5.   5.05 5.1  5.15 5.2  5.25 5.3  5.35 5.4  5.45 5.5  5.55 5.6
 5.65 5.7  5.75 5.8  5.85 5.9  5.95 6.   6.05 6.1  6.15 6.2  6.25 6.3
 6.35 6.4  6.45 6.5  6.55 6.6  6.65 6.7  6.75 6.8  6.85 6.9  6.95]
[0.05 0.1  0.15 0.2  0.25 0.3  0.35 0.4  0.45 0.5  0.55 0.6  0.65 0.7
 0.75 0.8  0.85 0.9  0.95 1.   1.05 1.1  1.15 1.2  1.25 1.3  1.35 1.4
 1.45 1.5  1.55 1.6  1.65 1.7  1.75 1.8  1.85 1.9  1.95 2.   2.05 2.1
 2.15 2.2  2.25 2.3  2.35 2.4  2.45 2.5  2.55 2.6  2.65 2.7  2.75

In [None]:
"""
Compute classification power over time for all frequencies and electrodes
"""

# For every recording phase, subject and frequency band, decode the 'low' vs
# 'high' condition from single-electrode power with a one-feature LDA, one
# model per time window. Each AUC is assessed with a label-permutation test.
# Results are written once per (subject, band) to
# <study>/classified/<phase>/<su>_LDA_clf_<freq>_<cond0>_<cond1>.npz and are
# not recomputed when that file already exists.
st = study('Ripples')
pow_path = path.join(st.path, 'feature/')
reps = ['Encoding/', 'Retrieval_new_odors/', 'Retrieval_new_rec/']
conds = ['low', 'high']
freqs = ['ripple', 'low_freq', 'HFA']
subjects = ['LEFC', 'CHAF', 'VACJ', 'SEMC', 'FERJ', 'PIRJ']
nperm = 100

# Neither object carries state between calls (permutation_test_score clones
# the estimator, and the unshuffled stratified splitter is deterministic),
# so one instance of each is enough for the whole run.
clf = LDA()
cv = SKFold(5)

for rep in reps:
    pow_name = path.join(pow_path, rep, '{}_cond={}_bipo_feat_norm.npz')
    path_save = path.join(st.path, 'classified/' + rep)
    # exist_ok avoids the check-then-create race of the exists()/makedirs() pair.
    makedirs(path_save, exist_ok=True)
    clf_name = path.join(path_save, '{}_LDA_clf_{}_{}_{}.npz')

    for su in subjects:
        # np.load on a .npz returns a lazy NpzFile: every mat['key'] access
        # re-reads that array from disk and the handle stays open until it is
        # closed. Read the large arrays once and close both files on exit.
        with np.load(pow_name.format(su, conds[0]), allow_pickle=True) as mat0, \
                np.load(pow_name.format(su, conds[1]), allow_pickle=True) as mat1:
            xpow0, xpow1 = mat0['xpow'], mat1['xpow']
            fnames = mat1['fname']

            # Window indices kept for decoding and the offset subtracted from
            # the stored time axis.
            # NOTE(review): `time` is saved at full length while auc/perm only
            # span to_take[0]:to_take[1] -- confirm that readers of the result
            # files crop it themselves.
            if rep != 'Retrieval_new_rec/':
                to_take, time = [54, 100], mat0['time'] - 3
            else:
                to_take, time = [39, 120], mat0['time'] - 5
            window = slice(to_take[0], to_take[1])

            for freq in freqs:
                band_idx = [idx for idx, fname in enumerate(fnames) if fname == freq]
                if not band_idx:
                    raise IndexError(
                        'frequency band {!r} not found in {}'.format(
                            freq, pow_name.format(su, conds[1])))
                id_f = band_idx[0]

                # =========== Select power: xpow is (nfreq, nelec, nwin, ntrial) ===========
                pow_list = [xpow[id_f, :, window, :] for xpow in (xpow0, xpow1)]
                nelecs, npts, _ = pow_list[0].shape
                print(su, mat0.files, 'xpow shape: ', [p.shape for p in pow_list])

                name_classif = clf_name.format(su, freq, conds[0], conds[1])
                if path.exists(name_classif):
                    print(name_classif, 'already computed')
                    continue

                # =========== Metadata stored next to the scores ===========
                kwargs = {
                    'names': mat0['new_labels'],
                    'channels': mat0['channels'],
                    'xyz': mat0['xyz'],
                    'time': time,
                }

                # Class labels: 0 for conds[0] trials, 1 for conds[1] trials.
                # They depend only on the trial counts, so they are built once
                # per band instead of once per electrode.
                y = np.hstack([np.array([label] * power.shape[-1])
                               for label, power in enumerate(pow_list)])

                # =========== Decode each electrode at each time window ===========
                permut = np.zeros((nelecs, npts, nperm))
                auc = np.zeros((nelecs, npts))
                for elec_num in range(nelecs):
                    print('--» processing',rep,su, 'elec', elec_num,'/',nelecs, 'freq',freq)
                    # (nwin, ntrial) -> (ntrial, nwin) for each condition
                    pow_data_elec = [power[elec_num].swapaxes(0, 1) for power in pow_list]

                    # Stack both conditions along the trial dimension.
                    x = np.concatenate(pow_data_elec, axis=0)
                    print('Size of the concatenated data: ', x.shape)
                    print('Size of label for classif: ', len(y))

                    for t in range(npts):
                        # Single feature: power in window t, shape (ntrial, 1).
                        X_t = x[:, t][:, np.newaxis]
                        score, permutation_scores, _pvalue = permutation_test_score(
                            clf, X_t, y, scoring="roc_auc", cv=cv,
                            n_permutations=nperm, n_jobs=-1)
                        permut[elec_num, t] = permutation_scores
                        auc[elec_num, t] = score

                # Save the scores and their permutation distributions.
                kwargs['perm'], kwargs['auc'] = permut, auc
                np.savez(name_classif, **kwargs)