# ERP Component measurement 
## Mean/Peak Amplitudes & Latencies for individual trials and averages over conditions
### Exports data for LME analysis, as well as generating various plots for EDA

---
Copyright 2024 [Aaron J Newman](https://github.com/aaronjnewman), [NeuroCognitive Imaging Lab](http://ncil.science), [Dalhousie University](https://dal.ca)


Released under [The 3-Clause BSD License](https://opensource.org/licenses/BSD-3-Clause)

---

## Which component to analyze

In [4]:
# Name of the ERP component to measure. It is used below as the key into
# config['Analysis']['components'] and as part of the output folder and file names.
component = 'n170'

## Initialization

In [6]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import pyarrow.feather as feather
from scipy.stats import zscore
import os.path as op
from glob import glob
from pathlib import Path
import json
import mne
mne.set_log_level(verbose='error')

## Read Parameters from config.json

Will import study-level parameters from `config.json` in `bids_root`

In [13]:
# Study-level parameters are read from config.json in the BIDS root.
# this shouldn't change if you run this script from its default location in code/import
bids_root = '../..'

config_file = op.join(bids_root, 'config.json')
# Use a context manager so the file handle is closed once the JSON is parsed
with open(config_file) as config_fh:
    config = json.load(config_fh)

study_name = config['Study']['Name']
# The task label is needed later to build the epochs filenames; it previously
# overwrote study_name, leaving task_name undefined.
task_name = config['Study']['TaskName']
data_type = config['EEG']['data_type']
eog = config['EEG']['eog']
montage_fname = config['EEG']['montage']

n_jobs = config['Preprocessing']['n_jobs']

# Parameters specific to the component selected above
component_p = config['Analysis']['components'][component]
# Measurement type: 'meana' selects the mean-amplitude branch below; any other
# value is passed to get_peak() as its `mode`.
component_meas = component_p['component_meas']

## Time windows of interest

In [None]:
# Define the total length of the epoch.
# This can be less than what is in the input files; the epochs are cropped to
# [t_min, t_max] when they are read in below.
t_min = float(component_p['t_min'])
t_max =  float(component_p['t_max'])
# NOTE(review): eval() executes whatever expression is stored in config.json.
#  If the baseline is always a plain literal such as "(None, 0)",
#  ast.literal_eval would be a safer drop-in — confirm before changing.
baseline = eval(component_p['baseline'])

# Group peak latency: value obtained from visual inspection of butterfly plots
# and added to config.json.
peak_lat = component_p['peak_lat']

# Width of the measurement window; the single-trial averaging window below
# spans peak +/- tw_width / 2.
tw_width = component_p['tw_width']

# Amount of time to shift event codes by, based on empirical testing with photocell
#  to determine lag between event code and actual stimulus appearance on screen.
#  For reasons that are unclear, we need to double this value (in sec) to get the correct shift
tshift = config['t_shift']

## Define ROIs
clusters of electrodes to average over for waveform plots

In [None]:
# Channel coordinates: tab-separated file with one row per channel (name, x, y, z)
montage = pd.read_csv(
    './9_18AverageNet128_v1.sfp',
    sep='\t',
    names=['Channel', 'ch_x', 'ch_y', 'ch_z'],
)

# config['rois'] is a list of single-entry dicts; flatten it into one dict, then
# turn each ROI's first entry (a ', '-separated string) into a list of channel names
rois_all = {name: spec for entry in config['rois'] for name, spec in entry.items()}
rois_all = {name: spec[0].split(', ') for name, spec in rois_all.items()}

# Keep only the ROIs listed for the selected component
rois = {name: rois_all[name] for name in component_p['rois']}

## Conditions and Contrasts of Interest

In [None]:
# Condition labels; used below to select epochs, to order plots, and as column
# names when computing contrasts.
# NOTE(review): `event_id` is not defined in any cell visible in this notebook —
#  presumably it must be loaded (e.g. from config.json) before this cell runs; confirm.
conditions = list(event_id.values())

# Each contrast is (label, first condition, second condition); the contrast value
# computed later is first minus second.
_contrast_pairs = (
    ('CS-FF', 'ConsonantString', 'FalseFont'),
    ('RW-FF', 'RealWord', 'FalseFont'),
    ('PW-FF', 'PseudoWord', 'FalseFont'),
    ('NW-FF', 'NovelWord', 'FalseFont'),
    ('RW-CS', 'RealWord', 'ConsonantString'),
    ('PW-CS', 'PseudoWord', 'ConsonantString'),
    ('NW-CS', 'NovelWord', 'ConsonantString'),
    ('RW-PW', 'RealWord', 'PseudoWord'),
    ('RW-NW', 'RealWord', 'NovelWord'),
    ('NW-PW', 'NovelWord', 'PseudoWord'),
)
contrasts = {label: [first, second] for label, first, second in _contrast_pairs}

# Plotting order follows the order in which the contrasts are declared
contr_order = [label for label in contrasts]

## Paths

In [None]:
# Input: preprocessed data written by the preprocessing step
source_path = op.join(bids_root, 'derivatives', 'erp_preprocessing')

# Output tree for this component: tabular data and figures
derivatives_path = op.join(bids_root, 'derivatives', 'erp_measurement', component)
out_path = op.join(derivatives_path, 'data')
fig_path = op.join(derivatives_path, 'figures')

# Create the output folders; exist_ok avoids the separate exists() check
for new_dir in (derivatives_path, out_path, fig_path):
    Path(new_dir).mkdir(parents=True, exist_ok=True)

# Filename pieces shared by the cells below
epochs_suffix = '-epo.fif'
group_stem = op.join(out_path, 'participants_')
pointplot_stem = op.join(fig_path, 'pointplot_')
swarmplot_stem = op.join(fig_path, 'swarmplot_')

## Figure settings

In [None]:
# Global seaborn look for every figure in this notebook
sns.set_palette('colorblind')
sns.set_style('white')
sns.set_context('talk')

# Look the palette up once and index into it
cb_palette = sns.color_palette('colorblind')

# One fixed colour per condition
colors = {'FalseFont': cb_palette[0],
          'ConsonantString': cb_palette[1],
          'PseudoWord': cb_palette[2],
          'NovelWord': cb_palette[3],
          'RealWord': cb_palette[4],
          }

# Contrasts against FalseFont / ConsonantString reuse the colour of their first
# condition; the word-vs-word contrasts get their own palette entries
contr_colors = {'CS-FF': colors['ConsonantString'],
                'RW-FF': colors['RealWord'],
                'PW-FF': colors['PseudoWord'],
                'NW-FF': colors['NovelWord'],
                'RW-CS': colors['RealWord'],
                'PW-CS': colors['PseudoWord'],
                'NW-CS': colors['NovelWord'],
                'RW-PW': cb_palette[8],
                'RW-NW': cb_palette[9],
                'NW-PW': cb_palette[6],
                }

# File extension used when saving figures
fig_format = 'pdf'

## Subject list

In [None]:
# Subject IDs are the names of the 'sub-*' entries found in the preprocessing output.
# Taking the basename (rather than the last 7 characters of the path) keeps this
# correct for IDs of any length.
prefix = 'sub-'
subjects = sorted(op.basename(p) for p in glob(op.join(source_path, prefix + '*')))

---
# Read in the data

When we read the data, we also crop the epochs as specified above, and time-shift the event onsets to match true stimulus timing

In [None]:
# Read each subject's preprocessed epochs, correct stimulus timing, crop,
# re-baseline, and attach the montage. Results are keyed by subject ID.
epochs = {}

# Read the task label directly from the config so the filename does not rely on
# a variable set in another cell.
task_label = config['Study']['TaskName']

for subject in subjects:
    # Read from the preprocessing output (source_path), where the subject list
    # was globbed, not from this notebook's own output folder.
    # NOTE(review): assumes the layout <source_path>/<subject>/<data_type>/ — confirm.
    subj_path = op.join(source_path, subject, data_type)
    f = op.join(subj_path, subject + '_task-' + task_label + '_desc-preproc' + epochs_suffix)
    print(f)
    epochs[subject] = mne.read_epochs(f,
                                      verbose=None,
                                      preload=True)

    # Correct for stimulus presentation delay by shifting the time axis.
    # NOTE(review): this writes to private MNE attributes; Epochs.shift_time()
    #  looks like the public equivalent — left as-is because tshift was
    #  calibrated against this behaviour.
    epochs[subject]._raw_times = epochs[subject]._raw_times - tshift
    epochs[subject]._times_readonly = epochs[subject]._times_readonly - tshift

    # Restrict to the window of interest, then re-baseline on the shifted time axis
    epochs[subject].crop(tmin=t_min, tmax=t_max).apply_baseline(baseline)

    epochs[subject].set_montage(montage_fname)

---
# Compute single-trial measurements



In [None]:
%%time

# Build a long-format table with one row per subject x condition x ROI x trial x channel,
# holding the mean amplitude in a window centred on the peak latency.
df_list = []

# Multiplier applied to the window means (1e6, the same value as the former 10e5);
# MNE keeps EEG in volts, so this yields microvolts.
amp_scale = 1e6
# Search window for the peak
tw_lo, tw_hi = component_p['tw_range'][0], component_p['tw_range'][1]

for subj in subjects:
    for cond in conditions:
        # Select this condition's trials once and reuse them for every ROI,
        # instead of re-indexing epochs[subj][cond] for each access.
        cond_epochs = epochs[subj][cond]
        # The condition average is only needed when searching for a peak
        cond_evoked = None if component_meas == 'meana' else cond_epochs.average()

        for roi, chans in rois_all.items():
            if component_meas == 'meana':
                # Mean-amplitude mode: centre the window on the midpoint of the
                # configured range; there is no peak channel or peak amplitude.
                peak = np.array([np.nan,
                                 np.median([tw_lo, tw_hi]),
                                 np.nan])
            else:
                # Find the peak in the specified time window among this ROI's channels.
                # Work on a copy so the shared average stays intact for the next ROI.
                tmp_dat = cond_evoked.copy().pick_channels(chans)
                try:
                    # try to find most negative/positive peak
                    peak = tmp_dat.get_peak(tmin=tw_lo,
                                            tmax=tw_hi,
                                            mode=component_meas,
                                            )
                except ValueError:
                    # Peak finding fails if, e.g., looking for a negative peak when
                    # all values in the window are positive. In that case, centre
                    # the window on the group peak latency instead.
                    tw_start = component_p['peak_lat'] - component_p['tw_width']
                    tw_end = component_p['peak_lat'] + component_p['tw_width']
                    peak = np.array([np.nan,
                                     np.median([tw_start, tw_end]),
                                     np.nan])

            # Define the averaging window, centred on the peak latency
            half_width = component_p['tw_width'] / 2
            peak_window = (peak[1] - half_width, peak[1] + half_width)
            idx_start, idx_stop = np.searchsorted(cond_epochs.times, peak_window)

            # Mean over the window for each (trial, channel); flattening is
            # trial-major, matching the repeat/tile ordering of the labels below.
            amplitudes = (cond_epochs.get_data(picks=chans)[:, :, idx_start:idx_stop]
                          .mean(axis=-1)
                          .flatten() * amp_scale)

            df_list.append(pd.DataFrame({'participant_id': subj,
                                         'Component': component,
                                         'Trial_Time': np.repeat(cond_epochs.events[:, 0], len(chans)),
                                         'Condition': cond,
                                         'ROI': roi,
                                         'Peak.Chan': peak[0],
                                         'Peak.Lat': peak[1],
                                         'Channel': np.tile(chans, cond_epochs.selection.shape),
                                         'Amplitude': amplitudes,
                                         }))

# concatenate list of dataframes, and add x,y,z coordinates of channels
df = pd.merge(pd.concat(df_list, ignore_index=True), montage, how='left', on='Channel')

In [None]:
# Show 24 randomly chosen rows of the single-trial table as a sanity check
df.sample(24)

## EDA

### Boxplot of raw data values

In [None]:
# Box plot of single-trial amplitudes per condition, before outlier removal
component_rows = df[df['Component'] == component]
ax = sns.catplot(data=component_rows,
                 kind='box',
                 x='Condition', y='Amplitude', hue='Condition',
                 )
ax.set_xticklabels(rotation=20)

plt.show()

## Remove Outliers

Remove individual data points based on z threshold. Compute separately for each subject and component.

In [None]:
# Cutoff for defining outliers, in SD
z_thresh = config['outlier_thresh']

# Standardise amplitudes within each participant x component group
amplitude_groups = df.loc[:, ['participant_id', 'Component', 'Amplitude']].groupby(
    ['participant_id', 'Component'])
df['Peak.Ampl.z'] = amplitude_groups.transform(zscore)

len_orig = len(df)

# Keep rows whose z-score lies within +/- z_thresh (bounds included)
within_thresh = df['Peak.Ampl.z'].between(-z_thresh, z_thresh)
df = df[within_thresh]

# Report the share of rows removed
n_dropped = len_orig - len(df)
pct_dropped = round((n_dropped / len_orig) * 100, 3)
print(f'{pct_dropped}% of data dropped as outliers based Peak.Amplitude z +/-{z_thresh}')

### Boxplots post-outlier removal

In [None]:
# Same box plot as before, now on the trimmed data and with a fixed condition order
trimmed_rows = df[df['Component'] == component]
ax = sns.catplot(data=trimmed_rows,
                 kind='box',
                 x='Condition', y='Amplitude', hue='Condition',
                 order=conditions,
                 )
ax.set_xticklabels(rotation=20)

plt.show()

### Export Trimmed Data For Analysis in R

In [None]:
# generate 1 file/subject because the aggregated file is big and creates issues eg pushing to GitHub

# The measurement type and time window (in ms) are the same for every subject,
# so build that part of the filename once.
tw_label = (str(round(component_p['tw_range'][0] * 1000))
            + '-' + str(round(component_p['tw_range'][1] * 1000)))
file_tail = '_indiv_trials_trimmed_' + component_meas + '_' + tw_label

for subj in subjects:
    out_dir = op.join(out_path, subj)
    # exist_ok avoids the separate exists() check
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    feather.write_feather(df[df['participant_id'] == subj],
                          out_dir + '/' + subj + '_' + component + file_tail + '_' + subj + '.feather')

---
## Aggregate over trials and channels, within subjects and conditions
- To ensure proper CIs (between-subject variance) in plots below.
- Also select only the ROIs of interest

In [None]:
# Average over trials and channels within each participant x condition x ROI,
# keeping only the ROIs of interest for this component.
# numeric_only=True skips the string columns (e.g. Channel); without it the mean
# raises TypeError on pandas >= 2.
roi_rows = df[df['ROI'].isin(rois)]
df_agg = (roi_rows
          .groupby(['participant_id', 'Component', 'Condition', 'ROI'])
          .mean(numeric_only=True)
          .reset_index())
# These columns are not meaningful once averaged over trials and channels
df_agg = df_agg.drop(columns=['Trial_Time', 'ch_x', 'ch_y', 'ch_z', 'Peak.Ampl.z'])
# write to file
feather.write_feather(df_agg, group_stem + 'trialavg.feather')
df_agg.sample(12)

## EDA - Aggregated Data

### Amplitude

In [None]:
# Box plot of subject-level mean amplitudes: ROIs on the x axis, one box per condition
agg_component_rows = df_agg[df_agg['Component'] == component]
ax = sns.catplot(data=agg_component_rows,
                 kind='box',
                 x='ROI', y='Amplitude', hue='Condition',
                 aspect=1.5,
                 )
ax.set_xticklabels(rotation=20)
plt.show()

In [None]:
# Swarm plot of subject-level mean amplitudes per condition, dodged by ROI
agg_component_rows = df_agg[df_agg['Component'] == component]
ax = sns.catplot(data=agg_component_rows,
                 kind='swarm',
                 x='Condition', y='Amplitude', hue='ROI',
                 order=conditions,
                 dodge=True,
                 aspect=1.25,
                 )
ax.set_xticklabels(rotation=30)
# Save to the figures folder, then display
ax.savefig(swarmplot_stem + 'Amplitude' + '.' + fig_format)
plt.show()

In [None]:
# Density of subject-level mean amplitudes: one panel per ROI (rows) x condition (columns)
sns.displot(data=df_agg,
            x='Amplitude',
            row='ROI', col='Condition',
            kind='kde',
            fill=True,
            )
plt.show()

In [None]:
# Same amplitude density grid as the previous cell (ROI rows x condition columns)
sns.displot(data=df_agg,
            x='Amplitude',
            row='ROI', col='Condition',
            kind='kde',
            fill=True,
            )
plt.show()

#### Latency



In [None]:
# Swarm plot of peak latency per condition, one panel per ROI,
# with points coloured by participant (legend hidden)
ax = sns.catplot(data=df_agg,
                 kind='swarm',
                 x='Condition', y='Peak.Lat', col='ROI',
                 order=conditions,
                 hue='participant_id', legend=False,
                 )
ax.set_xticklabels(rotation=20)
plt.show()

In [None]:
# Density of peak latencies per ROI (rows) x condition (columns), with a rug of
# the individual values; y axes are independent across panels
sns.displot(data=df_agg,
            x='Peak.Lat',
            row='ROI', col='Condition',
            kind='kde',
            fill=True,
            rug=True,
            facet_kws={'sharey': False},
            )
plt.show()

---
## Descriptives

### Amplitude

In [None]:
# Summary statistics of subject-level amplitudes per component x condition x ROI
amplitude_cols = df_agg.loc[:, ['Component', 'Condition', 'ROI', 'Amplitude']]
descriptives = amplitude_cols.groupby(['Component', 'Condition', 'ROI']).describe()

# Show the full table rather than a truncated view
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
descriptives['Amplitude'][['mean', 'std']]

## Point Plots of Amplitude


In [None]:

# Point plot of amplitude per condition, from one value per participant x condition.
# numeric_only=True skips the string columns (Component, ROI); without it the
# mean raises TypeError on pandas >= 2.
subj_cond_means = (df_agg
                   .groupby(['participant_id', 'Condition'])
                   .mean(numeric_only=True)
                   .reset_index())

ax = sns.catplot(kind='point',
                 data=subj_cond_means,
                 x='Condition', y='Amplitude',
                 join=False, dodge=True, order=conditions, hue='Condition', hue_order=conditions,
                 height=4, aspect=1.5,
                 )

ax.set_xticklabels(rotation=20)

# Save images to files
ax.savefig(pointplot_stem + 'Amplitude' + '.' + fig_format)

plt.show()

In [None]:

# Point plot of amplitude per condition, split by ROI.
# numeric_only=True skips the string column (Component); without it the mean
# raises TypeError on pandas >= 2.
subj_cond_roi_means = (df_agg
                       .groupby(['participant_id', 'Condition', 'ROI'])
                       .mean(numeric_only=True)
                       .reset_index())

ax = sns.catplot(kind='point',
                 data=subj_cond_roi_means,
                 x='Condition', y='Amplitude', hue='ROI',
                 join=False, dodge=.2, order=conditions, legend=False,
                 height=4, aspect=1.5,
                 )

ax.set_xticklabels(rotation=20)

# Save images to files. Uses its own filename: this figure previously wrote to
# the same path as the per-condition point plot and overwrote it.
ax.savefig(pointplot_stem + 'Amplitude_byROI' + '.' + fig_format)

plt.show()

### Compare L-R 

In [None]:

# Point plot with ROIs on the x axis and one joined line per condition.
# numeric_only=True skips the string column (Component); without it the mean
# raises TypeError on pandas >= 2.
roi_cond_means = (df_agg
                  .groupby(['participant_id', 'Condition', 'ROI'])
                  .mean(numeric_only=True)
                  .reset_index())

sns.catplot(kind='point',
            data=roi_cond_means,
            x='ROI', y='Amplitude',
            join=True, dodge=.333, hue='Condition', hue_order=conditions,
            height=6, aspect=1.1
            )

plt.show()

### Laterality (Left-Right differences)

In [None]:
# Per-participant left-minus-right amplitude difference for each condition.
hemi_dat = df_agg.loc[:, ['Component', 'participant_id', 'Condition', 'ROI', 'Amplitude']].set_index(
    ['Component', 'participant_id', 'Condition', 'ROI'])

# Wide table: one row per component x condition x participant, one column per ROI
lr_wide = pd.pivot_table(hemi_dat,
                         index=['Component', 'Condition', 'participant_id'],
                         columns=['ROI'],
                         values='Amplitude')

# Compute left - right explicitly. The previous .diff(axis=1) produced
# right - left while labelling it 'L-R Diff', the opposite sign to the contrast
# laterality computed later in this notebook.
lr_wide['L-R Diff'] = lr_wide['left'] - lr_wide['right']
lr_diff = lr_wide[['L-R Diff']].reset_index()


In [None]:
# Point plot of the hemisphere difference per condition (conditions on the y axis)
ax = sns.catplot(data=lr_diff,
                 kind='point',
                 x='L-R Diff', y='Condition',
                 order=conditions,
                 hue='Condition', hue_order=conditions, legend=False,
                 join=False, dodge=True,
                 height=4, aspect=2,
                 )
# Dashed reference line at zero difference
plt.axvline(0, linestyle='--', color='k')
plt.show()

## Latencies

### Descriptives

In [None]:
# Summary statistics of peak latency per component x condition x ROI
latency_cols = df_agg.loc[:, ['Component', 'Condition', 'ROI', 'Peak.Lat']]
descriptives = latency_cols.groupby(['Component', 'Condition', 'ROI']).describe()

# Show the full table rather than a truncated view
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
descriptives['Peak.Lat']

### Plot

In [None]:


# Point plot of peak latency per condition (conditions on the y axis).
# numeric_only=True skips the string column (Component); without it the mean
# raises TypeError on pandas >= 2.
latency_means = (df_agg
                 .groupby(['participant_id', 'Condition', 'ROI'])
                 .mean(numeric_only=True)
                 .reset_index())

ax = sns.catplot(kind='point',
                 data=latency_means,
                 x='Peak.Lat', y='Condition',
                 hue='Condition', hue_order=conditions, legend=False,
                 join=False, order=conditions,
                 height=4, aspect=2,
                 )

# Save images to files
ax.savefig(pointplot_stem + 'PeakLat_' + '.' + fig_format)

plt.show()

---
# Contrasts

For contrasts, we find the peak of the *difference* between conditions within our time window.

This is achieved by pivoting the DataFrame to wide format (columns for each condition containing amplitude values), computing between-column differences, then stacking back to a long-format DataFrame.

In [None]:
# Mean amplitude per participant x condition x channel, restricted to the ROIs of interest
per_channel = (df[df['ROI'].isin(rois)]
               .groupby(['participant_id', 'Component', 'Condition', 'Channel', 'ROI'])['Amplitude']
               .mean()
               .reset_index())

# Wide format: one column per condition
wide = per_channel.pivot(index=['participant_id', 'ROI', 'Channel'],
                         columns=['Condition'],
                         values=['Amplitude'])
# Drop the extra 'Amplitude' level that pivot adds to the column index
wide.columns = wide.columns.droplevel()

# Each contrast is the first condition minus the second
for contr, (first_cond, second_cond) in contrasts.items():
    wide[contr] = wide[first_cond] - wide[second_cond]

# Keep only the contrast columns
wide = wide.drop(columns=conditions)

# Back to long format, with properly named columns
df_contr = (wide.stack()
            .rename('Amplitude')
            .reset_index()
            .rename(columns={'Condition': 'Contrast'}))

In [None]:
# Show 16 randomly chosen rows of the long-format contrast table as a sanity check
df_contr.sample(16)

### Export Data For Analysis in R

In [None]:
# Write the per-channel contrast table to a feather file in the data output folder
feather.write_feather(df_contr, group_stem + 'trialavg_contr.feather')

### Descriptives

In [None]:
# Summary statistics of contrast amplitudes per contrast x ROI
contrast_cols = df_contr.loc[:, ['Contrast', 'ROI', 'Amplitude']]
descriptives = contrast_cols.groupby(['Contrast', 'ROI'])['Amplitude'].describe()
descriptives

## Aggregate over channels within subject/contrast/ROI

In [None]:
# Average over channels within each participant x contrast x ROI.
# numeric_only=True skips the string column (Channel); without it the mean
# raises TypeError on pandas >= 2.
df_contr_avg = (df_contr
                .groupby(['participant_id', 'Contrast', 'ROI'])
                .mean(numeric_only=True)
                .reset_index())
df_contr_avg.head(12)

In [None]:
# Box plot of contrast amplitudes, from one value per participant x contrast.
# numeric_only=True skips the string column (ROI); without it the mean raises
# TypeError on pandas >= 2.
subj_contr_means = (df_contr_avg
                    .groupby(['participant_id', 'Contrast'])
                    .mean(numeric_only=True)
                    .reset_index())

ax = sns.catplot(kind='box',
                 data=subj_contr_means,
                 x='Contrast', y='Amplitude',
                 order=contr_order, hue_order=contr_order,
                 palette=contr_colors,
                 height=6, aspect=2,
                 )
ax.set_xticklabels(rotation=30)

plt.show()

In [None]:
# Point plot of contrast amplitudes, from one value per participant x contrast.
# numeric_only=True skips the string column (ROI); without it the mean raises
# TypeError on pandas >= 2.
subj_contr_means = (df_contr_avg
                    .groupby(['participant_id', 'Contrast'])
                    .mean(numeric_only=True)
                    .reset_index())

ax = sns.catplot(kind='point',
                 x='Contrast', y='Amplitude', hue='Contrast', join=False,
                 order=contr_order, hue_order=contr_order,
                 data=subj_contr_means,
                 height=4, aspect=1.75,
                 )

# Dashed reference line at zero difference
plt.axhline(0, color='k', linestyle='--')

ax.set_xticklabels(rotation=30)
# Save images to files
ax.savefig(pointplot_stem + 'Amplitude' + '_contr-all.' + fig_format)

plt.show()

### Laterality (Left-Right differences)

In [None]:
# Left-minus-right difference of each contrast, per participant
hemi_dat = df_contr_avg.loc[:, ['participant_id', 'Contrast', 'ROI', 'Amplitude']].set_index(
    ['participant_id', 'Contrast', 'ROI'])

# Wide table: one row per contrast x participant, one column per ROI
df_lr_pt = pd.pivot_table(hemi_dat,
                          index=['Contrast', 'participant_id'],
                          columns=['ROI'],
                          values='Amplitude')
df_lr_pt['L-R'] = df_lr_pt['left'] - df_lr_pt['right']


ax = sns.catplot(kind='point',
                 data=df_lr_pt.reset_index(),
                 x='L-R', y='Contrast',
                 hue='Contrast', legend=False,
                 order=contr_order, hue_order=contr_order,
                 join=False, dodge=True,
                 height=5, aspect=2,
                 )
# Dashed reference line at zero difference
plt.axvline(0, color='k', linestyle='--')

# Save before showing, as the other figure cells do, and show only once
ax.savefig(pointplot_stem + 'Amplitude' + '_contr-all_L-R.' + fig_format)
plt.show()