# EEG-ERP Preprocessing 
## Batch script

This script will run on all subjects in `rawdata`. Steps performed include:
- filtering
- epoching
- mark (but don't correct) bad segments of data (trials/channels) using AutoReject
- fit ICA, excluding the epochs marked as bad
- auto-identify and remove ocular independent components
- apply ICA to epoched data
- apply AutoReject to ICA-cleaned data
- re-reference
- export clean epochs to a `.fif` file in `derivatives/erp_preprocessing`
- save an MNE report with details/visualizations of the above steps in `derivatives/erp_preprocessing/logs`

Most parameters that you would want to change are read from `config.yml`. 

It is recommended that you not apply any baseline correction at this stage (and the default `config.yml` file reflects this). Baseline correction can be applied in subsequent scripts, but applying it here precludes later using baseline regression.
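To illustrate, the sketch below applies classic baseline correction (subtracting each epoch's mean over a pre-stimulus window) to synthetic NumPy data. The array shapes and the -0.2 to 0 s window are assumptions, not values from `config.yml`. After correction every epoch's baseline mean is zero, so the trial-to-trial baseline values that baseline regression would enter as a predictor are no longer present in the saved data.

```python
import numpy as np

def baseline_means(data, times, tmin, tmax):
    """Mean over the baseline window for each epoch and channel.

    data: array of shape (n_epochs, n_channels, n_times); times in seconds.
    Returns shape (n_epochs, n_channels, 1) so the result broadcasts over time.
    """
    mask = (times >= tmin) & (times <= tmax)
    return data[:, :, mask].mean(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
times = np.linspace(-0.2, 0.8, 251)            # hypothetical epoch window, in s
data = rng.normal(size=(10, 4, times.size))    # 10 epochs x 4 channels of synthetic data
data += rng.normal(scale=5, size=(10, 4, 1))   # a different offset for each epoch/channel

base = baseline_means(data, times, -0.2, 0.0)  # would be the covariate in baseline regression
corrected = data - base                        # classic baseline correction

# after correction, the baseline covariate is zero for every epoch and channel
print(np.allclose(baseline_means(corrected, times, -0.2, 0.0), 0))
```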

Bad channels and non-ocular independent components can be manually identified after this script is run the first time. These can be stored in additional config files (bad channels in `rawdata`; independent components in `participants_manual_ic.yml`, described below), and then this script can be re-run to have them excluded. In an ideal world with reasonably clean data this should not be necessary, but it may be if, e.g., there are broken electrodes, or the data are EGI recordings or come from children.

---
Copyright 2023 [Aaron J Newman](https://github.com/aaronjnewman), [NeuroCognitive Imaging Lab](http://ncil.science), [Dalhousie University](https://dal.ca)

Released under the [The 3-Clause BSD License](https://opensource.org/licenses/BSD-3-Clause)

---

In [None]:
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import os.path as op
from os import remove
from glob import glob
from pathlib import Path
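import yaml
try:
    from yaml import CLoader as Loader  # fast C loader; requires PyYAML built with libyaml
except ImportError:
    from yaml import SafeLoader as Loader  # pure-Python fallback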
import numpy as np
import mne
from mne.preprocessing import annotate_amplitude
mne.set_log_level('error')
from mne_bids import BIDSPath, read_raw_bids
from scipy.stats import zscore
from autoreject import Ransac, get_rejection_threshold, AutoReject
from time import time

## Read Parameters from config.yml

This cell imports study-level parameters from `config.yml` in `bids_root`.
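The cell below assumes a particular layout for `config.yml`: top-level keys such as `study_name`, `task` and `drop_ch`, plus a `preprocessing_settings` section whose `filter`, `ica` and `epoch` entries are lists of single-key mappings. This sketch shows that layout as the Python object `yaml.load` would return, and how each list is flattened into one dictionary. The key names come from this script; all values (and the `flatten` helper) are hypothetical.

```python
# Hypothetical parsed config.yml: key names match those read by this script,
# values are placeholders for illustration only.
config = {
    'study_name': 'example_study',
    'task': 'example',
    'data_type': 'eeg',
    'drop_ch': 'M1, M2',
    'preprocessing_settings': {
        'n_jobs': 4,
        'filter': [{'l_freq': 0.1}, {'l_freq_ica': 1.0}, {'h_freq': 30.0}],
        'epoch': [{'tmin': -0.2}, {'tmax': 0.8}, {'baseline': None}],
    },
}

def flatten(list_of_dicts):
    """Merge a list of single-key dicts into one dict (later keys win)."""
    return {k: v for d in list_of_dicts for k, v in d.items()}

settings = config['preprocessing_settings']
filt_p = flatten(settings['filter'])
epoch_p = flatten(settings['epoch'])

# drop_ch is a comma-separated string: split it and strip the whitespace
drop_ch = [s.strip() for s in config['drop_ch'].split(',')]

print(filt_p)   # {'l_freq': 0.1, 'l_freq_ica': 1.0, 'h_freq': 30.0}
print(drop_ch)  # ['M1', 'M2']
```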

In [None]:
# this shouldn't change if you run this script from its default location in code/import
bids_root = '../..'

cfg_file = op.join(bids_root, 'config.yml')
with open(cfg_file, 'r') as f:
    config = yaml.load(f, Loader=Loader)

study_name = config['study_name']
task = config['task']
data_type = config['data_type']
eog = config['eog']

# drop_ch is optional in config.yml; only define it if channels are listed
if config.get('drop_ch') is not None:
    drop_ch = [s.strip() for s in config['drop_ch'].split(",")]


montage_fname = config['montage_fname']

# preprocessing settings are stored as lists of single-key mappings; flatten each into one dict
n_jobs = config['preprocessing_settings']['n_jobs']
filt_p = {k: v for d in config['preprocessing_settings']['filter'] for k, v in d.items()}
ica_p = {k: v for d in config['preprocessing_settings']['ica'] for k, v in d.items()}
epoch_p = {k: v for d in config['preprocessing_settings']['epoch'] for k, v in d.items()}
reject = epoch_p['reject']

### Paths

In [None]:
raw_path = op.join(bids_root, 'rawdata')

derivatives_path = op.join(bids_root, 'derivatives', 'erp_preprocessing')
Path(derivatives_path).mkdir(parents=True, exist_ok=True)

report_path = op.join(derivatives_path, 'logs')
Path(report_path).mkdir(parents=True, exist_ok=True)


epochs_suffix = '-epo.fif'

### Event codes - mappings between values and labels

Read `rawdata/event_code_mappings.yml` to get the labels for event codes for each experiment. This file is used to create the `event_id_map` dictionary.

The `extra_mappings` dictionary defined in the cell below provides additional code mappings within this script. It is intended for event types that this script itself creates. For example, if the EEG file contains codes for stimulus types, plus codes on each trial indicating whether the response was correct, then we might define additional codes here so that we can create separate event types for correct and incorrect trials of each condition.

These two dictionaries are combined to create the `event_id_map` dictionary, from which the `event_id` dictionary (labels mapped to event codes) passed to `mne.Epochs` is built.
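Two details of this are worth seeing in miniature. When the dictionaries are merged with `{**code_map, **extra_mappings}`, a label defined in both takes its code from `extra_mappings`. A contingent recode looks a fixed number of events ahead of each target to decide its new code. The sketch below uses a small, hypothetical MNE-style events array (sample, previous value, code) and invented codes; `recode_contingent` is a hypothetical helper, not part of MNE.

```python
import numpy as np

# Hypothetical mappings (labels and codes invented for illustration)
code_map = {'match': 11, 'mismatch': 10, '_corr': 1}
extra_mappings = {'match/correct': 2011, 'match': 12}
event_id_map = {**code_map, **extra_mappings}   # on a clash, extra_mappings wins

def recode_contingent(events, target_code, contingency_code, new_code, lag=1):
    """Return a copy of `events` in which each `target_code` event that is followed,
    `lag` events later, by `contingency_code` is recoded to `new_code`."""
    events = events.copy()
    targets = np.where(events[:, 2] == target_code)[0]
    followers = targets + lag
    in_range = followers < len(events)          # ignore targets too close to the end
    targets, followers = targets[in_range], followers[in_range]
    hit = events[followers, 2] == contingency_code
    events[targets[hit], 2] = new_code
    return events

events = np.array([[100, 0, 12], [150, 0, 1],    # match followed by a correct response
                   [300, 0, 12], [350, 0, 99],   # match followed by some other code
                   [500, 0, 12]])                # match with no following event
recoded = recode_contingent(events, event_id_map['match'], event_id_map['_corr'],
                            event_id_map['match/correct'])
print(recoded[:, 2].tolist())   # [2011, 1, 12, 99, 12]
```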

In [None]:
with open( op.join(raw_path, 'event_code_mappings.yml'), 'r') as f:
    code_map = yaml.load(f, Loader=Loader)
    
# add an entry for 0, because it can appear in OpenVibe data
extra_mappings = {'NULL':0, 
                  'match/correct':2011, 'match/incorrect':2010, 'match/noresp':2012,
                  'mismatch/correct':2001, 'mismatch/incorrect':2000, 'mismatch/noresp':2002,
                  }

# The line below combines the codes read from the file, and codes defined here
# comment it out if you are only using one or the other
event_id_map = {**code_map, **extra_mappings}

# contingent_events() can recode stimulus types ('targets') based on whether they were
#  followed by a correct response (contingencies)
proc_contingent_events = False # change to True to use this feature
if proc_contingent_events:
    # keys are the new labels; all codes must exist in event_id_map
    contingent_map = {'match/correct':{'target_code':event_id_map['match/noresp'],
                                       'contingency_code':event_id_map['_corr'],
                                       'contingency_lag':1
                                       },
                      }

def contingent_events():
    # recodes the global `events` array in place
    for c, vals in contingent_map.items():
        events_a = np.where(events[:, 2] == vals['target_code'])[0]
        events_b_idx = events_a + vals['contingency_lag']
        # ignore targets whose contingent event would fall past the last event
        keep = events_b_idx < events.shape[0]
        events_a, events_b_idx = events_a[keep], events_b_idx[keep]
        events_b = events[events_b_idx, 2] == vals['contingency_code']
        events[events_a[events_b], 2] = event_id_map[c]

### Subject list
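Subject folders are found by globbing `rawdata` for `sub-*`. A slightly more defensive variant, sketched below with `pathlib`, keeps only directories (so a stray file whose name starts with `sub-` is skipped) and also lists each subject's session labels without the `ses-` prefix, which is the form `BIDSPath(session=...)` expects. The helper name and the throwaway folder tree are hypothetical.

```python
import tempfile
from pathlib import Path

def find_subjects_sessions(raw_path, prefix='sub-'):
    """Map each subject folder name (e.g. 'sub-001') to its sorted session labels."""
    out = {}
    for sub_dir in sorted(Path(raw_path).glob(prefix + '*')):
        if not sub_dir.is_dir():
            continue  # skip files that happen to start with the prefix
        out[sub_dir.name] = sorted(d.name[len('ses-'):]
                                   for d in sub_dir.glob('ses-*') if d.is_dir())
    return out

with tempfile.TemporaryDirectory() as root:
    for p in ['sub-001/ses-01', 'sub-001/ses-02', 'sub-002/ses-01']:
        Path(root, p).mkdir(parents=True)
    Path(root, 'sub-info.tsv').touch()   # hypothetical stray file that should be ignored
    result = find_subjects_sessions(root)

print(result)  # {'sub-001': ['01', '02'], 'sub-002': ['01']}
```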

In [None]:
prefix = 'sub-'
subjects = sorted(op.basename(s) for s in glob(op.join(raw_path, prefix + '*')) if op.isdir(s))

In [None]:
subjects

## Read manually-marked independent components
Run this script once, then inspect the ICs in each subject's `.html` report, stored in the `derivatives/erp_preprocessing/logs` folder. Based on this, decide whether any additional ICs should be removed, or whether any automatically removed ICs should be kept. Add these to the `participants_manual_ic.yml` file located in the present folder. Then run this script again to apply the changes.

Additional ICs to remove were selected based on:
- participants for whom more than 15% of trials were removed by AutoReject after ICA correction
- participants for whom the average across all trials and all electrodes did not show a clear pattern of P1-N1-P2 components
- for such participants, the scalp map and details of each IC were visually inspected. Components were removed if they were
    - focal at a single electrode, or a very low number of electrodes
    - focal at the edges of the electrode montage
    - present on a low number of trials
    - showed no systematic pattern across trials that was time-locked to stimulus onset
    
ICs to add (un-remove) were selected based on:
- IC shows clear and consistent temporal dependency on stimulus onset
- IC appears to contain P1-N1-P2 complex, or part thereof
- IC has broad scalp distribution across many electrodes, characteristic of ERP component

In [None]:
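# optional file, created after inspecting the reports from the first run (see above)
cfg_file = './participants_manual_ic.yml'
if op.exists(cfg_file):
    with open(cfg_file, 'r') as f:
        ica_manual = yaml.load(f, Loader=Loader) or {}
else:
    ica_manual = {}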

## Handle study Metadata

Read the behavioural log file to obtain trial metadata. This will need editing for every study (or can simply be ignored if you have no behavioural data to add).
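Aligning behavioural rows with EEG events is the fragile part. The sketch below, using an invented three-trial log and a hypothetical events array, shows the steps the metadata handling relies on: reading a tab-separated BIDS `_beh.tsv` file (hence `sep='\t'`), dropping extra leading events (e.g. practice trials) so events and rows correspond one-to-one, and then selecting rows with an index array as `epochs.metadata = metadata.iloc[epochs.selection]` does. The selection here is made up rather than taken from `mne.Epochs`.

```python
import io
import numpy as np
import pandas as pd

# Invented behavioural log: one row per experimental trial, tab-separated as in BIDS _beh.tsv
log_tsv = "trial\tcondition\tcorrect\n1\tmatch\t1\n2\tmismatch\t0\n3\tmatch\t1\n"
metadata = pd.read_csv(io.StringIO(log_tsv), sep='\t')

# Hypothetical events array (sample, previous value, code): 2 practice events + 3 trials
events = np.array([[10, 0, 5], [20, 0, 5], [100, 0, 11], [200, 0, 10], [300, 0, 11]])

# Drop leading events so that events and metadata rows correspond one-to-one
n_extra = events.shape[0] - metadata.shape[0]
if n_extra > 0:
    events = events[n_extra:]

# Suppose epoching kept events 0 and 2 (e.g. event 1 was rejected)
selection = np.array([0, 2])
kept = metadata.iloc[selection].reset_index(drop=True)
print(kept['condition'].tolist())  # ['match', 'match']
```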

In [None]:
use_metadata = False # change to True if you have behavioural data you want to add to the EEG data output by this script

# list columns in log file that we do not want to include in the metadata
# add to this according to your study; typically there are numerous columns that are not relevant to the EEG data
beh_drop_cols = ['frameRate', 'expName', 
                        'session', 'participant'
                        ]

# read behavioural log file to obtain trial metadata
# this will need editing (or can simply be ignored) for every study
def get_metadata(subject, ses, events):
    """Read the BIDS behavioural log for one subject/session and align it with `events`.

    Returns (metadata, events); events may be trimmed if the EEG contains practice trials.
    """
    log_path = op.join(raw_path, subject, 'ses-' + ses, 'beh')
    log_file = op.join(log_path, subject + '_ses-' + ses + '_task-' + task + '_beh.tsv')
    metadata = pd.read_csv(log_file, sep='\t')  # BIDS .tsv files are tab-separated
    # drop practice and other non-trials
    metadata = metadata[~metadata['blockSort'].isna()]
    # combine trial counter and index columns from the two blocks
    # metadata['trial_num'] = metadata['trials.thisTrialN'].combine_first(metadata['trials_2.thisTrialN'])
    # metadata['trial_index'] = metadata['trials.thisIndex'].combine_first(metadata['trials_2.thisIndex'])

    # drop unneeded columns
    metadata = metadata.drop(columns=beh_drop_cols)

    # check if there are more events than metadata rows
    # (in which case EEG data was saved for practice trials, which we want to drop)
    if events.shape[0] > metadata.shape[0]:
        n_practice_events = events.shape[0] - metadata.shape[0]
        events = events[n_practice_events:]

    return metadata.reset_index(drop=True), events


## Main Loop
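The least obvious step in the loop below is the search for ocular ICs: `ica.find_bads_eog` is called with a z threshold that is lowered by `ica_zstep` until at least `n_max_eog` components exceed it. The standalone sketch below reproduces only that search, on a made-up array of per-component scores that are treated as already being z values; MNE's own scoring is not reproduced, and `pick_by_stepping_threshold` is a hypothetical helper.

```python
import numpy as np

def pick_by_stepping_threshold(z_scores, z_start, z_step, n_min):
    """Lower a |z| threshold from z_start in steps of z_step until at least n_min
    components exceed it, or the threshold reaches 0.

    Returns (indices above the last threshold tried, that threshold)."""
    z_abs = np.abs(np.asarray(z_scores, dtype=float))
    z_thresh = z_start
    picked = np.array([], dtype=int)
    while z_thresh > 0:
        picked = np.flatnonzero(z_abs > z_thresh)
        if picked.size >= n_min:
            break
        z_thresh -= z_step
    return picked.tolist(), round(z_thresh, 2)

scores = [4.1, -0.3, 2.6, 0.8, -2.2]   # made-up z values for 5 components
print(pick_by_stepping_threshold(scores, z_start=3.0, z_step=0.5, n_min=2))  # ([0, 2], 2.5)
```

As in the script, one step can lift several components over the threshold at once, so more than `n_min` may be returned.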

In [None]:
rej_log_list = []

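for subject in subjects:
    # session labels (without the 'ses-' prefix) present in this subject's rawdata folder
    sessions = sorted(p.name[len('ses-'):] for p in Path(raw_path, subject).glob('ses-*') if p.is_dir())
    for ses in sessions: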

        start_time = time()
        print('\n-------------------------')
        print('-------- ' + subject + ' --------')
        print('-------------------------')

        report = mne.Report(subject=subject, 
                            title=study_name + ' preprocessing: ' + subject + ' ' + ses,
                            verbose='WARNING')

        ### subject-specific paths
        in_path = BIDSPath(root=raw_path, 
                        subject=subject[len(prefix):],
                        session=ses,
                        datatype=data_type,
                        task=task
                        )

        ### Import data
        raw = read_raw_bids(in_path)
        if 'drop_ch' in locals():
            raw = raw.drop_channels(drop_ch)
        else:
            # foolproofing
            n_chan = len(raw.info['chs'])
            if n_chan > 40:
                print('WARNING: ' + subject + ' has ' + str(n_chan) + ' channels, but drop_ch is not defined in config.yml')
                print('If you used 32 channels in the booth room you need to define drop_ch.')

        ### Read events
        ### THIS PROBABLY WON'T WORK OUT OF THE BOX:
        ### IT PROVED HARD TO DEVELOP ROBUST CODE, GIVEN THE KLUGEY WAYS PREVIOUS STUDIES HANDLED EVENT CODES.
        ### LEFT AS A PROBLEM TO SOLVE FOR FUTURE AARON (OR SOME OTHER BRAVE SOUL)
        events, event_dict = mne.events_from_annotations(raw)

        # remove event code(s) associated with '__', which marks the start of the recording but is otherwise useless
        if '__' in event_dict.keys():
            events = np.delete(events, list(np.where(events[:, 2] == event_dict['__'])[0]), axis=0)
        
        # handle event codes from openVibe vs. other recording software
        if 'OVTK' in list(event_dict.keys())[0]:
            if list(event_dict.keys())[0].split('/')[1].split('_')[0] == 'OVTK':
                # OpenVibe exports codes in hex format with a bunch of leading text, and then MNE/BIDS 
                # maps these to a sequence of numbers starting at 10001. 
                # This will convert the events array to the original codes sent by the stimulus program.
                hex_dict = {int(str(k).split('_')[-1], 16):v for k, v in event_dict.items()}
                codes_conv = np.copy(events[:, 2])
                for k, v in hex_dict.items():
                    codes_conv[codes_conv==v] = k
                events[:, 2] = codes_conv 

        # Map labels to condition codes
        event_id = {}
        try:
            for label, code in event_id_map.items():
                event_id[label] = event_dict[label]
        except KeyError:
            # some labels are not in the annotations: use the codes in event_id_map directly
            event_id = event_id_map
        
        # remove event code == 0
        events = events[np.nonzero(events[:, 2])]

        # remove duplicate events
        events = np.unique(events, axis=0)

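        ## Here you could add study-specific event code processing, e.g. recoding events contingent on other events
        if proc_contingent_events == True:
            contingent_events()

        ## Align behavioural metadata with events (drops leading practice events if present)
        if use_metadata:
            metadata, events = get_metadata(subject, ses, events)

        # keep only labels whose codes occur in this recording;
        # mne.Epochs raises an error by default for event IDs with no matching events
        event_id = {label: code for label, code in event_id.items() if code in events[:, 2]}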
            
        ## Add events to report
        report.add_events(events, event_id=event_id, 
                            sfreq=raw.info['sfreq'],
                            title='Events'
                            )
        plt.close()

        # Mark any flat channels/segments
        # Note: set_annotations() replaces the existing annotations (from which events were read),
        # so run this after event code processing (but before filtering or segmentation)
        annotations, bads = annotate_amplitude(raw, 
                                            flat=0., 
                                            bad_percent=25,
                                            )
        raw.set_annotations(annotations)
    
        ### Filtering
        # channel selection
        picks = mne.pick_types(raw.info, 
                            eeg=True,
                            eog=True
                            )
        
        print('Filtering...')
        job_start = time()
        
        # Filter for ICA  
        raw_ica = raw.load_data().copy().filter(filt_p['l_freq_ica'], filt_p['h_freq'],
                                    picks=picks,
                                    n_jobs=n_jobs
                                )

        ## Filter for final
        raw.filter(filt_p['l_freq'], filt_p['h_freq'],
                picks=picks,
                n_jobs=n_jobs
                )
        
        print('Filtering took ' + str (time() - job_start) + ' s')
        ## Add raw to report
        report.add_raw(raw=raw, 
                psd=True, butterfly=True, 
                title='Raw data, bandpass filtered ' + str(filt_p['l_freq']) + '–' + str(filt_p['h_freq'])
                )
        # plots are generated that accumulate and hog resources. Close them to minimize this
        plt.close()
        
        ### Epoch data filtered for ICA
        epochs_ica = mne.Epochs(raw_ica,
                                events, event_id,
                                epoch_p['tmin'], epoch_p['tmax'],
                                baseline=epoch_p['baseline'], detrend=epoch_p['detrend'],
                                reject=None, 
                                flat=epoch_p['flat'],
                                preload=True
                            )

        
        # fit AutoReject to identify bad epochs; these are only marked here (via the reject log) and excluded from ICA fitting
        print('AutoReject pre-ICA...')
        job_start = time()
        ar = AutoReject(n_interpolate=[1, 2, 4, 8, 16],
                        random_state=ica_p['ica_random_state'],
                        picks=mne.pick_types(epochs_ica.info, eeg=True, eog=False),
                        n_jobs=n_jobs, 
                        verbose=False
                    )
        ar.fit(epochs_ica)
        plt.close()

        print('n_interpolate = ' +  str(ar.n_interpolate_['eeg']))
        print('AutoReject took ' + str (time() - job_start) + ' s')
        
        reject_log = ar.get_reject_log(epochs_ica)
        fig = reject_log.plot('horizontal', show=False)
        report.add_figure(fig=fig, title='AutoReject log (pre-ICA)')
        plt.close(fig)


        ### Fit ICA
        print('ICA...')
        job_start = time()
        ica = mne.preprocessing.ICA(method=ica_p['ica_method'],
                                    n_components=ica_p['n_components'],
                                    random_state=ica_p['ica_random_state'],
                                    max_iter='auto')
        
        ica.fit(epochs_ica[~reject_log.bad_epochs],  # added [~reject_log.bad_epochs] for AutoReject
                decim=3, 
                picks=['eeg']
                );

        # Identify ocular ICs
        # The default *z* threshold doesn't work for all subjects. This routine starts
        # with the z threshold from config and lowers it by `ica_zstep` until at least
        # n_max_eog EOG-related components are identified (or the threshold reaches 0).
        # Limitations: it assumes there will always be at least n_max_eog ocular components
        # (blinks are always present, but horizontal movements are not always present).
        # Once n_max_eog components are found the search stops, so further ocular components
        # with lower scores are not removed; conversely, a single step can push the threshold
        # below several scores at once, so more than n_max_eog components may be excluded.
        # In practice, many of these components (with EGI data) may not be ocular, but are (hopefully) not EEG.
        # Be sure to check the reports and confirm no ERP components are rejected!

        ica.exclude = []
        num_excl = 0
        z_thresh = ica_p['ica_zthresh'] 
        z_step = ica_p['ica_zstep']

        while num_excl < ica_p['n_max_eog'] and z_thresh > 0:
            eog_indices, eog_scores = ica.find_bads_eog(epochs_ica, threshold=z_thresh)
            num_excl = len(eog_indices)
            z_thresh -= z_step # won't impact things if num_excl is ≥ n_max_eog 

        ica.exclude = eog_indices
        z_thresh_final = round(z_thresh + z_step, 2)

        # Manual removal/re-addition of ICs based on visual inspection
        # (subjects/sessions without an entry in participants_manual_ic.yml are left unchanged)
        manual = ((ica_manual or {}).get(subject) or {}).get(ses) or {}
        for ic in manual.get('manually_excluded_ic') or []:
            if ic not in ica.exclude:
                ica.exclude.append(ic)
        for ic in manual.get('manually_included_ic') or []:
            if ic in ica.exclude:
                ica.exclude.remove(ic)

        # Create average of EOG events
        eog_evoked = mne.preprocessing.create_eog_epochs(raw_ica).average().apply_baseline(baseline=(None, epoch_p['tmin']))

        ## Add ICA to report
        report.add_ica(ica=ica, title='ICA', inst=epochs_ica,
                    eog_evoked=eog_evoked, 
                    eog_scores=eog_scores,
                    n_jobs=n_jobs
                    )
        print('ICA took ' + str (time() - job_start) + ' s')
        plt.close()   
            
        ### Segment filtered raw data into epochs for final analysis
        epochs = mne.Epochs(raw,
                            events, event_id,
                            epoch_p['tmin'], epoch_p['tmax'],
                            baseline=epoch_p['baseline'], detrend=epoch_p['detrend'],
                            reject=reject, 
                            flat=epoch_p['flat'],
                            preload=True
                        )

        ### ----- Combine metadata with epochs
        # Note that this uses the selection property of the epochs, which keeps track of the 
        #  original event indexes, even after some trials/epochs are dropped (e.g., if there were bad/flat channels)
        if use_metadata == True:
            epochs.metadata = metadata.iloc[epochs.selection]        
        
        ### Apply ICA correction to epochs
        ica.apply(epochs)
        
        ### Apply AutoReject to further clean epochs
        print('AutoReject post-ICA...')
        job_start = time()    
        epochs_clean, reject_log = ar.fit_transform(epochs, return_log=True)
        print('AutoReject took ' + str (time() - job_start) + ' s')
        plt.close()

        fig = reject_log.plot('horizontal', show=False)
        report.add_figure(fig=fig, title='AutoReject log (post-ICA)')
        plt.close(fig)

        ### Re-reference, now that channels are cleaned
        epochs_clean.set_eeg_reference(ref_channels=epoch_p['rereference']);
        
        ### Save cleaned epochs
        # build the output path explicitly (BIDSPath.fpath can match existing files when no
        # suffix/extension is given), create its folder if needed, and overwrite any previous output
        out_path = BIDSPath(root=derivatives_path,
                            subject=subject[len(prefix):],
                            session=ses,
                            datatype=data_type,
                            task=task
                            )
        out_path.mkdir(exist_ok=True)
        epochs_clean.save(op.join(out_path.directory, out_path.basename + epochs_suffix),
                          overwrite=True)

        # add epochs to report
        report.add_epochs(epochs_clean,  
                        title='Epochs'
                        )
        ### Save plot of average across all trials
        fig = epochs_clean.copy().average().plot(spatial_colors=True, 
                                                show=False);
        report.add_figure(fig=fig, title='Grand average over all epochs')
        plt.close(fig)
        
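        # Add plots of average of each condition (by label; skip conditions with no remaining epochs)
        for condition, code in epochs_clean.event_id.items():
            if not np.any(epochs_clean.events[:, 2] == code):
                continue
            fig = epochs_clean[condition].average().plot(spatial_colors=True,
                                                        show=False)
            report.add_figure(fig=fig, title=condition)
            plt.close(fig)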
        
        proc_time = time() - start_time
        print('Total processing time: ', proc_time)

        ### Report on how much was rejected
        rm_epochs = epochs.selection.shape[0] - epochs_clean.selection.shape[0]
        pct_epochs = rm_epochs / epochs.selection.shape[0] * 100   
        rej_log_list.append(pd.DataFrame({'id':subject,
                                        'ses':ses,
                                        'cpu_time':proc_time,
                                        'ntrials_rej':rm_epochs,
                                        '%t_rej':round(pct_epochs, 2),
                                        'ic_rm':len(ica.exclude),
                                        'n_interp':str(ar.n_interpolate_['eeg'])
                                        }, index=[0]
                                        )
                        )

        # Save report to file (one per subject and session)
        report_name = op.join(report_path, subject + '_ses-' + ses + '.html')
        report.save(report_name, overwrite=True)
    
# Collate report logs
rej_log = pd.concat(rej_log_list)
rej_log.to_csv(report_path + '/rejection_log_all_Ss.csv')