In [1]:
from ica_benchmark.io.load import Physionet_2009_Dataset, BCI_IV_Comp_Dataset, OpenBMI_Dataset
import mne
from pathlib import Path
from collections import namedtuple
import numpy as np
from itertools import product

mne.set_log_level(False)

physionet_dataset_folderpath = Path('/home/paulo/Documents/datasets/Physionet')
bci_dataset_folderpath = Path('/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf/')
bci_test_dataset_folderpath = Path('/home/paulo/Documents/datasets/BCI_Comp_IV_2a/true_labels/')
openbmi_dataset_folderpath = Path('/home/paulo/Documents/datasets/OpenBMI/edf/')

physionet_dataset = Physionet_2009_Dataset(physionet_dataset_folderpath)
bci_dataset = BCI_IV_Comp_Dataset(bci_dataset_folderpath, test_folder=bci_test_dataset_folderpath)
openbmi_dataset = OpenBMI_Dataset(openbmi_dataset_folderpath)

In [2]:
from mne import Epochs

def make_epochs_splits_indexes(arr, n=None, n_splits=2, sizes=None, shuffle=False, seed=1):
    if isinstance(arr, (tuple, list)):
        arr = np.array(arr)
    np.random.seed(seed)
    if n is None:
        if isinstance(arr, Epochs):
            n = len(arr.events)
        else:
            n = len(arr)
    sizes = sizes or [1 / n_splits] * n_splits

    assert np.sum(sizes) == 1.
    
    sizes = np.cumsum(
        [0] + [int(size * n) for size in sizes]
    )
    
    idx = np.arange(n)
    if shuffle:
        np.random.shuffle(idx)
    slices = [slice(start, end) for start, end in zip(sizes[:-1], sizes[1:])]
    indexes = [idx[s] for s in slices]
    return indexes

def make_epochs_splits(arr, n=None, n_splits=2, sizes=None, shuffle=False, seed=1):
    indexes = make_epochs_splits_indexes(arr, n=n, n_splits=n_splits, sizes=sizes, shuffle=shuffle, seed=seed)
    arrs = [arr[idx] for idx in indexes]
    return arrs

make_epochs_splits(np.linspace(10, 100, 20), sizes=[.3, .7], shuffle=False)

[array([10.        , 14.73684211, 19.47368421, 24.21052632, 28.94736842,
        33.68421053]),
 array([ 38.42105263,  43.15789474,  47.89473684,  52.63157895,
         57.36842105,  62.10526316,  66.84210526,  71.57894737,
         76.31578947,  81.05263158,  85.78947368,  90.52631579,
         95.26315789, 100.        ])]

In [3]:
from copy import deepcopy

