In [474]:
%pip install -e git+https://github.com/UN-GCPDS/python-gcpds.MI_prediction.git#egg=MI_prediction

Obtaining MI_prediction from git+https://github.com/UN-GCPDS/python-gcpds.MI_prediction.git#egg=MI_prediction
  Updating ./src/mi-prediction clone
  Running command git fetch -q --tags
  Running command git reset --hard -q 02f0cf71c60b1ca389163f6a338a97bf25a445a6
  Preparing metadata (setup.py) ... [?25ldone
Installing collected packages: MI_prediction
  Attempting uninstall: MI_prediction
    Found existing installation: MI-prediction 0.1
    Uninstalling MI-prediction-0.1:
      Successfully uninstalled MI-prediction-0.1
  Running setup.py develop for MI_prediction
Successfully installed MI_prediction-0.1
Note: you may need to restart the kernel to use updated packages.


# Create the Cho2017 resting-state dataset class

In [164]:
import logging

import mne
import numpy as np
import pandas as pd
from braindecode.datasets import BaseConcatDataset
from braindecode.datasets import BaseDataset as BD
from braindecode.datasets import create_from_mne_raw
from braindecode.preprocessing.preprocess import exponential_moving_standardize, preprocess, Preprocessor, scale
from braindecode.preprocessing.windowers import create_windows_from_events,_create_fixed_length_windows
from mne import create_info
from mne.channels import make_standard_montage
from mne.io import RawArray
from moabb.datasets import download as dl
from moabb.datasets.base import BaseDataset
from scipy.io import loadmat

In [165]:
# Module-level logger (not referenced by the code in this notebook).
log = logging.getLogger(__name__)
# Base URL of the Cho2017 GigaDB ``.mat`` files; ``data_path`` appends ``sNN.mat``.
# NOTE(review): plain FTP host; confirm it is still reachable (newer moabb
# releases may download these files from a different mirror).
GIGA_URL = "ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/"

In [166]:
class Cho2017_Rest(BaseDataset):
    """Resting-state recordings of the Cho et al. (2017) GigaDB dataset.

    Exposes the ``rest`` field of each subject's ``sNN.mat`` file as a MOABB
    dataset with one session and one run per subject. Subjects 32, 46 and 49
    are removed from ``subject_list``.
    """

    def __init__(self):
        super().__init__(
            subjects=list(range(1, 53)),
            sessions_per_subject=1,
            events=dict(rest=1),
            code="Cho2017_Rest",
            # NOTE(review): the resting recordings are ~66.5 s long (see the
            # RawArray repr for subject 1); the windowing code in this
            # notebook does not read this interval.
            interval=[0, 60],
            paradigm="imagery",
            doi="10.5524/100295",
        )

        for ii in [32, 46, 49]:
            self.subject_list.remove(ii)

    def _get_single_subject_data(self, subject):
        """Load the resting-state recording of ``subject`` as an MNE Raw.

        Returns
        -------
        dict
            ``{"session_0": {"run_0": raw}}``. ``raw`` has 64 EEG channels
            followed by 4 EMG channels. Every channel is de-meaned and
            multiplied by 1e-6 (presumably uV -> V, the unit MNE expects;
            confirm against the GigaDB documentation).
        """
        fname = self.data_path(subject)

        data = loadmat(
            fname,
            squeeze_me=True,
            struct_as_record=False,
            verify_compressed_data_integrity=False,
        )["eeg"]

        # fmt: off
        eeg_ch_names = [
            "Fp1", "AF7", "AF3", "F1", "F3", "F5", "F7", "FT7", "FC5", "FC3", "FC1",
            "C1", "C3", "C5", "T7", "TP7", "CP5", "CP3", "CP1", "P1", "P3", "P5", "P7",
            "P9", "PO7", "PO3", "O1", "Iz", "Oz", "POz", "Pz", "CPz", "Fpz", "Fp2",
            "AF8", "AF4", "AFz", "Fz", "F2", "F4", "F6", "F8", "FT8", "FC6", "FC4",
            "FC2", "FCz", "Cz", "C2", "C4", "C6", "T8", "TP8", "CP6", "CP4", "CP2",
            "P2", "P4", "P6", "P8", "P10", "PO8", "PO4", "O2",
        ]
        # fmt: on
        emg_ch_names = ["EMG1", "EMG2", "EMG3", "EMG4"]
        ch_names = eeg_ch_names + emg_ch_names
        # Derive the counts from the name lists so names and types cannot drift apart.
        ch_types = ["eeg"] * len(eeg_ch_names) + ["emg"] * len(emg_ch_names)
        montage = make_standard_montage("standard_1005")

        # Remove each channel's mean (DC offset) before rescaling.
        resting = data.rest - data.rest.mean(axis=1, keepdims=True)
        eeg_rest = resting * 1e-6

        info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=data.srate)
        raw = RawArray(data=eeg_rest, info=info, verbose=False)
        raw.set_montage(montage)

        return {"session_0": {"run_0": raw}}

    def data_path(
        self, subject, path=None, force_update=False, update_path=None, verbose=None
    ):
        """Return the result of moabb's ``data_dl`` for subject ``subject``'s ``.mat`` file.

        That result is what ``_get_single_subject_data`` passes to ``loadmat``.
        ``update_path`` is accepted for signature compatibility but unused.

        Raises
        ------
        ValueError
            If ``subject`` is not in ``subject_list``.
        """
        if subject not in self.subject_list:
            raise ValueError("Invalid subject number")

        url = "{:s}s{:02d}.mat".format(GIGA_URL, subject)
        return dl.data_dl(url, "GIGADB", path, force_update, verbose)

