## Note for writing up
Check which epochs are used in the end! I think the final epochs are band-pass filtered between 0.5 and 30 Hz, rather than with the 0.1 Hz high-pass and 50 Hz low-pass applied in this notebook.

Run this notebook to

- Load bdf files
- Set channel locations
- Apply low-pass filter
- Apply high-pass filter
- Downsample
- Manually select very bad segments
- Apply mastoid reference
- Interpolate noisy channels
- Compute ICA to remove eye movement artifacts
- Save cleaned data
                                                                                                                     

## Import stuff and create directories

In [None]:
import os
import numpy as np
import mne
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,corrmap)
import matplotlib.pyplot as plt
import seaborn as sns
import re
import sklearn 

%matplotlib qt

In [None]:
def mkdir(p):
    """Create directory `p` and any missing parent directories.

    The path may use '/' or '\\' as separator. One line is printed for every
    directory that is actually created; directories that already exist are
    left untouched, so the call can be repeated safely.

    Parameters
    ----------
    p : str
        Relative or absolute directory path.
    """
    # Keep a drive prefix (e.g. 'C:') out of the component split.
    drive, tail = os.path.splitdrive(p)
    # Drop empty components (leading, trailing or doubled separators):
    # joining them would produce '' and make os.mkdir('') fail.
    parts = [part for part in re.split('/|\\\\', tail) if part]
    # A leading separator marks an absolute path, so start from the root.
    bp = drive + os.sep if tail[:1] in ('/', '\\') else drive
    for pp in parts:
        bp = os.path.join(bp, pp)
        if not os.path.exists(bp):
            os.mkdir(bp)
            print('%s created.' % bp)

            
# Make sure the folders used by this notebook exist before anything is read or written.
for folder in ('EEGdata', 'annotations', 'EEGdata/cleaned_mastoid_reference'):
    mkdir(folder)


## Load the original bdf file

In [None]:
# Recording session to process; the cells below branch on the values 1 and 2.
session = 2

In [None]:
# Participant number; used to build the input, annotation and output file names.
subject = 21

In [None]:
data_dir = 'EEGdata'

# Recordings are stored as '<subject>_<session>.bdf' inside data_dir.
# Fail here with a clear message instead of leaving `filepath` undefined
# when `session` is neither 1 nor 2.
if session not in (1, 2):
    raise ValueError('session must be 1 or 2, got %r' % (session,))

filepath = os.path.join(data_dir, '%i_%i.bdf' % (subject, session))

In [None]:
# Open the BDF recording at `filepath`; load_data() is called on it in a later cell.
raw = mne.io.read_raw_bdf(filepath) 

## Channel locations

In [None]:
# Remove the external channels EXG7 and EXG8 before renaming the remaining ones.
raw = raw.drop_channels(['EXG7', 'EXG8'])

# New names for the channels that follow the first 64:
# four eye channels, then the two channels later used as reference (ML, MR).
eogs = ['ER', 'EL', 'ELA', 'ELB']
emgs = ['ML', 'MR']

# zip() stops at the shorter sequence, so only as many channels as there are
# new names get renamed.
old_names = raw.ch_names[64:72]
new_names = eogs + emgs
mne.rename_channels(info=raw.info, mapping=dict(zip(old_names, new_names)))

In [None]:
# Show the text summary of the Raw object.
print(raw)

In [None]:
# Show the measurement info (channel names, sampling rate, ...) after renaming.
print(raw.info)

In [None]:
# Read the signal data into memory before the filtering and resampling cells below.
raw = raw.load_data()

## Low-pass filter to < 50 Hz

In [None]:
# l_freq=None with h_freq=50: low-pass only, 50 Hz upper edge.
raw = raw.filter(l_freq=None, h_freq=50)

## High-pass filter to > .1 Hz

In [None]:
# l_freq=.1 with h_freq=None: high-pass only, 0.1 Hz lower edge.
raw = raw.filter(l_freq=.1, h_freq=None)

## Downsample to 250 Hz

In [None]:
# Resample the continuous recording to 250 Hz.
# NOTE(review): events are extracted from the 'Status' channel only further
# down, i.e. after this resampling step — confirm trigger timing is preserved.
raw = raw.resample(250)

