## Calculate ERPs

In [None]:
#Importing files and modules

import numpy as np
import matplotlib.pyplot as plt
from os import path
%matplotlib notebook
from brainpipe.system import study
from brainpipe.visual import *
from brainpipe.statistics import *
from mne.baseline import rescale
from mne.filter import filter_data
import pandas as pd
from pandas import ExcelWriter

from detect_peaks import detect_peaks

## User variables

In [None]:
# --- Acquisition / filtering ---------------------------------------------
# Sampling frequency (Hz); used to convert sample indices to milliseconds.
sf = 512.0
# Upper cut-off (Hz) of the low-pass filter applied to the single trials.
low_pass_filter = 10.0

# --- Windows, expressed as [start, stop] sample indices --------------------
# Baseline window: -250 ms to 0 ms (128 samples at 512 Hz).
baseline = [640, 768]
# Analysis window that is compared against the baseline (768 samples).
data_to_use = [768, 1536]

# --- Normalisation / statistics -------------------------------------------
# Baseline-correction mode; options listed by the author: 'ratio', 'mean', 'percent'.
norm_mode = 'mean'
# Number of permutations used for the permutation statistics.
n_perm = 200

# Compute ERPs and Stats

In [None]:
# For every subject and every electrode: low-pass filter the single trials,
# baseline-normalise them, run permutation statistics of the post-stimulus
# window against the baseline, then plot and save one ERP figure per electrode.
st = study('Olfacto')
path_data = path.join(st.path, 'database/TS_E_all_cond_by_block_trigs_th40_art400_30_250_5s_concatOK/')

subjects = ['CHAF', 'VACJ', 'SEMC', 'PIRJ', 'LEFC', 'MICP']

conds = ['all']

# Number of electrodes to process for each subject.
n_elec = {
    'CHAF': 107,
    'VACJ': 139,
    'SEMC': 107,
    'PIRJ': 106,
    'LEFC': 193,
    'MICP': 105,
}

