In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display

import mne

# Dataset

[Link to OpenNeuro](https://openneuro.org/datasets/ds002725/versions/1.0.0)  
[Link to Paper](https://www.nature.com/articles/s41598-019-45105-2)  

**Recording**  

EEG was recorded via a 32 channel (31 channel EEG and 1 channel electrocardiogram) MRI-compatible BrainAmp MR and BrainCap MR EEG system (Brain Products Inc., Germany). EEG was recorded `at 5,000 Hz, without filtering` (an analogous approach to [60]), and with an amplitude resolution of 0.5uV. The reference electrode was placed at FCz. All electrodes were placed according to `the International 10/20 system`. Impedances were kept below 15kΩ throughout the experiments.

**Preprocessing**  

The imaging artefact was first attenuated using the Average Artefact Subtraction (AAS) method [62], as implemented in Vision Analyzer software (BrainProducts). The ballistocardiogram artefact was also removed from the EEG via the AAS method. The cleaned EEG was then visually checked to confirm successful attenuation of the artefacts.

In [None]:
# Notebook configuration: which recording to process and where to write results.
# Subject identifier, interpolated into the 'sub-{SUB_ID}' directory and file names.
SUB_ID = '01'
# Run identifier, interpolated into the 'genMusic{GEN_ID}' part of the file names.
GEN_ID = '01'
# Optional prefix prepended to the dataset path when reading the EDF file
# (empty string = paths are relative to the current working directory).
EXT_ROOT = ''
# Directory the cleaned recording is written to at the end of the notebook.
SAVE_PATH = 'preprocessing/data/'

In [None]:
# Build the dataset paths from the configuration constants and open the recording.
# Example of a fully resolved recording path:
# FILE_PATH = 'data/openneuro/sub-01/eeg/sub-01_task-genMusic01_eeg.edf'
ROOT_PATH = f'data/openneuro/sub-{SUB_ID}/eeg'
# EDF recording for the selected subject / run.
RAW_FILE = f'sub-{SUB_ID}_task-genMusic{GEN_ID}_eeg.edf'
# Channel description table (tab-separated) that accompanies the recording.
INFO_FILE = f'sub-{SUB_ID}_task-genMusic{GEN_ID}_channels.tsv'
raw = mne.io.read_raw_edf(os.path.join(EXT_ROOT, ROOT_PATH, RAW_FILE))

In [None]:
# Display the measurement info of the loaded recording.
raw.info

# Channels Info

In [None]:
# Load the channel description table that accompanies the recording.
# EXT_ROOT is prepended exactly as for the EDF file, so both files are
# resolved from the same dataset root when EXT_ROOT is non-empty.
ch_info = pd.read_csv(os.path.join(EXT_ROOT, ROOT_PATH, INFO_FILE), sep='\t')
ch_info

In [None]:
# Translate the channel types listed in the TSV 'type' column into the
# channel-type names passed to raw.set_channel_types.
ch_mapping = {
    'EEG': 'eeg',
    'ECG': 'ecg',
    'response': 'resp',
    'reponse': 'resp',  # misspelled variant of 'response', mapped the same way
    'stimuli': 'stim',
}

# Pair channel names with TSV rows by position; an unknown type raises KeyError.
# NOTE(review): assumes the TSV rows are in the same order as raw.ch_names.
ch_types = {
    ch_name: ch_mapping[tsv_type]
    for ch_name, tsv_type in zip(raw.ch_names, ch_info['type'])
}
raw.set_channel_types(ch_types)

# Basic Plotting

In [None]:
%matplotlib inline
# Browse the raw traces, 31 channels per page. The figure is assigned to `_`
# so it is not echoed a second time as the cell's output value.
# NOTE(review): `scalings` is given as the string '1e-4' rather than a dict or 'auto';
# presumably applied to every channel type - confirm this is intended.
_ = raw.plot(n_channels=31, scalings='1e-4')

In [None]:
# Attach standard 10-20 electrode positions to the recording.
montage_1020 = mne.channels.make_standard_montage('standard_1020')
# on_missing='ignore': channels without a position in the montage do not raise.
raw.set_montage(montage_1020, on_missing='ignore')  # ignore other channels

In [None]:
%matplotlib inline
# Plot the sensor layout with channel names to check the montage visually.
_ = raw.plot_sensors(show_names=True)

# PSD

[Link to MNE](https://mne.tools/stable/auto_tutorials/raw/40_visualize_raw.html)

In [None]:
# Names of the first 31 channels, used later as `picks` for the PSD and the ICA fit.
# NOTE(review): presumably the 31 EEG channels described in the Recording notes -
# relies on channel order; verify against ch_info.
data_channels = raw.ch_names[:31]

## Power line noise

In [None]:
%matplotlib inline
def add_arrows(axes, line_freqs=(60, 120, 180, 240)):
    """Mark selected frequencies on PSD axes with red downward arrows.

    Parameters
    ----------
    axes : iterable
        Axes to annotate. The last line of each axis (``ax.lines[-1]``) is
        read as the curve: x-data are frequencies, y-data are the plotted values.
        The x-data are assumed to be sorted ascending (required by
        ``np.searchsorted``).
    line_freqs : iterable of float, optional
        Frequencies to mark. Defaults to 60 Hz and its harmonics, which is the
        original behaviour; pass e.g. ``(50, 100, 150, 200)`` for 50 Hz mains.

    Notes
    -----
    Frequencies outside the plotted range are skipped instead of raising, and
    the peak-search window is clipped at the start of the curve.
    """
    for ax in axes:
        curve = ax.lines[-1]
        freqs = np.asarray(curve.get_xdata())
        psds = np.asarray(curve.get_ydata())
        if freqs.size == 0:
            continue  # nothing plotted on this axis
        for freq in line_freqs:
            # skip targets that are not covered by the plotted frequency range
            if freq < freqs[0] or freq > freqs[-1]:
                continue
            idx = np.searchsorted(freqs, freq)
            # get ymax of a small region around the freq. of interest;
            # clip the lower bound so a negative start cannot empty the slice
            lo = max(idx - 4, 0)
            y = psds[lo : (idx + 5)].max()
            ax.arrow(
                x=freqs[idx],
                y=y + 18,
                dx=0,
                dy=-12,
                color="red",
                width=0.1,
                head_width=3,
                length_includes_head=True,
            )

# Average PSD up to 250 Hz with arrows at the frequencies add_arrows marks.
# NOTE: `psd` here holds the object returned by .plot() (its .axes are used
# below), not the spectrum itself; the name is rebound to a spectrum further down.
# NOTE(review): add_arrows marks 60 Hz and harmonics - confirm the mains
# frequency of the recording site (50 Hz in many regions).
psd = raw.compute_psd(fmax=250).plot(average=True, picks="data", exclude="bads")
add_arrows(psd.axes[:2])

### TODO: Notch Filter

If the figure above shows peaks at the power-line frequency and its harmonics, apply a notch filter at those frequencies.

## By Brain-Wave Frequency Band

In [None]:
# %matplotlib tk
# psd = raw.compute_psd(fmax=35, picks=data_channels)
# _ = psd.plot(picks='data')

In [None]:
%matplotlib inline
# PSD restricted to 0.5-45 Hz, computed on the channels in `data_channels`,
# then plotted as an average over channels.
psd = raw.compute_psd(fmin=0.5, fmax=45, picks=data_channels)
_ = psd.plot(picks='data', average=True)

In [None]:
%matplotlib inline
# Scalp topographies of the spectrum computed in the previous cell.
_ = psd.plot_topomap()

# Preprocessing

[See Best Practices](https://mne.tools/stable/auto_tutorials/preprocessing/30_filtering_resampling.html#best-practices)

## Remove Bad Channels

In [None]:
# Guard: the rest of the notebook expects no channels to be marked as bad;
# stop here if the recording has any.
assert len(raw.info['bads']) == 0
raw.info['bads']

## Setting

In [None]:
# Derive the low-pass cut-off from the target sampling rate.
current_sfreq = raw.info["sfreq"]
desired_sfreq = 512  # Hz
# Integer decimation factor closest to current/desired. Clamp to >= 1 so a
# recording already sampled below desired_sfreq / 2 cannot produce decim == 0
# (which would divide by zero below and yield an infinite low-pass cut-off).
decim = max(1, np.round(current_sfreq / desired_sfreq).astype(int))
# Sampling rate that this integer decimation would actually produce.
obtained_sfreq = current_sfreq / decim
# Low-pass at one third of that rate, applied before down-sampling.
lowpass_freq = obtained_sfreq / 3.0

## Low-pass filtering

In [None]:
# Load the samples into memory on `raw` itself, then low-pass a copy so the
# unfiltered recording stays available in `raw`.
raw.load_data()
raw_filtered = raw.copy().filter(l_freq=None, h_freq=lowpass_freq)

## Re-compute Music events

In [None]:
# Position of the 'music' channel in the recording, used to look up its
# description in the channel table.
ch_music = raw_filtered.ch_names.index('music')
# Positional row lookup.
# NOTE(review): assumes ch_info rows are in the same order as the channel names - confirm.
ch_info.iloc[ch_music].status_description

In [None]:
# Rescale the 'music' channel by 20 * 1e6 and round to integers so its values can
# be read as discrete event codes by the event detection below.
# NOTE(review): the factor presumably undoes the unit scaling of the stored signal;
# verify against the status_description shown above.
# WARNING: modifies raw_filtered in place - running this cell twice rescales twice.
raw_filtered.apply_function(lambda x: np.round(x * 20 * 1e6).astype(int), picks='music')

In [None]:
# Extract events from the 'music' channel. Column 2 of `events` holds the event
# code, which the following cells use as the music identifier.
events = mne.find_events(raw_filtered, stim_channel='music', consecutive=True)

In [None]:
# Distinct event codes found on the 'music' channel.
# np.unique de-duplicates AND sorts, so the list (and everything zipped with it
# below) has a stable ascending order; .tolist() yields plain Python ints, which
# display cleanly and hash/compare equal to the numpy codes stored in `events`.
music_ids = np.unique(events[:, 2]).tolist()
music_ids

In [None]:
# Directory holding the generated music stimuli.
music_root = 'data/openneuro/stimuli/generated'
# Build one file name per event code from its decimal digits:
# <code // 100>-<tens digit>_<units digit>.wav
music_files = [
    f'{event_code // 100}-{event_code // 10 % 10}_{event_code % 10}.wav'
    for event_code in music_ids
]

In [None]:
%matplotlib inline
# Event timeline, labelled by stimulus file name (event_id maps name -> event code).
_ = mne.viz.plot_events(events, event_id=dict(zip(music_files, music_ids)))

## Set Annotations

In [None]:
# Annotations currently attached to the recording, shown before they are
# replaced in the next cell.
raw_filtered.annotations

In [None]:
# Convert the music events into annotations described by their stimulus file
# name, and attach them to the recording (replacing the current annotations).
annot_from_events = mne.annotations_from_events(
    events=events,
    sfreq=raw_filtered.info["sfreq"],
    orig_time=raw_filtered.info["meas_date"],
    # event code -> stimulus file name
    event_desc={code: fname for code, fname in zip(music_ids, music_files)},
)
raw_filtered.set_annotations(annot_from_events)

In [None]:
%matplotlib inline
# Browse a 60 s window starting at t = 190 s to check the new annotations
# against the traces.
_ = raw_filtered.plot(start=190, duration=60, n_channels=31, scalings='1e-4')

## Down-sampling

[About Sample Rate](https://www.researchgate.net/post/What_is_the_advantage_of_very_high_sampling_rates_in_EEG_systems)

In [None]:
# Down-sample to the target rate, then re-reference the EEG.
# Both calls act on raw_filtered itself, so re-running this cell re-applies them.
# NOTE(review): resamples to desired_sfreq, whereas the low-pass cut-off above was
# derived from obtained_sfreq - confirm this difference is intended.
raw_filtered.resample(desired_sfreq)
# Called with default arguments - presumably an average reference; confirm.
raw_filtered.set_eeg_reference()

## Band Pass Filter

In [None]:
# Band-pass 0.5-100 Hz ahead of ICA; the narrower 0.5-45 Hz filter is applied
# after the ICA components have been removed.
raw_filtered.filter(l_freq=0.5, h_freq=100) # 100 hz for ICA first

In [None]:
%matplotlib inline
# Visual check of the band-passed data: 40 s window, up to 35 channels per page.
_ = raw_filtered.plot(duration=40, n_channels=35, scalings='5e-4')

## Remove Artifact with ICA

In [None]:
# Configure the ICA decomposition: 15 components, extended Infomax, and a fixed
# random_state so the decomposition is reproducible across runs.
ica = mne.preprocessing.ICA(
    n_components=15,
    max_iter="auto",
    method="infomax",
    random_state=42,
    fit_params=dict(extended=True),
)

In [None]:
# Fit the ICA on the channels listed in `data_channels` only.
ica.fit(raw_filtered, picks=data_channels)

In [None]:
%matplotlib inline
# Scalp maps of the fitted components.
_ = ica.plot_components(colorbar=True)

In [None]:
%matplotlib inline
# Time courses of the components for the band-passed recording.
_ = ica.plot_sources(raw_filtered)

### Manual Selection [🔗](https://labeling.ucsd.edu/tutorial/labels)

In [None]:
# %matplotlib inline
# _ = ica.plot_properties(raw_filtered, picks=[1], )

### ICA Classification [🔗](https://mne.tools/mne-icalabel/stable/auto_examples/iclabel_automatic_artifact_correction_ica.html)

In [None]:
from mne_icalabel import label_components

In [None]:
# Classify every ICA component with ICLabel. The result is indexed below by its
# 'labels' and 'y_pred_proba' entries, one value per component; the per-category
# component lists are read from ica.labels_ after this call.
cls_components = label_components(raw_filtered, ica, method="iclabel")

#### Label Components

In [None]:
# Component-rejection settings used by the per-category cells below.
# A component is excluded only if its category flag is True AND its predicted
# probability (y_pred_proba) is strictly greater than ICA_THRESHOLD.
ICA_THRESHOLD = 0.9
REMOVE_BRAIN = False
REMOVE_MUSCLE = True
REMOVE_EOG = True
REMOVE_ECG = False
# REMOVE_NOISE gates both the 'line_noise' and the 'ch_noise' categories.
REMOVE_NOISE = True
REMOVE_OTHER = False
# Start from an empty exclusion list; re-run this cell before re-running the
# selection cells to reset the selection.
ica.exclude = []

##### Brain

In [None]:
# Mark 'brain' components for exclusion when enabled and predicted with a
# probability above ICA_THRESHOLD. The membership check keeps the cell
# idempotent: re-running it does not append the same component twice.
if REMOVE_BRAIN:
    for idx in ica.labels_['brain']:
        if cls_components['y_pred_proba'][idx] > ICA_THRESHOLD and idx not in ica.exclude:
            ica.exclude.append(idx)

In [None]:
%matplotlib inline
# Show diagnostics and the ICLabel prediction for every excluded 'brain' component.
for comp in ica.labels_['brain']:
    if comp not in ica.exclude:
        continue
    _ = ica.plot_properties(raw_filtered, picks=[comp])
    label = cls_components['labels'][comp]
    percent = cls_components['y_pred_proba'][comp] * 100
    print(f"ICA{comp:03d}, Predict: {label} ({percent:.2f}%)")

##### Muscle

In [None]:
# Mark 'muscle' components for exclusion when enabled and predicted with a
# probability above ICA_THRESHOLD. The membership check keeps the cell
# idempotent: re-running it does not append the same component twice.
if REMOVE_MUSCLE:
    for idx in ica.labels_['muscle']:
        if cls_components['y_pred_proba'][idx] > ICA_THRESHOLD and idx not in ica.exclude:
            ica.exclude.append(idx)

In [None]:
%matplotlib inline
# Show diagnostics and the ICLabel prediction for every excluded 'muscle' component.
for comp in ica.labels_['muscle']:
    if comp not in ica.exclude:
        continue
    _ = ica.plot_properties(raw_filtered, picks=[comp])
    label = cls_components['labels'][comp]
    percent = cls_components['y_pred_proba'][comp] * 100
    print(f"ICA{comp:03d}, Predict: {label} ({percent:.2f}%)")

##### EOG (Eyes)

In [None]:
# Mark 'eog' components for exclusion when enabled and predicted with a
# probability above ICA_THRESHOLD. The membership check keeps the cell
# idempotent: re-running it does not append the same component twice.
if REMOVE_EOG:
    for idx in ica.labels_['eog']:
        if cls_components['y_pred_proba'][idx] > ICA_THRESHOLD and idx not in ica.exclude:
            ica.exclude.append(idx)

In [None]:
%matplotlib inline
# Show diagnostics and the ICLabel prediction for every excluded 'eog' component.
for comp in ica.labels_['eog']:
    if comp not in ica.exclude:
        continue
    _ = ica.plot_properties(raw_filtered, picks=[comp])
    label = cls_components['labels'][comp]
    percent = cls_components['y_pred_proba'][comp] * 100
    print(f"ICA{comp:03d}, Predict: {label} ({percent:.2f}%)")

##### ECG

In [None]:
# Mark 'ecg' components for exclusion when enabled and predicted with a
# probability above ICA_THRESHOLD. The membership check keeps the cell
# idempotent: re-running it does not append the same component twice.
if REMOVE_ECG:
    for idx in ica.labels_['ecg']:
        if cls_components['y_pred_proba'][idx] > ICA_THRESHOLD and idx not in ica.exclude:
            ica.exclude.append(idx)

In [None]:
%matplotlib inline
# Show diagnostics and the ICLabel prediction for every excluded 'ecg' component.
for comp in ica.labels_['ecg']:
    if comp not in ica.exclude:
        continue
    _ = ica.plot_properties(raw_filtered, picks=[comp])
    label = cls_components['labels'][comp]
    percent = cls_components['y_pred_proba'][comp] * 100
    print(f"ICA{comp:03d}, Predict: {label} ({percent:.2f}%)")

##### Line Noise

In [None]:
# Mark 'line_noise' components for exclusion when REMOVE_NOISE is enabled and the
# prediction probability exceeds ICA_THRESHOLD. The membership check keeps the
# cell idempotent: re-running it does not append the same component twice.
if REMOVE_NOISE:
    for idx in ica.labels_['line_noise']:
        if cls_components['y_pred_proba'][idx] > ICA_THRESHOLD and idx not in ica.exclude:
            ica.exclude.append(idx)

In [None]:
%matplotlib inline
# Show diagnostics and the ICLabel prediction for every excluded 'line_noise' component.
for comp in ica.labels_['line_noise']:
    if comp not in ica.exclude:
        continue
    _ = ica.plot_properties(raw_filtered, picks=[comp])
    label = cls_components['labels'][comp]
    percent = cls_components['y_pred_proba'][comp] * 100
    print(f"ICA{comp:03d}, Predict: {label} ({percent:.2f}%)")

##### Channel Noise

In [None]:
# Mark 'ch_noise' components for exclusion when REMOVE_NOISE is enabled and the
# prediction probability exceeds ICA_THRESHOLD. The membership check keeps the
# cell idempotent: re-running it does not append the same component twice.
if REMOVE_NOISE:
    for idx in ica.labels_['ch_noise']:
        if cls_components['y_pred_proba'][idx] > ICA_THRESHOLD and idx not in ica.exclude:
            ica.exclude.append(idx)

In [None]:
%matplotlib inline
# Show diagnostics and the ICLabel prediction for every excluded 'ch_noise' component.
for comp in ica.labels_['ch_noise']:
    if comp not in ica.exclude:
        continue
    _ = ica.plot_properties(raw_filtered, picks=[comp])
    label = cls_components['labels'][comp]
    percent = cls_components['y_pred_proba'][comp] * 100
    print(f"ICA{comp:03d}, Predict: {label} ({percent:.2f}%)")

##### Other

In [None]:
# Mark 'other' components for exclusion when enabled and predicted with a
# probability above ICA_THRESHOLD. The membership check keeps the cell
# idempotent: re-running it does not append the same component twice.
if REMOVE_OTHER:
    for idx in ica.labels_['other']:
        if cls_components['y_pred_proba'][idx] > ICA_THRESHOLD and idx not in ica.exclude:
            ica.exclude.append(idx)

In [None]:
%matplotlib inline
# Show diagnostics and the ICLabel prediction for every excluded 'other' component.
for comp in ica.labels_['other']:
    if comp not in ica.exclude:
        continue
    _ = ica.plot_properties(raw_filtered, picks=[comp])
    label = cls_components['labels'][comp]
    percent = cls_components['y_pred_proba'][comp] * 100
    print(f"ICA{comp:03d}, Predict: {label} ({percent:.2f}%)")

#### Remove Noise Components

In [None]:
# Remove the components listed in ica.exclude from a copy of the recording
# (raw_filtered itself is left untouched), then narrow the band to 0.5-45 Hz.
reconst_raw = ica.apply(raw_filtered.copy())
reconst_raw.filter(l_freq=0.5, h_freq=45)

In [None]:
%matplotlib inline
# Side-by-side check: the recording before ('original') and after ('ICA')
# component removal, same scaling and 600 s window.
_ = raw_filtered.plot(title='original', scalings='1e-3', duration=600)
_ = reconst_raw.plot(title='ICA', scalings='1e-3', duration=600)

### By EOG/ECG Channels [🔗](https://mne.tools/stable/auto_tutorials/preprocessing/40_artifact_correction_ica.html#using-an-eog-channel-to-select-ica-components) (Not working, No EOG, Bad ECG)

##### EOG  
Since there is no EOG channel in the dataset, we skip this part.

In [None]:
# eog_indices, eog_scores = ica.find_bads_eog(raw)
# ica.exclude += eog_indices

# # barplot of ICA component "EOG match" scores
# ica.plot_scores(eog_scores)

# # plot diagnostics
# ica.plot_properties(raw, picks=eog_indices)

# # plot ICs applied to raw data, with EOG matches highlighted
# ica.plot_sources(raw, show_scrollbars=False)

# # plot ICs applied to the averaged EOG epochs, with EOG matches highlighted
# ica.plot_sources(eog_evoked)

##### ECG 🚩

In [None]:
# %matplotlib inline
# _ = raw_filtered.copy().pick('ecg').plot(n_channels=1, duration=60, start=120, scalings='5e-3')

Creating ECG epochs [🔗](https://mne.tools/stable/auto_tutorials/preprocessing/40_artifact_correction_ica.html#visualizing-the-artifacts)

In [None]:
# %matplotlib inline
# ecg_epochs = mne.preprocessing.create_ecg_epochs(raw_filtered)
# ecg_epochs.plot_image(combine="mean")

In [None]:
# avg_ecg_epochs = ecg_epochs.copy().average().apply_baseline((None, 0))

In [None]:
# %matplotlib inline
# avg_ecg_epochs.plot_topomap(times=np.linspace(-0.05, 0.05, 11))

In [None]:
# ecg_evoked = mne.preprocessing.create_ecg_epochs(raw_filtered).average()
# ecg_evoked.apply_baseline(baseline=(None, -0.2))
# ecg_evoked.plot_joint()

In [None]:
# # 'correlation' method uses ECG channel as reference.
# ecg_indices, ecg_scores = ica.find_bads_ecg(raw_filtered, method="correlation", threshold="auto")

In [None]:
# ica.exclude += ecg_indices

# # barplot of ICA component "ECG match" scores
# ica.plot_scores(ecg_scores)

# # plot diagnostics
# if len(ecg_indices) > 0:
#     ica.plot_properties(raw, picks=ecg_indices)
# else:
#     print('Not matching with ECG.')

# # plot ICs applied to raw data, with ECG matches highlighted
# ica.plot_sources(raw, show_scrollbars=False)

# # plot ICs applied to the averaged ECG epochs, with ECG matches highlighted
# ica.plot_sources(ecg_evoked)

# Save Result

In [None]:
# Write the cleaned recording to SAVE_PATH, replacing any previous output.
# Create the output directory first so the save does not fail on a fresh
# checkout where it does not exist yet (skipped when SAVE_PATH is empty,
# i.e. when saving into the current working directory).
if SAVE_PATH:
    os.makedirs(SAVE_PATH, exist_ok=True)
reconst_raw.save(os.path.join(SAVE_PATH, f'ica_sub-{SUB_ID}_task-genMusic{GEN_ID}_eeg.fif'), overwrite=True)
reconst_raw