# Pipeline for BIAPT lab EEG Preprocessing: 
#### inspired by: https://github.com/hoechenberger/pybrain_mne/
#### adapted by: Beatrice PDK, Victoria Sus and Charlotte Maschke, 
#### This pipeline uses MNE Python to preprocess EEG data. Please go here:
####                                https://mne.tools/stable/overview/index.html
####  for more documentation on MNE Python

## Setup and import

In [None]:
import matplotlib
#import mne_bids
import pathlib
import mne
import os
import os.path as op
from mne import viz
import numpy as np
import PyQt5
# PyQt5 is imported for the Qt backend used by the interactive plotting functions.


from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,corrmap)
#import openneuro

#from mne_bids import BIDSPath, read_raw_bids, print_dir_tree, make_report

# Ensure Matplotlib uses the Qt5Agg backend, 
# which is the best choice for MNE-Python's 
# interactive plotting functions.
matplotlib.use('Qt5Agg')
%matplotlib qt
import matplotlib.pyplot as plt

### Enter the recording information you want to preprocess

In [None]:
# --- Where the data lives -------------------------------------------------
# Root folder of the input data, root folder for all outputs, and the file
# extension of the raw recording (the loading cell handles 'mff' and 'set').
location = 'input location'
output = 'output location'
extension = 'mff'

# --- Which recording to process -------------------------------------------
# Participant identifier used in every input and output file name.
ID = "030MW"
# Task label of the input file; options: "sedon1", "sedoff", "sedon2"
task = "sedoff"
# Task label used when saving; options: "sedon1", "turn01", "sedoff", "turn02", "sedon2"
save_task_as = "sedoff"

In [None]:
# Build the path of the raw recording from the settings above:
#   <location>/sub-<ID>/eeg/sub-<ID>_task-<task>_eeg.<extension>
raw_path = f"{location}/sub-{ID}/eeg/sub-{ID}_task-{task}_eeg.{extension}"
# Display the resulting path so it can be checked before loading.
raw_path

## Load the raw data!

In [None]:
# Read the recording with the MNE reader that matches the file format.
# The branches are mutually exclusive, and an unsupported extension fails
# right here with a clear message instead of leaving `raw` undefined (or
# silently reusing a `raw` left over from a previous run of this cell).
if extension == 'mff':
    raw = mne.io.read_raw_egi(raw_path)
elif extension == 'set':
    raw = mne.io.read_raw_eeglab(raw_path)
else:
    raise ValueError(
        f"Unsupported extension {extension!r}: expected 'mff' or 'set'")
# Display the summary of the loaded recording.
raw

# To crop (skip if no crop needed)

In [None]:
# Segment of the recording to keep, in seconds.
crop_start, crop_end = 690.1, 1364
# NOTE(review): crop() is called on `raw` itself (no copy is taken first).
# MNE documents Raw.crop as operating in place, so `raw` is presumably
# shortened as well and re-running this cell would crop again — confirm.
eeg_cropped = raw.crop(tmin=crop_start, tmax=crop_end)
# Display the summary of the cropped recording.
eeg_cropped

# Run if no crop needed

In [None]:
# No cropping: continue with a copy of the loaded recording so that the
# following steps work on `eeg_cropped` in both the crop and no-crop case.
eeg_cropped = raw.copy()

## Resample the data to 250 Hz

In [None]:
# Bring the recording to the common sampling rate; data already at that
# rate are only copied.
target_sfreq = 250
already_at_target = eeg_cropped.info['sfreq'] == target_sfreq
# NOTE(review): resample() is called on `eeg_cropped` directly; MNE documents
# it as modifying the object in place, so in that branch `resampled` and
# `eeg_cropped` are presumably the same object — confirm.
resampled = eeg_cropped.copy() if already_at_target else eeg_cropped.resample(target_sfreq)

### Keep the EEG only

In [None]:
# Keep only the EEG channels. To keep other channel types as well, see the
# MNE documentation for the available pick_types arguments.
eeg = resampled.pick_types(eeg=True)

# Report how many EEG channels remain (the count is shown as the cell output).
print('Number of channels in EEG:')
n_eeg_channels = len(eeg.ch_names)
n_eeg_channels

In [None]:
# The reference channel is named differently across recordings. Rename
# whichever known name is present to 'VREF' so that all subjects share the
# same reference label. Each entry maps a known name to the message printed
# when it is replaced.
reference_aliases = {
    'E129': 'REF E129 replaced',
    'E1001': 'REF E1001 replaced',
    'Vertex Reference': 'Vertex Reference replaced',
}
for old_name, message in reference_aliases.items():
    if old_name in eeg.ch_names:
        print(message)
        mne.rename_channels(eeg.info, {old_name: 'VREF'})

## Apply filtering

In [None]:
# Load the samples into memory (until now only metadata was read), then
# band-pass filter between 0.1 and 50 Hz. Both calls are made on `eeg`
# itself, without taking a copy first.
eeg_filtered = eeg.load_data().filter(l_freq=0.1, h_freq=50)

