In [12]:
import numpy as np
import matplotlib.pyplot as plt
import scikits.bootstrap as boot
from os import path
from matplotlib.ticker import ScalarFormatter, MaxNLocator

from brainpipe.system import study
from brainpipe.visual import *
from brainpipe.statistics import *
from mne.baseline import rescale
from mne.filter import filter_data
from scipy.stats import *

In [24]:
# ANALYSIS PARAMETERS
low_pass_filter = 10.
sf = 512.
norm_mode = 'zscore' #'ratio' 'mean' 'percent' 
bsl = (922,1024) #-200 to 0ms before odor perception
data_to_use = [922, 2560] #-200 to 2000ms after odor
time_points = data_to_use[1]-data_to_use[0]
n_rep = 100 #bootstrap
alpha = 0.05

In [26]:
st = study('Olfacto')
data_path = path.join(st.path, 'database/Encoding_EpiPerf_LowHigh/')
dataname = path.join(data_path, '{}_odor_{}_common_renamed_OFC_HC.npz')
savename = path.join(data_path, 'ERPs_0.001/ERP_{}_{}_elec({})_sig0.001_boot100_3s.png')
savename2 = path.join(data_path, 'ERPs_0.001/ERP_{}_{}_elec({})_sig0.001_boot100_3s.pdf')

dict_E = {'LEFC':["o4-o3","b2-b1","b4-b3","d4-d3","d5-d4"],'FERJ':["b'6-b'5","j'3-j'2"],
          'PIRJ':["o7-o6","b'2-b'1"],'SEMC':["b2-b1"],'VACJ':["d'2-d'1"]}
dict_R = {'FERJ':["j'3-j'2"],'PIRJ':["o12-o11"],
          'SEMC':["o11-o10"],'VACJ':["o'12-o'11","d'3-d'2"]}
conds, exp = ['low','high'], 'E'
dict_ = dict_E if exp == 'E' else dict_R

for su in dict_:
    mat0 = np.load(dataname.format(su,conds[0]),allow_pickle=True)
    mat1 = np.load(dataname.format(su,conds[1]),allow_pickle=True)
    # 3584 points >> 7s from -3 to +4s
    data0, data1 = mat0['x'], mat1['x']
    labels, channels = mat0['labels'], mat0['channels']
    idx = [i for i, chan in enumerate(channels) if chan in dict_[su]]
    xyz = mat0['xyz']

    for elec in idx:
        x0, x1, lab = data0[elec].swapaxes(0,1), data1[elec].swapaxes(0,1), labels[elec]
        xyz_elec = xyz[elec]
        print(su, elec, xyz_elec, lab)
        print(x0.shape, x1.shape)
        
        time = np.arange(x0.shape[1])
        norm_x0 = rescale(x0, times=time, baseline=bsl, mode=norm_mode)
        norm_x1 = rescale(x1, times=time, baseline=bsl, mode=norm_mode)
        print(norm_x0.shape, norm_x1.shape)
        
        data0_rep, data1_rep = np.array([]), np.array([])
        for i in range(n_rep):
            if norm_x0.shape[0] > norm_x1.shape[0]:
                data0_sel = np.mean(norm_x0[np.random.randint(norm_x0.shape[0],
                                            size=norm_x1.shape[0]), :], axis=0)
                data1_sel = norm_x1
            if norm_x0.shape[0] < norm_x1.shape[0]:
                data0_sel = norm_x0
                data1_sel = np.mean(norm_x1[np.random.randint(norm_x1.shape[0], 
                                            size=norm_x0.shape[0]), :], axis=0)
            elif norm_x0.shape[0] == norm_x1.shape[0]:
                data0_sel, data1_sel = norm_x0, norm_x1
            data0_rep = np.vstack((data0_rep, data0_sel)) if np.size(data0_rep) else data0_sel
            data1_rep = np.vstack((data1_rep, data1_sel)) if np.size(data1_rep) else data1_sel
        print('data bad, good',data0_rep.shape, data1_rep.shape)
        
        mean_x0, mean_x1 = np.mean(data0_rep,axis=0),np.mean(data1_rep,axis=0)
        print(mean_x0.shape, mean_x1.shape)
        mean_x0, mean_x1 = np.array(mean_x0, dtype='float64'), np.array(mean_x1, dtype='float64')
        
        filt_x0 = filter_data(mean_x0, sfreq=512, l_freq=None, h_freq=low_pass_filter, 
                                                  method='fir', phase='zero-double')
        filt_x1 = filter_data(mean_x1, sfreq=512, l_freq=None, h_freq=low_pass_filter, 
                                                  method='fir', phase='zero-double')
        final_x0 = filt_x0[data_to_use[0]:data_to_use[1]][np.newaxis]
        final_x1 = filt_x1[data_to_use[0]:data_to_use[1]][np.newaxis]
        print(su, lab, 'low norm shape', final_x0.shape, 'high norm shape', 
              final_x1.shape)
        
        # ========================== PREPARE PLOTS AND SAVE STATS =========================================
        # plot and figure parameters
        xfmt = ScalarFormatter(useMathText=True)
        xfmt.set_powerlimits((0,3))
        fig = plt.figure(1,figsize=(9,6))
        title = 'ERP and Stats for {} High/Low in {}'.format(su,lab)
        fig.suptitle(title, fontsize=12)
        times_plot = 1000 * np.arange((bsl[0] - bsl[1]), data_to_use[1]-bsl[1]) / sf
        print('time plots', times_plot.shape)
        # Plot the ERPs and the stats
        data_all = np.concatenate((final_x0, final_x1), axis=0)
        y = np.array([0]+[1])
        BorderPlot(times_plot, data_all, y=y, kind='sd', alpha=0.2, color=['b','m'], 
                   linewidth=2, ncol=1, xlabel='Time (ms)',ylabel = r' $\mu$V', 
                   legend = ['bad', 'good'])
        addLines(plt.gca(), vLines=[0], vColor=['r'], vWidth=[2], hLines=[0], 
                 hColor=['#000000'], hWidth=[2])
        rmaxis(plt.gca(), ['right', 'top'])
        plt.legend(loc=0, handletextpad=0.1, frameon=False)
        plt.gca().yaxis.set_major_locator(MaxNLocator(3,integer=True))
        plt.savefig(savename.format(su,lab,elec), dpi=300, bbox_inches='tight')
        plt.savefig(savename2.format(su,lab,elec), dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close()
        

-> Olfacto loaded
LEFC 3 [ 30.65 -20.5  -10.75] aHC
(34, 3584) (21, 3584)
Applying baseline correction (mode: zscore)
Applying baseline correction (mode: zscore)
(34, 3584) (21, 3584)
data bad, good (100, 3584) (2100, 3584)
(3584,) (3584,)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 677 samples (1.322 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 677 samples (1.322 sec) selected
LEFC aHC low norm shape (1, 1638) high norm shape (1, 1638)
time plots (1638,)
LEFC 5 [ 38.45 -20.1  -10.95] aHC
(34, 3584) (21, 3584)
Applying baseline correction (mode: zscore)
Applying baseline correction (mode: zscore)
(34, 3584) (21, 3584)
data bad, good (100, 3584) (2100, 3584)
(3584,) (3584,)
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length of 677 samples (1.322 sec) selected
Setting up low-pass filter at 10 Hz
h_trans_bandwidth chosen to be 2.5 Hz
Filter length

In [None]:
ci.shape
plt.plot(times_plot, np.mean(norm_filt_x0,axis=0),'b-')
plt.plot(times_plot, np.mean(norm_filt_x1,axis=0),'m-')
plt.show()