In [2]:
from ica_benchmark.io.load import Physionet_2009_Dataset, BCI_IV_Comp_Dataset, OpenBMI_Dataset
import mne
from pathlib import Path
from collections import namedtuple

mne.set_log_level(False)

physionet_dataset_folderpath = Path('/home/paulo/Documents/datasets/Physionet')
bci_dataset_folderpath = Path('/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf/')
bci_test_dataset_folderpath = Path('/home/paulo/Documents/datasets/BCI_Comp_IV_2a/true_labels/')
openbmi_dataset_folderpath = Path('/home/paulo/Documents/datasets/OpenBMI/edf/')

physionet_dataset = Physionet_2009_Dataset(physionet_dataset_folderpath)
bci_dataset = BCI_IV_Comp_Dataset(bci_dataset_folderpath, test_folder=bci_test_dataset_folderpath)
openbmi_dataset = OpenBMI_Dataset(openbmi_dataset_folderpath)

In [3]:
def make_epochs_splits(arr, n=None, n_splits=2, sizes=None, shuffle=False, seed=1):
    if isinstance(arr, (tuple, list)):
        arr = np.array(arr)
    np.random.seed(seed)
    n = n or len(arr)
    sizes = sizes or [1 / n_splits] * n_splits

    assert np.sum(sizes) == 1.
    
    sizes = np.cumsum(
        [0] + [int(size * n) for size in sizes]
    )
    
    idx = np.arange(n)
    if shuffle:
        np.random.shuffle(idx)
    slices = [slice(start, end) for start, end in zip(sizes[:-1], sizes[1:])]
    arrs = [arr[idx[s]] for s in slices]
    return arrs


In [4]:
from copy import deepcopy

class Splitter():
    
    INTER_SESSION = True
    INTER_SUBJECT = True
    TRAIN_TEST = True
    
    UNIQUE_SESSION = False
    
    SESSION_KWARGS = dict(intra=dict(), inter=dict())
    
    def __init__(self, dataset, uids, sessions, train_folds, load_kwargs=None, fold_sizes=None):
        self.dataset = dataset
        self.uids = uids
        self.sessions = sessions
        self.train_folds = train_folds
        self.load_kwargs = load_kwargs or load_kwargs
        self.fold_sizes = fold_sizes

    def inter_subject_splitting(self, shuffle=False, sizes=[.5, .5], seed=1):
        np.random.seed(seed)
        uids = deepcopy(self.uids)
        if shuffle:
            np.random.shuffle(uids)
        splits_uids = make_epochs_splits(uids, sizes=sizes)
        # [(info, epochs), (info, epochs), ...]
        splits = [
            (
                dict(
                    sessions=self.sessions,
                    train_folds=self.train_folds,
                    uid=split_uids
                ),
                mne.concatenate_epochs(
                    [
                        self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                        for session in self.sessions
                        for train in self.train_folds
                        for uid in split_uids
                    ]
                )
            )
            for split_uids in splits_uids
        ]
        return splits

    def inter_session_splitting(self):
        assert len(self.sessions) > 1, "You are using the inter session protocol, but only passed 1 session"
        splits = [
            (
                dict(uids=self.uids, sessions=[session], train_folds=self.train_folds),
                mne.concatenate_epochs(
                    [
                        self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                        for train
                        in self.train_folds
                    ]
                )
            )
            for uid in self.uids
            for session in self.sessions
        ]
        return splits

    def intra_session_default_splitting(self):
        
        assert len(self.sessions) == 1, "Your are using an intra session splitting but had more than 1 session"
        splits = [
            (
                dict(uid=self.uids, sessions=[session], train_folds=[train]),
                self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
            )
            for session in self.sessions
            for train in self.train_folds
        ]
        return splits

    def intra_session_splitting(self, uid, session):
        epochs = mne.concatenate_epochs(
            [
                self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                for train in self.train_folds
            ]
        )
        splits = make_epochs_splits(
            epochs,
            n=len(epochs.events),
            sizes=self.fold_sizes
        )
        splits = [
            (
                dict(uid=self.uids, session=session, train_folds=self.train_folds, size=size),
                split
            )
            for split, size
            in zip(splits, self.fold_sizes)
        ]
        return splits


    def make_splits(self, inter_session=True, inter_subject=False):

        assert all(np.isin(self.uids, self.dataset.list_uids()))
        session_key = "inter" if inter_session else "intra"
        subject_key = "inter" if inter_subject else "intra"
        
        splits = list()

        # Inter session for the inter subject protocol does not make sense.
        # So, if inter_subject is true, inter_session does not matter
        if inter_subject:
            splits = self.inter_subject_splitting()
            yield splits
            
        # Intra subject
        else:
            
            for uid in self.uids:
                if inter_session:
                    splits = self.inter_session_splitting()
                    yield splits

                else:
                    
                    if self.fold_sizes is None:
                        splits = self.intra_session_default_splitting()
                    else:
                        assert len(self.sessions) == 1, "You are using the intra session protocol but passed more than 1 session"
                        session = self.sessions[0]
                        splits = self.intra_session_splitting(uid, session)
                        
                    yield splits