class Splitter():
    
    INTER_SESSION = True
    INTER_SUBJECT = True
    TRAIN_TEST = True
    
    UNIQUE_SESSION = False
    
    SESSION_KWARGS = dict(intra=dict(), inter=dict())
    
    def __init__(self, dataset, uids, sessions, train_folds, load_kwargs=None, fold_sizes=None):
        self.dataset = dataset
        self.uids = uids
        self.sessions = sessions
        self.train_folds = train_folds
        self.load_kwargs = load_kwargs or load_kwargs
        self.fold_sizes = fold_sizes

    def inter_subject_splitting(self, shuffle=False, sizes=[.5, .5], seed=1):
        np.random.seed(seed)
        uids = deepcopy(self.uids)
        if shuffle:
            np.random.shuffle(uids)
        splits_uids = make_epochs_splits(uids, sizes=sizes)
        # [(info, epochs), (info, epochs), ...]
        splits = [
            (
                dict(
                    sessions=self.sessions,
                    train_folds=self.train_folds,
                    uid=split_uids
                ),
                mne.concatenate_epochs(
                    [
                        self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                        for session in self.sessions
                        for train in self.train_folds
                        for uid in split_uids
                    ]
                )
            )
            for split_uids in splits_uids
        ]
        return splits

    def inter_session_splitting(self):
        assert len(self.sessions) > 1, "You are using the inter session protocol, but only passed 1 session"
        splits = [
            (
                dict(uids=self.uids, sessions=[session], train_folds=self.train_folds),
                mne.concatenate_epochs(
                    [
                        self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                        for train
                        in self.train_folds
                    ]
                )
            )
            for uid in self.uids
            for session in self.sessions
        ]
        return splits

    def intra_session_default_splitting(self):
        
        assert len(self.sessions) == 1, "Your are using an intra session splitting but had more than 1 session"
        splits = [
            (
                dict(uid=self.uids, sessions=[session], train_folds=[train]),
                self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
            )
            for session in self.sessions
            for train in self.train_folds
        ]
        return splits

    def intra_session_splitting(self, uid, session):
        epochs = mne.concatenate_epochs(
            [
                self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                for train in self.train_folds
            ]
        )
        splits = make_epochs_splits(
            epochs,
            n=len(epochs.events),
            sizes=self.fold_sizes
        )
        splits = [
            (
                dict(uid=self.uids, session=session, train_folds=self.train_folds, size=size),
                split
            )
            for split, size
            in zip(splits, self.fold_sizes)
        ]
        return splits


    def make_splits(self, inter_session=True, inter_subject=False):

        assert all(np.isin(self.uids, self.dataset.list_uids()))
        session_key = "inter" if inter_session else "intra"
        subject_key = "inter" if inter_subject else "intra"
        
        splits = list()

        # Inter session for the inter subject protocol does not make sense.
        # So, if inter_subject is true, inter_session does not matter
        if inter_subject:
            splits = self.inter_subject_splitting()
            yield splits
            
        # Intra subject
        else:
            
            for uid in self.uids:
                if inter_session:
                    splits = self.inter_session_splitting()
                    yield splits

                else:
                    
                    if self.fold_sizes is None:
                        splits = self.intra_session_default_splitting()
                    else:
                        assert len(self.sessions) == 1, "You are using the intra session protocol but passed more than 1 session"
                        session = self.sessions[0]
                        splits = self.intra_session_splitting(uid, session)
                        
                    yield splits
import numpy as np
kwargs = dict(
    inter_session=False,
    inter_subject=False
)
splitter = Splitter(
    openbmi_dataset,
    uids=["1", "2"],
    sessions=[1],
    train_folds=[True, False],
    load_kwargs=dict(
        reject=False
    )
#     fold_sizes=[.4, .6]
)
for j, splits in enumerate(splitter.make_splits(**kwargs)):
    print(f"FOLD {j}")
    for i, (info, split) in enumerate(splits):
        print(f"Split {i}")
        print("\tInfo", info)
        print("\tSplit", split)
        print()


NameError: name 'uid' is not defined

In [4]:
from itertools import product

class SplitArg():
    def __init__(self, name, arg_list, to_split):
        self.name = name
        self.arg_list = arg_list
        self.to_split = to_split
        

def group_iterator(*args):
    split_args_iter = [
        arg
        for arg
        in args
        if arg.to_split
    ]
    merge_args_iter = [
        arg
        for arg
        in args
        if not arg.to_split
    ]
    # Split loop
    for split_args in product(*[arg.arg_list for arg in split_args_iter]):
                
        split_arg_dict = {
            arg.name: split_args[i]
            for i, arg
            in enumerate(split_args_iter)
        }
        # Merge loop
        merge_args_list = list()
        for merge_args in product(*[arg.arg_list for arg in merge_args_iter]):
            merge_arg_dict = {
                arg.name: merge_args[i]
                for i, arg
                in enumerate(merge_args_iter)
            }
            run_kwargs = {**split_arg_dict, **merge_arg_dict}
            merge_args_list.append(run_kwargs)
        yield merge_args_list


load_kwargs=dict(
    tmin=1,
    tmax=3.5,
    reject=False
)

from copy import deepcopy
def split(dataset, uids, sessions, train_folds, fold_sizes, load_kwargs):
    kwargs_iterator = group_iterator(
        SplitArg("uid", uids, True),
        SplitArg("session", sessions, False),
        SplitArg("train", train_folds, True),
    )
    epochs_list = list()
    
    for splits_kwargs in kwargs_iterator:
        split_epochs = list()
