# EEG-ERP Preprocessing 
## Batch script

This script will run on all subjects in `rawdata`. Steps performed include:
- filtering
- epoching
- marking bad segments of data (trials/channels) using AutoReject
- passing the marked data to ICA
- automatically identifying and removing ocular independent components
- applying ICA to the epoched data
- applying AutoReject to the ICA-cleaned data
- re-referencing
- exporting clean epochs to a `.fif` file in `derivatives/erp_preprocessing`
- saving an MNE report with details/visualizations of the above steps in `derivatives/erp_preprocessing/logs`

Most parameters that you would want to change are read from `config.yml`. 

It is recommended that you not apply any baseline correction at this stage (the default `config.yml` file reflects this). Baseline correction can be applied in subsequent scripts, but applying it here precludes using baseline regression later.
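
To make the deferred step concrete, the sketch below shows what subtractive baseline correction does to one channel of one epoch: the mean of the pre-stimulus samples is subtracted from every sample. The function name, the toy values, and the plain-list representation are illustrative assumptions and not part of this pipeline; in a later script the same operation would normally be done with MNE's `Epochs.apply_baseline()`.

```python
from statistics import fmean

def baseline_correct(samples, times, baseline=(None, 0.0)):
    """Subtract the mean of the baseline window from every sample.

    Hypothetical helper for illustration. `baseline` is (start, end) in
    seconds; None means "from the first sample" or "to the last sample".
    """
    start, end = baseline
    start = times[0] if start is None else start
    end = times[-1] if end is None else end
    window = [s for s, t in zip(samples, times) if start <= t <= end]
    offset = fmean(window)
    return [s - offset for s in samples]

# Toy epoch: 5 samples from -0.2 s to 0.2 s with a constant offset of 2.0
times = [-0.2, -0.1, 0.0, 0.1, 0.2]
samples = [2.0, 2.0, 2.0, 5.0, 3.0]
corrected = baseline_correct(samples, times)
```

Leaving the epochs uncorrected at this stage keeps the choice between subtraction and regression open for later scripts.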

Bad channels and non-ocular independent components can be manually identified after this script has been run once. These can be stored in additional config files saved in `rawdata`, and the script can then be re-run to exclude them. With reasonably clean data this should not be necessary, but it may be needed in some cases, e.g., broken electrodes, EGI data, or data from children.

---
Copyright 2023 [Aaron J Newman](https://github.com/aaronjnewman), [NeuroCognitive Imaging Lab](http://ncil.science), [Dalhousie University](https://dal.ca)

Released under [The 3-Clause BSD License](https://opensource.org/licenses/BSD-3-Clause)

---

In [None]:
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import os.path as op
from os import remove
from glob import glob
from pathlib import Path
import yaml
from yaml import CLoader as Loader
import mne
mne.set_log_level('error')
from mne_bids import BIDSPath, read_raw_bids
from scipy.stats import zscore
from autoreject import Ransac, get_rejection_threshold, AutoReject
from time import time

## Read Parameters from config.yml

Imports study-level parameters from `config.yml` in `bids_root`.
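
The cell below expects several `config.yml` sections to be lists of single-key mappings and flattens each one into a single dictionary. The sketch shows that flattening on a hand-written structure standing in for the parsed YAML; the channel names, event codes, and condition labels are made-up assumptions, not the study's actual settings.

```python
def flatten(list_of_dicts):
    """Merge a list of single-key dicts into one dict (later keys win)."""
    return {k: v for d in list_of_dicts for k, v in d.items()}

# Hypothetical parsed YAML; the real values live in config.yml
config = {
    'eog': [{'VEOG': ['Fp1', 'IO1']}, {'HEOG': ['F7', 'F8']}],
    'events': [{'11': ['face/upright']}, {'12': ['face/inverted']}],
}

eog = flatten(config['eog'])

# Each event maps a trigger code to a one-item list; take that item.
# Indexing with [-1] leaves `config` unchanged, whereas .pop() empties the list.
event_id = {code: labels[-1]
            for code, labels in flatten(config['events']).items()}
```

Here `eog` maps each bipolar channel name to its anode/cathode pair, and `event_id` maps each trigger code to a condition label.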

In [None]:
# this shouldn't change if you run this script from its default location in code/import
bids_root = '../..'

cfg_file = op.join(bids_root, 'config.yml')
with open(cfg_file, 'r') as f:
    config = yaml.load(f, Loader=Loader)

study_name = config['study_name']
task = config['task']
data_type = config['data_type']
eog = {k: v for d in config['eog'] for k, v in d.items()}
drop_ch = config['drop_ch']            
montage_fname = config['montage_fname']
event_id = {k: v.pop() for d in config['events'] for k, v in d.items()}

# fix per changes to config
n_jobs = config['preprocessing_settings']['n_jobs']
filt_p = {k: v for d in config['preprocessing_settings']['filter'] for k, v in d.items()}
ica_p = {k: v for d in config['preprocessing_settings']['ica'] for k, v in d.items()}
epoch_p = {k: v for d in config['preprocessing_settings']['epoch'] for k, v in d.items()}
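# reject thresholds are stored as strings in config.yml and converted to numbers here;
# eval() executes arbitrary code, so only use a config file you trust
reject = {k: eval(v) for i in epoch_p['reject'] for k, v in i.items()}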

### Paths

In [None]:
raw_path = op.join(bids_root, 'rawdata')

derivatives_path = op.join(bids_root, 'derivatives', 'erp_preprocessing_rej')
Path(derivatives_path).mkdir(parents=True, exist_ok=True)

report_path = op.join(derivatives_path, 'logs')
Path(report_path).mkdir(parents=True, exist_ok=True)


epochs_suffix = '-epo.fif'

### Subject list
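
The cell below takes the last seven characters of each matching path, which assumes IDs of the form `sub-` plus three digits. A sketch of a length-independent alternative using `pathlib` follows; the helper name and the temporary directory are illustrative assumptions.

```python
from pathlib import Path
import tempfile

def list_subjects(raw_path, prefix='sub-'):
    """Return sorted names of subject directories directly under raw_path."""
    return sorted(p.name for p in Path(raw_path).glob(prefix + '*')
                  if p.is_dir())

# Demonstrate on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    for name in ('sub-002', 'sub-001', 'sub-0103'):
        (Path(tmp) / name).mkdir()
    (Path(tmp) / 'participants.tsv').touch()  # non-subject file is ignored
    subjects = list_subjects(tmp)
```

The main loop also slices the ID with `subject[-3:]`, so IDs of another length would need the prefix stripped by length there as well.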

In [None]:
prefix = 'sub-'
subjects = sorted([s[-7:] for s in glob(raw_path + '/' + prefix + '*')])

## Read manually-marked independent components
Run this script once, inspect the ICs, and decide whether any additional ICs should be removed, or whether any automatically removed ICs should be retained.

Additional ICs to remove were selected based on:
- participants for whom more than 15% of trials were removed by AutoReject after ICA correction
- participants for whom the average across all trials and all electrodes did not show a clear pattern of P1-N1-P2 components
- for such participants, the scalp map and details of each IC were visually inspected. Components were removed if they
    - were focal at a single electrode, or a very low number of electrodes
    - were focal at the edges of the electrode montage
    - were present on a low number of trials
    - showed no systematic pattern across trials that was time-locked to stimulus onset
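
The first criterion can be checked against the rejection log that the main loop writes to `logs/rejection_log_all_Ss.csv`, whose `id` and `%t_rej` columns hold the subject and the percentage of trials rejected. A minimal sketch, assuming that file layout; the function name and the inline sample rows are made up for illustration.

```python
import csv
import io

def flag_high_rejection(csv_file, threshold=15.0):
    """Return IDs whose percentage of rejected trials exceeds threshold."""
    reader = csv.DictReader(csv_file)
    return [row['id'] for row in reader if float(row['%t_rej']) > threshold]

# Made-up rows in the column layout of rejection_log_all_Ss.csv
sample = io.StringIO(
    ",id,cpu_time,ntrials_rej,%t_rej,ic_rm,n_interp\n"
    "0,sub-001,310.2,12,4.0,3,8\n"
    "0,sub-002,298.7,61,20.33,2,16\n"
)
flagged = flag_high_rejection(sample)
```

With the real log, pass an open file handle in place of the `StringIO` object.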
    
ICs to add (un-remove) were selected based on:
- IC shows clear and consistent temporal dependency on stimulus onset
- IC appears to contain P1-N1-P2 complex, or part thereof
- IC has broad scalp distribution across many electrodes, characteristic of ERP component
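
These decisions are recorded in `participants_manual_IC.yml` and applied in the main loop, which appends each `add_ics` entry to `ica.exclude` and removes each `rm_ics` entry from it. The sketch below mirrors that logic on plain lists; the subject IDs and component numbers are invented, and the file layout shown as a dictionary is an assumption inferred from how the loop reads it.

```python
def apply_manual_ics(exclude, manual):
    """Return a new exclusion list after applying manual decisions.

    `add_ics` are added to the exclusion list; `rm_ics` are dropped from it.
    Unlike list.remove(), entries that are absent are silently ignored.
    """
    updated = list(exclude)
    for ic in manual.get('add_ics', []):
        if ic not in updated:
            updated.append(ic)
    return [ic for ic in updated if ic not in manual.get('rm_ics', [])]

# Hypothetical contents of participants_manual_IC.yml after parsing
ica_manual = {
    'sub-003': {'add_ics': [7, 12]},
    'sub-010': {'rm_ics': [1]},
}

auto_excluded = [0, 1]
result_003 = apply_manual_ics(auto_excluded, ica_manual['sub-003'])
result_010 = apply_manual_ics(auto_excluded, ica_manual['sub-010'])
```

Returning a new list rather than editing in place keeps the automatically detected components available for comparison in the report.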

In [None]:
cfg_file = op.join(raw_path, 'participants_manual_IC.yml')
with open(cfg_file, 'r') as f:
    ica_manual = yaml.load(f, Loader=Loader)

## Main Loop
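
The ocular-IC step in the loop below starts from the configured z threshold and lowers it in fixed steps until at least `n_max_eog` components are flagged. The sketch isolates that search with a stand-in scoring function in place of `ica.find_bads_eog()`; the scores, the `floor` guard, and the function names are illustrative assumptions rather than part of the script.

```python
def step_down_threshold(find_bads, z_start, z_step, n_min, floor=0.0):
    """Lower the threshold until find_bads returns at least n_min items.

    Stops at `floor` so the search cannot run forever when too few
    components ever cross the threshold.
    """
    z = z_start
    found = find_bads(z)
    while len(found) < n_min and z - z_step >= floor:
        z = round(z - z_step, 10)  # rounding limits float drift across steps
        found = find_bads(z)
    return found, z

# Stand-in for ica.find_bads_eog: component index -> made-up z score
scores = {0: 3.4, 1: 2.1, 2: 0.6, 3: 0.2}

def fake_find_bads(threshold):
    return [ic for ic, z in scores.items() if z > threshold]

bad_ics, z_final = step_down_threshold(fake_find_bads,
                                       z_start=3.0, z_step=0.5, n_min=2)
```

This version returns the threshold that produced the final result directly, whereas the loop below recovers it afterwards by adding `z_step` back.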

In [None]:
rej_log_list = []

for subject in subjects:
    start_time = time()
    print('\n-------------------------')
    print('-------- ' + subject + ' --------')
    print('-------------------------')

    report = mne.Report(subject=subject, 
                        title=study_name + ' preprocessing: ' + subject,
                        verbose='WARNING')

    ### subject-specific paths
    in_path = BIDSPath(root=raw_path, 
                       subject=subject[-3:], 
                       datatype=data_type,
                       task=task
                      )    

    ### Import data
    raw = read_raw_bids(in_path)
    
    # manually-flagged bad channels
    # raw.info['bads'] = bad_ch[subject]
    
    # drop useless channels that mess with topomaps
    raw.drop_channels(drop_ch)

    # Create bipolar EOG channels
    raw = mne.set_bipolar_reference(raw.load_data(), 
                                    anode=[e[0] for e in eog.values()],
                                    cathode=[e[1] for e in eog.values()],
                                    ch_name=list(eog.keys()),
                                    ch_info=[{'kind': 202} for _ in eog]  # 202 = FIFF code for an EOG channel
                                   )    
    
    ### Filtering
    # channel selection
    picks = mne.pick_types(raw.info, 
                           eeg=True,
                           eog=True
                          )
    
    ## Filter for ICA  
    raw_ica = raw.copy().filter(filt_p['l_freq_ica'], filt_p['h_freq'],
                                picks=picks,
                                n_jobs=n_jobs
                               )

    ## Filter for final
    raw.filter(filt_p['l_freq'], filt_p['h_freq'],
               picks=picks,
               n_jobs=n_jobs
              )
    
    ## Add raw to report
    report.add_raw(raw=raw, 
               psd=True, butterfly=True, 
               title='Raw data, bandpass filtered ' + str(filt_p['l_freq']) + '–' + str(filt_p['h_freq'])
              )

    ### Read events
    events, event_dict = mne.events_from_annotations(raw)

    event_dict_new = {}
    for key, value in event_dict.items():
        if key in event_id.keys():
            # rename events of experimental interest
            event_dict_new[event_id[key]] = value
        # else:
        #     # keep other events to create more epochs for ICA to fit on
        #     event_dict_new[key] = value                       
    
    ## Add events to report
    report.add_events(events, event_id=event_dict_new, 
                      sfreq=raw.info['sfreq'],
                      title='Events'
                     )
    
    ### Epoch data filtered for ICA
    epochs_ica = mne.Epochs(raw_ica,
                            events, event_dict_new,
                            epoch_p['tmin'], epoch_p['tmax'],
                            baseline=epoch_p['baseline'], detrend=epoch_p['detrend'],
                            reject=None, 
                            flat=epoch_p['flat'],
                            preload=True
                           )

    
    # use AutoReject to remove bad epochs, repair sensors and return clean epochs.
    ar = AutoReject(n_interpolate=[8, 16, 32, 64, 96],
                    random_state=ica_p['ica_random_state'],
                    picks=mne.pick_types(epochs_ica.info, eeg=True, eog=False),
                    n_jobs=n_jobs, 
                    verbose=False
                   )
    ar.fit(epochs_ica)
    # if subject == 'sub-045':
    #     ar.transform(epochs_ica)
    print('n_interpolate = ' +  str(ar.n_interpolate_['eeg']))
    reject_log = ar.get_reject_log(epochs_ica)
    fig = reject_log.plot('horizontal', show=False);
    report.add_figure(fig=fig, title='AutoReject log')

    ### Fit ICA
    ica = mne.preprocessing.ICA(method='fastica',
                                n_components=ica_p['n_components'],
                                random_state=ica_p['ica_random_state'],
                                max_iter='auto')
    
    ica.fit(epochs_ica[~reject_log.bad_epochs],  # added [~reject_log.bad_epochs] for AutoReject
            decim=3, 
            picks=['eeg']
            );

    # Identify ocular ICs
    # The default *z* threshold doesn't work for
    # all subjects. This routine starts with the default z (from config) and steps down
    # until at least n_max_eog EOG components are identified.
    # The limitations of this are that it assumes there will always be at least n_max_eog EOG
    # components (blinks are always present, but horizontal movements are not
    # always present), and may not work if there are > 3 components, if the
    # score of the third is > `z_step` less than the score of the second.
    # In practice, many of these components (with EGI data) may not be ocular, but are (hopefully) not EEG.
    # Be sure to check the reports and confirm no ERP components are rejected!

    ica.exclude = []
    num_excl = 0
    z_thresh = ica_p['ica_zthresh'] 
    z_step = ica_p['ica_zstep']

    while num_excl < ica_p['n_max_eog']:
        eog_indices, eog_scores = ica.find_bads_eog(epochs_ica, threshold=z_thresh)
        num_excl = len(eog_indices)
        z_thresh -= z_step # won't impact things if num_excl is ≥ n_max_eog 

    ica.exclude = eog_indices
    z_thresh_final = round(z_thresh + z_step, 2)

    # Manual changes to the exclusion list based on visual inspection:
    # 'add_ics' entries are added to ica.exclude; 'rm_ics' entries are taken off it
    if subject in ica_manual:
        if 'add_ics' in ica_manual[subject]:
            for ic in ica_manual[subject]['add_ics']:
                ica.exclude.append(ic)
        if 'rm_ics' in ica_manual[subject]:
            for ic in ica_manual[subject]['rm_ics']:
                ica.exclude.remove(ic)         

    # Create average of EOG events
    eog_evoked = mne.preprocessing.create_eog_epochs(raw_ica).average().apply_baseline(baseline=(None, epoch_p['tmin']))

    ## Add ICA to report
    report.add_ica(ica=ica, title='ICA', inst=epochs_ica,
                   eog_evoked=eog_evoked, 
                   eog_scores=eog_scores,
                   n_jobs=n_jobs
                  )
    
    ### Segment filtered raw data into epochs for final analysis
    epochs = mne.Epochs(raw,
                        events, event_dict_new,
                        epoch_p['tmin'], epoch_p['tmax'],
                        baseline=epoch_p['baseline'], detrend=epoch_p['detrend'],
                        reject=reject, 
                        flat=epoch_p['flat'],
                        preload=True
                       )

    ### Apply ICA correction to epochs
    ica.apply(epochs)
    

    ### Apply AutoReject to further clean epochs
    ar = AutoReject(n_interpolate=[8, 16, 32, 64, 96],
                    random_state=ica_p['ica_random_state'],
                    picks=mne.pick_types(epochs.info, eeg=True, eog=False),
                    n_jobs=n_jobs, 
                    verbose=False
                   )
    epochs_clean, reject_log = ar.fit_transform(epochs, return_log=True)
    fig = reject_log.plot('horizontal', show=False)
    report.add_figure(fig=fig, title='AutoReject log (after ICA)')

    ### Re-reference, now that channels are cleaned
    epochs_clean.set_eeg_reference(ref_channels='average');
    
    ### Save cleaned epochs
    out_path = BIDSPath(root=derivatives_path, 
                       subject=subject[-3:], 
                       datatype=data_type,
                       task=task
                      )    
    # remove old fif file if it exists, and update bids_path
    if str(out_path.fpath).endswith(epochs_suffix):
        remove(out_path.fpath)
        out_path = BIDSPath(root=derivatives_path, 
                   subject=subject[-3:], 
                   datatype=data_type,
                   task=task
                  )    
    # save the file
    epochs_clean.save(str(out_path.fpath) + epochs_suffix, 
                      overwrite=True)

    # add epochs to report
    report.add_epochs(epochs_clean,  
                      title='Epochs'
                     )
    ### Save plot of average across all trials
    fig = epochs_clean.copy().average().plot(spatial_colors=True, 
                                      show=False);
    
    # add figures to report
    report.add_figure(fig=fig, title='Grand average over all epochs')
    plt.close(fig)
    
    # Add plots of average of each condition
    for condition in event_id.values():
        fig = epochs_clean[condition].copy().average().plot(spatial_colors=True, 
                                                            show=False);
        report.add_figure(fig=fig, title=condition)
        plt.close(fig)
    
    proc_time = time() - start_time
    print('Total processing time: ', proc_time)

    ### Report on how much was rejected
    rm_epochs = epochs.selection.shape[0] - epochs_clean.selection.shape[0]
    pct_epochs = rm_epochs / epochs.selection.shape[0] * 100   
    rej_log_list.append(pd.DataFrame({'id':subject, 
                                      'cpu_time':proc_time,
                                      'ntrials_rej':rm_epochs,
                                      '%t_rej':round(pct_epochs, 2),
                                      'ic_rm':len(ica.exclude),
                                      'n_interp':str(ar.n_interpolate_['eeg'])
                                     }, index=[0]
                                    )
                       )

    # Save report to file
    report_name = report_path + '/' + subject + '.html'
    report.save(report_name, overwrite=True)
    
# Collate report logs
rej_log = pd.concat(rej_log_list)
rej_log.to_csv(report_path + '/rejection_log_all_Ss.csv')