## Calculate ERPs

In [None]:
#Importing files and modules

import numpy as np
import matplotlib.pyplot as plt
from os import path
%matplotlib notebook
from brainpipe.system import study
from brainpipe.visual import *
from brainpipe.statistics import *
from mne.baseline import rescale
from mne.filter import filter_data

## Plot the baseline-corrected ERP for one electrode and one condition

In [None]:
# Load one subject's data, select one electrode, average over trials and
# baseline-normalize the resulting ERP.
st = study('Olfacto')
path_data = path.join(st.path, 'database/TS_E_all_cond_by_block_trigs/')

# Data file for 1 subject
filename = 'MICP_E1E2_concat_allfilter1.npz'
low_pass_filter = 10.
norm_mode = 'mean'  # other mne rescale modes used here: 'ratio', 'percent'

# Read the arrays we need, then close the .npz archive (np.load keeps it open otherwise)
with np.load(path.join(path_data, filename)) as data_all:
    data = data_all['x']
    channel = [ch[0] for ch in data_all['channel']]
print(data.shape)

# Data for one elec + name.
# data_elec is used as (time, trial): averaging over axis 1 below yields the time course.
elec = 37
data_elec = data[elec, :, :]
ntrials = data_elec.shape[1]
# Print the name of the selected electrode (previously always printed channel 0)
print(channel[elec], 'n_trials :', ntrials)
print(data_elec.shape)

# Average the data for one electrode (over trials)
mean_data_elec = np.mean(data_elec, axis=1)
print('size of the data:', mean_data_elec.shape)

# Normalize the ERP. `times` is a plain sample index, so the baseline window is in samples.
baseline = [973, 1024]
times = np.arange(mean_data_elec.shape[0])
norm_data = rescale(mean_data_elec, times=times, baseline=baseline, mode=norm_mode)
print(norm_data.shape, norm_data.dtype)

## Compute stats

In [None]:
# Permutation test of the post-stimulus samples against the pre-stimulus baseline.
# Windows are sample indices into the 1-D ERP `norm_data`:
#   973-1024  -> baseline, 1024-1792 -> evoked signal.
# NOTE: this rebinds `baseline` from the [start, stop] window used for rescaling
# to the baseline samples themselves.
baseline = norm_data[973:1024]
print(baseline.shape)
evoked_signal = norm_data[1024:1792]

# brainpipe permutations along axis 0 (the only axis of these 1-D signals).
# Original author's note: with axis=-1 the shapes of a and b could differ.
perm_data = perm_swap(baseline, evoked_signal,
                      n_perm=1000, axis=0, rndstate=0)
print(perm_data[0].shape, perm_data[1].shape)

# Two-tailed p-values of the evoked samples against the second permutation set
p_vals = perm_2pvalue(evoked_signal, perm_data[1],
                      n_perm=1000, threshold=None, tail=2)

# Prepend p = 1 for every baseline sample so p_vals_to_plot spans samples 973-1792
add_p_to_plot = np.ones(int(baseline.shape[0]))
p_vals_to_plot = np.insert(p_vals, 0, add_p_to_plot)
print(p_vals_to_plot.shape)


## Plot statistics and ERPs

In [None]:
# Low-pass filter the normalized ERP (for illustration only) and plot it together
# with the permutation p-values on a millisecond time axis.
sfreq = 512  # sampling rate (Hz) used by filter_data and for the time axis
data_to_filter = np.array(norm_data, dtype='float64')
print(data_to_filter.shape, data_to_filter.dtype)
filtered_data = filter_data(data_to_filter, sfreq=sfreq, l_freq=None, h_freq=low_pass_filter,
                            method='fir', phase='zero-double')
print(filtered_data.shape)

# Window to plot, in samples: from the start of the baseline (973) to 1792;
# sample 1024 (end of the baseline window) is taken as time 0.
plot_start, onset, plot_stop = 973, 1024, 1792
norm_data_to_plot = norm_data[plot_start:plot_stop]
print('size window to plot', norm_data_to_plot.shape)
filtered_data_to_plot = filtered_data[plot_start:plot_stop]
times_plot = 1000 * (np.arange(norm_data_to_plot.shape[0]) + plot_start - onset) / sfreq
print(times_plot.shape)

# Plot the ERP data and filtered data
fig, ax = plt.subplots()
fig.subplots_adjust(top=0.85)
ax.set_title('ERP_' + norm_mode, fontsize=14, fontweight='bold')
ax.set_xlabel('Times (ms)', fontsize=12)
ax.set_ylabel('Potential', fontsize=12)
ax.plot(times_plot, norm_data_to_plot, '#808080', linewidth=1, label='data')
ax.plot(times_plot, filtered_data_to_plot, 'g-', linewidth=2,
        label='filtered data < ' + str(low_pass_filter) + 'Hz')
lines = (-100, 0)  # vertical markers, in ms like times_plot
addPval(ax, p_vals_to_plot, p=0.05, x=times_plot, y=0.5, color='darkred', lw=3)
addPval(ax, p_vals_to_plot, p=0.01, x=times_plot, y=0.2, color='darkblue', lw=4)
addLines(ax, vLines=lines, vColor=['firebrick'] * 2, vWidth=[2] * 2,
         hLines=[0], hColor=['#000000'], hWidth=[2])
ax.grid()
ax.legend(fontsize='small')
plt.show()

## Plot ERPs for all electrodes of all subjects