In [167]:
# Instantiate the resting-state dataset; the recordings are loaded by get_data below.
ds_r = Cho2017_Rest()

In [168]:
# Load subject 1 as a nested {subject: {session: {run: RawArray}}} dict.
s1 = ds_r.get_data([1])

In [169]:
# Display the nested dict returned by get_data.
s1

{1: {'session_0': {'run_0': <RawArray | 68 x 34048 (66.5 s), ~17.8 MB, data loaded>}}}

In [170]:
def _fetch_and_unpack_moabb_data(dataset, subject_ids):
    data = dataset.get_data(subject_ids)
    raws, subject_ids, session_ids, run_ids = [], [], [], []
    for subj_id, subj_data in data.items():
        for sess_id, sess_data in subj_data.items():
            for run_id, raw in sess_data.items():
                raws.append(raw)
                subject_ids.append(subj_id)
                session_ids.append(sess_id)
                run_ids.append(run_id)
    description = pd.DataFrame({
        'subject': subject_ids,
        'session': session_ids,
        'run': run_ids
    })
    return raws, description

In [171]:
# Flatten subject 1 into a list of raws plus a subject/session/run description table.
raws,description = _fetch_and_unpack_moabb_data(ds_r, [1])

In [None]:
# A single RawArray: 68 channels (64 EEG + 4 EMG).
raws

[<RawArray | 68 x 34048 (66.5 s), ~17.8 MB, data loaded>]

In [173]:
# Print each (index, row) pair of the description table.
# NOTE(review): ending the cell with ``description`` would render it as a table
# instead; this loop also leaves ``i`` bound as a notebook global.
for i in description.iterrows():
    print(i)

(0, subject            1
session    session_0
run            run_0
Name: 0, dtype: object)


In [174]:
# Build one braindecode BaseDataset per (raw, description row) pair.
all_base_ds = [BD(raw, row)
                       for raw, (_, row) in zip(raws, description.iterrows())]

In [175]:
# Loop form of the comprehension in the previous cell; ``db`` ends up bound to
# the BaseDataset built from the last (raw, row) pair only.
# NOTE(review): redundant with ``all_base_ds``; consider deleting this cell.
for raw, (_, row) in zip(raws, description.iterrows()):
    db = BD(raw, row)

In [176]:
# Show the BaseDataset left in ``db`` by the loop.
db

<braindecode.datasets.base.BaseDataset at 0x299101880>

In [444]:
# Show the list of BaseDatasets built by the comprehension.
all_base_ds

[<braindecode.datasets.base.BaseDataset at 0x299e19970>]

