# ERP Visualization 
## Butterfly plots and electrode montage

Some general-purpose plots.

---
Copyright 2023 [Aaron J Newman](https://github.com/aaronjnewman), [NeuroCognitive Imaging Lab](http://ncil.science), [Dalhousie University](https://dal.ca)

Released under [The 3-Clause BSD License](https://opensource.org/licenses/BSD-3-Clause)

---

## Load in the necessary libraries/packages we'll need

In [None]:
import pandas as pd
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import os.path as op
from glob import glob
from pathlib import Path
import yaml
from yaml import CLoader as Loader

import mne
# Raise MNE's logging threshold to 'error' so routine info/warning output is hidden.
mne.set_log_level(verbose='error')

## Read Parameters from config.yml

Import study-level parameters from `config.yml` in `bids_root`.

In [None]:
# ERP component analysed in this notebook; used as the key into config['components'].
component = "n170"

In [None]:
# Path to the BIDS root. This shouldn't change if you run this script from its
# default location in code/import.
bids_root = '../..'

cfg_file = op.join(bids_root, 'config.yml')
# NOTE: CLoader is not one of PyYAML's safe loaders; that is acceptable only
# because config.yml is a trusted, study-local file.
with open(cfg_file, 'r', encoding='utf-8') as f:
    config = yaml.load(f, Loader=Loader)

study_name = config['study_name']
task = config['task']
data_type = config['data_type']
# 'eog' is a list of mappings; merge them into one dict
eog = {k: v for d in config['eog'] for k, v in d.items()}
montage_fname = config['montage_fname']
# 'events' is a list of mappings whose values are lists; keep the last element
# of each list. Indexing with [-1] (rather than .pop()) leaves `config` intact.
event_id = {k: v[-1] for d in config['events'] for k, v in d.items()}

n_jobs = config['preprocessing_settings']['n_jobs']

# Parameters specific to the component analysed in this notebook
component_p = config['components'][component]
component_meas = component_p['component_meas']


## Time windows of interest

In [None]:
import ast

# Define the total length of the epoch.
# This can be less than what is in the input files; the epochs are cropped to
# it with the crop function when the data are read.
t_min = float(component_p['t_min'])
t_max = float(component_p['t_max'])
# The baseline is stored as a string in config.yml (presumably a tuple literal
# such as "(None, 0)"). Parse it with ast.literal_eval rather than eval() so
# that the config cannot execute arbitrary code.
baseline = ast.literal_eval(component_p['baseline'])

# value obtained from butterfly plots visual inspection and added to config.yml
peak_lat = component_p['peak_lat']

tw_width = component_p['tw_width']

# Amount of time (in sec) to shift event codes by, based on empirical testing
# with a photocell to determine the lag between the event code and the actual
# stimulus appearance on screen. For reasons that are unclear, this value needs
# to be doubled to get the correct shift.
# NOTE(review): the value is used as-is here; confirm that config['t_shift']
# already contains the doubled value.
tshift = config['t_shift']

## Conditions and Contrasts of Interest

In [None]:
# Condition labels, in the order they appear in event_id
conditions = [*event_id.values()]

# Pairwise contrasts: name -> [first condition, second condition].
# Each contrast is computed as first minus second.
contrasts = {
    'CS-FF': ['ConsonantString', 'FalseFont'],
    'RW-FF': ['RealWord', 'FalseFont'],
    'PW-FF': ['PseudoWord', 'FalseFont'],
    'NW-FF': ['NovelWord', 'FalseFont'],
    'RW-CS': ['RealWord', 'ConsonantString'],
    'PW-CS': ['PseudoWord', 'ConsonantString'],
    'NW-CS': ['NovelWord', 'ConsonantString'],
    'RW-PW': ['RealWord', 'PseudoWord'],
    'RW-NW': ['RealWord', 'NovelWord'],
    'NW-PW': ['NovelWord', 'PseudoWord'],
}

## Paths

In [None]:
# Input: preprocessed epochs
source_path = op.join(bids_root, 'derivatives', 'erp_preprocessing')

# Output folders. mkdir(exist_ok=True) creates them if missing and is a no-op
# otherwise, avoiding the separate exists() check.
derivatives_path = op.join(bids_root, 'derivatives', 'erp_visualization', 'full_tw')
Path(derivatives_path).mkdir(parents=True, exist_ok=True)

fig_path = op.join(derivatives_path, 'figures')
Path(fig_path).mkdir(parents=True, exist_ok=True)

# Filename suffix of the per-subject epochs files
epochs_suffix = '-epo.fif'

## Figure settings

In [None]:
# One colorblind-safe colour per condition: the first five entries of the
# seaborn 'colorblind' palette, assigned in this order.
colors = dict(zip(['FalseFont', 'ConsonantString', 'PseudoWord', 'NovelWord', 'RealWord'],
                  sns.color_palette('colorblind')))

# Line style per condition (FalseFont and RealWord share a solid line)
linestyles = dict(FalseFont='-',
                  ConsonantString=':',
                  PseudoWord='-.',
                  NovelWord='--',
                  RealWord='-')

# For big arrays of waveplots
waveplot_figsize = (18, 6)
fig_format = 'pdf'

# Filename prefixes for saved figures; callers append a name and the extension
jointplot_stem = f'{fig_path}/jointplot_avgref_'
waveplot_stem = f'{fig_path}/{component}_waveforms_full_tw_'
topoplot_stem = f'{fig_path}/{component}_topoplot_full_tw_'

## Subject list

In [None]:
prefix = 'sub-'
# Subject IDs are the names of the `sub-*` directories in the preprocessing
# derivatives. Use each directory's basename so IDs of any length work, and
# skip any non-directory matches (e.g. stray files whose names start with sub-).
subjects = sorted(op.basename(p) for p in glob(op.join(source_path, prefix + '*'))
                  if op.isdir(p))

---
# Read in the data

When reading the data, we also time-shift the event onsets to match true stimulus timing, crop the epochs to the window specified above, apply baseline correction, and set the electrode montage.

In [None]:
epochs = {}
for subject in subjects:
    subj_path = op.join(source_path, subject, 'eeg')
    # Build the filename with the shared epochs_suffix instead of re-typing it
    epo_fname = op.join(subj_path, subject + '_task-' + task + epochs_suffix)
    epochs[subject] = mne.read_epochs(epo_fname,
                                      verbose=None,
                                      preload=True)

    # Correct for stimulus presentation delay by moving the time axis earlier
    # by `tshift` seconds.
    # NOTE(review): this overwrites MNE private attributes. The public
    # Epochs.shift_time(-tshift, relative=True) may be a replacement; confirm
    # it is equivalent for the installed MNE version before switching.
    epochs[subject]._raw_times = epochs[subject]._raw_times - tshift
    epochs[subject]._times_readonly = epochs[subject]._times_readonly - tshift

    # Crop to the analysis window, then baseline-correct within it
    epochs[subject].crop(tmin=t_min, tmax=t_max).apply_baseline(baseline)

    epochs[subject].set_montage(montage_fname)

## Create Evokeds

In [None]:
# For each condition, a list with one evoked response (average of that
# condition's epochs) per subject, in subject order.
evoked = {cond: [epochs[subj][cond].average() for subj in subjects]
          for cond in conditions}

## Grand Averages

In [None]:
# Grand average across subjects, computed separately for each condition
gavg = {cond: mne.grand_average(evoked[cond]) for cond in conditions}

### Average across conditions

In [None]:
# Per-subject averages over all epochs (no condition selection), then their grand average
evoked_all = list(map(lambda subj: epochs[subj].average(), subjects))
gavg_all = mne.grand_average(evoked_all)

## Compute Contrasts
Differences between pairs of conditions

In [None]:
# For each contrast, a per-subject difference wave: first condition minus second
evoked_diff = {
    contr: [mne.combine_evoked([ev_first, ev_second], weights=[1, -1])
            for ev_first, ev_second in zip(evoked[conds[0]], evoked[conds[1]])]
    for contr, conds in contrasts.items()
}

## Plot montage
The montage is plotted in this script because it is generic (not tied to a particular component).

In [None]:
# Electrode montage with channel labels, saved to the figures folder
matplotlib.rcParams['figure.dpi'] = 72
montage_fig = mne.viz.plot_sensors(gavg[conditions[0]].info,
                                   show_names=True,
                                   pointsize=10)
montage_fig.savefig(f'{fig_path}/montage.{fig_format}')
plt.show()


In [None]:
# Electrode montage without labels and with larger markers, saved to the figures folder
matplotlib.rcParams['figure.dpi'] = 72
montage_fig_nolabels = mne.viz.plot_sensors(gavg[conditions[0]].info,
                                            show_names=False,
                                            pointsize=125)
montage_fig_nolabels.savefig(f'{fig_path}/montage_nolabels.{fig_format}')
plt.show()


## Define ROIs
Clusters of electrodes to average over for the waveform plots; the same electrodes are highlighted on the topographic maps.

In [None]:
# ROIs in config.yml are a list of mappings {roi_name: [channel_string, ...]}.
# Only the first string of each list is used; it holds comma-separated channel
# names. Flatten into {roi_name: [channel, ...]}, tolerating any whitespace
# around the commas and ignoring empty entries (e.g. from a trailing comma).
rois_all = {roi: [ch.strip() for ch in chans[0].split(',') if ch.strip()]
            for d in config['rois'] for roi, chans in d.items()}

# Keep only the ROIs requested for this component, in the order it lists them
rois = {roi: rois_all[roi] for roi in component_p['rois']}

#### Create mask identifying ROI electrodes

In [None]:
# Boolean mask of shape (n_channels, n_times) that is True for ROI electrodes at
# every time point; passed to plot_topomap to highlight those electrodes.
chs = pd.Series(gavg[conditions[0]].ch_names)

roi_elec = [ch for roi_chans in rois.values() for ch in roi_chans]
num_tp = gavg[conditions[0]].data.shape[1]
mask = np.tile(chs.isin(roi_elec).to_numpy()[:, np.newaxis], (1, num_tp))


## Butterfly plots

Times for topo maps were hard-coded based on first running the plots without specifying peak times. Times from the automatic peak finding were then used in a way that is consistent across all conditions.

In [None]:
# Joint (butterfly + topomap) plot of the grand average across all conditions.
# No times are given, so plot_joint chooses the topomap times itself; those
# inform the hard-coded peak_times used below. This figure is not saved.
uv_range = 13
ylim = 13
plt.close()
joint_ts_args = dict(hline=[0], ylim=dict(eeg=[-ylim, ylim]))
joint_topo_args = dict(sensors=False, contours=False,
                       vmin=-uv_range, vmax=uv_range)
gavg_all.plot_joint(title='Average across all conditions',
                    ts_args=joint_ts_args,
                    topomap_args=joint_topo_args)
plt.show()

### Set peak times (shared by all conditions) based on the plot above

In [None]:
# Topomap times (in seconds) used for every condition's joint plot
peak_times = [0.106, 0.190, 0.306, 0.428]

In [None]:
uv_range = 13
ylim = 13

# Joint (butterfly + topomap) plot of each condition's grand average, with
# topomaps at the fixed peak_times; each figure is saved to disk.
for cond in conditions:
    joint_fig = gavg[cond].plot_joint(
        times=peak_times,
        title=cond,
        ts_args=dict(hline=[0], ylim=dict(eeg=[-ylim, ylim])),
        topomap_args=dict(sensors=False, contours=False,
                          vmin=-uv_range, vmax=uv_range),
    )
    out_fname = jointplot_stem + cond + '.' + fig_format
    joint_fig.savefig(out_fname)

## Waveforms

In [None]:
ylim = {'eeg': [-10.5, 10]}

# One panel per ROI showing every condition's waveform, averaged over the ROI's
# channels (combine='mean') and across subjects. The number of panels comes from
# len(rois) instead of being fixed at 2, so components configured with more ROIs
# no longer raise an IndexError. squeeze=False keeps axs 2-D even for one panel.
n_panels = max(len(rois), 1)
fig, axs = plt.subplots(1, n_panels, figsize=waveplot_figsize, squeeze=False)
for ax, (roi, chans) in zip(axs[0], rois.items()):
    mne.viz.plot_compare_evokeds({c: evoked[c] for c in conditions},
                                 picks=chans, combine='mean',
                                 title=roi,
                                 colors=colors, linestyles=linestyles,
                                 ylim=ylim,
                                 show_sensors='upper right', legend='lower right',
                                 ci=False,
                                 axes=ax,
                                 # display once, after every panel is drawn and saved
                                 show=False)

# Save images to files
fig.savefig(waveplot_stem + 'allcond_legend.' + fig_format)

plt.show()

---
# Visualization

## Topo map time series for each condition

In [None]:
# Topomap series for each condition: one map every 100 ms from t_min, each
# averaged over a tw_width window, with ROI electrodes highlighted via `mask`.
times = np.arange(t_min, t_max, .100)

uv_range = 10

for cond in conditions:
    roi_highlight = {'marker': 'o',
                     'markerfacecolor': 'w',
                     'markeredgecolor': 'k',
                     'linewidth': 0,
                     'markersize': 6}
    gavg[cond].plot_topomap(times,
                            average=tw_width,
                            ch_type='eeg',
                            show_names=False,
                            sensors=False,
                            contours=False,
                            colorbar=True,
                            vmin=-uv_range,
                            vmax=uv_range,
                            title=cond,
                            mask=mask,
                            mask_params=roi_highlight)


## Topoplots of contrasts

In [None]:
# Topomap series of each contrast's grand-average difference wave: one map every
# 100 ms starting at 90 ms, each averaged over a tw_width window; saved to disk.
times = np.arange(.090, t_max, .100)

uv_range = 5

for contr in contrasts:
    contr_gavg = mne.grand_average(evoked_diff[contr])
    contr_fig = contr_gavg.plot_topomap(times,
                                        average=tw_width,
                                        ch_type='eeg',
                                        show_names=False,
                                        sensors=False,
                                        contours=False,
                                        colorbar=True,
                                        vmin=-uv_range,
                                        vmax=uv_range,
                                        mask=mask,
                                        title=contr)
    contr_fig.savefig(topoplot_stem + 'ts_' + contr + '.' + fig_format)

## Plot at a smaller microvolt range, to more easily see effects that are smaller than the contrasts with FF

In [None]:
# Same topomap series as for all contrasts, but only for contrasts that do not
# involve FalseFont, on a +/-2 uV colour scale.
# NOTE(review): the filenames match those written for the same contrasts at
# +/-5 uV, so those files are overwritten here. Confirm this is intended.
times = np.arange(.090, t_max, .100)

uv_range = 2

contr_no_ff = ['RW-CS', 'PW-CS', 'NW-CS',
               'RW-PW', 'RW-NW', 'NW-PW']

for contr in contr_no_ff:
    contr_gavg = mne.grand_average(evoked_diff[contr])
    contr_fig = contr_gavg.plot_topomap(times,
                                        average=tw_width,
                                        ch_type='eeg',
                                        show_names=False,
                                        sensors=False,
                                        contours=False,
                                        colorbar=True,
                                        vmin=-uv_range,
                                        vmax=uv_range,
                                        mask=mask,
                                        title=contr)
    contr_fig.savefig(topoplot_stem + 'ts_' + contr + '.' + fig_format)