## Manually select very bad segments

In [None]:
# Check whether annotations have already been saved for this participant.
# Session 1 files are named '<subject>-annot.fif', session 2 files '<subject>_2-annot.fif'.

if session == 1:
    annot_file = '%s/%i-annot.fif' % ('annotations', subject)
if session == 2:
    annot_file = '%s/%i_2-annot.fif' % ('annotations', subject)


if os.path.exists(annot_file):
    # Insert the file name into the prompt so the user sees which file was found.
    txt = 'Annotations file found at %s.\nDo you want to load it?\ny/n...' % annot_file
    resp = input(txt)
    if resp.lower() == 'y':
        # Attach the stored annotations to the recording.
        old_annotations = mne.read_annotations(annot_file)
        raw = raw.set_annotations(old_annotations)


In [None]:
# Plot raw data and manually select bad segments.
# A synthetic 'a' key press is sent to the figure canvas right after plotting
# (presumably to switch the browser into annotation mode — confirm).
# NOTE(review): this relies on the figure exposing `canvas.key_press_event`;
# check that it is available with the matplotlib/MNE versions in use.
raw.plot(n_channels=68, highpass=None, lowpass=None).canvas.key_press_event('a')



In [None]:
# Offer to write the current annotations to this participant's annotation file.
# The prompt is only shown when the recording carries an annotations object;
# answering 'y' (any case) overwrites an existing file.

if raw.annotations is not None and input('Save annotations to file?\ny/n...').lower() == 'y':
    raw.annotations.save(annot_file, overwrite=True)

## Apply mastoid reference

In [None]:
# Re-reference to the mastoid channels.
# Participant 3 in session 1 is referenced to MR only because ML fell off;
# every other recording (including all other session-1 participants) uses
# both ML and MR.
if session == 1 and subject == 3:
    raw = raw.set_eeg_reference(ref_channels=['MR'])  # ML fell off
else:
    raw = raw.set_eeg_reference(ref_channels=['ML', 'MR'])
    
    

In [None]:
# ML and MR are removed from the recording at this point.
raw = raw.drop_channels(['ML', 'MR'])

# Channel-type map: the eye channels become 'eog', the first 64 channels 'eeg'.
d = {name: 'eog' for name in eogs}
d.update({name: 'eeg' for name in raw.ch_names[:64]})

raw.set_channel_types(d)

# Attach electrode positions from the 'biosemi64' montage.
raw = raw.set_montage('biosemi64')

## Interpolate noisy channels if necessary

In [None]:
# Plot raw data to check for noisy channels.
# A synthetic 'a' key press is sent to the figure canvas right after plotting
# (presumably to switch the browser into annotation mode — confirm).
# NOTE(review): this relies on the figure exposing `canvas.key_press_event`;
# check that it is available with the matplotlib/MNE versions in use.
raw.plot(n_channels=68, highpass=None, lowpass=None).canvas.key_press_event('a')


In [None]:
# Plot PSD (check for bad channels)
# reject_by_annotation=False: presumably the segments annotated as bad above
# are still included in the spectrum — confirm this is intended.

raw.plot_psd(reject_by_annotation=False)

In [None]:
# Plot channel covariance matrix (check for bad channels)

# Fetch the first 64 channels through the public accessor rather than the
# private `_data` attribute; rows/columns follow the order of raw.ch_names[:64].
X = raw.get_data(picks=raw.ch_names[:64])
cov = np.cov(X)
plt.figure(figsize=(8,6))
sns.heatmap(cov, cmap='seismic', center=0)
plt.title('Channel covariance (first 64 channels)');

In [None]:
# Plot channel correlation matrix (check for bad channels)

# Fetch the first 64 channels through the public accessor rather than the
# private `_data` attribute; rows/columns follow the order of raw.ch_names[:64].
X = raw.get_data(picks=raw.ch_names[:64])
cor = np.corrcoef(X)
plt.figure(figsize=(8,6))
sns.heatmap(cor, cmap='seismic', center=0)
plt.title('Channel correlation (first 64 channels)');

In [None]:
# Use this to figure out which number refers to which channel:
# replace the index with the row/column number read off the heatmaps above.

print(raw.ch_names[6])


