In [1]:
# ICA demo, based on: https://mne.tools/stable/auto_tutorials/preprocessing/plot_40_artifact_correction_ica.html

import mne
import os
import numpy as np
from mne import io
import sys as sys
sys.path.append('..')
from estimators.linear import train_linear as estimator
import pickle
import matplotlib.pyplot as plt
%matplotlib inline
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,
                               corrmap)

In [2]:
# Read data

# Data directory, relative to the notebook's working directory. The trailing
# separator keeps `data_path + 'file'` style concatenation working.
data_path = os.path.join(os.getcwd(), '..', 'sample_data', '')

# Set up for reading the raw recording (presumably Graz BCI IV 2b, per the
# montage filename below — TODO confirm).
raw_fname = os.path.join(data_path, 'B0101T.gdf')
raw = io.read_raw_gdf(raw_fname, preload=True)
# Mark the three EOG channels so MNE does not treat them as EEG.
raw.set_channel_types({'EOG:ch01': 'eog', 'EOG:ch02': 'eog', 'EOG:ch03': 'eog'})

# Read montage / digitisation points. A separate variable keeps `raw_fname`
# pointing at the recording instead of being overwritten with the montage path.
montage_fname = os.path.join(data_path, 'GrazIV2B_montage.elc')
montage = mne.channels.read_custom_montage(montage_fname)
raw.set_montage(montage)

# ICA can take a long time, so one could fit on a subset of the data via
# raw.crop(); with only 3 EEG channels it is fast anyway.


Extracting EDF parameters from C:\Users\matth\OneDrive\Bureaublad\mne_examples\sample_data\B0101T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 604802  =      0.000 ...  2419.208 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


<RawGDF  |  B0101T.gdf, n_channels x n_times : 6 x 604803 (2419.2 sec), ~27.7 MB, data loaded>

In [None]:
# Summarise how the ocular artifact shows up across each channel type:
# epoch the data around the detected EOG events, average those epochs, and
# baseline-correct using the window from epoch start to 200 ms before each event.
eog_evoked = create_eog_epochs(raw).average().apply_baseline(baseline=(None, -0.2))
eog_evoked.plot_joint()

EOG channel index for this subject is: [3 4 5]
Filtering the data to remove DC offset to help distinguish blinks from saccades
Setting up band-pass filter from 1 - 10 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed frequency-domain design (firwin2) method
- Hann window
- Lower passband edge: 1.00
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.75 Hz)
- Upper passband edge: 10.00 Hz
- Upper transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 10.25 Hz)
- Filter length: 4096 samples (16.384 sec)

Now detecting blinks and generating corresponding events
Found 392 significant peaks
Number of EOG events detected : 392
392 matching events found
No baseline correction applied
Not setting metadata
Loading data for 392 events and 251 original time points ...
0 bad epochs dropped
Applying baseline correction (mode: mean)


In [None]:
# High-pass filter a copy of the data to remove slow drifts before fitting ICA;
# `raw` itself stays unfiltered so the ICA can later be applied to it.
# A dedicated name avoids confusion with the EOG z-score threshold used later.
hp_cutoff_hz = 5.  # high-pass cutoff frequency in Hz (the linked MNE tutorial uses 1 Hz)
filt_raw = raw.copy().filter(l_freq=hp_cutoff_hz, h_freq=None)
filt_raw

In [None]:
# Fit ICA on the high-passed copy and plot the resulting solution.
# There are only 3 EEG channels, so no more than 3 components can be estimated;
# random_state makes the decomposition reproducible across runs.
n_ica_components = 3
ica = ICA(n_components=n_ica_components, random_state=97)
ica.fit(filt_raw)

# `raw` was read with preload=True, so it is already in memory for plotting.
ica.plot_sources(raw)
ica.plot_components()

In [None]:
# Plot the reconstructed EEG with each ICA component excluded in turn, to see
# which component's removal changes the signal. Looping over the fitted
# component count replaces one copy-pasted call per component.
for comp_idx in range(ica.n_components_):
    ica.plot_overlay(raw, exclude=[comp_idx], picks='eeg')

In [None]:
# Per-component diagnostic plots computed on `raw`; picks=None (the default)
# leaves the choice of components to MNE.
ica.plot_properties(inst=raw, picks=None)

In [None]:
# Find the ICA components that correlate with the EOG channels and mark them
# for exclusion (the reconstruction itself happens in the next cell).

# Threshold passed to find_bads_eog — tune this! Lower values flag more components.
eog_threshold = 1.2

ica.exclude = []  # reset any previous selection before searching again
# find which ICs match the EOG pattern
eog_indices, eog_scores = ica.find_bads_eog(raw, threshold=eog_threshold)
ica.exclude = eog_indices

# barplot of IC component "EOG match" scores
ica.plot_scores(eog_scores)

# plot diagnostics of the matched ICs; guard so an empty pick list is never
# handed to plot_properties when nothing passed the threshold
if eog_indices:
    ica.plot_properties(raw, picks=eog_indices)
else:
    print(f'No ICA component exceeded the EOG threshold of {eog_threshold}.')

# plot ICs applied to raw data, with EOG matches highlighted
ica.plot_sources(raw)

# plot ICs applied to the averaged EOG epochs, with EOG matches highlighted
ica.plot_sources(eog_evoked)

In [None]:
# Reconstruct the signal without the components in ica.exclude. apply() works
# on the copy it is given and returns it, so `raw` stays untouched for the
# before/after comparison below.
reconst_raw = ica.apply(raw.copy())

# Before (original) and after (reconstructed) views.
raw.plot()
reconst_raw.plot()

In [None]:
# Classify the reconstructed signal

tmin, tmax = 2, 4  # epoch window in seconds relative to each event (time chosen from jupyter notebook)
event_id = {'left': 10, 'right': 11}
# NOTE(review): 10/11 are the IDs mne.events_from_annotations auto-assigns to
# this file's left/right annotation descriptions; confirm against the mapping it
# returns if a different recording is used.

# Band-pass a copy so `reconst_raw` keeps the unfiltered ICA reconstruction and
# re-running this cell does not filter the data a second time.
reconst_filt = reconst_raw.copy().filter(6, 14, fir_design='firwin')  # extract alpha band (see jupyter notebook)
events, _ = mne.events_from_annotations(reconst_filt)

# Read epochs. MNE epoch slicing INCLUDES the sample at tmax, so one sample
# period is subtracted to get (tmax - tmin) * sfreq samples per epoch.
epochs = mne.Epochs(reconst_filt, events, event_id, tmin, tmax - 1 / reconst_filt.info['sfreq'],
                    proj=True, baseline=None, preload=True,
                    picks=[0, 2])  # use only C3 and C4, they are different
labels = epochs.events[:, -1]

# fit classifier
best_est = estimator(epochs, labels)