# Notch-filter at 60 Hz on a copy, so `eeg_filtered` keeps the band-passed
# data without the notch.
notch_input = eeg_filtered.copy()
eeg_notch = notch_input.notch_filter(freqs=60)
# Alternative used for 027MW, where the default notch was not wide enough:
#eeg_notch = eeg_filtered.copy().notch_filter(freqs= 60, notch_widths = 2.25)

In [None]:
%matplotlib qt
viz.plot_raw_psd(eeg_notch, exclude = ['VREF'], fmax = 70)
plt.title(ID)
if not os.path.exists('{}/Preprocessing_info/sub-{}'.format(output,ID,save_task_as)) :
    os.makedirs('{}/Preprocessing_info/sub-{}'.format(output,ID,save_task_as))
plt.savefig('{}/Preprocessing_info/sub-{}/sub-{}_task-{}_PSD_raw_filtered.png'.
            format(output,ID,ID,save_task_as))

# Skip if not needed

In [None]:
#eeg_notch = eeg_notch.copy().notch_filter(freqs= 47) 
#eeg_notch = eeg_notch.copy().notch_filter(freqs= 41.5)

In [None]:
%matplotlib qt
viz.plot_raw_psd(eeg_notch, exclude = ['VREF'], fmax = 70)
plt.title(ID)
if not os.path.exists('{}/Preprocessing_info/sub-{}'.format(output,ID,save_task_as)) :
    os.makedirs('{}/Preprocessing_info/sub-{}'.format(output,ID,save_task_as))
plt.savefig('{}/Preprocessing_info/sub-{}/sub-{}_task-{}_PSD_raw_filtered2.png'.
            format(output,ID,ID,save_task_as))

## Visualize raw data to identify bad channels

In [None]:
# Open the data browser on the filtered recording, showing 60 channels and
# 20 seconds per screen, to inspect the traces and identify bad channels.
eeg_notch.plot(n_channels=60, duration=20)

Verify that the bad channels were labelled correctly

In [None]:
# Channels currently flagged as bad on the filtered recording.
# Note: this is the list stored in `info['bads']` itself, not a copy.
marked_bad = eeg_notch.info['bads']
# Display the list so the selection can be checked.
marked_bad

In [None]:
# Write the bad-channel names to a text file, one name per line, in the
# participant's Preprocessing_info folder (created first if it is missing).
info_dir = '{}/Preprocessing_info/sub-{}'.format(output,ID)
if not os.path.exists(info_dir):
    os.makedirs(info_dir)

bads_file = '{}/Preprocessing_info/sub-{}/sub-{}_task-{}_marked_bads.txt'.format(output,ID,ID,save_task_as)
bads_text = "\n".join(marked_bad)
with open(bads_file, 'w') as outfile:
    outfile.write(bads_text)


In [None]:
# Plot-only exclusion list: the marked bad channels plus 'VREF'. A new list
# is built, so 'VREF' is not added to the recording's own bad channels.
bads_tmp = [*eeg_notch.info['bads'], 'VREF']

In [None]:
%matplotlib qt
# Power spectral density up to 70 Hz with the channels listed in `bads_tmp`
# excluded from the plot.
viz.plot_raw_psd(eeg_notch, exclude = bads_tmp, fmax = 70)

## Remove bad channels

In [None]:
# Drop the channels listed in `marked_bad` from a copy of the filtered
# recording; `eeg_notch` itself keeps all of its channels.
eeg_badremoved = eeg_notch.copy().drop_channels(marked_bad)

## Segment into 10-sec epochs

In [None]:
# Cut the continuous recording into fixed-length epochs of 10 seconds with
# no overlap between consecutive epochs.
epochs = mne.make_fixed_length_epochs(eeg_badremoved, duration = 10, overlap=0)

In [None]:
# Display the epochs summary to verify the initial number of epochs.
epochs

## Average Reference the data

In [None]:
# Load the epoch data into memory, then re-reference to the average of all
# channels.
# NOTE(review): both calls are made on `epochs` itself (no copy), so
# `eeg_avg_ref` presumably refers to the same object as `epochs` — confirm.
epochs.load_data()
eeg_avg_ref = epochs.set_eeg_reference(ref_channels='average')

## Remove Non-Brain Electrodes 

In [None]:
# Electrodes treated as non-brain in this pipeline.
non_brain_el = [
    'E127', 'E126', 'E17', 'E21', 'E14', 'E25', 'E8', 'E128',
    'E125', 'E43', 'E120', 'E48', 'E119', 'E49', 'E113', 'E81',
    'E73', 'E88', 'E68', 'E94', 'E63', 'E99', 'E56', 'E107',
]

# Flag each non-brain electrode as bad, skipping electrodes that are no
# longer in the data or that were already marked as noisy channels.
for electrode in non_brain_el:
    if electrode not in eeg_avg_ref.info['ch_names'] or electrode in marked_bad:
        continue
    eeg_avg_ref.info['bads'].append(electrode)
    