In [182]:
class MOABBDataset_Rest(BaseConcatDataset):
    """BaseConcatDataset with one braindecode ``BaseDataset`` per MOABB run.

    Parameters
    ----------
    dataset : moabb dataset instance or class
        Source dataset. If a class is given, it is instantiated with
        ``dataset_kwargs``.
    subject_ids : list of int
        Subjects to load; forwarded to ``dataset.get_data``.
    dataset_kwargs : dict | None
        Constructor keyword arguments, used only when ``dataset`` is a class.
        When they are supplied together with an already-built instance, they
        cannot be applied, so a ``UserWarning`` is emitted instead of
        discarding them silently.
    """

    def __init__(self, dataset, subject_ids, dataset_kwargs=None):
        if isinstance(dataset, type):
            dataset = dataset(**(dataset_kwargs or {}))
        elif dataset_kwargs:
            import warnings
            warnings.warn(
                "dataset_kwargs is ignored when an already instantiated "
                "dataset is passed; pass the dataset class instead.",
                stacklevel=2,
            )
        raws, description = _fetch_and_unpack_moabb_data(dataset, subject_ids)
        # Pair each raw with its description row (a pandas Series).
        all_base_ds = [
            BD(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
        ]
        super().__init__(all_base_ds)

In [472]:
# Wrap subject 1 of the resting-state dataset as a braindecode BaseConcatDataset.
dataset = MOABBDataset_Rest(dataset=Cho2017_Rest(), subject_ids=[1])

In [473]:
# Inspect the first (here the only) BaseDataset in the concatenation.
dataset.datasets[0]

<braindecode.datasets.base.BaseDataset at 0x2b7f79d90>

In [331]:
from mne import Epochs

In [332]:
# Sampling frequency of the first recording, in Hz.
sfreq = dataset.datasets[0].raw.info["sfreq"]

In [333]:
# The resting-state raw carries no annotations, so there are no events for an
# annotation-based windower to use; the custom windower below builds its own.
dataset.datasets[0].raw.annotations

<Annotations | 0 segments>

In [457]:
from braindecode.preprocessing.windowers import _compute_window_inds, _check_windowing_arguments, WindowsDataset
from joblib import Parallel, delayed

In [372]:
# Shortcut to the first recording for interactive inspection. The function
# parameters named ``ds`` in the next cells shadow this global.
ds = dataset.datasets[0]

In [437]:
def _create_windows_from_events(
        ds, infer_window_size_stride,
        trial_start_offset_samples, trial_stop_offset_samples,
        window_size_samples=None, window_stride_samples=None,
        drop_last_window=False, preload=False,
        drop_bad_windows=True, picks=None, reject=None, flat=None,
        on_missing='error', accepted_bads_ratio=0.0):
    """Window a single annotation-free (resting-state) recording.

    braindecode's event-based windower reads trials from annotations, which
    these recordings lack. Instead, one pseudo-trial is synthesised. It starts
    at the first sample of ``ds.raw`` and lasts the largest whole number of
    seconds that fits in the recording (e.g. 66.5 s -> 66 s). The trailing
    fractional second is not windowed.

    Parameters
    ----------
    ds : braindecode BaseDataset
        Recording to window; ``ds.raw`` and ``ds.description`` are used.
    infer_window_size_stride : bool
        If True, window size and stride are set to the offset-adjusted
        pseudo-trial length, i.e. one window per recording. If False,
        ``window_size_samples`` / ``window_stride_samples`` are used as given.
    trial_start_offset_samples, trial_stop_offset_samples : int
        Sample offsets added to the pseudo-trial start and stop.
    window_size_samples, window_stride_samples : int | None
        Window geometry; only used when ``infer_window_size_stride`` is False.
    drop_last_window, accepted_bads_ratio
        Forwarded to braindecode's ``_compute_window_inds``.
    preload, picks, reject, flat, on_missing
        Forwarded to ``mne.Epochs``.
    drop_bad_windows : bool
        If True, ``drop_bad()`` is applied to the epochs.

    Returns
    -------
    WindowsDataset
        Windows of ``ds.raw`` carrying ``ds.description``. Every window has
        target ``-1``, since resting state has no class label.

    Raises
    ------
    ValueError
        If ``trial_stop_offset_samples`` moves the trial stop past the end of
        the recording.
    """
    # Use the recording we were given, not a notebook global.
    raw = ds.raw
    sfreq = raw.info["sfreq"]

    # MNE event sample indices include ``first_samp`` (as ``last_samp`` below
    # does), so the pseudo-trial onset is expressed the same way.
    n_whole_seconds = int(raw.n_times / sfreq)
    onsets = np.array([raw.first_samp])
    stops = onsets + int(n_whole_seconds * sfreq)

    last_samp = raw.first_samp + raw.n_times
    if stops[-1] + trial_stop_offset_samples > last_samp:
        raise ValueError(
            '"trial_stop_offset_samples" too large. Stop of last trial '
            f'({stops[-1]}) + "trial_stop_offset_samples" '
            f'({trial_stop_offset_samples}) must be smaller than length of'
            f' recording ({len(ds)}).')

    if infer_window_size_stride:
        # A single window spanning the whole (offset-adjusted) pseudo-trial.
        window_size_samples = (stops[0] + trial_stop_offset_samples
                               - (onsets[0] + trial_start_offset_samples))
        window_stride_samples = window_size_samples

    _, i_window_in_trials, starts, stops = _compute_window_inds(
        onsets, stops, trial_start_offset_samples, trial_stop_offset_samples,
        window_size_samples, window_stride_samples, drop_last_window,
        accepted_bads_ratio)

    # Every window gets target -1. The same value is the MNE event id, and
    # event_id=None lets MNE build the id mapping from the events array.
    target = -1
    events = np.array(
        [[start, window_size_samples, target] for start in starts])

    metadata = pd.DataFrame({
        'i_window_in_trial': i_window_in_trials,
        'i_start_in_trial': starts,
        'i_stop_in_trial': stops,
        'target': events[:, -1]})

    # tmax is inclusive in MNE, hence ``window_size_samples - 1``.
    mne_epochs = mne.Epochs(
        raw, events, event_id=None, baseline=None, tmin=0,
        tmax=(window_size_samples - 1) / sfreq,
        metadata=metadata, preload=preload, picks=picks, reject=reject,
        flat=flat, on_missing=on_missing)

    if drop_bad_windows:
        mne_epochs.drop_bad()

    return WindowsDataset(mne_epochs, ds.description)

In [460]:
def create_windows_from_events(
        concat_ds, trial_start_offset_samples=0, trial_stop_offset_samples=0,
        window_size_samples=None, window_stride_samples=None,
        drop_last_window=False, preload=False,
        drop_bad_windows=True, picks=None, reject=None, flat=None,
        on_missing='error', accepted_bads_ratio=0.0, n_jobs=1):
    """Window every recording in ``concat_ds`` with ``_create_windows_from_events``.

    The arguments are first validated with braindecode's
    ``_check_windowing_arguments``. When ``window_size_samples`` is None, the
    per-recording helper is told to infer window size and stride itself.
    Recordings are processed through ``joblib.Parallel(n_jobs=n_jobs)``.

    Returns
    -------
    BaseConcatDataset
        One windows dataset per entry of ``concat_ds.datasets``, in order.
    """
    _check_windowing_arguments(
        trial_start_offset_samples, trial_stop_offset_samples,
        window_size_samples, window_stride_samples)

    # Every recording is windowed with the same settings.
    shared_kwargs = dict(
        infer_window_size_stride=window_size_samples is None,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        preload=preload,
        drop_bad_windows=drop_bad_windows,
        picks=picks,
        reject=reject,
        flat=flat,
        on_missing=on_missing,
        accepted_bads_ratio=accepted_bads_ratio,
    )
    per_recording_jobs = (
        delayed(_create_windows_from_events)(single_ds, **shared_kwargs)
        for single_ds in concat_ds.datasets
    )
    return BaseConcatDataset(Parallel(n_jobs=n_jobs)(per_recording_jobs))

In [461]:
# Window each recording with the notebook-local ``create_windows_from_events``
# defined above (it shadows the braindecode function of the same name imported earlier).
trials = create_windows_from_events(dataset, trial_start_offset_samples=0, trial_stop_offset_samples=0)

Adding metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1 events and 33792 original time points ...
0 bad epochs dropped


# Test the Cho2017 resting-state dataset class

In [4]:
from MI_prediction.Datasets.Moabb import MOABBDataset_Rest
from MI_prediction.Datasets import Cho2017_Rest
from MI_prediction.Utils.Windowers import create_windows_from_events

In [3]:
# Same pipeline, now using the packaged classes imported from MI_prediction.
dataset = MOABBDataset_Rest(dataset=Cho2017_Rest(), subject_ids=[1])

In [5]:
sfreq = dataset.datasets[0].raw.info["sfreq"]
# Both trial offsets are 0 s expressed in samples (``int(0*sfreq)`` == 0);
# preload=True requests that the windows be loaded into memory.
trials= create_windows_from_events(dataset,trial_start_offset_samples=int(0*sfreq),
                    trial_stop_offset_samples=int(0*sfreq), preload=True)

Adding metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1 events and 33792 original time points ...
0 bad epochs dropped


In [9]:
# Data array of the first window (presumably the X of an (X, y, ...) item):
# 68 channels x 33792 samples, i.e. 66 whole seconds at 512 Hz.
trials[0][0].shape

(68, 33792)