for su in subjects:
    # The file name depends only on the subject, so load it once per subject
    # instead of once per electrode. The context manager closes the archive;
    # the arrays extracted from it remain valid afterwards.
    filename = su + '_E1E2_concat_all_bipo_new.npz'
    print(filename)
    with np.load(path.join(path_data, filename)) as data_all:
        data, channel, label = data_all['x'], data_all['channel'], data_all['label']
    # NOTE(review): a previous version unwrapped nested entries with
    # data_all['channel'][i][0] / data_all['label'][i][0]; confirm the stored
    # layout if channel[elec] is not directly a string.

    for elec in range(n_elec[su]):
        # NOTE(review): `cond` is never used in the body; with more than one
        # entry in `conds` the same figure would simply be recomputed.
        for cond in conds:
            # Select data for one electrode; indexed below as (time points, trials).
            data_elec = data[elec, :, :]
            ntrials = data_elec.shape[1]
            print ('\nOriginal data : ', data.shape, 'Channel : ', channel[elec], 'Label : ', label[elec], 'N_trials :', ntrials, 'One elec shape : ', data_elec.shape)

            # Filter data for one electrode (all trials). The time axis is moved
            # last for filter_data and moved back afterwards. A separate name is
            # used so that the subject-level `data` array is not overwritten.
            elec_data = np.array(data_elec, dtype='float64')
            data_to_filter = np.swapaxes(elec_data, 0, 1)
            filtered_data = filter_data(data_to_filter, sfreq=sf, l_freq=None, h_freq=low_pass_filter, method='fir', phase='zero-double')
            filtered_data = np.swapaxes(filtered_data, 0, 1)
            print ('Size of filtered data:', filtered_data.shape,)

            # Normalise the non-averaged data (all trials). `times` is expressed
            # in samples so that `baseline` can be given as sample indices.
            times = np.arange(filtered_data.shape[0])
            print ('time points : ', times.shape)
            filtered_data_to_norm = np.swapaxes(filtered_data, 0, 1)
            norm_filtered_data = rescale(filtered_data_to_norm, times=times, baseline=baseline, mode=norm_mode)
            norm_filtered_data = np.swapaxes(norm_filtered_data, 0, 1)
            print ('Size norm & filtered data : ', norm_filtered_data.shape,)

            # ================= SELECT DATA FOR STATISTICS =================
            # Define a range vector for the baseline and data :
            baseline_range = range(baseline[0], baseline[1])
            data_range = range(data_to_use[0], data_to_use[1])

            # Get the baseline and data from the FILTERED (not normalised) data :
            baseline_tr = filtered_data[baseline_range, :]
            data_elec_tr = filtered_data[data_range, :]

            # Mean the baseline across time (one value per trial) :
            baseline_tr = baseline_tr.mean(0)[np.newaxis, ...]
            print('-> Shape of the baseline : ', baseline_tr.shape,' and the selected data :', data_elec_tr.shape)

            # Repeat the baseline by the number of time points so that it has
            # the same shape as the data (required by the swap function) :
            baseline_tr_rep = np.tile(baseline_tr, (data_elec_tr.shape[0], 1))
            print('-> Shape of repeated baseline :', baseline_tr_rep.shape)

            # Swap filtered data and baseline across trials (dim 1) :
            perm_data = perm_swap(baseline_tr_rep, data_elec_tr, axis=1, n_perm=n_perm)[0]
            print('-> Shape of permuted data / baseline : ', perm_data.shape)

            # Average the permutations over axis 2.
            # NOTE(review): the original comment said "mean across time"; if
            # perm_data is (n_perm, time points, trials) this is the trial
            # axis -- confirm against perm_swap's return shape.
            perm_data_mean = perm_data.mean(2)
            print('-> Shape of meaned permuted data / baseline : ', perm_data_mean.shape)

            # Uncorrected p-values of the trial-averaged data against the permutations :
            p_vals = perm_2pvalue(data_elec_tr.mean(1), perm_data_mean, n_perm=n_perm, threshold=None, tail=2)
            print('-> Shape of non-corrected p-values : ', p_vals.shape)

            # Maximum-statistic correction (intended as a correction across time) :
            perm_corr = maxstat(perm_data, axis=2)[..., 0]
            print('-> Shape of maxstat perm : ', perm_corr.shape)

            # Corrected p-values from the max-stat permutations :
            p_vals_corr = perm_2pvalue(data_elec_tr.mean(1), perm_corr, n_perm=n_perm, threshold=None, tail=2)
            print('-> Shape of corrected p-values : ', p_vals_corr.shape)

            # Test if there are significant p-values after multiple comparison :
            print('-> Significant p-values after mutiple comparison? ', p_vals_corr.min() <= 0.05)

            # ============ PREPARE DATA TO PLOT AND PLOT THE ERPs ============
            # Normalised data from the start of the baseline to the end of the analysis window :
            filtered_data_to_plot = norm_filtered_data[range(baseline[0], data_to_use[1])]
            print('-> Shape of filtered data to plot : ', filtered_data_to_plot.shape)

            # Time vector in ms, with 0 at the end of the baseline :
            n_baseline = baseline[1] - baseline[0]
            times_plot = 1000 * np.arange(-n_baseline, filtered_data_to_plot.shape[0] - n_baseline) / sf
            print('-> Shape of time vector : ', times_plot.shape)

            # P-values to plot: the samples before the analysis window are
            # padded with 10, a value above any p-value threshold used below,
            # so that the p-value vectors match the length of the time vector.
            n_pad = data_to_use[0] - baseline[0]
            p_vals_to_plot = np.insert(p_vals, 0, 10 * np.ones((n_pad,)))
            p_vals_corr_to_plot = np.insert(p_vals_corr, 0, 10 * np.ones((n_pad,)))
            print('-> Shape of p-values to plot :', p_vals_to_plot.shape, p_vals_corr_to_plot.shape)

            # Prepare the plot
            fig = plt.figure(0, figsize=(12, 7))
            ax = fig.add_subplot(111)
            fig.subplots_adjust(top=0.85)
            ax.set_xlabel('Times (ms)', fontsize=12)
            ax.set_ylabel('Potential', fontsize=12)

            # Plot the trial average with its SEM band
            BorderPlot(times_plot, filtered_data_to_plot, kind='sem', color='', alpha=0.2, linewidth=2, ncol=1,
                       title=su+'_ERP_Odor_bipo_'+norm_mode+'_'+channel[elec]+'_'+label[elec]+' elec_num: '+str(elec)+'_ntrials:'+str(ntrials),)

            # Mark significant samples (uncorrected, then corrected) and the
            # stimulus onset / zero level.
            cur_ax = plt.gca()
            lines = [0]  # time vector is in ms
            addPval(cur_ax, p_vals_to_plot, p=0.05, x=times_plot, y=0.5, color='darkred', lw=3)
            addPval(cur_ax, p_vals_to_plot, p=0.01, x=times_plot, y=1, color='darkblue', lw=4)
            addPval(cur_ax, p_vals_corr_to_plot, p=0.05, x=times_plot, y=2, color='red', lw=3)
            addPval(cur_ax, p_vals_corr_to_plot, p=0.01, x=times_plot, y=3, color='dodgerblue', lw=4)
            addLines(cur_ax, vLines=lines, vColor=['firebrick']*2, vWidth=[2]*2, hLines=[0], hColor=['#000000'], hWidth=[2])
            plt.grid()

            # ==================== SAVE PLOTS of ERPs ====================
            # `rep` ends with the subject name, which becomes the file-name prefix.
            rep = path.join(st.path, 'feature/ERP_Encoding_all_bipo_250ms_mean_thr40_art400_30_250/', su)
            fname = (rep + '_E1E2_ERP_concat_all_bipo_' + channel[elec] + '_' + str(elec) + '_' + label[elec] + '.png')
            print (fname)
            plt.savefig(fname, dpi=300, bbox_inches='tight')
            plt.close()

    # Release this subject's arrays before loading the next file. Only names
    # that are guaranteed to be bound at this point are deleted (the previous
    # `del x, ..., n_trials` raised NameError on names that were never defined).
    del data_all, data, channel, label