In [None]:
# Print the channels now flagged as bad on the average-referenced epochs
# (the non-brain electrodes added in the previous step).
print(eeg_avg_ref.info['bads'])

In [None]:
# Remove the non-brain channels: drop every channel flagged in
# `info['bads']` from a copy, leaving `eeg_avg_ref` itself untouched.
eeg_brainonly = eeg_avg_ref.copy().drop_channels(eeg_avg_ref.info['bads'])

### Reject epochs with amplitude bigger than 2000 µVolt 

Epochs whose peak-to-peak amplitude on the scalp exceeds 2000 µVolt are not linked to physiological causes; the accepted physiological amplitude is below 800 µVolt.

In [None]:
# Work on a loaded copy of the brain-only epochs.
epochs_clean = eeg_brainonly.copy().load_data()
# NOTE: automatic amplitude-based rejection is currently disabled (the two
# lines below are commented out), so no epochs are dropped by this cell.
# The commented threshold is 600 µV (600*1e-6 V), which differs from the
# 2000 µV mentioned in the heading above.
#epochs_clean.drop_bad({'eeg':600*1e-6})
#epochs_clean.plot_drop_log()

In [None]:
# Open the epochs browser (3 epochs and up to 120 channels per screen) to
# inspect the remaining epochs for bad ones.
epochs_clean.plot(title='bad_epochs_remaining', n_epochs=3, n_channels=120, scalings=20e-6)

## Check final data

In [None]:
#eeg_brainonly.plot(n_epochs=3, n_channels=100, scalings=20e-6)

In [None]:
# Display the summary of the final epochs.
epochs_clean

In [None]:
# marked_bad = epochs_clean.info['bads']
# marked_bad

In [None]:
# epochs_clean = epochs_clean.copy().drop_channels(marked_bad)

In [None]:
%matplotlib qt
epochs_clean.plot_psd(fmax=70, exclude = ['VREF'])
plt.title(ID)
if not os.path.exists('{}/Preprocessing_info/sub-{}'.format(output,ID,save_task_as)) :
    os.makedirs('{}/Preprocessing_info/sub-{}'.format(output,ID,save_task_as))
plt.savefig('{}/Preprocessing_info/sub-{}/sub-{}_task-{}_PSD_final.png'.
            format(output,ID,ID,save_task_as))

In [None]:
# Indices of the epochs whose drop-log entry is non-empty, converted to
# strings so they can be written to a text file.
dropped_epochs = [
    str(index)
    for index, reasons in enumerate(epochs_clean.drop_log)
    if len(reasons)
]

# Write one index per line into the participant's Preprocessing_info folder
# (created first if it is missing).
info_dir = '{}/Preprocessing_info/sub-{}'.format(output,ID)
if not os.path.exists(info_dir):
    os.makedirs(info_dir)

dropped_file = '{}/Preprocessing_info/sub-{}/sub-{}_task-{}_dropped_epochs.txt'.format(output,ID,ID,save_task_as)
with open(dropped_file, 'w') as outfile:
    outfile.write("\n".join(dropped_epochs))

## Save final brain data

In [None]:
# Collect the summary counts for this recording.
n_epochs_final = len(epochs_clean)
n_channels_final = len(epochs_clean.info['ch_names'])

output_summary = {
    "nr_epochs_final": n_epochs_final,
    # difference between the initial and the final number of epochs
    "nr_epochs_dropped": len(epochs) - n_epochs_final,
    "nr_channels_final": n_channels_final,
    # NOTE(review): 129 is hard-coded — presumably the channel count of the
    # recording net including the reference; confirm for other montages.
    "nr_channels_dropped": 129 - len(non_brain_el) - n_channels_final,
}

# Write the printed form of the summary dictionary to the SUMMARY text file
# in the participant's Preprocessing_info folder.
summary_file = '{}/Preprocessing_info/sub-{}/sub-{}_task-{}_SUMMARY.txt'.format(output,ID,ID,save_task_as)
with open(summary_file, 'w') as file:
    print(output_summary, file=file)

In [None]:
# Folder for the final epochs of this participant. The same path is used for
# creating the folder and for saving: previously the existence check used
# '{output}/...' while makedirs and save used '/{output}/...' (leading
# slash), so the checked and the written locations could differ. All other
# outputs of this notebook are written under '{output}/...' as well.
eeg_dir = f'{output}/sub-{ID}/eeg'
# exist_ok=True creates the folder when missing and is a no-op otherwise.
os.makedirs(eeg_dir, exist_ok=True)

# Save the cleaned epochs, replacing an existing file of the same name.
epochs_clean.save(f"{eeg_dir}/sub-{ID}_task-{save_task_as}_epoch_eeg.fif", overwrite=True)

In [None]:
# Signal that the notebook ran through to the end.
print('End :)')