In [None]:
# For every subject and electrode: low-pass filter each trial, average over trials,
# baseline-normalize, then plot and save raw vs. filtered ERPs to PNG files.
st = study('Olfacto')
path_data = path.join(st.path, 'database/TS_E_all_cond_by_block_trigs_60th_200art/')
subjects = ['CHAF', 'VACJ', 'SEMC', 'FERJ', 'PIRJ', 'LEFC', 'MICP']
# subjects = ['MICP']
conds = ['all']
norm_mode = 'mean'
low_pass_filter = 10.
sfreq = 512  # sampling rate (Hz) used by filter_data and for the time axis
n_elec = {
    'CHAF': 51,
    'VACJ': 152,
    'SEMC': 118,
    'FERJ': 126,
    'PIRJ': 117,
    'LEFC': 210,
    'MICP': 122,
}

# Sample indices for this dataset: baseline = -100 ms to 0 (time 0 = sample 768),
# plotted window = -100 ms to +1500 ms.
baseline = [717, 768]
onset = baseline[1]
plot_start, plot_stop = 717, 1536
out_dir = path.join(st.path, 'feature/ERP_Encoding_all_mono_100ms_mean_thr30_art400/')

for su in subjects:
    # The file depends only on the subject: load it once, not once per electrode,
    # and close the .npz archive after reading.
    filename = su + '_E1E2_concat_allfilter1.npz'
    print(filename)
    with np.load(path.join(path_data, filename)) as data_all:
        data = data_all['x']
        channel = [ch[0] for ch in data_all['channel']]
        label = [lab[0] for lab in data_all['label']]
    rep = path.join(out_dir, su)  # output file prefix for this subject

    for elec in range(n_elec[su]):
        # NOTE(review): `cond` is not used in the computation or the file name, so
        # with more than one condition each iteration would overwrite the same PNG.
        for cond in conds:
            # Select data; data_elec is used as (time, trial) — axis 1 is averaged below
            data_elec = data[elec, :, :]
            n_trials = data_elec.shape[1]
            print('nb of trials = ', n_trials)
            print(data_elec.shape)

            # Filter each trial; put trials first so time is the last axis for filter_data.
            # A separate name keeps the loaded `data` array intact.
            data_to_filter = np.swapaxes(np.array(data_elec, dtype='float64'), 0, 1)
            filtered_data = filter_data(data_to_filter, sfreq=sfreq, l_freq=None,
                                        h_freq=low_pass_filter, method='fir',
                                        phase='zero-double')
            print(data_to_filter.shape)

            # Average over trials (axis 0 after the swap, axis 1 before it)
            mean_filtered_data = np.mean(filtered_data, axis=0)
            mean_data_elec = np.mean(data_elec, axis=1)
            print(mean_data_elec.shape, mean_data_elec.shape[0])

            # Normalize; `times` is a sample index, so `baseline` is in samples
            times = np.arange(mean_data_elec.shape[0])
            norm_data = rescale(mean_data_elec, times=times, baseline=baseline,
                                mode=norm_mode, copy=True, verbose=None)
            norm_filtered_data = rescale(mean_filtered_data, times=times, baseline=baseline,
                                         mode=norm_mode, copy=True, verbose=None)
            print(norm_data.shape, norm_filtered_data.shape)

            # NOTE(review): permutation statistics (perm_swap / perm_2pvalue, with a
            # maxstat correction) used to be sketched here, commented out, using the
            # 973:1024 / 1024:1792 windows of the single-subject file. Those windows do
            # not match this dataset's 717:768 baseline; re-derive them before re-enabling.

            # Data to plot: -100 ms to +1500 ms
            norm_data_to_plot = norm_data[plot_start:plot_stop]
            print('size window to plot', norm_data_to_plot.shape)
            filtered_data_to_plot = norm_filtered_data[plot_start:plot_stop]
            times_plot = 1000 * (np.arange(norm_data_to_plot.shape[0]) + plot_start - onset) / sfreq
            print(times_plot.shape)

            # Plot the ERP data and filtered data
            fig, ax = plt.subplots()
            fig.subplots_adjust(top=0.85)
            ax.set_title(su+'_ERP_Odor_'+norm_mode+'_'+channel[elec]+'_'+label[elec]+' elec_num: '+str(elec)+'_ntrials:'+str(n_trials), fontsize=14, fontweight='bold')
            ax.set_xlabel('Times (ms)', fontsize=12)
            ax.set_ylabel('Potential', fontsize=12)
            ax.plot(times_plot, norm_data_to_plot, '#808080', linewidth=1, label='data')
            ax.plot(times_plot, filtered_data_to_plot, 'g-', linewidth=2,
                    label='filtered data < ' + str(low_pass_filter) + 'Hz')
            lines = [0]  # vertical marker at time 0, in ms like times_plot
            addLines(ax, vLines=lines, vColor=['firebrick'], vWidth=[2],
                     hLines=[0], hColor=['#000000'], hWidth=[2])
            ax.grid()
            ax.legend(fontsize='small')

            # Save the plot
            fname = (rep + '_E1E2_ERP_concat_all_mono_' + channel[elec] + '_' + str(elec) + '_' + label[elec] + '.png')
            print(fname)
            fig.savefig(fname, dpi=300, bbox_inches='tight')
            plt.close(fig)

# Free the large arrays (the original `del x, ...` raised NameError: `x` was never defined)
del data, channel, n_elec, n_trials