import numpy as np
kwargs = dict(
    inter_session=False,
    inter_subject=False
)
splitter = Splitter(
    openbmi_dataset,
    uids=["1", "2"],
    sessions=[1],
    train_folds=[True, False],
    load_kwargs=dict(
        reject=False
    )
#     fold_sizes=[.4, .6]
)
for j, splits in enumerate(splitter.make_splits(**kwargs)):
    print(f"FOLD {j}")
    for i, (info, split) in enumerate(splits):
        print(f"Split {i}")
        print("\tInfo", info)
        print("\tSplit", split)
        print()


NameError: name 'uid' is not defined

In [161]:
def group_iterator(*args):
    split_args_iter = [
        arg
        for arg
        in args
        if arg.to_split
    ]
    merge_args_iter = [
        arg
        for arg
        in args
        if not arg.to_split
    ]
    # Split loop
    for split_args in product(*[arg.arg_list for arg in split_args_iter]):
                
        split_arg_dict = {
            arg.name: split_args[i]
            for i, arg
            in enumerate(split_args_iter)
        }
        # Merge loop
        merge_args_list = list()
        for merge_args in product(*[arg.arg_list for arg in merge_args_iter]):
            merge_arg_dict = {
                arg.name: merge_args[i]
                for i, arg
                in enumerate(merge_args_iter)
            }
            run_kwargs = {**split_arg_dict, **merge_arg_dict}
            merge_args_list.append(run_kwargs)
        yield merge_args_list


load_kwargs=dict(
    tmin=1,
    tmax=3.5,
    reject=False
)

from copy import deepcopy
def split(dataset, uids, sessions, train_folds, fold_sizes, load_kwargs):
    kwargs_iterator = group_iterator(
        SplitArg("uid", uids, True),
        SplitArg("session", sessions, False),
        SplitArg("train", train_folds, False),
    )
    epochs_list = list()
    
    for splits_kwargs in kwargs_iterator:
        split_epochs = list()
        for split_kwargs in splits_kwargs:
            kwargs = deepcopy(split_kwargs)
            uid = kwargs.pop("uid")
            epochs, _ = dataset.load_subject(uid, **kwargs, **load_kwargs)
            split_epochs.append(epochs)
        print(splits_kwargs)
        print()
        if fold_sizes is not None:
            split_epochs = mne.concatenate_epochs(split_epochs)
            split_epochs = make_epochs_splits(split_epochs, n=len(split_epochs.events), sizes=fold_sizes)
        yield split_epochs
#         yield splits_kwargs

#         epochs_list.append(epochs)
#     epochs = mne.concatenate_epochs(epochs_list)
#     yield split
for k in split(openbmi_dataset, ["1", "2"], [1, 2], [True, False], None, load_kwargs):
#     print(k)
    pass

[{'uid': '1', 'session': 1, 'train': True}, {'uid': '1', 'session': 1, 'train': False}, {'uid': '1', 'session': 2, 'train': True}, {'uid': '1', 'session': 2, 'train': False}]

[{'uid': '2', 'session': 1, 'train': True}, {'uid': '2', 'session': 1, 'train': False}, {'uid': '2', 'session': 2, 'train': True}, {'uid': '2', 'session': 2, 'train': False}]



In [158]:
k[2]

IndexError: list index out of range

In [76]:
len(k)

2

In [23]:
from itertools import product
    
class SplitArg():
    def __init__(self, name, arg_list, to_split):
        self.to_split = to_split
        self.arg_list = arg_list
        self.name = name

DEFAULT_FOLDS = [True, False]
DEFAULT_SESSIONS = [1, 2]
DEFAULT_UIDS = ["1"]

def split(dataset, uids, sessions, train_folds, fold_sizes, load_kwargs):
    
    args = [uids, sessions, train_folds, fold_sizes]
    split_args_iter = [
        arg
        for arg
        in [uids, sessions, train_folds]
        if arg.to_split
    ]
    merge_args_iter = [
        arg
        for arg
        in [uids, sessions, train_folds]
        if not arg.to_split
    ]
    
    data = dict()
    
    # Split loop
    for split_args in product(*[arg.arg_list for arg in split_args_iter]):
                
        # Merge loop
        split_arg_dict = {
            arg.name: split_args[i]
            for i, arg
            in enumerate(split_args_iter)
        }
        epochs_list = list()
        for merge_args in product(*[arg.arg_list for arg in merge_args_iter]):
            merge_arg_dict = {
                arg.name: merge_args[i]
                for i, arg
                in enumerate(merge_args_iter)
            }
            run_kwargs = {**split_arg_dict, **merge_arg_dict, **load_kwargs}
            print(run_kwargs)
#             epochs, _ = dataset.load_subject(**run_kwargs)
#             epochs_list.append(epochs)
        
        if fold_sizes is not None:
            epochs = mne.concatenate_epochs(epochs_list)
            idx = np.arange(len(epochs))
            np.random.

        print()

split(
    None,
    SplitArg("uid", ["A", "B", "C"], False),
    SplitArg("session", [1, 2, 3], True),
    SplitArg("train", [True, False], False),
    [.3, .2, .5],
    dict()
)