In [None]:
# Channels to interpolate, keyed by participant number (one table per session).
# Add entries here for channels to interpolate.

if session == 1:
    participant_bads = {
        1: ['PO4'], 10: ['P2'], 13: ['PO4'], 14: ['PO4'],
        16: ['P2'], 20: ['P2'], 21: ['POz'],
    }

if session == 2:
    participant_bads = {
        3: ['P2'], 5: ['P2'], 10: ['P2'], 11: ['P2'],
        12: ['PO4'], 13: ['PO4'], 16: ['P2'], 21: ['P2'],
    }

# Only participants listed in the table get their bad-channel list replaced;
# interpolation is then run for whatever is marked bad in raw.info.
if subject in participant_bads:
    raw.info['bads'] = participant_bads[subject]
raw = raw.interpolate_bads()


## Independent component analysis (ICA)

### Fit ICA and plot components

In [None]:
# Work on a copy so that `raw` keeps its current filtering; the copy is
# high-passed at 1 Hz to remove low-frequency drifts before the ICA fit.

filt_raw = raw.copy().load_data()
filt_raw.filter(l_freq=1., h_freq=None)

In [None]:
# fit the ICA (15 components - check with Nick?)
# The decomposition is fitted on the 1 Hz high-passed copy (filt_raw), not on raw;
# random_state is fixed so that repeated runs use the same seed.

ica = ICA(n_components=15, method='fastica', max_iter='auto', random_state=97)
ica.fit(filt_raw)

In [None]:
# Plot the component sources for `raw` (the ICA itself was fitted on filt_raw).
ica.plot_sources(raw)

In [None]:
# Plot the fitted components to help decide which ones to exclude.
ica.plot_components()

In [None]:
# pick components to exclude
# (component indices; used by the overlay and property plots in the next cells)

exclude_list = [0, 3]

In [None]:
# plot an overlay of the original signal against the reconstructed signal with the artifactual ICs excluded
# (only the 'eeg' channels are shown; the candidates come from exclude_list)

ica.plot_overlay(raw, exclude=exclude_list, picks='eeg')

In [None]:
# further check the components you want to exclude
# (one property figure per entry of exclude_list; fmax=35 is passed on via psd_args)

ica.plot_properties(raw, picks=exclude_list, psd_args={'fmax': 35.})

In [None]:
# enter components to exclude (most likely the ones in exclude_list)
# Per-participant record of ICA component indices, one table per session.
# Keys are integer participant numbers, values are lists of component indices.
# NOTE(review): participant 14 has no entry in either table — confirm whether
# that participant is intentionally left out.

if session == 1:
    excludes = {
        1: [0, 3],
        2: [1, 2],
        3: [1, 3],
        4: [0, 6],
        5: [0, 3],
        6: [0, 3],
        7: [0, 6],
        8: [0, 2],
        9: [0, 2],
        10: [0, 5],
        11: [0, 4],
        12: [0, 6],
        13: [0, 1],
        15: [0, 2],
        16: [0, 3],
        17: [0, 4],
        18: [0, 4],
        19: [0, 5],
        20: [0, 1],
        21: [0, 2]
    }


if session == 2:    
    excludes = {
        1: [0, 6],
        2: [0, 4],
        3: [0, 3],
        4: [1, 8],
        5: [0, 1],
        6: [0, 3],
        7: [0, 5],
        8: [0, 4],
        9: [0, 2],
        10: [0, 9],
        11: [0, 4],
        12: [0, 4],
        13: [0, 3],
        15: [0, 4],
        16: [1, 6],
        17: [0, 4],
        18: [0, 4],
        19: [1, 6],
        20: [1, 9],
        21: [0, 3]
    }

In [None]:
# Take the components to remove from the per-participant table. The table is
# keyed by integer participant number, so it must be indexed with `subject`
# itself (a '%i'-formatted string key is never found).
# Participants without a table entry fall back to the manually chosen exclude_list.
ica.exclude = excludes.get(subject, exclude_list)
print('Excluding ICA components: %s' % (ica.exclude,))

In [None]:
# Apply the ICA to `raw`, removing the components listed in ica.exclude.
raw = ica.apply(raw)