#         for split_kwargs in splits_kwargs:
#             kwargs = deepcopy(split_kwargs)
#             uid = kwargs.pop("uid")
#             epochs, _ = dataset.load_subject(uid, **kwargs, **load_kwargs)
#             split_epochs.append(epochs)
        print(splits_kwargs)
        print()
#         if fold_sizes is not None:
#             split_epochs = mne.concatenate_epochs(split_epochs)
#             split_epochs = make_epochs_splits(split_epochs, n=len(split_epochs.events), sizes=fold_sizes)
#         yield split_epochs
        yield splits_kwargs

#         epochs_list.append(epochs)
#     epochs = mne.concatenate_epochs(epochs_list)
#     yield split
for k in split(openbmi_dataset, ["1", "2"], [1, 2], [True, False], None, load_kwargs):
#     print(k)
    pass

[{'uid': '1', 'train': True, 'session': 1}, {'uid': '1', 'train': True, 'session': 2}]

[{'uid': '1', 'train': False, 'session': 1}, {'uid': '1', 'train': False, 'session': 2}]

[{'uid': '2', 'train': True, 'session': 1}, {'uid': '2', 'train': True, 'session': 2}]

[{'uid': '2', 'train': False, 'session': 1}, {'uid': '2', 'train': False, 'session': 2}]



In [5]:
iters = [
    [1, 2, 3],
    ["a", "b", "c", "d"]
]
(1, "a"), (1, "b"), (1, "c"), (1, "d")

iters = [
    [1, 2, 3],
    [
        ["a", "b"], ["A", "B"]
    ],
    [True, False]
]
(1, "a"), (1, "b"), (1, "c"), (1, "d")


iters = [
    [1, 2, 3],
    [
        ["a", "b"], ["A", "B"]
    ]
]
(1, "a", "A"), (1, "a", "B"), (1, "b", "A"), (1, "b", "B")

def unpack_product(iters):
    iters = product(*iters)
    for iter_list in iters:
        print(iter_list)
        is_list = False
        for i, iter_value in enumerate(iter_list):
            is_list = is_list or isinstance(iter_value, list)
        for i, iter_value in enumerate(iter_list):
            iter_list[i] = [iter_list[i]] if is_list else iter_list[i]
    return iter_list
unpack_product(iters)

(1, ['a', 'b'])


TypeError: 'tuple' object does not support item assignment

In [158]:
from sklearn.model_selection import KFold, LeaveOneOut

cv = LeaveOneOut()
for a, b in cv.split(np.arange(100)):
    print(a)
    print(b)
    print()

[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
 97 98 99]
[0]

[ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
 97 98 99]
[1]

[ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
 97 98 99]
[2]

