# Plot scatterplots + regression line relating ERP to behavioural data

---
Copyright 2022 [Aaron J Newman](https://github.com/aaronjnewman), [NeuroCognitive Imaging Lab](http://ncil.science), [Dalhousie University](https://dal.ca)

Released under [The 3-Clause BSD License](https://opensource.org/licenses/BSD-3-Clause)

---

## Which component to analyze

In [None]:
# Name of the ERP component to analyze; it is used as a key into
# config['components'] and as part of the output directory / figure file names.
component = 'n170'

## Initialization

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import pyarrow.feather as feather
from scipy.stats import zscore
import os.path as op
from glob import glob
from pathlib import Path
import yaml
from yaml import CLoader as Loader
import mne
mne.set_log_level(verbose='error')

## Read Parameters from config.yml

Will import study-level parameters from `config.yml` in `bids_root`

In [None]:
# Root of the BIDS dataset, relative to this notebook.
# this shouldn't change if you run this script from its default location in code/import
bids_root = '../..'

# Load the study-level parameters.
# NOTE(review): yaml.load with the full (C)Loader can construct arbitrary Python
#  objects. That is acceptable for a trusted, local config.yml, but consider
#  yaml.safe_load / CSafeLoader if the file can come from elsewhere.
cfg_file = op.join(bids_root, 'config.yml')
with open(cfg_file, 'r') as f:
    config = yaml.load(f, Loader=Loader)

study_name = config['Study']['Name']
# Task label used to build the epochs file names when the ERP data are read.
# (This value was previously assigned to `study_name` a second time, which
#  overwrote the study name and left `task` undefined.)
task = config['Study']['TaskName']
data_type = config['data_type']

# config['eog'] and config['events'] are lists of mappings; merge each into one dict
eog = {k: v for d in config['eog'] for k, v in d.items()}
montage_fname = config['montage_fname']
# each event's value is a list; keep its last element (note: pop() also removes it from `config`)
event_id = {k: v.pop() for d in config['events'] for k, v in d.items()}

n_jobs = config['Preprocessing']['n_jobs']

# Parameters specific to the component selected above
component_p = config['components'][component]
component_meas = component_p['component_meas']

## Time windows of interest

In [None]:
# Epoch limits for this component, in the units stored in config.yml.
# These may be narrower than the epochs saved on disk; the data are cropped to them.
t_min, t_max = (float(component_p[key]) for key in ('t_min', 't_max'))

# The baseline is stored in config.yml as a string and evaluated as a Python expression.
# NOTE(review): eval() executes arbitrary code; only safe with a trusted config.yml
#  (ast.literal_eval would be a safer choice if the value is always a plain literal).
baseline = eval(component_p['baseline'])

# Peak latency, obtained from visual inspection of butterfly plots and stored in config.yml
peak_lat = component_p['peak_lat']

# Width of the time window used around the peak
tw_width = component_p['tw_width']

# Amount of time to shift event codes by, based on empirical testing with a photocell
#  to determine the lag between event code and actual stimulus appearance on screen.
# NOTE(review): an earlier comment said this value had to be doubled to get the
#  correct shift; nothing here doubles it, so confirm the value in config.yml
#  already accounts for that.
tshift = config['t_shift']

## Define ROIs
clusters of electrodes to average over for waveform plots

In [None]:
# config['rois'] is a list of {roi_name: channels} mappings; merge them into one dict
rois_all = {}
for roi_entry in config['rois']:
    rois_all.update(roi_entry)

# The first element of each ROI's value is a ', '-separated string of channel names;
# turn it into a list of names
rois_all = {roi: chs[0].split(', ') for roi, chs in rois_all.items()}

# Keep only the ROIs listed for the selected component
rois = {roi: rois_all[roi] for roi in component_p['rois']}

## Conditions and Contrasts of Interest

In [None]:
# Condition labels, in the order they appear in event_id
conditions = list(event_id.values())

# Two-letter abbreviations used in the contrast names
_cond_abbrev = {'FF': 'FalseFont',
                'CS': 'ConsonantString',
                'PW': 'PseudoWord',
                'NW': 'NovelWord',
                'RW': 'RealWord',
                }

# Contrast names ('A-B'), in display order
contr_order = ['CS-FF', 'RW-FF', 'PW-FF', 'NW-FF',
               'RW-CS', 'PW-CS', 'NW-CS',
               'RW-PW', 'RW-NW', 'NW-PW']

# Map each contrast name to its pair of full condition labels, e.g. 'CS-FF' -> ['ConsonantString', 'FalseFont']
contrasts = {name: [_cond_abbrev[abbr] for abbr in name.split('-')]
             for name in contr_order}

## Paths

In [None]:
# Input locations
raw_path = op.join(bids_root)
source_path = op.join(bids_root, 'derivatives', 'erp_preprocessing')
data_path = op.join(bids_root, 'derivatives', 'behavioral_demographic', 'data')

# Output locations for this component
derivatives_path = op.join(bids_root, 'derivatives', 'erp_lme', component)
out_path = op.join(derivatives_path, 'data')
fig_path = op.join(derivatives_path, 'figures')

# Create the output directories if needed. exist_ok=True replaces the separate
# "exists() == False" checks, so re-running the cell (or a directory appearing
# between check and creation) does not raise.
for _out_dir in (derivatives_path, out_path, fig_path):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

# File name suffix / stems used when reading epochs and writing results
epochs_suffix = '-epo.fif'
group_stem = op.join(out_path, 'participants_')
scatterplot_stem = op.join(fig_path, 'regr_plot_' + component + '_')
swarmplot_stem = op.join(fig_path, 'swarmplot_')

## Figure settings

In [None]:
# Global seaborn appearance
sns.set_palette('colorblind')
sns.set_style('white')
sns.set_context('talk')

# The 'colorblind' palette, looked up once and indexed below
_palette = sns.color_palette('colorblind')

# One colour per condition: the first five palette entries, in this order
colors = dict(zip(['FalseFont', 'ConsonantString', 'PseudoWord', 'NovelWord', 'RealWord'],
                  _palette))

# One colour per contrast. Contrasts against FalseFont or ConsonantString reuse the
# colour of their first condition; the remaining three get their own palette entries.
contr_colors = {
    'CS-FF': colors['ConsonantString'],
    'RW-FF': colors['RealWord'],
    'PW-FF': colors['PseudoWord'],
    'NW-FF': colors['NovelWord'],
    'RW-CS': colors['RealWord'],
    'PW-CS': colors['PseudoWord'],
    'NW-CS': colors['NovelWord'],
    'RW-PW': _palette[8],
    'RW-NW': _palette[9],
    'NW-PW': _palette[6],
}

# File format (extension) for saved figures
fig_format = 'pdf'

## Subject list

In [None]:
# Subject IDs are the names of the entries in source_path that start with `prefix`.
# Taking the base name of each match (rather than the last 7 characters of the
# path) keeps the full ID whatever its length.
prefix = 'sub-'
subjects = sorted(op.basename(s) for s in glob(source_path + '/' + prefix + '*'))

---
# Read ERP data

When we read the data, we also crop the epochs as specified above, and time-shift the event onsets to match true stimulus timing

In [None]:
# Dict of Epochs objects, keyed by subject ID.
# NOTE(review): `task` (the task label in the file name) must be defined before
#  this cell runs — presumably from config['Study']['TaskName']; confirm.
epochs = {}
for subject in subjects:
    subj_path = op.join(source_path, subject, 'eeg')
    epochs_file = f"{subj_path}/{subject}_task-{task}{epochs_suffix}"
    epochs[subject] = subj_epochs = mne.read_epochs(epochs_file,
                                                    verbose=None,
                                                    preload=True)

    # Correct for stimulus presentation delay by shifting both time axes by tshift.
    # This has to happen before cropping, so that the crop limits refer to corrected times.
    subj_epochs._raw_times = subj_epochs._raw_times - tshift
    subj_epochs._times_readonly = subj_epochs._times_readonly - tshift

    # Restrict to the window of interest, then re-baseline (both operate in place)
    subj_epochs.crop(tmin=t_min, tmax=t_max)
    subj_epochs.apply_baseline(baseline)

    subj_epochs.set_montage(montage_fname)


---
# Compute single-trial measurements



In [None]:
%%time

# Build one long-format DataFrame with a row per trial x channel, for every
# subject x condition x ROI. 'Amplitude' is the mean over a time window of width
# component_p['tw_width'] centred on the peak latency.
df_list = []

# Scale factor applied to the amplitudes (same value as the 10e5 used previously).
# NOTE(review): presumably converts volts to microvolts — confirm.
amp_scale = 1e6

# Search window for the peak, and half-width of the averaging window
tw_start, tw_stop = component_p['tw_range'][0], component_p['tw_range'][1]
tw_half_width = component_p['tw_width'] / 2

for subj in subjects:
    for cond in conditions:
        # Selecting a condition builds a new Epochs object, so do it once per
        # subject/condition rather than repeatedly inside the ROI loop.
        cond_epochs = epochs[subj][cond]
        if component_meas != 'meana':
            # The condition average does not depend on the ROI either
            cond_evoked = cond_epochs.average()

        for roi, chans in rois_all.items():
            if component_meas == 'meana':
                # No peak search: centre the window on the midpoint of tw_range.
                # Peak channel (and the unused third slot) are left as NaN.
                peak = np.array([np.nan,
                                 np.median([tw_start, tw_stop]),
                                 np.nan])
            else:
                # find peak in specified timewindow, among channels in this ROI;
                # pick_channels works in place, hence the copy of the shared average
                peak = cond_evoked.copy().pick_channels(chans).get_peak(
                    tmin=tw_start,
                    tmax=tw_stop,
                    mode=component_p['component_meas'],
                )

            # define time window for averaging, centred on peak
            peak_window = (peak[1] - tw_half_width, peak[1] + tw_half_width)
            idx_start, idx_stop = np.searchsorted(cond_epochs.times, peak_window)

            # Mean over the time window for every trial and channel, flattened
            # trial-major (all channels of trial 0, then trial 1, ...). The mean
            # already yields a new array, so no copy of the epochs is needed.
            amplitude = (cond_epochs.get_data(picks=chans)[:, :, idx_start:idx_stop]
                         .mean(axis=-1)
                         .flatten()
                         * amp_scale)

            # Trial_Time repeats each trial's event sample once per channel and
            # Channel cycles through the ROI's channels once per trial, matching
            # the trial-major order of `amplitude`.
            df_list.append(pd.DataFrame({'participant_id': subj,
                                         'Component': component,
                                         'Trial_Time': np.repeat(cond_epochs.events[:, 0], len(chans)),
                                         'Condition': cond,
                                         'ROI': roi,
                                         'Peak.Chan': peak[0],
                                         'Peak.Lat': peak[1],
                                         'Channel': np.tile(chans, cond_epochs.selection.shape),
                                         'Amplitude': amplitude,
                                         }))

# concatenate the per-subject/condition/ROI dataframes into one
df = pd.concat(df_list, ignore_index=True)

## Remove Outliers

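Remove individual data points whose amplitude z-score lies outside ± the `outlier_thresh` value (in SD) set in `config.yml`. Z-scores are computed separately for each subject and component.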

In [None]:
# Outlier cutoff, in SD units
z_thresh = config['outlier_thresh']

# Standardize amplitudes within each participant x component
ampl_by_subj_comp = df.groupby(['participant_id', 'Component'])['Amplitude']
df['Peak.Ampl.z'] = ampl_by_subj_comp.transform(zscore)

len_orig = len(df)

# Keep only rows whose z-score lies within [-z_thresh, z_thresh] (bounds included)
within_thresh = df['Peak.Ampl.z'].between(-z_thresh, z_thresh)
df = df[within_thresh]

# Report how much data was removed
n_dropped = len_orig - len(df)
pct_dropped = round((n_dropped / len_orig) * 100, 3)
print(f'{pct_dropped}% of data dropped as outliers based Peak.Amplitude z +/-{z_thresh}')

---
## Aggregate over trials and channels, within subjects and conditions
- To ensure proper CIs (between-subject variance) in plots below.
- Also select only the ROIs of interest for this component

In [None]:
# Rows belonging to the ROIs of interest (membership is tested against the keys of `rois`)
in_roi = df['ROI'].isin(rois)

# Mean amplitude per participant x component x condition x ROI
agg_keys = ['participant_id', 'Component', 'Condition', 'ROI']
df_agg = (df[in_roi]
          .groupby(agg_keys)['Amplitude']
          .mean()
          .reset_index())

# write to file
feather.write_feather(df_agg, f'{group_stem}trialavg.feather')

# show a random sample of 12 rows
df_agg.sample(12)

## Read demographic & standardized test data

### Need to rename rt and acc columns

In [None]:
# Options shared by every file: tab-separated, first column used as the index
_tsv_opts = dict(sep='\t', index_col=0)

# Condition codes used to name the accuracy and RT columns
_cond_codes = ['CNST', 'FFNT', 'NVWD', 'PSWD', 'RLWD']

# Read the individual tables. For the accuracy and RT files the header row is
# replaced by 'acc_'/'rt_'-prefixed names, so the two sets of columns stay distinct.
_beh_tables = [
    pd.read_csv(op.join(raw_path, 'participants.tsv'), **_tsv_opts),
    pd.read_csv(op.join(raw_path, 'participants_std_tests.tsv'), **_tsv_opts),
    pd.read_csv(op.join(data_path, 'ldt_acc_by_subj_wide.tsv'),
                header=0, names=['acc_' + code for code in _cond_codes], **_tsv_opts),
    pd.read_csv(op.join(data_path, 'ldt_rt_trimmed_by_subj_wide.tsv'),
                header=0, names=['rt_' + code for code in _cond_codes], **_tsv_opts),
    pd.read_csv(op.join(data_path, 'ldt_sdt_data.tsv'), **_tsv_opts),
]

# Combine side by side, aligned on the index
dfB = pd.concat(_beh_tables, axis=1)

dfB.head()

In [None]:
# Attach the behavioural/demographic columns to the aggregated ERP data by matching
# df_agg's 'participant_id' column against dfB's index. how='outer' keeps rows
# that are present in only one of the two tables.
dfA = df_agg.join(dfB, on='participant_id', how='outer')

# Show a random sample of 12 rows
dfA.sample(12)

In [None]:
# List the available columns (candidate values for `beh_var` in the plots below)
dfA.columns

## Plot relationships between ERP component amplitude and behavioural measures

### Novel Word accuracy

In [None]:
# Behavioural measure plotted on the x axis
beh_var = 'Ortho_Choice'

# Scatterplot + regression line (95% CI) of amplitude against the behavioural
# measure: one column of panels per ROI, one row per condition.
# Note that lmplot returns a grid of axes, held here in `ax`.
ax = sns.lmplot(x=beh_var,
                y="Amplitude",
                data=dfA,
                row='Condition',
                row_order=conditions,
                col='ROI',
                ci=95,
                scatter_kws=dict(s=33),
                )

ax.savefig(f'{scatterplot_stem}{beh_var}.{fig_format}')

plt.show()

In [None]:
# Switch to the smaller 'paper' context for this figure
sns.set_context('paper')

# Behavioural measure plotted on the x axis
beh_var = 'acc_NVWD'

# Scatterplot + regression line (95% CI) of amplitude against the behavioural
# measure: one panel per ROI, with all conditions overlaid and distinguished by colour.
ax = sns.lmplot(x=beh_var,
                y="Amplitude",
                data=dfA,
                hue='Condition',
                hue_order=conditions,
                col='ROI',
                ci=95,
                scatter_kws=dict(s=33),
                )
# optional fixed y-axis range:
# ax.set(ylim=(-17, 0))

ax.savefig(f'{scatterplot_stem}{beh_var}_combined.{fig_format}')

plt.show()