## Save to EEGdata/cleaned_mastoid_reference/sXXX-raw.fif (session 2: sXXX_2-raw.fif)

In [None]:
# Output file pattern per session; the participant number is filled in below.
# Nothing is written for any other session value. Existing files are overwritten.
save_paths = {
    1: 'EEGdata/cleaned_mastoid_reference/s%i-raw.fif',
    2: 'EEGdata/cleaned_mastoid_reference/s%i_2-raw.fif',
}

if session in save_paths:
    raw.save(save_paths[session] % subject, overwrite=True)

## 

## Inspect ERPs

In [None]:
# # strategic condition

# event_dict = {'exp_start': 1, 
#               'block_start': 2, 
#               'fixation_cross': 3,
#               'stimulus_left_correct': 4, 
#               'stimulus_right_correct': 5, 
#               'response_left': 6,
#               'response_right': 7,
#               'highlight_box': 8,
#               'confidence_rating': 9,
#               'partner_marker_left': 10,
#               'partner_marker_right': 11,
#               'higher_conf_box': 12,
#               'feedback_correct': 13,
#               'feedback_incorrect': 14,
#               'exp_end': 15
#              }



# non-strategic condition
# Maps event names to integer event codes; passed as event_id to
# plot_events and mne.Epochs in the cells below.
# NOTE(review): this dictionary has no 'feedback_correct' / 'feedback_incorrect'
# entries (they only appear in the commented-out strategic dictionary above),
# yet later cells select epochs by those names — confirm which dictionary
# should be active before running the feedback cells.

event_dict = {'exp_start': 1, 
              'block_start': 2, 
              'fixation_cross': 3,
              'stimulus_left_correct': 4, 
              'stimulus_right_correct': 5, 
              'response_left': 6,
              'response_right': 7,
              'highlight_box': 8,
              'confidence_rating': 9,
              'partner_marker_left': 10,
              'partner_marker_right': 11,
              #'exp_end': 15
             }

In [None]:
# identify stimulus events
# (read from the 'Status' channel of the cleaned recording)
    
events = mne.find_events(raw, stim_channel='Status')

In [None]:
# show events timecourse
# (event codes are labelled with the names from event_dict)

fig = mne.viz.plot_events(events, event_id=event_dict, sfreq=raw.info['sfreq'],
                          first_samp=raw.first_samp)

In [None]:
# Cut epochs from -0.2 s to 0.5 s around the events named in event_dict.
epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.2, tmax=0.5, preload=True)

### Feedback Related Negativity

In [None]:
# Select the epochs for each feedback type by event name.
# NOTE(review): the event_dict that is active in this notebook (non-strategic
# condition) defines neither 'feedback_correct' nor 'feedback_incorrect', so
# this selection presumably requires the strategic dictionary — confirm.
feedback_correct_epochs = epochs['feedback_correct']
feedback_incorrect_epochs = epochs['feedback_incorrect']

In [None]:
# Average each feedback selection over its epochs.
feedback_correct_evoked = feedback_correct_epochs.average()
feedback_incorrect_evoked = feedback_incorrect_epochs.average()

In [None]:
# One list of single-epoch evoked objects per feedback type; the dictionary
# keys are the condition names handed to plot_compare_evokeds.
evokeds = dict(correct_feedback=list(epochs['feedback_correct'].iter_evoked()),
               incorrect_feedback=list(epochs['feedback_incorrect'].iter_evoked()))

# Compare the two conditions at Cz with the y-axis inverted.
mne.viz.plot_compare_evokeds(evokeds, combine='mean', picks=['Cz'], invert_y=True)

### Stimulus-locked P300

In [None]:
# Select the epochs of both stimulus event names for the image plots below.
stimulus_epochs = epochs['stimulus_left_correct', 'stimulus_right_correct']

In [None]:
# Image plot of the stimulus-locked epochs at channel Pz.
stimulus_epochs.plot_image(picks=['Pz'])

### Visual Potential P100

In [None]:
# Image plot of the stimulus-locked epochs at channel PO8.
stimulus_epochs.plot_image(picks=['PO8'])

### Frontocentral N100

In [None]:
# Image plot of the stimulus-locked epochs at channel FCz.
stimulus_epochs.plot_image(picks=['FCz'])