## Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from os import path

from brainpipe.system import study
from brainpipe.feature import TF
from brainpipe.visual import *
from brainpipe.statistics import *

In [None]:
sf = 512  # Sampling frequency (Hz)
lines = [0]  # x positions (ms) of vertical lines drawn on each TF panel; 0 is sample 1536 (time zero)
f = (3, 120, 5, 2)  # Frequency vector (from, to, width, step) in Hz: 3 to 120 Hz, 5 Hz bands, 2 Hz step
baseline = [1178, 1536]  # Baseline in samples: ~700 ms of rest (358 samples at 512 Hz) in the middle of the extracted 1 s
width, step = 80, 20  # Time-smoothing window (width, step) passed to TF; presumably in samples — confirm
window_to_plot = [-200, 2000]  # Plotted time window in ms (same units as tfObj.xvec)

### Compute TF on unbalanced data for all selected electrodes
#### Baseline-normalized power (`norm=3`, baseline samples 1178-1536) - over 5.5s after odor is sent

In [None]:
# Study handle plus input/output folders for the encoding-phase low/high analysis.
st = study('Olfacto')
path_data = path.join(st.path, 'database/Encoding_EpiPerf_LowHigh/')
path_respi = '/media/karim/Datas4To/1_Analyses_Intra_EM_Odor/1bis_OE_BaseSam/JPlailly201306_seeg_ALS/behavior/respiration_amplitude/'
path2save = path.join(st.path, 'feature/1_TF_Encoding_EpiPerf_LowHigh/')

