In [1063]:
import glob
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import PyQt5
import seaborn as sns
from scipy import io, stats

In [1064]:
# Use the interactive Qt backend so the blocking MNE browser windows used below
# (raw.plot(block=True), ica.plot_sources(..., block=True)) open for manual inspection.
plt.switch_backend('QtAgg')

In [1065]:
def load_data(cnt_data, mrk_data, session=1):
    """Extract one session's continuous signal and markers from loaded .mat contents.

    Parameters
    ----------
    cnt_data : dict
        Result of ``scipy.io.loadmat`` on ``cnt.mat``. ``cnt_data["cnt"]`` is indexed
        as ``[0, session - 1]``, so its second axis enumerates sessions.
    mrk_data : dict
        Result of ``scipy.io.loadmat`` on ``mrk.mat``, indexed the same way.
    session : int, optional
        1-based session number (default 1).

    Returns
    -------
    dict
        ``clab``      : list of channel labels
        ``fs``        : sampling frequency (scalar)
        ``x``         : signal array scaled by 1e-6 (presumably uV -> V, the unit MNE
                        expects -- TODO confirm the file's units)
        ``time``      : marker onsets converted from milliseconds to seconds
        ``y``         : class-label array as stored in the file
        ``event``     : list of per-event values taken from the first field of ``event``
        ``className`` : list of class names

    Raises
    ------
    IndexError
        If ``session`` is not in ``1 .. number of sessions``. Without the explicit
        check, ``session=0`` (or a negative value) would silently select a session
        counted from the end through negative indexing.
    """
    n_sessions = cnt_data["cnt"].shape[1]
    if not 1 <= session <= n_sessions:
        raise IndexError(f"session must be in 1..{n_sessions}, got {session}")

    # Look up each session struct once instead of re-indexing for every field.
    cnt_session = cnt_data["cnt"][0, session - 1]
    mrk_session = mrk_data["mrk"][0, session - 1]

    return {
        "clab": [label[0] for label in cnt_session["clab"][0, 0][0]],
        "fs": cnt_session["fs"][0, 0][0, 0],
        "x": cnt_session["x"][0, 0] * 1e-6,
        "time": mrk_session["time"][0, 0][0] * 1e-3,    # conversion from 'ms' to 's'
        "y": mrk_session["y"][0, 0],
        "event": [ev[0] for ev in mrk_session["event"][0, 0][0, 0][0]],
        "className": [name[0] for name in mrk_session["className"][0, 0][0]],
    }

In [1066]:
# Dataset selector: 'A' selects sessions 1/3/5, any other value selects 2/4/6.
dataset = 'A'

# Subject id; used to locate the subject folder and to name the saved epochs file.
subject = '29'

# Epoch window relative to each event marker, in seconds.
tmin = -5.0
tmax = 20.0

# Baseline-correction window: from the epoch start (None) up to -2 s.
baseline = (None, -2.0)

In [1067]:
%%capture

raws = []

# Locate the subject folder. The pattern is built with os.path.join rather than
# hard-coded "\\" separators, which glob treats as escape characters on POSIX systems
# (the old pattern therefore only matched on Windows).
path_pattern = os.path.join(
    "..", "dataset", "EEG_[0-2][0-9]-[0-2][0-9]",
    f"subject {subject}", "with occular artifact",
)
matching_dirs = sorted(glob.glob(path_pattern))  # sorted -> deterministic choice
if not matching_dirs:
    raise FileNotFoundError(f"No data folder matches {path_pattern!r}")
path = matching_dirs[0]

montage = mne.channels.make_standard_montage("standard_1005")   # international 10-5 system

cnt = io.loadmat(os.path.join(path, "cnt.mat"))
mrk = io.loadmat(os.path.join(path, "mrk.mat"))

# sessions 1, 3, 5 for Dataset A (lhmi/rhmi)
# sessions 2, 4, 6 for Dataset B (ma/baseline)
sessions = [1, 3, 5] if dataset == 'A' else [2, 4, 6]

for session in sessions:
    task = load_data(cnt, mrk, session)

    sfreq = task["fs"]  # 200 Hz
    ch_names = task["clab"]
    # Channels whose label ends in "EOG" are typed as EOG; everything else is EEG.
    ch_types = ["eog" if name.endswith("EOG") else "eeg" for name in ch_names]

    # Transposed so rows are channels, as RawArray expects.
    data = task["x"].transpose()  # ~ 120,000 samples and ~ 600 seconds per session

    onset = task["time"]
    # Fixed 10 s annotation length (presumably the task period -- TODO confirm).
    duration = [10] * len(onset)
    # A truthy entry in the first row of y marks a cond1 trial, otherwise cond2.
    # For Dataset A => cond1 = lhmi and cond2 = rhmi
    # For Dataset B => cond1 = ma and cond2 = baseline
    description = ["cond1" if y else "cond2" for y in task["y"][0]]

    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
    annotations = mne.Annotations(onset=onset, duration=duration, description=description)

    raw = mne.io.RawArray(data=data, info=info)
    raw.set_montage(montage)
    raw.set_annotations(annotations)

    raws.append(raw)

# Join the selected sessions into one continuous recording.
raw = mne.concatenate_raws(raws)

In [1068]:
# raw.plot_sensors(kind='3d', ch_type='eeg', block=True)

In [1069]:
# Interactive inspection: channels marked in the browser end up in raw.info["bads"].
# NOTE(review): this manual step cannot be reproduced from code alone; the recorded
# output shows ['P8'] marked for this subject -- consider setting raw.info["bads"]
# explicitly so that Restart & Run All gives the same result.
raw.plot(block=True)

Channels marked as bad:
['P8']