[ 0  1  2  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25

In [197]:
a = dict(a=1, b=2)
a.pop('a')

1

In [20]:
import pandas as pd
from sklearn.model_selection import KFold, LeaveOneOut
from itertools import tee

def remove_key(d, k):
    return {
        key: value
        for key, value in d.items()
        if key != k
    }

class Split():
    
    def __init__(self, kwarg_dict_list, fold_sizes=None):
        self.kwargs_list = kwarg_dict_list
        self.fold_sizes = fold_sizes
    
    def to_dataframe():
        return pd.DataFrame.from_records(self.kwargs_list)

    def __repr__(self):
        return str(self.kwargs_list)
    
    def load_epochs(self, dataset, **load_kwargs):
        epochs = mne.concatenate_epochs(
            [
                dataset.load_subject(kwargs["uid"], **remove_key(kwargs, "uid"), **load_kwargs)[0]
                for kwargs in self.kwargs_list
            ]
        )
        if self.fold_sizes is None:
            return epochs
        else:
            return make_epochs_splits(epochs, sizes=self.fold_sizes)
        


df = pd.DataFrame(product(openbmi_dataset.list_uids(), ["1", "2"], [True, False]), columns=["uid", "session", "train"])

k = ["uid", "session"]
# df.sort_values(by=k).set_index(k)

uids = openbmi_dataset.list_uids()[:3]
sessions = [1, 2]
train_folds = [True, False]


# for uid, session in product(uids, sessions):
#     a = df.query("uid == @uid").query("session == @session")
#     display(a)
    
def inter_subject(splitter):
    for uid_splits_idxs in splitter.split(uids):
        splits_uids = [uids[idx] for idx in uid_splits_idxs]
        yield [
            Split(
                [
                    dict(
                        uid=uid,
                        session=session,
                        train=train_fold
                    )
                    for uid in split_uids
                    for session in sessions
                    for train_fold in train_folds
                ]
            )
            for split_uids in splits_uids
        ]
    
def inter_session():
    for uid in uids:
        yield [
            Split(
                [
                    dict(
                        uid=uid,
                        session=session,
                        train=train_fold
                    )
                    for train_fold in train_folds
                ]
            )
            for session in sessions
        ]
        
def intra_session(fold_sizes=None):
    for uid in uids:
        for session in sessions:
            if fold_sizes is None:
                yield [
                    Split(
                        [
                            dict(
                                uid=uid,
                                session=session,
                                train=train_fold
                            )
                        ],
                        None
                    )
                    for train_fold in train_folds
                ]
            else:
                yield [
                    Split(
                        [
                            dict(
                                uid=uid,
                                session=session,
                                train=train_fold
                            )
                            for train_fold in train_folds
                        ],
                        fold_sizes
                    )
                ]
                
load_kwargs=dict(
    reject=False
)
fold_sizes = [.4, .6]
split_fn = inter_session
for splits in inter_session():
    if (split_fn is inter_session) and (fold_sizes is not None):
        splits_folds_epochs = list()
        for split in splits:
            epochs_list = make_epochs_splits(
                split.load_epochs(
                    openbmi_dataset,
                    **load_kwargs
                ),
                sizes=fold_sizes
            )
            splits_folds_epochs.append(
                epochs_list
            )
        folds_epochs_list = list(zip(*splits_folds_epochs))
        break
        print(folds_epochs_list)
        split_epochs = [mne.concatenate_epochs(epochs_list) for epochs_list in folds_epochs_list]
        display(split_epochs)
    else:
        for split in splits:
            epochs = split.load_epochs(openbmi_dataset, **load_kwargs)
            display(epochs)
    print("NEXT")

  epochs = mne.concatenate_epochs(
  epochs = mne.concatenate_epochs(


In [23]:
folds_epochs_list[1]

(<Epochs |  120 events (all good), -0.3 - 0.7 sec, baseline off, ~56.9 MB, data loaded,
  '0': 63
  '1': 57>,
 <Epochs |  120 events (all good), -0.3 - 0.7 sec, baseline off, ~56.9 MB, data loaded,
  '0': 63
  '1': 57>)

In [12]:
filepaths_df = openbmi_dataset.list_subject_filepaths()
uid = '25'
session = 1
train = True
(
    filepaths_df.query("uid == @uid")
    .query("train == @train")
    .query("session == @session")
    .path
)

108    /home/paulo/Documents/datasets/OpenBMI/edf/ses...
Name: path, dtype: object

In [13]:
filepaths_df.query("uid == @uid").query("train == @train").query("session == @session")

Unnamed: 0,path,train,session,uid
108,/home/paulo/Documents/datasets/OpenBMI/edf/ses...,True,1,25


In [195]:
split.kwargs_list

[{'uid': '25', 'session': '1', 'train': True},
 {'uid': '25', 'session': '1', 'train': False}]

In [120]:
uid

('48',)

In [76]:
from itertools import product

class SplitArg():
    def __init__(self, name, arg_list, level=0):
        self.name = name
        self.arg_list = arg_list
        self.level = level

    def __repr__(self):
        return "|<{}> {}: {}|".format(self.level, self.name, self.arg_list)

def group_iterator(*args):
    
    levels = sorted(set([arg.level for arg in args]))
    
    levels_args_iters = [
        [
            arg
            for arg in args
            if arg.level == level
        ]
        for level in levels
    ]
    print(levels_args_iters)
    # Split loop
    level_iterables = [
        product(*[arg.arg_list for arg in level_args_iter])
        for level_args_iter in levels_args_iters
    ]
    for all_level_args_lists in product(*level_iterables):
        split_arg_dict = {
#             arg.name: levels_args_iters[level].arg_list[i]
            levels_args_iters[level][i].name: arg_value
            for level, level_args_iter in enumerate(all_level_args_lists)
            for i, arg_value in enumerate(level_args_iter)
        }
        print(split_arg_dict)
        
#         merge_args_list = list()
#         for merge_args in product(*[arg.arg_list for arg in merge_args_iter]):
#             merge_arg_dict = {
#                 arg.name: merge_args[i]
#                 for i, arg
#                 in enumerate(merge_args_iter)
#             }
#             run_kwargs = {**split_arg_dict, **merge_arg_dict}
#             merge_args_list.append(run_kwargs)
        yield split_arg_dict


load_kwargs=dict(
    tmin=1,
    tmax=3.5,
    reject=False
)

from copy import deepcopy
def split(dataset, uids, sessions, train_folds, fold_sizes, load_kwargs):
    kwargs_iterator = group_iterator(
        SplitArg("uid", uids, 0),
        SplitArg("session", sessions, 0),
        SplitArg("train", train_folds, 0),
    )
    epochs_list = list()
    
    for splits_kwargs in kwargs_iterator:
        split_epochs = list()
#         for split_kwargs in splits_kwargs:
#             kwargs = deepcopy(split_kwargs)
#             uid = kwargs.pop("uid")
#             epochs, _ = dataset.load_subject(uid, **kwargs, **load_kwargs)
#             split_epochs.append(epochs)
#         print(splits_kwargs)
#         print()
#         if fold_sizes is not None:
#             split_epochs = mne.concatenate_epochs(split_epochs)
#             split_epochs = make_epochs_splits(split_epochs, n=len(split_epochs.events), sizes=fold_sizes)
        yield splits_kwargs
#         yield splits_kwargs

#         epochs_list.append(epochs)
#     epochs = mne.concatenate_epochs(epochs_list)
#     yield split
for k in split(openbmi_dataset, ["1", "2"], [1, 2], [True, False], None, load_kwargs):
#     print(k)
    pass

[[|<0> uid: ['1', '2']|, |<0> session: [1, 2]|, |<0> train: [True, False]|]]
{'uid': '1', 'session': 1, 'train': True}
{'uid': '1', 'session': 1, 'train': False}
{'uid': '1', 'session': 2, 'train': True}
{'uid': '1', 'session': 2, 'train': False}
{'uid': '2', 'session': 1, 'train': True}
{'uid': '2', 'session': 1, 'train': False}
{'uid': '2', 'session': 2, 'train': True}
{'uid': '2', 'session': 2, 'train': False}


In [None]:
from itertools import product

class SplitArg():
    def __init__(self, name, arg_list, level=0):
        self.name = name
        self.arg_list = arg_list
        self.level = level

    def __repr__(self):
        return "|<{}> {}: {}|".format(self.level, self.name, self.arg_list)

def group_iterator(*args):
    
    levels = sorted([arg.level for arg in args])
    
    levels_args_iters = [
        [
            arg
            for arg in args
            if arg.level == level
        ]
        for level in levels
    ]
    # Split loop
    print([level_args for level_args in levels_args_iters])
    return
    level_iterables = [
        product(*[arg.arg_list for arg in level_args_iter])
        for level_args_iter in levels_args_iters
    ]
    print([list(a) for a in level_iterables])
    return
    for all_level_args_lists in product(*level_iterables):
                
        split_arg_dict = {
            arg.name: level_args_iter[i]
            for level_args_iter in all_level_args_lists
            for i, arg in enumerate(level_args_iter)
        }
        yield arg_dict
        # Merge loop
        
#         merge_args_list = list()
#         for merge_args in product(*[arg.arg_list for arg in merge_args_iter]):
#             merge_arg_dict = {
#                 arg.name: merge_args[i]
#                 for i, arg
#                 in enumerate(merge_args_iter)
#             }
#             run_kwargs = {**split_arg_dict, **merge_arg_dict}
#             merge_args_list.append(run_kwargs)
        yield merge_args_list


load_kwargs=dict(
    tmin=1,
    tmax=3.5,
    reject=False
)

from copy import deepcopy
def split(dataset, uids, sessions, train_folds, fold_sizes, load_kwargs):
    kwargs_iterator = group_iterator(
        SplitArg("uid", uids, True),
        SplitArg("session", sessions, False),
        SplitArg("train", train_folds, True),
    )
    epochs_list = list()
    
    for splits_kwargs in kwargs_iterator:
        split_epochs = list()
#         for split_kwargs in splits_kwargs:
#             kwargs = deepcopy(split_kwargs)
#             uid = kwargs.pop("uid")
#             epochs, _ = dataset.load_subject(uid, **kwargs, **load_kwargs)
#             split_epochs.append(epochs)
#         print(splits_kwargs)
#         print()
#         if fold_sizes is not None:
#             split_epochs = mne.concatenate_epochs(split_epochs)
#             split_epochs = make_epochs_splits(split_epochs, n=len(split_epochs.events), sizes=fold_sizes)
        yield splits_kwargs
#         yield splits_kwargs

#         epochs_list.append(epochs)
#     epochs = mne.concatenate_epochs(epochs_list)
#     yield split
for k in split(openbmi_dataset, ["1", "2"], [1, 2], [True, False], None, load_kwargs):
    print(k)
    pass

In [76]:
len(k)

2

In [23]:
from itertools import product
    
class SplitArg():
    def __init__(self, name, arg_list, to_split):
        self.to_split = to_split
        self.arg_list = arg_list
        self.name = name

DEFAULT_FOLDS = [True, False]
DEFAULT_SESSIONS = [1, 2]
DEFAULT_UIDS = ["1"]

def split(dataset, uids, sessions, train_folds, fold_sizes, load_kwargs):
    
    args = [uids, sessions, train_folds, fold_sizes]
    split_args_iter = [
        arg
        for arg
        in [uids, sessions, train_folds]
        if arg.to_split
    ]
    merge_args_iter = [
        arg
        for arg
        in [uids, sessions, train_folds]
        if not arg.to_split
    ]
    
    data = dict()
    
    # Split loop
    for split_args in product(*[arg.arg_list for arg in split_args_iter]):
                
        # Merge loop
        split_arg_dict = {
            arg.name: split_args[i]
            for i, arg
            in enumerate(split_args_iter)
        }
        epochs_list = list()
        for merge_args in product(*[arg.arg_list for arg in merge_args_iter]):
            merge_arg_dict = {
                arg.name: merge_args[i]
                for i, arg
                in enumerate(merge_args_iter)
            }
            run_kwargs = {**split_arg_dict, **merge_arg_dict, **load_kwargs}
            print(run_kwargs)
#             epochs, _ = dataset.load_subject(**run_kwargs)
#             epochs_list.append(epochs)
        
        if fold_sizes is not None:
            epochs = mne.concatenate_epochs(epochs_list)
            idx = np.arange(len(epochs))
            np.random.

        print()

split(
    None,
    SplitArg("uid", ["A", "B", "C"], False),
    SplitArg("session", [1, 2, 3], True),
    SplitArg("train", [True, False], False),
    [.3, .2, .5],
    dict()
)

{'session': 1, 'uid': 'A', 'train': True}
{'session': 1, 'uid': 'A', 'train': False}
{'session': 1, 'uid': 'B', 'train': True}
{'session': 1, 'uid': 'B', 'train': False}
{'session': 1, 'uid': 'C', 'train': True}
{'session': 1, 'uid': 'C', 'train': False}

{'session': 2, 'uid': 'A', 'train': True}
{'session': 2, 'uid': 'A', 'train': False}
{'session': 2, 'uid': 'B', 'train': True}
{'session': 2, 'uid': 'B', 'train': False}
{'session': 2, 'uid': 'C', 'train': True}
{'session': 2, 'uid': 'C', 'train': False}

{'session': 3, 'uid': 'A', 'train': True}
{'session': 3, 'uid': 'A', 'train': False}
{'session': 3, 'uid': 'B', 'train': True}
{'session': 3, 'uid': 'B', 'train': False}
{'session': 3, 'uid': 'C', 'train': True}
{'session': 3, 'uid': 'C', 'train': False}



In [77]:
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
from sklearn.base import BaseEstimator
from sklearn.metrics import balanced_accuracy_score

class GetData(BaseEstimator):
    
    def fit(self, X, y=None):
        return self
    
    def transform(self, X, y=None):
        filter_kwargs = dict(
            method="iir",
            iir_params=dict(
                order=5,
                ftype="butter"
            )
        )
        return X.load_data().filter(8, 30, **filter_kwargs).get_data()


# kwargs = dict(
#     uids=["1", "2", "3", "4"],
#     inter_session=False,
#     inter_subject=True,
#     sessions=[2],
#     train_folds=[True, False],
#     fold_sizes=[.5, .5],
#     load_kwargs=dict(
#         reject=False
#     )
# )
kwargs = dict(
    uids=["1", "2", "3", "4"],
    inter_session=False,
    inter_subject=False,
    sessions=[1],
    train_folds=[True, False],
    fold_sizes=[.25, .75],
    load_kwargs=dict(
        tmin=1,
        tmax=3.5,
        reject=False
    )
)
for j, splits in enumerate(OpenBMI_Splitter(openbmi_dataset).make_splits(**kwargs)):
    infos = [info for info, split in splits]
    splits = [split for info, split in splits]
    print(infos)
    clf = make_pipeline(
        GetData(),
        CSP(5, log=True),
        LDA()
    )
    train_epochs = splits[0]
    train_labels = train_epochs.events[:, 2]
    
    test_epochs = splits[1]
    test_labels = test_epochs.events[:, 2]
    
    clf.fit(train_epochs, train_labels)
    pred = clf.predict(test_epochs)
    acc = balanced_accuracy_score(test_labels, pred)
    print(acc)

  epochs = mne.concatenate_epochs(


[{'uid': '1', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '1', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.5903271692745377


  epochs = mne.concatenate_epochs(


[{'uid': '2', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '2', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.8466749866286325


  epochs = mne.concatenate_epochs(


[{'uid': '3', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '3', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.8866999465145302


  epochs = mne.concatenate_epochs(


[{'uid': '4', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '4', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.47759601706970123


In [71]:
train_labels

array([0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
       1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1,
       0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
       1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0])

In [46]:
[f"{i} {j}" for i in range(3) for j in range(3)]

['0 0', '0 1', '0 2', '1 0', '1 1', '1 2', '2 0', '2 1', '2 2']

In [20]:
def split_array(arr, n_splits=2, sizes=None):
    n = len(arr)
    sizes = sizes or [1 / n] * n_splits
    assert np.sum(sizes) == 1.
    sizes = [int(size * n) for size in sizes]
    sizes = [0] + sizes
    sizes = np.cumsum(sizes)
    slices = [slice(start, end) for start, end in zip(sizes[:-1], sizes[1:])]
    arrs = [arr[s] for s in slices]
    return arrs

arrs = split_array(np.random.rand(100), sizes=[.2, .3, .1, .4])
list(map(len, arrs))

[20, 30, 10, 40]

In [44]:
split_array(openbmi_dataset.list_uids(), sizes=[.2, .3, .5])

[array(['25', '15', '41', '12', '37', '2', '42', '6', '52', '18'],
       dtype=object),
 array(['38', '30', '34', '5', '11', '44', '29', '8', '17', '33', '46',
        '23', '19', '4', '13', '22'], dtype=object),
 array(['35', '45', '36', '39', '16', '50', '10', '53', '7', '28', '27',
        '14', '32', '20', '31', '51', '24', '9', '21', '40', '47', '26',
        '49', '54', '43', '1', '3'], dtype=object)]

In [41]:
e = bci_dataset.load_subject("1")[0]
split_array(e, n=len(e.events), sizes=[.2, .8])

  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  next(self.gen)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  next(self.gen)


[<Epochs |  57 events (good & bad), -0.3 - 0.7 sec, baseline off, ~26 kB, data not loaded>,
 <Epochs |  230 events (good & bad), -0.3 - 0.7 sec, baseline off, ~26 kB, data not loaded>]