# Condition names, in the order used to build file names (index 0 = low, 1 = high).
conds = [
    'low',
    'high',
]
# Subject codes processed by the per-electrode loop.
subjects = [
    'LEFC',
    'CHAF',
    'VACJ',
    'SEMC',
    'FERJ',
    'MICP',
    'PIRJ',
]
# Anatomical labels (values of the 'Mai_RL' field) whose electrodes are kept.
rois_to_keep = [
    'ACC',
    'Amg',
    'Amg-PirT',
    'HC',
    'IFG',
    'Ins',
    'MFG',
    'OFC',
    'PHG',
    'SFG',
    'pPirT',
]
# For every subject and every kept electrode: compute the TF map of the low and
# high conditions (trial counts equalized), plot Low / High / High - Low and
# save the figure as PDF and PNG in path2save.
for su in subjects:
    fname_low = path.join(path_data, su + '_odor_' + conds[0] + '_bipo_sel_physFT.npz')
    fname_high = path.join(path_data, su + '_odor_' + conds[1] + '_bipo_sel_physFT.npz')
    # Read everything from a single handle per file and close it afterwards.
    # 'Mai_RL' needs allow_pickle=True (as in the original load of that field).
    with np.load(fname_low, allow_pickle=True) as npz_low:
        data0 = npz_low['x']
        names = npz_low['Mai_RL']
        channels = npz_low['channels']
    with np.load(fname_high) as npz_high:
        data1 = npz_high['x']

    id_rois = np.where([roi in rois_to_keep for roi in names])[0]
    data0, data1 = data0[id_rois, ...], data1[id_rois, ...]
    names, channels = names[id_rois], channels[id_rois]
    print(su, 'low shape: ', data0.shape, 'high shape: ', data1.shape)

    # ========================= COMPUTE TF FOR 1 ELEC =============================================
    for elec in range(data0.shape[0]):
        # Keep a leading singleton axis: (1, npts, n_trials)
        data0_elec, data1_elec = data0[elec, :, :][np.newaxis], data1[elec, :, :][np.newaxis]
        # sf comes from the configuration cell (previously re-forced to 512 here).
        channel, label, npts = channels[elec], names[elec], data0_elec.shape[1]
        print(su, channel, label, data0_elec.shape, 'nb points', npts)

        time = 1000 * np.arange(-1536, npts - 1536) / sf  # ms; sample 1536 is time 0
        print('Time points: ', len(time), min(time), max(time))
        # Equalize trial counts between conditions (truncate to the smaller one)
        trials_sel = min(data0_elec.shape[2], data1_elec.shape[2])
        data0_elec, data1_elec = data0_elec[:, :, :trials_sel], data1_elec[:, :, :trials_sel]
        print('trials', trials_sel)
        tfObj = TF(sf, npts, f=f, time=time, width=width, step=step, baseline=baseline, norm=3)
        xtf_low, _ = tfObj.get(data0_elec)
        xtf_high, _ = tfObj.get(data1_elec)
        xtf_low, xtf_high = 100 * np.swapaxes(xtf_low, 0, 1), 100 * np.swapaxes(xtf_high, 0, 1)
        xtf_diff = xtf_high - xtf_low

        # Plot all TF
        xtf_all = np.concatenate((xtf_low, xtf_high, xtf_diff), axis=0)
        print(xtf_all.shape)
        timebin = np.array(tfObj.xvec)
        i_start = np.argmin(np.abs(timebin - window_to_plot[0]))
        # +1 so the bin closest to the window end is included (slice end is exclusive)
        i_stop = np.argmin(np.abs(timebin - window_to_plot[1])) + 1
        sl = slice(i_start, i_stop)
        title = su + ' Elec(' + str(elec) + ') ' + channel + ' ' + label
        fig = plt.figure(elec, figsize=(15, 5))
        fig, allax = tfObj.plot2D(fig, xtf_all[:, :, sl], xvec=tfObj.xvec[sl], cmap='viridis',
                                  yvec=tfObj.yvec, xlabel='Time (ms)', vmin=-150, vmax=150, ycb=-20,
                                  ylabel='Frequency (hz)', figtitle=title, title=['Low', 'High', 'High - Low'],
                                  cblabel='Power modulations (%)', pltype='imshow', resample=(0.5, 0.1),
                                  sharex=False, sharey=False, subdim=(1, 3), subspace={'top': 0.8})
        for k in allax:
            addLines(k, vLines=lines, vColor=['firebrick'] * 3, vWidth=[2] * 3)
        # Save all your plots
        base = su + '_Elec(' + str(elec) + ')_' + channel + '_' + label + '_Low_High'
        fig.savefig(path.join(path2save, base + '.pdf'), dpi=300, bbox_inches='tight')
        fig.savefig(path.join(path2save, base + '.png'), dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close()
    # The former `del xtf_all, ..., sf, channel` was removed: it deleted the global
    # config `sf` and raised NameError for subjects without any kept electrode.
    


In [None]:
feat, phase, win, th = 'pow', 'Encoding', 1.0, '0.01'
###############################################################################
# Input/output locations (templates are filled with str.format later on)
st = study('Olfacto')
path_data = path.join(st.path, 'database/' + phase + '_EpiPerf_LowHigh/')
npz_form = path.join(path_data, '{}_odor_{}_bipo_sel_physFT.npz')  # (subject, condition)
path_classif = path.join(st.path, 'figure/0_Classif_Power_' + phase[0] + '_EpiPerf_LowHigh_1000perm_BBG/')
classif_npz = path.join(path_classif, '{}_sources_{}_{}_low_high_sel_physFT.npz')
masks_form = path.join(path_classif, 'masks_stat/All_subjects_mask_stat_{}_minwin{}_th{}.npy')  # (freq, win, th)
path2save = path.join(st.path, 'feature/1_TF_' + phase + '_EpiPerf_LowHigh/TF_roi/')
f_form_save = path.join(path2save, 'TF_{}_{}_Low_High_{}.png')  # (region, freq, subject)
###############################################################################
conds = ['low', 'high']
sf = 512  # Sampling frequency (Hz)
lines = [0]  # x positions (ms) of vertical lines on each panel; 0 is sample 1536 (time zero)
f = (3, 120, 4, 2)  # Frequency vector (from, to, width, step) in Hz: 3 to 120 Hz, 4 Hz bands, 2 Hz step
baseline = [1178, 1536]  # Baseline in samples: ~700 ms of rest in the middle of the extracted 1 s
width, step = 80, 20  # Time-smoothing window (width, step) passed to TF; presumably in samples — confirm
window_to_plot = [-700, 2000]  # Plotted time window in ms (same units as tfObj.xvec)
###############################################################################
rois_list = {'Frontal': ['IFG', 'MFG', 'SFG'],
             'Olf': ['pPirT', 'Amg-PirT', 'Ins', 'OFC'],
             'MTL': ['HC']}
# Frequency-band tag used in the classification/mask file names. The analysis
# cell reads `freq`; it was previously only defined as `freqs` (NameError).
freq = '3_gamma'
freqs = freq  # old name kept for backward compatibility
subjs = ['CHAF', 'LEFC', 'FERJ', 'MICP', 'SEMC', 'VACJ', 'PIRJ']
rois_to_keep = ['ACC', 'Amg', 'Amg-PirT', 'HC', 'IFG', 'Ins', 'MFG', 'OFC', 'PHG',
                'SFG', 'pPirT']
###############################################################################
region = 'MTL'

# Sources of the pooled classification restricted to the ROIs of interest, and
# the significance mask of those sources.
mat = np.load(classif_npz.format('All_subjects', freq, 'odor'))
id_rois = np.where([roi in rois_to_keep for roi in mat['s_labels']])
subjects = mat['su_codes'][id_rois]
mask = np.load(masks_form.format(freq, str(win), th))
mask = np.logical_not(mask)  # inverse of visbrain
data_low, data_high = np.array([]), np.array([])  # NOTE(review): never filled below; kept in case later cells use them
for s, su in enumerate(sorted(subjs)):
    # NOTE(review): assumes su_codes 'S0', 'S1', ... follow the alphabetical order of subjs — confirm.
    mask_su = mask[np.where(subjects == 'S' + str(s))]
    # One handle per file, closed right after reading. 'Mai_RL' is loaded with
    # allow_pickle=True, as in the per-electrode cell.
    with np.load(npz_form.format(su, conds[0]), allow_pickle=True) as npz_low:
        Mai_RL, channels, x_low = npz_low['Mai_RL'], npz_low['channels'], npz_low['x']
    with np.load(npz_form.format(su, conds[1])) as npz_high:
        x_high = npz_high['x']
    print('Mai_RL before sel', len(Mai_RL))

    # Same selection chain for labels, data and channels:
    # kept ROIs -> significant electrodes (mask) -> electrodes of `region`.
    id_roi_su = np.where([roi in rois_to_keep for roi in Mai_RL])
    Mai_RL = Mai_RL[id_roi_su][mask_su]
    in_region = np.where([x in rois_list[region] for x in Mai_RL])
    su_low = x_low[id_roi_su][mask_su][in_region]
    su_high = x_high[id_roi_su][mask_su][in_region]
    channels = channels[id_roi_su][mask_su][in_region]
    print(region, su, Mai_RL, len(Mai_RL), su_low.shape, su_high.shape, channels)

    if not np.size(channels):
        continue  # no significant electrode of this subject in the region

    # Pool electrodes and trials per time sample:
    # (n_elec, npts, n_trials) -> (npts, n_elec * n_trials). The time axis is
    # moved to the front first; reshaping the 3-D array directly would mix
    # samples of different electrodes/time points whenever n_elec > 1.
    npts = su_low.shape[1]
    su_low = np.swapaxes(su_low, 0, 1).reshape(npts, -1)
    su_high = np.swapaxes(su_high, 0, 1).reshape(npts, -1)

    time = 1000 * np.arange(-1536, npts - 1536) / sf  # ms; sample 1536 is time 0
    print('Time points: ', len(time), min(time), max(time))
    tfObj = TF(sf, npts, f=f, time=time, width=width, step=step,
               baseline=baseline, norm=3)
    xtf_low, _ = tfObj.get(su_low)
    xtf_high, _ = tfObj.get(su_high)
    xtf_low, xtf_high = 100 * np.swapaxes(xtf_low, 0, 1), 100 * np.swapaxes(xtf_high, 0, 1)

    # Plot all TF
    xtf_all = np.concatenate((xtf_low, xtf_high), axis=0)
    print(xtf_all.shape)
    timebin = np.array(tfObj.xvec)
    i_start = np.argmin(np.abs(timebin - window_to_plot[0]))
    # +1 so the bin closest to the window end is included (slice end is exclusive)
    i_stop = np.argmin(np.abs(timebin - window_to_plot[1])) + 1
    sl = slice(i_start, i_stop)
    title = 'TF sig elecs in ' + region + ' for ' + freq[2:]
    fig = plt.figure(figsize=(10, 5))
    fig, allax = tfObj.plot2D(fig, xtf_all[:, :, sl], xvec=tfObj.xvec[sl], cmap='viridis',
                              yvec=tfObj.yvec, xlabel='Time (ms)', vmin=-200, vmax=200, ycb=-20,
                              ylabel='Frequency (hz)', figtitle=title, title=['Low', 'High'],
                              cblabel='Power modulations (%)', pltype='imshow', resample=(0.01, 0.01),
                              sharex=False, sharey=False, subdim=(1, 2), subspace={'top': 0.8})
    for k in allax:
        addLines(k, vLines=lines, vColor=['firebrick'] * 2, vWidth=[2] * 2)
    # Save all your plots
    fname = f_form_save.format(region, freq, su)
    fig.savefig(fname, dpi=300, bbox_inches='tight')
    plt.clf()
    plt.close()