<MNEBrowseFigure size 1121x800 with 4 Axes>

In [1070]:
%%capture
# Interpolate the channels listed in raw.info["bads"] from their neighbours; this relies
# on the sensor positions attached with set_montage during loading.
raw.interpolate_bads()

In [1071]:
# Clear the bad-channel list now that those channels have been interpolated.
# NOTE(review): interpolate_bads() resets bads by default (reset_bads=True), so this is
# likely a redundant safeguard.
raw.info["bads"] = []

In [1072]:
%%capture

# Re-reference the EEG channels to their common average.
raw.set_eeg_reference(ref_channels="average")

In [1073]:
%%capture
# Band-pass a copy to 0.5-50 Hz; `raw` itself stays unfiltered for the ICA cell below.
raw_filtered = (
    raw.copy()
    .filter(l_freq=0.5, h_freq=50)
)

In [1074]:
%%capture

# ICA is fitted on a separate 1 Hz high-passed copy of `raw`; the resulting
# decomposition is applied to `raw_filtered` in a later cell.
raw_ica = raw.copy().filter(l_freq=1.0, h_freq=None)

ica = mne.preprocessing.ICA(
    method='picard',
    fit_params={"ortho": True, "extended": True},
    n_components=0.999,   # fractional value: chosen by explained variance, not a fixed count
    max_iter=500,
    random_state=42,      # fixed seed so the decomposition is reproducible
)

ica.fit(raw_ica)

In [1075]:
%%capture

# Epochs of +/-0.5 s around EOG events found in `raw_ica`, baseline-corrected up to 0 s.
eog_epochs = mne.preprocessing.create_eog_epochs(
    raw_ica,
    tmin=-0.5,
    tmax=0.5,
    baseline=(None, 0),
    reject=None,
)
eog_evoked = eog_epochs.average()   # NOTE(review): not used by any later cell

# Score ICA components against the EOG epochs and mark the matches for exclusion.
# NOTE(review): the recorded output shows this returned [] -- a component was then
# chosen by hand in the plot_sources browser.
eog_inds, eog_scores = ica.find_bads_eog(eog_epochs)
ica.exclude = eog_inds

In [1076]:
# Components flagged automatically by find_bads_eog.
ica.exclude

[]

In [1077]:
# Interactive review of the ICA sources; clicking a component toggles it in ica.exclude.
# NOTE(review): the selection is manual and is not reproducible on a fresh kernel; the
# recorded output shows ica.exclude == [0] afterwards -- consider setting it explicitly.
ica.plot_sources(raw_ica, block=True)

Creating RawArray with float64 data, n_channels=26, n_times=360903
    Range : 0 ... 360902 =      0.000 ...  1804.510 secs
Ready.


<MNEBrowseFigure size 1121x800 with 4 Axes>

In [1078]:
# Components excluded after the manual review in the sources browser.
ica.exclude

[0]

In [1079]:
%%capture

# Remove the components in ica.exclude from `raw_filtered`.
# NOTE(review): apply() works in place, so re-running this cell re-applies ICA to
# already-cleaned data.
ica.apply(raw_filtered)  # Apply ICA

In [1080]:
# Visual check of the ICA-cleaned data; any channels marked here go into
# raw_filtered.info["bads"] (the recorded output shows none for this subject).
raw_filtered.plot(block=True)

Channels marked as bad:
none


<MNEBrowseFigure size 1121x800 with 4 Axes>

In [1081]:
%%capture
# Interpolate any channels marked bad in the previous browser session.
raw_filtered.interpolate_bads()

In [1082]:
# Clear the bad-channel list after interpolation.
# NOTE(review): likely redundant because interpolate_bads() resets bads by default.
raw_filtered.info["bads"] = []

In [1083]:
%%capture

# Narrow the EEG channels (picks="eeg") to 8-30 Hz using MNE's default IIR design
# (iir_params=None). EOG channels keep only the earlier 0.5-50 Hz band-pass.
raw_filtered.filter(picks="eeg", l_freq=8, h_freq=30, method="iir", iir_params=None)

In [1084]:
# Convert the cond1/cond2 annotations into an events array. The code mapping is pinned
# so every subject's saved epochs use the same integers, rather than depending on
# automatic numbering of whichever descriptions are present in this recording.
events, event_id = mne.events_from_annotations(
    raw_filtered, event_id={"cond1": 1, "cond2": 2})

Used Annotations descriptions: ['cond1', 'cond2']


In [1085]:
# Cut epochs from tmin to tmax around every event, apply baseline correction with the
# configured window, and drop epochs that overlap a BAD annotation (e.g. the segment
# boundaries added when the sessions were concatenated).
epochs = mne.Epochs(
    raw_filtered,
    events,
    event_id=event_id,
    baseline=baseline,
    tmin=tmin,
    tmax=tmax,
    reject_by_annotation=True,
)

Not setting metadata
60 matching events found
Setting baseline interval to [-5.0, -2.0] s
Applying baseline correction (mode: mean)
0 projection items activated


In [1086]:
# Write the epochs to epochs_eeg/Dataset_<dataset>/subject_<subject>_epo.fif.
# Create the output directory first, because saving into a missing directory fails.
epochs_dir = os.path.join("epochs_eeg", f"Dataset_{dataset}")
os.makedirs(epochs_dir, exist_ok=True)
epochs.save(os.path.join(epochs_dir, f"subject_{subject}_epo.fif"), overwrite=True)

Overwriting existing file.
Using data from preloaded Raw for 60 events and 5001 original time points ...
0 bad epochs dropped
Using data from preloaded Raw for 1 events and 5001 original time points ...
Overwriting existing file.
Using data from preloaded Raw for 60 events and 5001 original time points ...