{'session': 1, 'uid': 'A', 'train': True}
{'session': 1, 'uid': 'A', 'train': False}
{'session': 1, 'uid': 'B', 'train': True}
{'session': 1, 'uid': 'B', 'train': False}
{'session': 1, 'uid': 'C', 'train': True}
{'session': 1, 'uid': 'C', 'train': False}

{'session': 2, 'uid': 'A', 'train': True}
{'session': 2, 'uid': 'A', 'train': False}
{'session': 2, 'uid': 'B', 'train': True}
{'session': 2, 'uid': 'B', 'train': False}
{'session': 2, 'uid': 'C', 'train': True}
{'session': 2, 'uid': 'C', 'train': False}

{'session': 3, 'uid': 'A', 'train': True}
{'session': 3, 'uid': 'A', 'train': False}
{'session': 3, 'uid': 'B', 'train': True}
{'session': 3, 'uid': 'B', 'train': False}
{'session': 3, 'uid': 'C', 'train': True}
{'session': 3, 'uid': 'C', 'train': False}



In [77]:
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
from sklearn.base import BaseEstimator
from sklearn.metrics import balanced_accuracy_score

class GetData(BaseEstimator):
    
    def fit(self, X, y=None):
        return self
    
    def transform(self, X, y=None):
        filter_kwargs = dict(
            method="iir",
            iir_params=dict(
                order=5,
                ftype="butter"
            )
        )
        return X.load_data().filter(8, 30, **filter_kwargs).get_data()


# kwargs = dict(
#     uids=["1", "2", "3", "4"],
#     inter_session=False,
#     inter_subject=True,
#     sessions=[2],
#     train_folds=[True, False],
#     fold_sizes=[.5, .5],
#     load_kwargs=dict(
#         reject=False
#     )
# )
kwargs = dict(
    uids=["1", "2", "3", "4"],
    inter_session=False,
    inter_subject=False,
    sessions=[1],
    train_folds=[True, False],
    fold_sizes=[.25, .75],
    load_kwargs=dict(
        tmin=1,
        tmax=3.5,
        reject=False
    )
)
for j, splits in enumerate(OpenBMI_Splitter(openbmi_dataset).make_splits(**kwargs)):
    infos = [info for info, split in splits]
    splits = [split for info, split in splits]
    print(infos)
    clf = make_pipeline(
        GetData(),
        CSP(5, log=True),
        LDA()
    )
    train_epochs = splits[0]
    train_labels = train_epochs.events[:, 2]
    
    test_epochs = splits[1]
    test_labels = test_epochs.events[:, 2]
    
    clf.fit(train_epochs, train_labels)
    pred = clf.predict(test_epochs)
    acc = balanced_accuracy_score(test_labels, pred)
    print(acc)

  epochs = mne.concatenate_epochs(


[{'uid': '1', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '1', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.5903271692745377


  epochs = mne.concatenate_epochs(


[{'uid': '2', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '2', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.8466749866286325


  epochs = mne.concatenate_epochs(


[{'uid': '3', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '3', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.8866999465145302


  epochs = mne.concatenate_epochs(


[{'uid': '4', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '4', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.47759601706970123


In [71]:
train_labels

array([0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
       1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1,
       0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
       1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0])

In [46]:
[f"{i} {j}" for i in range(3) for j in range(3)]

['0 0', '0 1', '0 2', '1 0', '1 1', '1 2', '2 0', '2 1', '2 2']

In [20]:
def split_array(arr, n_splits=2, sizes=None):
    n = len(arr)
    sizes = sizes or [1 / n] * n_splits
    assert np.sum(sizes) == 1.
    sizes = [int(size * n) for size in sizes]
    sizes = [0] + sizes
    sizes = np.cumsum(sizes)
    slices = [slice(start, end) for start, end in zip(sizes[:-1], sizes[1:])]
    arrs = [arr[s] for s in slices]
    return arrs

arrs = split_array(np.random.rand(100), sizes=[.2, .3, .1, .4])
list(map(len, arrs))

[20, 30, 10, 40]

In [44]:
split_array(openbmi_dataset.list_uids(), sizes=[.2, .3, .5])

[array(['25', '15', '41', '12', '37', '2', '42', '6', '52', '18'],
       dtype=object),
 array(['38', '30', '34', '5', '11', '44', '29', '8', '17', '33', '46',
        '23', '19', '4', '13', '22'], dtype=object),
 array(['35', '45', '36', '39', '16', '50', '10', '53', '7', '28', '27',
        '14', '32', '20', '31', '51', '24', '9', '21', '40', '47', '26',
        '49', '54', '43', '1', '3'], dtype=object)]

In [41]:
e = bci_dataset.load_subject("1")[0]
split_array(e, n=len(e.events), sizes=[.2, .8])

  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  next(self.gen)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  next(self.gen)


[<Epochs |  57 events (good & bad), -0.3 - 0.7 sec, baseline off, ~26 kB, data not loaded>,
 <Epochs |  230 events (good & bad), -0.3 - 0.7 sec, baseline off, ~26 kB, data not loaded>]