In [12]:
from ica_benchmark.io.load import Physionet_2009_Dataset, BCI_IV_Comp_Dataset, OpenBMI_Dataset
import mne
from pathlib import Path
from collections import namedtuple
import numpy as np
from itertools import product
from sklearn.model_selection import KFold, LeaveOneOut
import pandas as pd

mne.set_log_level(False)

physionet_dataset_folderpath = Path('/home/paulo/Documents/datasets/Physionet')
bci_dataset_folderpath = Path('/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf/')
bci_test_dataset_folderpath = Path('/home/paulo/Documents/datasets/BCI_Comp_IV_2a/true_labels/')
openbmi_dataset_folderpath = Path('/home/paulo/Documents/datasets/OpenBMI/edf/')

physionet_dataset = Physionet_2009_Dataset(physionet_dataset_folderpath)
bci_dataset = BCI_IV_Comp_Dataset(bci_dataset_folderpath, test_folder=bci_test_dataset_folderpath)
openbmi_dataset = OpenBMI_Dataset(openbmi_dataset_folderpath)

In [517]:
from collections.abc import Iterable

def flatten_deepest(lst):
    res = []
    for item in lst:
        if isinstance(item, Iterable) and not isinstance(item, str):
            if any(isinstance(i, Iterable) and not isinstance(i, str) for i in item):
                res.append(flatten_deepest(item))
            else:
                res.append(flatten_list(item))
        else:
            res.append(item)
    return res

list_of_lists = [
    [
        [1, 2],
        [[4], 5]
    ],
    [5]
]
    
print(list_of_lists)
print(flatten_deepest(list_of_lists))

[[[1, 2], [[4], 5]], [5]]
[[[1, 2], [[4], 5]], [5]]


In [514]:
isinstance("12", Iterable)

True

In [476]:
from mne import Epochs

def make_epochs_splits_indexes(arr, n=None, n_splits=2, sizes=None, shuffle=False, seed=1):
    if isinstance(arr, (tuple, list)):
        arr = np.array(arr)
    np.random.seed(seed)
    if n is None:
        if isinstance(arr, Epochs):
            n = len(arr.events)
        else:
            n = len(arr)
    sizes = sizes or [1 / n_splits] * n_splits

    assert np.sum(sizes) == 1.
    
    sizes = np.cumsum(
        [0] + [int(size * n) for size in sizes]
    )
    
    idx = np.arange(n)
    if shuffle:
        np.random.shuffle(idx)
    slices = [slice(start, end) for start, end in zip(sizes[:-1], sizes[1:])]
    indexes = [idx[s] for s in slices]
    return indexes

def make_epochs_splits(arr, n=None, n_splits=2, sizes=None, shuffle=False, seed=1):
    indexes = make_epochs_splits_indexes(arr, n=n, n_splits=n_splits, sizes=sizes, shuffle=shuffle, seed=seed)
    arrs = [arr[idx] for idx in indexes]
    return arrs

make_epochs_splits(np.linspace(10, 100, 20), sizes=[1.], shuffle=False)

[array([ 10.        ,  14.73684211,  19.47368421,  24.21052632,
         28.94736842,  33.68421053,  38.42105263,  43.15789474,
         47.89473684,  52.63157895,  57.36842105,  62.10526316,
         66.84210526,  71.57894737,  76.31578947,  81.05263158,
         85.78947368,  90.52631579,  95.26315789, 100.        ])]

In [477]:
np.linspace(10, 100, 20)

array([ 10.        ,  14.73684211,  19.47368421,  24.21052632,
        28.94736842,  33.68421053,  38.42105263,  43.15789474,
        47.89473684,  52.63157895,  57.36842105,  62.10526316,
        66.84210526,  71.57894737,  76.31578947,  81.05263158,
        85.78947368,  90.52631579,  95.26315789, 100.        ])

In [4]:
from copy import deepcopy

class Splitter():
    
    INTER_SESSION = True
    INTER_SUBJECT = True
    TRAIN_TEST = True
    
    UNIQUE_SESSION = False
    
    SESSION_KWARGS = dict(intra=dict(), inter=dict())
    
    def __init__(self, dataset, uids, sessions, train_folds, load_kwargs=None, fold_sizes=None):
        self.dataset = dataset
        self.uids = uids
        self.sessions = sessions
        self.train_folds = train_folds
        self.load_kwargs = load_kwargs or load_kwargs
        self.fold_sizes = fold_sizes

    def inter_subject_splitting(self, shuffle=False, sizes=[.5, .5], seed=1):
        np.random.seed(seed)
        uids = deepcopy(self.uids)
        if shuffle:
            np.random.shuffle(uids)
        splits_uids = make_epochs_splits(uids, sizes=sizes)
        # [(info, epochs), (info, epochs), ...]
        splits = [
            (
                dict(
                    sessions=self.sessions,
                    train_folds=self.train_folds,
                    uid=split_uids
                ),
                mne.concatenate_epochs(
                    [
                        self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                        for session in self.sessions
                        for train in self.train_folds
                        for uid in split_uids
                    ]
                )
            )
            for split_uids in splits_uids
        ]
        return splits

    def inter_session_splitting(self):
        assert len(self.sessions) > 1, "You are using the inter session protocol, but only passed 1 session"
        splits = [
            (
                dict(uids=self.uids, sessions=[session], train_folds=self.train_folds),
                mne.concatenate_epochs(
                    [
                        self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                        for train
                        in self.train_folds
                    ]
                )
            )
            for uid in self.uids
            for session in self.sessions
        ]
        return splits

    def intra_session_default_splitting(self):
        
        assert len(self.sessions) == 1, "Your are using an intra session splitting but had more than 1 session"
        splits = [
            (
                dict(uid=self.uids, sessions=[session], train_folds=[train]),
                self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
            )
            for session in self.sessions
            for train in self.train_folds
        ]
        return splits

    def intra_session_splitting(self, uid, session):
        epochs = mne.concatenate_epochs(
            [
                self.dataset.load_subject(uid, session=session, train=train, **self.load_kwargs)[0]
                for train in self.train_folds
            ]
        )
        splits = make_epochs_splits(
            epochs,
            n=len(epochs.events),
            sizes=self.fold_sizes
        )
        splits = [
            (
                dict(uid=self.uids, session=session, train_folds=self.train_folds, size=size),
                split
            )
            for split, size
            in zip(splits, self.fold_sizes)
        ]
        return splits


    def make_splits(self, inter_session=True, inter_subject=False):

        assert all(np.isin(self.uids, self.dataset.list_uids()))
        session_key = "inter" if inter_session else "intra"
        subject_key = "inter" if inter_subject else "intra"
        
        splits = list()

        # Inter session for the inter subject protocol does not make sense.
        # So, if inter_subject is true, inter_session does not matter
        if inter_subject:
            splits = self.inter_subject_splitting()
            yield splits
            
        # Intra subject
        else:
            
            for uid in self.uids:
                if inter_session:
                    splits = self.inter_session_splitting()
                    yield splits

                else:
                    
                    if self.fold_sizes is None:
                        splits = self.intra_session_default_splitting()
                    else:
                        assert len(self.sessions) == 1, "You are using the intra session protocol but passed more than 1 session"
                        session = self.sessions[0]
                        splits = self.intra_session_splitting(uid, session)
                        
                    yield splits
import numpy as np
kwargs = dict(
    inter_session=False,
    inter_subject=False
)
splitter = Splitter(
    openbmi_dataset,
    uids=["1", "2"],
    sessions=[1],
    train_folds=[True, False],
    load_kwargs=dict(
        reject=False
    )
#     fold_sizes=[.4, .6]
)
for j, splits in enumerate(splitter.make_splits(**kwargs)):
    print(f"FOLD {j}")
    for i, (info, split) in enumerate(splits):
        print(f"Split {i}")
        print("\tInfo", info)
        print("\tSplit", split)
        print()


NameError: name 'uid' is not defined

In [4]:
from itertools import product

class SplitArg():
    def __init__(self, name, arg_list, to_split):
        self.name = name
        self.arg_list = arg_list
        self.to_split = to_split
        

def group_iterator(*args):
    split_args_iter = [
        arg
        for arg
        in args
        if arg.to_split
    ]
    merge_args_iter = [
        arg
        for arg
        in args
        if not arg.to_split
    ]
    # Split loop
    for split_args in product(*[arg.arg_list for arg in split_args_iter]):
                
        split_arg_dict = {
            arg.name: split_args[i]
            for i, arg
            in enumerate(split_args_iter)
        }
        # Merge loop
        merge_args_list = list()
        for merge_args in product(*[arg.arg_list for arg in merge_args_iter]):
            merge_arg_dict = {
                arg.name: merge_args[i]
                for i, arg
                in enumerate(merge_args_iter)
            }
            run_kwargs = {**split_arg_dict, **merge_arg_dict}
            merge_args_list.append(run_kwargs)
        yield merge_args_list


load_kwargs=dict(
    tmin=1,
    tmax=3.5,
    reject=False
)

from copy import deepcopy
def split(dataset, uids, sessions, train_folds, fold_sizes, load_kwargs):
    kwargs_iterator = group_iterator(
        SplitArg("uid", uids, True),
        SplitArg("session", sessions, False),
        SplitArg("train", train_folds, True),
    )
    epochs_list = list()
    
    for splits_kwargs in kwargs_iterator:
        split_epochs = list()
#         for split_kwargs in splits_kwargs:
#             kwargs = deepcopy(split_kwargs)
#             uid = kwargs.pop("uid")
#             epochs, _ = dataset.load_subject(uid, **kwargs, **load_kwargs)
#             split_epochs.append(epochs)
        print(splits_kwargs)
        print()
#         if fold_sizes is not None:
#             split_epochs = mne.concatenate_epochs(split_epochs)
#             split_epochs = make_epochs_splits(split_epochs, n=len(split_epochs.events), sizes=fold_sizes)
#         yield split_epochs
        yield splits_kwargs

#         epochs_list.append(epochs)
#     epochs = mne.concatenate_epochs(epochs_list)
#     yield split
for k in split(openbmi_dataset, ["1", "2"], [1, 2], [True, False], None, load_kwargs):
#     print(k)
    pass

[{'uid': '1', 'train': True, 'session': 1}, {'uid': '1', 'train': True, 'session': 2}]

[{'uid': '1', 'train': False, 'session': 1}, {'uid': '1', 'train': False, 'session': 2}]

[{'uid': '2', 'train': True, 'session': 1}, {'uid': '2', 'train': True, 'session': 2}]

[{'uid': '2', 'train': False, 'session': 1}, {'uid': '2', 'train': False, 'session': 2}]



In [5]:
iters = [
    [1, 2, 3],
    ["a", "b", "c", "d"]
]
(1, "a"), (1, "b"), (1, "c"), (1, "d")

iters = [
    [1, 2, 3],
    [
        ["a", "b"], ["A", "B"]
    ],
    [True, False]
]
(1, "a"), (1, "b"), (1, "c"), (1, "d")


iters = [
    [1, 2, 3],
    [
        ["a", "b"], ["A", "B"]
    ]
]
(1, "a", "A"), (1, "a", "B"), (1, "b", "A"), (1, "b", "B")

def unpack_product(iters):
    iters = product(*iters)
    for iter_list in iters:
        print(iter_list)
        is_list = False
        for i, iter_value in enumerate(iter_list):
            is_list = is_list or isinstance(iter_value, list)
        for i, iter_value in enumerate(iter_list):
            iter_list[i] = [iter_list[i]] if is_list else iter_list[i]
    return iter_list
unpack_product(iters)

(1, ['a', 'b'])


TypeError: 'tuple' object does not support item assignment

In [158]:
from sklearn.model_selection import KFold, LeaveOneOut

cv = LeaveOneOut()
for a, b in cv.split(np.arange(100)):
    print(a)
    print(b)
    print()

[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
 97 98 99]
[0]

[ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
 97 98 99]
[1]

[ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
 97 98 99]
[2]

[ 0  1  2  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25

In [266]:
import pandas as pd
from sklearn.model_selection import KFold, LeaveOneOut
from itertools import tee
from warnings import warn

def remove_key(d, k):
    return {
        key: value
        for key, value in d.items()
        if key != k
    }

class Split():
    
    def __init__(self, kwarg_dict_list, fold_sizes=None):
        self.kwargs_list = kwarg_dict_list
        self.fold_sizes = fold_sizes
    
    def to_dataframe():
        return pd.DataFrame.from_records(self.kwargs_list)

    def __repr__(self):
        dict_reps = [str(d) for d in self.kwargs_list]
        return "Split({})".format(",".join(dict_reps))
    
    def load_epochs(self, dataset, **load_kwargs):
        epochs = mne.concatenate_epochs(
            [
                dataset.load_subject(kwargs["uid"], **remove_key(kwargs, "uid"), **load_kwargs)[0]
                for kwargs in self.kwargs_list
            ]
        )
        return epochs


# df = pd.DataFrame(product(openbmi_dataset.list_uids(), ["1", "2"], [True, False]), columns=["uid", "session", "train"])

# k = ["uid", "session"]

uids = openbmi_dataset.list_uids()[:3]
sessions = [1, 2]
runs = [1, 2]


# for uid, session in product(uids, sessions):
#     a = df.query("uid == @uid").query("session == @session")
#     display(a)
                
class Splitter():
    
    INTER_SESSION = True
    INTER_SUBJECT = True
    TRAIN_TEST = True
    
    UNIQUE_SESSION = False
    
    SESSION_KWARGS = dict(intra=dict(), inter=dict())
    
    def default_cv_splitter(self):
        return KFold(4)
    
    def __init__(self, dataset, uids, sessions, runs, load_kwargs=None, cv_splitter=None, unsafe=False, intra_session_shuffle=False):
        self.dataset = dataset
        self.uids = uids
        self.sessions = sessions
        self.runs = runs
        self.load_kwargs = load_kwargs or load_kwargs
        self.cv_splitter = cv_splitter or self.default_cv_splitter()
        self.intra_session_shuffle = intra_session_shuffle
        
    def validate_config(self, mode, fold_sizes=None):
        valid_modes = [
            "inter_subject",
            "inter_session",
            "intra_session_intra_run",
            "intra_session_intra_run_merge",
            "intra_session_inter_run"
        ]
        
        assert mode in valid_modes, "Please choose one mode among the following: {}".format(", ".join(valid_modes))
        if mode == "inter_subject":
            if fold_sizes is not None:
                warn("You are using the inter_subject mode, so the fold_sizes argument will not be used")
#         if mode == "intra_session":
#             if len(self.runs) > 
        elif mode == "inter_session":
            if len(self.runs) > 1:
                warn("You are using inter session protocol with more than one run. All runs from each session will be concatenated and yielded in different steps.")

        elif (mode in ("intra_session_intra_run", "intra_session_intra_run_merge")):
            if fold_sizes is None:
                warn("You are using intra session protocol with no fold sizes. The splitter will only yield one epoch at time")

        elif (mode == "intra_session_inter_run"):
            if (len(self.runs) == 1):
                warn("You are using an intra session protocol, splitting by run, but only passed one run. The splitter can only yield one epoch at time (from the only run you passed as argument)")

    def inter_subject(self):
        for uid_splits_idxs in self.cv_splitter.split(self.uids):
            splits_uids = [uids[idx] for idx in uid_splits_idxs]
            yield [
                Split(
                    [
                        dict(
                            uid=uid,
                            session=session,
                            run=run
                        )
                        for uid in split_uids
                        for session in self.sessions
                        for run in self.runs
                    ]
                )
                for split_uids in splits_uids
            ]

    def inter_session(self):
        for uid in self.uids:
            yield [
                Split(
                    [
                        dict(
                            uid=uid,
                            session=session,
                            run=run
                        )
                        for run in self.runs
                    ]
                )
                for session in self.sessions
            ]

    def intra_session_inter_run(self):
        for uid in self.uids:
            for session in self.sessions:
                yield [
                    Split(
                        [
                            dict(
                                uid=uid,
                                session=session,
                                run=run
                            )
                        ]
                    )
                    for run in self.runs
                ]
    
    def intra_session_intra_run(self):
        for uid in self.uids:
            for session in self.sessions:
                for run in self.runs:
                    yield [
                        Split(
                            [
                                dict(
                                    uid=uid,
                                    session=session,
                                    run=run
                                )
                            ]
                        )
                    ]

    def intra_session_intra_run_merge(self):
        for uid in self.uids:
            for session in self.sessions:
                yield [
                    Split(
                        [
                            dict(
                                uid=uid,
                                session=session,
                                run=run
                            )
                            for run in self.runs
                        ]
                    )
                ]
                    
    def yield_splits_epochs(self, mode, fold_sizes=None):
        self.validate_config(mode, fold_sizes)

        split_fn_dict = dict(
            # Intra subject, inter session
            inter_session=self.inter_session,
            # Inter subject, will concatenate all sessions and runs
            inter_subject=self.inter_subject,
            # Intra subject, intra_session, inter run (will split runs)
            intra_session_inter_run=self.intra_session_inter_run,
            # Intra subject, intra_session, intra run (will split using fold sizes)
            intra_session_intra_run=self.intra_session_intra_run,
            # Intra subject, intra_session, intra run (will merge all runs and split using fold sizes)
            intra_session_intra_run_merge=self.intra_session_intra_run_merge,
        )

        split_fn = split_fn_dict[mode]
        for splits in split_fn():
            splits_epochs = [
                split.load_epochs(openbmi_dataset, **load_kwargs)
                for split in splits
            ]
            assert len(splits) == len(splits_epochs)
            if mode in ("intra_session_intra_run", "intra_session_intra_run_merge") and (fold_sizes is not None):
                assert len(splits_epochs) == 1
                epochs = splits_epochs[0]
                splits_epochs = make_epochs_splits(epochs, sizes=fold_sizes, shuffle=self.intra_session_shuffle)
            yield splits, splits_epochs
                    
splitter = Splitter(
    openbmi_dataset,
    uids=openbmi_dataset.list_uids()[:1],
    sessions=openbmi_dataset.SESSIONS,
#     runs=openbmi_dataset.RUNS,
    runs=[1, 2],
    load_kwargs=dict(
        reject=False
    ),
#     fold_sizes=None,
    cv_splitter=None,
    intra_session_shuffle=False
)
for splits, epochs in splitter.yield_splits_epochs(mode="intra_session_intra_run_merge", fold_sizes=[.3, .7]):
    print(splits)
    print(epochs)

  epochs = mne.concatenate_epochs(


[Split({'uid': '25', 'session': 1, 'run': 1},{'uid': '25', 'session': 1, 'run': 2})]
[<Epochs |  60 events (all good), 1 - 3.5 sec, baseline off, ~71.0 MB, data loaded,
 '0': 24
 '1': 36>, <Epochs |  140 events (all good), 1 - 3.5 sec, baseline off, ~165.7 MB, data loaded,
 '0': 76
 '1': 64>]


  epochs = mne.concatenate_epochs(


[Split({'uid': '25', 'session': 2, 'run': 1},{'uid': '25', 'session': 2, 'run': 2})]
[<Epochs |  60 events (all good), 1 - 3.5 sec, baseline off, ~71.0 MB, data loaded,
 '0': 24
 '1': 36>, <Epochs |  140 events (all good), 1 - 3.5 sec, baseline off, ~165.7 MB, data loaded,
 '0': 76
 '1': 64>]


In [143]:
def my_product(*inp):
    return (dict(zip(inp.keys(), values)) for values in product(*inp.values()))


list(product_dict(l1=l1, l2=l2))

[{'l1': 1, 'l2': 'a'},
 {'l1': 1, 'l2': 'b'},
 {'l1': 1, 'l2': 'c'},
 {'l1': 2, 'l2': 'a'},
 {'l1': 2, 'l2': 'b'},
 {'l1': 2, 'l2': 'c'},
 {'l1': 3, 'l2': 'a'},
 {'l1': 3, 'l2': 'b'},
 {'l1': 3, 'l2': 'c'}]

In [357]:
l1 = [1, 2]
l2 = [101, 102]
l3 = ["a", "b"]
l4 = ["A", "B"]

def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in product(*vals):
        yield dict(zip(keys, instance))

def _split_group_iterator(outer_split_kwargs=None, inner_split_kwargs=None, merge_kwargs=None):
    outer_split_kwargs = outer_split_kwargs or dict()
    inner_split_kwargs = inner_split_kwargs or dict()
    merge_kwargs = merge_kwargs or dict()

    for outside_kwargs in product_dict(**outer_split_kwargs):
        splits = [
            [
                dict(
                    **outside_kwargs,
                    **inside_kwargs,
                    **merge_kwargs
                )
                for merge_kwargs in product_dict(**merge_kwargs)
            ]
            for inside_kwargs in product_dict(**inner_split_kwargs)
        ]
        yield splits

def split_group_iterator(outer_split_kwargs=None, inner_split_kwargs=None, merge_kwargs=None):

    for iteration_kwargs_list in _split_group_iterator(outer_split_kwargs, inner_split_kwargs, merge_kwargs):
        splits = [
            Split(
                [
                    dict(
                        **kwargs
                    )
                    for kwargs in kwargs_list
                ]
            )
            for kwargs_list in iteration_kwargs_list
        ]
        yield splits
        
# def split_group_iterator(outer_kwargs_list_dict=None, inner_kwargs_list_dict=None):
#     outer_kwargs_list_dict = outer_kwargs_list_dict or dict()
#     inner_kwargs_list_dict = inner_kwargs_list_dict or dict()    

#     for kwargs in group_iterator(outer_kwargs_list_dict, inner_kwargs_list_dict):
# #         yield = [
# #             Split(kwargs)
# #         ]
#         yield kwargs
        
# a = group_iterator(dict(l1=l1, l2=l2), dict(l3=l3))
a = split_group_iterator(dict(l=l1), dict(l3=l3), dict())
a = list(a)
for x in a:
    print(x)

[Split({'l1': 1, 'l3': 'a'}), Split({'l1': 1, 'l3': 'b'})]
[Split({'l1': 2, 'l3': 'a'}), Split({'l1': 2, 'l3': 'b'})]


In [327]:
# function Recurse (y, number) 
#    if (number > 1)
#       Recurse ( y, number - 1 )
#    else
#       for x in range (y)
#           whatever()

l1 = [1, 2]
l2 = [101, 102]
l3 = ["a", "b"]
l4 = ["A", "B"]

def group_iterator(split_kwargs_dicts, d=None, level=None):
    
    level = level or 0
#     level = level or len(split_kwargs_dicts) - 1
    d = d or dict()


    for inner_d in product_dict(**split_kwargs_dicts[level]): 

#         if level == 0:
        if level == (len(split_kwargs_dicts) - 1):

            yield {**d, **inner_d}
        else:

#             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level-1)
            yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level + 1)



In [328]:
# Two level iterator, first level is cross product of l1 and l2, inside level is the possibilities of l3
a = group_iterator([dict(session=[1, 2]), dict(uid=[1, 2, 3, 4])])
a = list(a)
b = [list(x) for x in a]
b

[[{'session': 1, 'uid': 1},
  {'session': 1, 'uid': 2},
  {'session': 1, 'uid': 3},
  {'session': 1, 'uid': 4}],
 [{'session': 2, 'uid': 1},
  {'session': 2, 'uid': 2},
  {'session': 2, 'uid': 3},
  {'session': 2, 'uid': 4}]]

In [8]:
[[{'l1': 1, 'l2': 1},
  {'l1': 1, 'l2': 2},
  {'l1': 1, 'l2': 3},
  {'l1': 1, 'l2': 4}],
 [{'l1': 2, 'l2': 1},
  {'l1': 2, 'l2': 2},
  {'l1': 2, 'l2': 3},
  {'l1': 2, 'l2': 4}]]

[[{'l1': 1, 'l2': 1},
  {'l1': 1, 'l2': 2},
  {'l1': 1, 'l2': 3},
  {'l1': 1, 'l2': 4}],
 [{'l1': 2, 'l2': 1},
  {'l1': 2, 'l2': 2},
  {'l1': 2, 'l2': 3},
  {'l1': 2, 'l2': 4}]]

In [337]:
# Two level iterator, first level varies possibilities for l1, inner level is the cross product of l2 and l3
a = group_iterator([dict(l1=l1), dict(l2=l2, l3=l3)])
a = list(a)
b = [list(x) for x in a]
b

[[{'l1': 1, 'l2': 101, 'l3': 'a'},
  {'l1': 1, 'l2': 101, 'l3': 'b'},
  {'l1': 1, 'l2': 102, 'l3': 'a'},
  {'l1': 1, 'l2': 102, 'l3': 'b'}],
 [{'l1': 2, 'l2': 101, 'l3': 'a'},
  {'l1': 2, 'l2': 101, 'l3': 'b'},
  {'l1': 2, 'l2': 102, 'l3': 'a'},
  {'l1': 2, 'l2': 102, 'l3': 'b'}]]

In [349]:
# 3 level iterator
a = group_iterator(
    [
        dict(l1=l1),
        dict(l2=l2),
        dict(l3=l3)
    ]
)
for level_1_iter in a:
    for level_2_iter in level_1_iter:
        for level_3_iter in level_2_iter:
            print(level_3_iter)
# b = [list(x) for x in a]
# c = [list(x) for x in b]
# d = [list(x) for x in c]


{'l1': 1, 'l2': 101, 'l3': 'a'}
{'l1': 1, 'l2': 101, 'l3': 'b'}
{'l1': 1, 'l2': 102, 'l3': 'a'}
{'l1': 1, 'l2': 102, 'l3': 'b'}
{'l1': 2, 'l2': 101, 'l3': 'a'}
{'l1': 2, 'l2': 101, 'l3': 'b'}
{'l1': 2, 'l2': 102, 'l3': 'a'}
{'l1': 2, 'l2': 102, 'l3': 'b'}


In [351]:
# 3 level iterator
a = group_iterator(
    [
        dict(l1=l1),
        dict(l21=l2, l22=l2),
        dict(l3=l3)
    ]
)
for level_1_iter in a:
    for level_2_iter in level_1_iter:
        for level_3_iter in level_2_iter:
            print(level_3_iter)
# b = [list(x) for x in a]
# c = [list(x) for x in b]
# d = [list(x) for x in c]


{'l1': 1, 'l21': 101, 'l22': 101, 'l3': 'a'}
{'l1': 1, 'l21': 101, 'l22': 101, 'l3': 'b'}
{'l1': 1, 'l21': 101, 'l22': 102, 'l3': 'a'}
{'l1': 1, 'l21': 101, 'l22': 102, 'l3': 'b'}
{'l1': 1, 'l21': 102, 'l22': 101, 'l3': 'a'}
{'l1': 1, 'l21': 102, 'l22': 101, 'l3': 'b'}
{'l1': 1, 'l21': 102, 'l22': 102, 'l3': 'a'}
{'l1': 1, 'l21': 102, 'l22': 102, 'l3': 'b'}
{'l1': 2, 'l21': 101, 'l22': 101, 'l3': 'a'}
{'l1': 2, 'l21': 101, 'l22': 101, 'l3': 'b'}
{'l1': 2, 'l21': 101, 'l22': 102, 'l3': 'a'}
{'l1': 2, 'l21': 101, 'l22': 102, 'l3': 'b'}
{'l1': 2, 'l21': 102, 'l22': 101, 'l3': 'a'}
{'l1': 2, 'l21': 102, 'l22': 101, 'l3': 'b'}
{'l1': 2, 'l21': 102, 'l22': 102, 'l3': 'a'}
{'l1': 2, 'l21': 102, 'l22': 102, 'l3': 'b'}


In [350]:
# Indefinitely deep iterator
a = group_iterator(
    [
        dict(l1=[1, 2]),
        dict(l2=[1, 2]),
        dict(l3=[1, 2]),
        dict(l4=[1, 2]),
        dict(l5=[1, 2]),
        dict(l6=[1, 2]),        
    ]
)


for level_1_iter in a:
    for level_2_iter in level_1_iter:
        for level_3_iter in level_2_iter:
            for level_4_iter in level_3_iter:
                for level_5_iter in level_4_iter:
                    for level_6_iter in level_5_iter:
                        print(level_6_iter)


{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 1, 'l5': 1, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 1, 'l5': 1, 'l6': 2}
{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 1, 'l5': 2, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 1, 'l5': 2, 'l6': 2}
{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 2, 'l5': 1, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 2, 'l5': 1, 'l6': 2}
{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 2, 'l5': 2, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 1, 'l4': 2, 'l5': 2, 'l6': 2}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 1, 'l5': 1, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 1, 'l5': 1, 'l6': 2}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 1, 'l5': 2, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 1, 'l5': 2, 'l6': 2}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 2, 'l5': 1, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 2, 'l5': 1, 'l6': 2}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 2, 'l5': 2, 'l6': 1}
{'l1': 1, 'l2': 1, 'l3': 2, 'l4': 2, 'l5': 2, 'l6': 2}
{'l1': 1, 'l2': 2, 'l3': 1, 'l4': 1, 'l5': 1, 'l6': 1}
{'l1': 1, 'l2': 2, 'l3': 1, 'l4': 1, 'l5': 1, 'l6': 2}
{'l1': 1, 

In [541]:
iterable = group_iterator(
    [
        dict(uid=["1", "2"]),
        dict(session=[1, 2]),
        dict(run=[1, 2])
    ]
)
# a = list(iterable)
# c = [[list(x) for x in b] for b in a]
print("Iter, Split, Merge_idx")
for iter_n, iteration in enumerate(iterable):
    for splits_n, splits in enumerate(iteration):
        for fold_n, fold in enumerate(splits):
            print(iter_n, splits_n, fold_n, fold)
    

Iter, Split, Merge_idx
0 0 0 {'uid': '1', 'session': 1, 'run': 1}
0 0 1 {'uid': '1', 'session': 1, 'run': 2}
0 1 0 {'uid': '1', 'session': 2, 'run': 1}
0 1 1 {'uid': '1', 'session': 2, 'run': 2}
1 0 0 {'uid': '2', 'session': 1, 'run': 1}
1 0 1 {'uid': '2', 'session': 1, 'run': 2}
1 1 0 {'uid': '2', 'session': 2, 'run': 1}
1 1 1 {'uid': '2', 'session': 2, 'run': 2}


In [None]:
Iter, Split, Merge_idx
0 0 0 {'uid': '1', 'session': 1, 'run': 1}
0 1 0 {'uid': '2', 'session': 1, 'run': 1}
1 0 0 {'uid': '3', 'session': 1, 'run': 1}
1 1 0 {'uid': '4', 'session': 1, 'run': 1}


In [None]:
iterable = group_iterator(
    [
        dict(uid=["1", "2"]),
        dict(session=[1, 2]),
        dict(run=[1, 2])
    ]
)
# a = list(iterable)
# c = [[list(x) for x in b] for b in a]
for iter_n, iteration in enumerate(iterable):
    for splits_n, splits in enumerate(iteration):
        for fold_n, fold in enumerate(splits):
            print(iter_n, splits_n, fold_n, fold)
    

In [542]:
list(product_dict(uid=[1, 2], session=[1, 2]))

[{'uid': 1, 'session': 1},
 {'uid': 1, 'session': 2},
 {'uid': 2, 'session': 1},
 {'uid': 2, 'session': 2}]

In [350]:
from types import GeneratorType
from itertools import chain

# def cv_product_dict(**kwargs):
#     keys = kwargs.keys()
#     vals = kwargs.values()
#     for instance in product(*vals):
#         yield dict(zip(keys, instance))

# def group_iterator(split_kwargs_dicts, d=None, level=None):
    
#     level = level or 0
# #     level = level or len(split_kwargs_dicts) - 1
#     d = d or dict()


#     for inner_d in cv_product_dict(**split_kwargs_dicts[level]): 

# #         if level == 0:
#         if level == (len(split_kwargs_dicts) - 1):

#             yield {**d, **inner_d}
#         else:

# #             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level-1)
#             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level + 1)


ITERABLES_TYPES = (list, tuple, GeneratorType, product)

def unpack_deep_iterable(deep_iterable):
    # Keep levels as a nested list
    if isinstance(deep_iterable, (GeneratorType, tuple, list)):
        # If deep_iterable is iterable, just make sure that if it is a generator that it is iterated
        deep_iterable = list(deep_iterable)
        return [
            unpack_deep_iterable(shallow_iterable)
            for shallow_iterable in deep_iterable
        ]
    else:
        return deep_iterable

def flatten_deep_iterable(deep_iterable):
    # Returns a flat iterator of all items that are not in ITERABLES_TYPES inside deep_iterable
    for item in deep_iterable:
        if isinstance(item, ITERABLES_TYPES):
            for nested_item in flatten_deep_iterable(item):
                yield nested_item
        else:
            yield item

# def flatten_deep_iterable(container):
#     return list(_flatten_deep_iterable(container))
            

    
iterable = group_iterator(
    [
        dict(uid=["1", "2"]),
        dict(session=[1, 2]),
        dict(run=[1, 2]),
        
    ]
)
list(flatten_deep_iterable(iterable))
# count_deep_iterable_levels(iterable, return_levels=True)
# count_deep_iterable_levels([[1], [[2,], [[3, ], [4, 5, 6]]]], return_levels=True)


[{'uid': '1', 'session': 1, 'run': 1},
 {'uid': '1', 'session': 1, 'run': 2},
 {'uid': '1', 'session': 2, 'run': 1},
 {'uid': '1', 'session': 2, 'run': 2},
 {'uid': '2', 'session': 1, 'run': 1},
 {'uid': '2', 'session': 1, 'run': 2},
 {'uid': '2', 'session': 2, 'run': 1},
 {'uid': '2', 'session': 2, 'run': 2}]

In [56]:
def count_deep_iterable_levels(deep_iterable, level=0, return_levels=False):
    # If return_levels, returns a nested list of the deep_iterable level
    # Else returns the max level of the iterable
    
    if isinstance(deep_iterable, (GeneratorType, list, tuple)):
        deep_iterable = list(deep_iterable)
        levels = [count_deep_iterable_levels(shallow_iterable, level + 1, return_levels=return_levels) for shallow_iterable in deep_iterable]
        if return_levels:
            return levels
        else:
            if len(levels) == 0:
                return level
            return max(levels)
    else:
        return level
    
iterable = [
    [1],
    [
        [1], [1], []
    ]
]
count_deep_iterable_levels(
    iterable, return_levels=False
)

3

In [358]:

def check_constraints(constraints, idx, kwargs):
    constraints = {**constraints[idx]}
    for kwarg in constraints:
        print(kwarg)
        if kwargs[kwarg] not in constraints[kwarg]:
            return False
    return True

def constrained_group_iterator(split_kwargs_dicts, d=None, level=None, constraining_function=None, level_idx_dict=None):
    constraining_function = constraining_function or (lambda l, i, kwargs: (kwargs, True))
    level = level or 0
    level_idx_dict = level_idx_dict or dict()
#     level = level or len(split_kwargs_dicts) - 1
    d = d or dict()

    for idx, inner_d in enumerate(product_dict(**split_kwargs_dicts[level])):
        level_idx_dict[level] = idx

#         if level == 0:
        kwargs = {**d, **inner_d}
        kwargs, valid = constraining_function(level, level_idx_dict, kwargs)
        if not valid:
            continue

        if level == (len(split_kwargs_dicts) - 1):
            yield kwargs
        else:
#             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level-1)
            yield constrained_group_iterator(split_kwargs_dicts, d=kwargs, level=level + 1, constraining_function=constraining_function, level_idx_dict=level_idx_dict)

def constrained_split_group_iterator(split_kwargs_dicts):

    for iteration_splits_kwargs in group_iterator(split_kwargs_dicts):
        yield [
            Split(
                [
                    dict(
                        **split_kwargs
                    )
                    for split_kwargs in splits_kwargs_list
                ]
            )
            for splits_kwargs_list in iteration_splits_kwargs
        ]

def create_splitter_constraint_fn(splitter, uids):
    split_df = pd.DataFrame()
    for fold, (train_idx, test_idx) in enumerate(splitter.split(uids)):
        train_uids, test_uids = uids[train_idx], uids[test_idx]
        
        train_split_df = pd.DataFrame()
        train_split_df["uid"] = train_uids
        train_split_df["fold"] = fold
        train_split_df["group"] = "train"
        test_split_df = pd.DataFrame()
        test_split_df["uid"] = test_uids
        test_split_df["fold"] = fold
        test_split_df["group"] = "test"
        split_df = pd.concat(
            [
                split_df,
                pd.concat([train_split_df, test_split_df], axis=0),
            ],
            axis=0
        )
        
    def my_constraint_fn(level, level_idx_dict, kwargs):
        kwargs = {**kwargs}

        if (not "group" in kwargs) or (not "uid" in kwargs):
            return kwargs, True

        fold = kwargs["fold"]
        group = kwargs["group"]
        uid = kwargs["uid"]
#         r = uid in split_df.query("group == @group").query("fold == @fold").uid.to_numpy()
        r = uid in split_df[(split_df.group == group) & (split_df.fold == fold)].uid.to_numpy()
        kwargs.pop("group")
        kwargs.pop("fold")

        return kwargs, r
    
    return my_constraint_fn

uids = np.arange(8)
k = 4
kfold_iterable = constrained_group_iterator(
    [
        dict(fold=np.arange(k)),
        dict(group=["train", "test"]),
        dict(uid=uids),
    ],
    constraining_function=create_splitter_constraint_fn(KFold(k), uids)
)
list(unpack_deep_iterable(kfold_iterable))


[[[{'uid': 2}, {'uid': 3}, {'uid': 4}, {'uid': 5}, {'uid': 6}, {'uid': 7}],
  [{'uid': 0}, {'uid': 1}]],
 [[{'uid': 0}, {'uid': 1}, {'uid': 4}, {'uid': 5}, {'uid': 6}, {'uid': 7}],
  [{'uid': 2}, {'uid': 3}]],
 [[{'uid': 0}, {'uid': 1}, {'uid': 2}, {'uid': 3}, {'uid': 6}, {'uid': 7}],
  [{'uid': 4}, {'uid': 5}]],
 [[{'uid': 0}, {'uid': 1}, {'uid': 2}, {'uid': 3}, {'uid': 4}, {'uid': 5}],
  [{'uid': 6}, {'uid': 7}]]]

In [409]:
def kfold_split_group_iterator(splitter, uids=None, k=5):
    uids = np.arange(9)
    kfold_iterable = constrained_group_iterator(
        [
            dict(fold=np.arange(k)),
            dict(group=["train", "test"]),
            dict(uid=uids),
        ],
        constraining_function=create_splitter_constraint_fn(splitter, uids)
    )
    for iteration_splits_kwargs in kfold_iterable:
        yield [
            Split(
                [
                    dict(
                        **split_kwargs
                    )
                    for split_kwargs in splits_kwargs_list
                ]
            )
            for splits_kwargs_list in iteration_splits_kwargs
        ]
print(
    json.dumps(
        list(unpack_deep_iterable(kfold_split_group_iterator(KFold(4)))),
        default=str,
        indent=4
    ),
)

[
    [
        "Split({'uid': 3},{'uid': 4},{'uid': 5},{'uid': 6},{'uid': 7},{'uid': 8})",
        "Split({'uid': 0},{'uid': 1},{'uid': 2})"
    ],
    [
        "Split({'uid': 0},{'uid': 1},{'uid': 2},{'uid': 5},{'uid': 6},{'uid': 7},{'uid': 8})",
        "Split({'uid': 3},{'uid': 4})"
    ],
    [
        "Split({'uid': 0},{'uid': 1},{'uid': 2},{'uid': 3},{'uid': 4},{'uid': 7},{'uid': 8})",
        "Split({'uid': 5},{'uid': 6})"
    ],
    [
        "Split({'uid': 0},{'uid': 1},{'uid': 2},{'uid': 3},{'uid': 4},{'uid': 5},{'uid': 6})",
        "Split({'uid': 7},{'uid': 8})"
    ],
    [
        "Split()",
        "Split()"
    ]
]


In [372]:
for train, test in KFold(5).split(np.arange(8)):
    print(train)
    print(test)
    print()

[2 3 4 5 6 7]
[0 1]

[0 1 4 5 6 7]
[2 3]

[0 1 2 3 6 7]
[4 5]

[0 1 2 3 4 5 7]
[6]

[0 1 2 3 4 5 6]
[7]



In [195]:
def optimized_constrained_group_iterator(split_kwargs_dicts, d, level, constraining_function, level_idx_dict):

    for idx, inner_d in enumerate(product_dict(**split_kwargs_dicts[level])):
        level_idx_dict[level] = idx

#         if level == 0:
        kwargs = {**d, **inner_d}
        kwargs, valid = constraining_function(level, level_idx_dict, kwargs)
        if not valid:
            continue

        if level == (len(split_kwargs_dicts) - 1):
            yield kwargs
        else:
            
#             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level-1)
            yield constrained_group_iterator(split_kwargs_dicts, d=kwargs, level=level + 1, constraining_function=constraining_function, level_idx_dict=level_idx_dict)


In [196]:
k = 5
uids = np.arange(50)

In [197]:
%%timeit
unpack_deep_iterable(
    constrained_group_iterator(
        [
            dict(fold=np.arange(k)),
            dict(group=["train", "test"]),
            dict(uid=uids),
        ],
        constraining_function=create_splitter_constraint_fn(KFold(k), uids)
    )
)

327 ms ± 17.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [198]:
%%timeit
unpack_deep_iterable(
    optimized_constrained_group_iterator(
        [
            dict(fold=np.arange(k)),
            dict(group=["train", "test"]),
            dict(uid=uids),
        ],
        dict(),
        0,
        create_splitter_constraint_fn(KFold(k), uids),
        dict()
    )
)

308 ms ± 17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [200]:
from copy import deepcopy, copy

uids = np.arange(8)
k = 4
kfold_iterable = constrained_group_iterator(
    [
        dict(fold=np.arange(k + 1)),
        dict(group=["train", "test"]),
        dict(uid=uids),
    ],
    constraining_function=create_splitter_constraint_fn(KFold(k), uids)
)
print(list(unpack_deep_iterable(kfold_iterable)))

def is_deeply_empty(deep_list):
    if not isinstance(deep_list, (GeneratorType, tuple, list)):
        return False
    
    # If deep list is a generator, iterate and unpack it
    deep_list = list(deep_list)
    
    # If it is empty, it is obviously deeply_empty
    if len(deep_list) == 0:
        return True
            
    # If not, it can only be deeply empty if all inner iterables are deeply_empty
    is_empty_list = [is_deeply_empty(shallow_list) for shallow_list in deep_list]
    
    return all(is_empty_list)

def prune_nested_list(deep_list):
    # Keep levels as a nested list
    if isinstance(deep_list, (GeneratorType, tuple, list)):
        # If deep_iterable is iterable, just make sure that if it is a generator that it is iterated
        deep_list = list(deep_list)
        lists =  [
            prune_nested_list(shallow_list)
            for shallow_list in deep_list
            if not is_deeply_empty(shallow_list)
        ]
        return lists
    else:
        return deep_list
    
kfold_iterable = constrained_group_iterator(
    [
        dict(fold=np.arange(k + 1)),
        dict(group=["train", "test"]),
        dict(uid=uids),
    ],
    constraining_function=create_splitter_constraint_fn(KFold(k), uids)
)
# 
print(prune_nested_list(unpack_deep_iterable(kfold_iterable)))


[[[{'uid': 2}, {'uid': 3}, {'uid': 4}, {'uid': 5}, {'uid': 6}, {'uid': 7}], [{'uid': 0}, {'uid': 1}]], [[{'uid': 0}, {'uid': 1}, {'uid': 4}, {'uid': 5}, {'uid': 6}, {'uid': 7}], [{'uid': 2}, {'uid': 3}]], [[{'uid': 0}, {'uid': 1}, {'uid': 2}, {'uid': 3}, {'uid': 6}, {'uid': 7}], [{'uid': 4}, {'uid': 5}]], [[{'uid': 0}, {'uid': 1}, {'uid': 2}, {'uid': 3}, {'uid': 4}, {'uid': 5}], [{'uid': 6}, {'uid': 7}]], [[], []]]
[[[{'uid': 2}, {'uid': 3}, {'uid': 4}, {'uid': 5}, {'uid': 6}, {'uid': 7}], [{'uid': 0}, {'uid': 1}]], [[{'uid': 0}, {'uid': 1}, {'uid': 4}, {'uid': 5}, {'uid': 6}, {'uid': 7}], [{'uid': 2}, {'uid': 3}]], [[{'uid': 0}, {'uid': 1}, {'uid': 2}, {'uid': 3}, {'uid': 6}, {'uid': 7}], [{'uid': 4}, {'uid': 5}]], [[{'uid': 0}, {'uid': 1}, {'uid': 2}, {'uid': 3}, {'uid': 4}, {'uid': 5}], [{'uid': 6}, {'uid': 7}]]]


In [17]:
uids = np.arange(8)
k = 4
kfold_iterable = constrained_group_iterator(
    [
        dict(fold=np.arange(9)),
        dict(group=["train", "test"]),
        dict(uid=uids),
    ],
    constraining_function=create_kfold_constraint_fn(LeaveOneOut(), uids)
)
list(unpack_deep_iterable(kfold_iterable))

[[[{'uid': 1},
   {'uid': 2},
   {'uid': 3},
   {'uid': 4},
   {'uid': 5},
   {'uid': 6},
   {'uid': 7}],
  [{'uid': 0}]],
 [[{'uid': 0},
   {'uid': 2},
   {'uid': 3},
   {'uid': 4},
   {'uid': 5},
   {'uid': 6},
   {'uid': 7}],
  [{'uid': 1}]],
 [[{'uid': 0},
   {'uid': 1},
   {'uid': 3},
   {'uid': 4},
   {'uid': 5},
   {'uid': 6},
   {'uid': 7}],
  [{'uid': 2}]],
 [[{'uid': 0},
   {'uid': 1},
   {'uid': 2},
   {'uid': 4},
   {'uid': 5},
   {'uid': 6},
   {'uid': 7}],
  [{'uid': 3}]],
 [[{'uid': 0},
   {'uid': 1},
   {'uid': 2},
   {'uid': 3},
   {'uid': 5},
   {'uid': 6},
   {'uid': 7}],
  [{'uid': 4}]],
 [[{'uid': 0},
   {'uid': 1},
   {'uid': 2},
   {'uid': 3},
   {'uid': 4},
   {'uid': 6},
   {'uid': 7}],
  [{'uid': 5}]],
 [[{'uid': 0},
   {'uid': 1},
   {'uid': 2},
   {'uid': 3},
   {'uid': 4},
   {'uid': 5},
   {'uid': 7}],
  [{'uid': 6}]],
 [[{'uid': 0},
   {'uid': 1},
   {'uid': 2},
   {'uid': 3},
   {'uid': 4},
   {'uid': 5},
   {'uid': 6}],
  [{'uid': 7}]],
 [[], []]]

In [408]:
def split_group_iterator(split_kwargs_dicts):

    for iteration_splits_kwargs in group_iterator(split_kwargs_dicts):
        yield [
            Split(
                [
                    dict(
                        **split_kwargs
                    )
                    for split_kwargs in splits_kwargs_list
                ]
            )
            for splits_kwargs_list in iteration_splits_kwargs
        ]
        
def create_split_group_iterator(outer_split_kwargs=None, inner_split_kwargs=None, merge_kwargs=None):
    outer_split_kwargs = outer_split_kwargs or dict()
    inner_split_kwargs = inner_split_kwargs or dict()
    merge_kwargs = merge_kwargs or dict()
    
    return split_group_iterator(
        [
            outer_split_kwargs,
            inner_split_kwargs,
            merge_kwargs ,        
        ]
    ) 
inter_run_iterator = create_split_group_iterator(
    dict(uid=[1,], session=[1, 2]),
    dict(),
    dict(run=[1, 2]),        
)
intra_run_iterator = create_split_group_iterator(
    dict(uid=[1,], session=[1, 2], run=[1, 2]),
    dict(),
    dict(),
)
inter_session_iterator = create_split_group_iterator(
    dict(uid=[1,]),
    dict(session=[1, 2]),
    dict(run=[1, 2]),
)
inter_session_inter_run_iterator = create_split_group_iterator(
    dict(uid=[1,], run=[1, 2]),
    dict(session=[1, 2]),
    dict(),
)
intra_session_iterator = create_split_group_iterator(
    dict(uid=[1, 2], session=[1, 2]),
    dict(),
    dict(run=[1, 2]),
)
intra_session_inter_run_iterator = create_split_group_iterator(
    dict(uid=[1, 2], session=[1, 2]),
    dict(run=[1, 2]),
    dict(),
)
inter_subject = create_split_group_iterator(
    dict(),
    dict(uid=[1, 2, 3]),
    dict(run=[1, 2], session=[1, 2]),
)


list(unpack_deep_iterable(inter_run_iterator))
list(unpack_deep_iterable(intra_run_iterator))
list(unpack_deep_iterable(inter_session_iterator))
# list(unpack_deep_iterable(inter_session_inter_run_iterator))
list(unpack_deep_iterable(intra_session_iterator))
# list(unpack_deep_iterable(inter_run_iterator))
list(unpack_deep_iterable(intra_session_inter_run_iterator))


[[Split({'uid': 1, 'session': 1, 'run': 1}),
  Split({'uid': 1, 'session': 1, 'run': 2})],
 [Split({'uid': 1, 'session': 2, 'run': 1}),
  Split({'uid': 1, 'session': 2, 'run': 2})],
 [Split({'uid': 2, 'session': 1, 'run': 1}),
  Split({'uid': 2, 'session': 1, 'run': 2})],
 [Split({'uid': 2, 'session': 2, 'run': 1}),
  Split({'uid': 2, 'session': 2, 'run': 2})]]

In [464]:
remove_key(dict(uid=1, session=1, run=1), "uid")

{'session': 1, 'run': 1}

In [473]:
class Split():
    
    def __init__(self, kwarg_dict_list, fold_sizes=None):
        self.kwargs_list = kwarg_dict_list
        self.fold_sizes = fold_sizes
    
    def to_dataframe():
        return pd.DataFrame.from_records(self.kwargs_list)

    def __repr__(self):
        dict_reps = [str(d) for d in self.kwargs_list]
        return "Split({})".format(",".join(dict_reps))
    
    def load_epochs(self, dataset, **load_kwargs):
        epochs = mne.concatenate_epochs(
            [
                dataset.load_subject(kwargs["uid"], **remove_key(kwargs, "uid"), **load_kwargs)[0]
                for kwargs in self.kwargs_list
            ]
        )
        return epochs


    
inter_session_iterator = create_split_group_iterator(
    dict(uid=['1', '2']),
    dict(session=[1, 2]),
    dict(run=[1, 2],),
)
for fold_splits in inter_session_iterator:
    train_split, test_split = fold_splits
    print(train_split.load_epochs(openbmi_dataset, reject=False))

  epochs = mne.concatenate_epochs(


<Epochs |  200 events (all good), -0.3 - 0.7 sec, baseline off, ~94.8 MB, data loaded,
 '0': 100
 '1': 100>


  epochs = mne.concatenate_epochs(


<Epochs |  200 events (all good), -0.3 - 0.7 sec, baseline off, ~94.8 MB, data loaded,
 '0': 100
 '1': 100>


In [470]:
uid = 1
session = 1
run = 1
filepaths_df = openbmi_dataset.list_subject_filepaths()
filepath = (
    filepaths_df.query("uid == @uid")
    .query("run == @run")
    .query("session == @session")
    .path
)
filepath

Series([], Name: path, dtype: object)

In [503]:
def transformer_from_processor(processor_fn):
    
    def transformer(*args, **kwargs):
        item = processor_fn(*args, **kwargs)
        if len(item)

def myfn(item):
    if isinstance(item, ITERABLES_TYPES):
        if count_deep_iterable_levels(item) == 2:
            return item[0]
    return item

def apply_to_iterator(iterator, fn, level=2):
    for item in iterator:
        if isinstance(item, ITERABLES_TYPES):
            if count_deep_iterable_levels(item) == level:
                
                yield fn(item)
            else:
                yield apply_to_iterator(item, fn, level=level)
        else:
            yield fn(item)
intra_session_iterator = create_split_group_iterator(
    dict(uid=[1, 2], session=[1, 2]),
    dict(),
    dict(run=[1, 2]),
)
# count_deep_iterable_levels(intra_session_iterator)
list(
    unpack_deep_iterable(
        apply_to_iterator(
            intra_session_iterator,
            myfn,
#             lambda x: x,
            level=2
        )
    )
)


[[Split({'uid': 1, 'session': 1, 'run': 1},{'uid': 1, 'session': 1, 'run': 2})],
 [Split({'uid': 1, 'session': 2, 'run': 1},{'uid': 1, 'session': 2, 'run': 2})],
 [Split({'uid': 2, 'session': 1, 'run': 1},{'uid': 2, 'session': 1, 'run': 2})],
 [Split({'uid': 2, 'session': 2, 'run': 1},{'uid': 2, 'session': 2, 'run': 2})]]

In [420]:
import time

def apply_to_iterator(iterator, fn):
    for item in iterator:
        if isinstance(item, ITERABLES_TYPES):
            yield apply_to_iterator(item, fn)
        else:
            yield fn(item)

iterator = [
    [
        1, 2, 3,
        [
            3, 7
        ]
    ],
    [
        3,
        [
            3, 4
        ]
    ]
]

print(
    list(
        unpack_deep_iterable(
            apply_to_iterator(iterator, lambda x: x ** 2)
        )
    )
)

[[1, 4, 9, [9, 49]], [9, [9, 16]]]


In [440]:
a = [[1]]
unpack_deep_iterable(chain(*[[1], [2, 3], [[2, 3]]]))
# count_deep_iterable_levels(a)

[1, 2, 3, [2, 3]]

In [281]:

iterator_1 = group_iterator(
    [
        dict(uid=[1,]),
        dict(session=[1, 2]),
        dict(run=[1, 2]),
    ]
)

iterator_2 = group_iterator(
    [
        dict(uid=[1,]),
        dict(session=[1, 2]),
        dict(run=[1, 2]),
    ]
)

list(unpack_deep_iterable(iterator_1))

TypeError: __main__.product_dict() argument after ** must be a mapping, not list

In [218]:
from itertools import permutations

list(permutations([1, 2, 3], 3))

[(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]

In [252]:
# def product_dict(**kwargs):
#     keys = kwargs.keys()
#     vals = kwargs.values()
#     for instance in product(*vals):
#         yield dict(zip(keys, instance))

def combinatorial_product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    vals_combinations = [permutations(val) for val in vals]
    for vals_combination in product(*vals_combinations):
        for instance in vals_combination:
            yield dict(zip(keys, instance))

list(
    product_dict(
        a=permutations([1, 2]),
        b=permutations([3, 4])
    )
)

[{'a': (1, 2), 'b': (3, 4)},
 {'a': (1, 2), 'b': (4, 3)},
 {'a': (2, 1), 'b': (3, 4)},
 {'a': (2, 1), 'b': (4, 3)}]

In [258]:
from itertools import chain

def concatenate_deep_iterators(*iterators):
    iterators = [item if isinstance(item, ITERABLES_TYPES) else [item] for item in iterators]
    
    chained_iterators = chain(*iterators)
    return chained_iterators
list(concatenate_deep_iterators([1, [2]], [[4], [5]], 3))


[1, [2], 4, [5], 3]

In [381]:
import json
from copy import copy, deepcopy


#Needs fix
# def group_iterator(split_kwargs_dicts, d=None, level=None):
    
#     level = level or 0
# #     level = level or len(split_kwargs_dicts) - 1
#     d = d or dict()


#     for inner_d in product_dict(**split_kwargs_dicts[level]): 

# #         if level == 0:
#         if level == (len(split_kwargs_dicts) - 1):

#             yield {**d, **inner_d}
#         else:

# #             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level-1)
#             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level + 1)


def combinatorial_group_iterator(split_kwargs_dicts, make_combinations=None, d=None, level=None):
    
    level = level or 0
#     level = level or len(split_kwargs_dicts) - 1
    d = d or dict()
    make_combinations = make_combinations or [dict() for level in split_kwargs_dicts]

    for inner_d in product_dict(**split_kwargs_dicts[level]): 

#         if level == 0:
        if level == (len(split_kwargs_dicts) - 1):

            yield {**d, **inner_d}
        else:
            
            if level < (len(split_kwargs_dicts) - 1):
                new_make_combinations = copy(make_combinations)
                new_make_combinations[level + 1] = False

                permutation_dict = {
                    k: permutations(v) if make_combinations[level][k] else [v]
                    for k, v
                    in split_kwargs_dicts[level + 1].items()
                }

                iterators = list()
                for permuted_d in product_dict(**permutation_dict):
                    new_split_kwargs_dicts = deepcopy(split_kwargs_dicts)
                    new_split_kwargs_dicts[level + 1] = permuted_d
                    iterator = combinatorial_group_iterator(new_split_kwargs_dicts, new_make_combinations, d={**d, **inner_d}, level=level + 1)
                    iterators.append(iterator)
        #             return iterator
                yield list(chain(*iterators))
            else:

#             yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level-1)
                yield combinatorial_group_iterator(split_kwargs_dicts, make_combinations, d={**d, **inner_d}, level=level + 1)


print(
    "Normal",
    json.dumps(
        list(
            unpack_deep_iterable(
                group_iterator(
                    [
                        dict(uid=[1,]),
                        dict(session=[1, 2]),
                        dict(run=[1, 2]),
                    ]
                )
            )
        ),
        indent=4
    )
)
        
print(
    "Combinatorial",
    json.dumps(
        list(
            unpack_deep_iterable(
                combinatorial_group_iterator(
                    [
                        dict(uid=[1,]),
                        dict(session=[1, 2]),
                        dict(run=[1, 2]),
                    ],
                    make_combinations=[
                        dict(uid=False),
                        dict(session=True),
                        dict(run=False)
                    ]
                ),
            )
        ),
        indent=4
    )
)

Normal [
    [
        [
            {
                "uid": 1,
                "session": 1,
                "run": 1
            },
            {
                "uid": 1,
                "session": 1,
                "run": 2
            }
        ],
        [
            {
                "uid": 1,
                "session": 2,
                "run": 1
            },
            {
                "uid": 1,
                "session": 2,
                "run": 2
            }
        ]
    ]
]


KeyError: 'session'

In [223]:
list(
    combinatorial_product_dict(
        a=[1, 2, 3],
        b=[3, 4]
    )
)

[{'a': 1, 'b': 3},
 {'a': 1, 'b': 4},
 {'a': 2, 'b': 3},
 {'a': 2, 'b': 4},
 {'a': 3, 'b': 3},
 {'a': 3, 'b': 4},
 {'a': 3, 'b': 1},
 {'a': 3, 'b': 2},
 {'a': 3, 'b': 3},
 {'a': 4, 'b': 1},
 {'a': 4, 'b': 2},
 {'a': 4, 'b': 3}]

In [23]:
folds_epochs_list[1]

(<Epochs |  120 events (all good), -0.3 - 0.7 sec, baseline off, ~56.9 MB, data loaded,
  '0': 63
  '1': 57>,
 <Epochs |  120 events (all good), -0.3 - 0.7 sec, baseline off, ~56.9 MB, data loaded,
  '0': 63
  '1': 57>)

In [12]:
filepaths_df = openbmi_dataset.list_subject_filepaths()
uid = '25'
session = 1
train = True
(
    filepaths_df.query("uid == @uid")
    .query("train == @train")
    .query("session == @session")
    .path
)

108    /home/paulo/Documents/datasets/OpenBMI/edf/ses...
Name: path, dtype: object

In [13]:
filepaths_df.query("uid == @uid").query("train == @train").query("session == @session")

Unnamed: 0,path,train,session,uid
108,/home/paulo/Documents/datasets/OpenBMI/edf/ses...,True,1,25


In [77]:
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
from sklearn.base import BaseEstimator
from sklearn.metrics import balanced_accuracy_score

class GetData(BaseEstimator):
    
    def fit(self, X, y=None):
        return self
    
    def transform(self, X, y=None):
        filter_kwargs = dict(
            method="iir",
            iir_params=dict(
                order=5,
                ftype="butter"
            )
        )
        return X.load_data().filter(8, 30, **filter_kwargs).get_data()


# kwargs = dict(
#     uids=["1", "2", "3", "4"],
#     inter_session=False,
#     inter_subject=True,
#     sessions=[2],
#     train_folds=[True, False],
#     fold_sizes=[.5, .5],
#     load_kwargs=dict(
#         reject=False
#     )
# )
kwargs = dict(
    uids=["1", "2", "3", "4"],
    inter_session=False,
    inter_subject=False,
    sessions=[1],
    train_folds=[True, False],
    fold_sizes=[.25, .75],
    load_kwargs=dict(
        tmin=1,
        tmax=3.5,
        reject=False
    )
)
for j, splits in enumerate(OpenBMI_Splitter(openbmi_dataset).make_splits(**kwargs)):
    infos = [info for info, split in splits]
    splits = [split for info, split in splits]
    print(infos)
    clf = make_pipeline(
        GetData(),
        CSP(5, log=True),
        LDA()
    )
    train_epochs = splits[0]
    train_labels = train_epochs.events[:, 2]
    
    test_epochs = splits[1]
    test_labels = test_epochs.events[:, 2]
    
    clf.fit(train_epochs, train_labels)
    pred = clf.predict(test_epochs)
    acc = balanced_accuracy_score(test_labels, pred)
    print(acc)

  epochs = mne.concatenate_epochs(


[{'uid': '1', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '1', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.5903271692745377


  epochs = mne.concatenate_epochs(


[{'uid': '2', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '2', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.8466749866286325


  epochs = mne.concatenate_epochs(


[{'uid': '3', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '3', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.8866999465145302


  epochs = mne.concatenate_epochs(


[{'uid': '4', 'session': 1, 'train_folds': [True, False], 'size': 0.25}, {'uid': '4', 'session': 1, 'train_folds': [True, False], 'size': 0.75}]
0.47759601706970123


In [71]:
train_labels

array([0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
       1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1,
       0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
       1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0])

In [46]:
[f"{i} {j}" for i in range(3) for j in range(3)]

['0 0', '0 1', '0 2', '1 0', '1 1', '1 2', '2 0', '2 1', '2 2']

In [20]:
def split_array(arr, n_splits=2, sizes=None):
    n = len(arr)
    sizes = sizes or [1 / n] * n_splits
    assert np.sum(sizes) == 1.
    sizes = [int(size * n) for size in sizes]
    sizes = [0] + sizes
    sizes = np.cumsum(sizes)
    slices = [slice(start, end) for start, end in zip(sizes[:-1], sizes[1:])]
    arrs = [arr[s] for s in slices]
    return arrs

arrs = split_array(np.random.rand(100), sizes=[.2, .3, .1, .4])
list(map(len, arrs))

[20, 30, 10, 40]

In [44]:
split_array(openbmi_dataset.list_uids(), sizes=[.2, .3, .5])

[array(['25', '15', '41', '12', '37', '2', '42', '6', '52', '18'],
       dtype=object),
 array(['38', '30', '34', '5', '11', '44', '29', '8', '17', '33', '46',
        '23', '19', '4', '13', '22'], dtype=object),
 array(['35', '45', '36', '39', '16', '50', '10', '53', '7', '28', '27',
        '14', '32', '20', '31', '51', '24', '9', '21', '40', '47', '26',
        '49', '54', '43', '1', '3'], dtype=object)]

In [556]:
import pandas as pd
from sklearn.model_selection import KFold
from warnings import warn
from itertools import product, chain
import numpy as np
from mne import Epochs
import mne
from types import GeneratorType
from ica_benchmark.io.load import OpenBMI_Dataset
from collections.abc import Iterable
from pathlib import Path

ITERABLES_TYPES = (list, tuple, GeneratorType, product, chain)


def apply_to_iterator(iterator, fn):
    for item in iterator:
        if isinstance(item, ITERABLES_TYPES):
            yield apply_to_iterator(item, fn)
        else:
            yield fn(item)


def unpack_deep_iterable(deep_iterable):
    # Keep levels as a nested list
    if isinstance(deep_iterable, (GeneratorType, tuple, list)):
        # If deep_iterable is iterable, just make sure that if it is a generator that it is iterated
        deep_iterable = list(deep_iterable)
        return [
            unpack_deep_iterable(shallow_iterable)
            for shallow_iterable in deep_iterable
        ]
    else:
        return deep_iterable


def flatten_deep_iterable(deep_iterable):
    # Returns a flat iterator of all items that are not in ITERABLES_TYPES inside deep_iterable
    for item in deep_iterable:
        if isinstance(item, ITERABLES_TYPES):
            for nested_item in flatten_deep_iterable(item):
                yield nested_item
        else:
            yield item

def apply_to_iterator(iterator, fn):
    for item in iterator:
        if isinstance(item, ITERABLES_TYPES):
            yield apply_to_iterator(item, fn)
        else:
            yield fn(item)


def make_epochs_splits_indexes(arr, n=None, n_splits=2, sizes=None, shuffle=False, seed=1):
    if not isinstance(arr, (Epochs,)):
        arr = np.array(arr)

    np.random.seed(seed)
    if n is None:
        if isinstance(arr, Epochs):
            n = len(arr.events)
        else:
            n = len(arr)

    sizes = sizes or [1 / n_splits] * n_splits

    assert np.sum(sizes) == 1.

    sizes = np.cumsum(
        [0] + [int(size * n) for size in sizes]
    )

    idx = np.arange(n)
    if shuffle:
        np.random.shuffle(idx)
    slices = [slice(start, end) for start, end in zip(sizes[:-1], sizes[1:])]
    indexes = [idx[s] for s in slices]
    return indexes


def make_epochs_splits(arr, n=None, n_splits=2, sizes=None, shuffle=False, seed=1):
    indexes = make_epochs_splits_indexes(arr, n=n, n_splits=n_splits, sizes=sizes, shuffle=shuffle, seed=seed)
    arrs = [arr[idx] for idx in indexes]
    return arrs


def remove_key(d, k):
    return {
        key: value
        for key, value in d.items()
        if key != k
    }


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in product(*vals):
        yield dict(zip(keys, instance))


def insideout_group_iterator(split_kwargs_dicts, d=None, level=None):

    level = level or len(split_kwargs_dicts) - 1

    for inner_d in product_dict(**split_kwargs_dicts[level]):
        if level == 0:
            yield {**d, **inner_d}
        else:
            yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level - 1)


def group_iterator(split_kwargs_dicts, d=None, level=None):

    level = level or 0
    d = d or dict()

    for inner_d in product_dict(**split_kwargs_dicts[level]): 
        if level == (len(split_kwargs_dicts) - 1):
            yield {**d, **inner_d}
        else:
            yield group_iterator(split_kwargs_dicts, d={**d, **inner_d}, level=level + 1)


def split_group_iterator(split_kwargs_dicts):

    for iteration_splits_kwargs in group_iterator(split_kwargs_dicts):
        yield [
            Split(
                [
                    dict(
                        **split_kwargs
                    )
                    for split_kwargs in splits_kwargs_list
                ]
            )
            for splits_kwargs_list in iteration_splits_kwargs
        ]


def create_split_group_iterator(outer_split_kwargs=None, inner_split_kwargs=None, merge_kwargs=None):
    outer_split_kwargs = outer_split_kwargs or dict()
    inner_split_kwargs = inner_split_kwargs or dict()
    merge_kwargs = merge_kwargs or dict()

    return split_group_iterator(
        [
            outer_split_kwargs,
            inner_split_kwargs,
            merge_kwargs,
        ]
    )


class Split():

    def __init__(self, kwarg_dict_list):
        self.kwargs_list = kwarg_dict_list

    def to_dataframe(self):
        return pd.DataFrame.from_records(self.kwargs_list)

    def __repr__(self):
        dict_reps = [str(d) for d in self.kwargs_list]
        return "Split({})".format(",".join(dict_reps))

    def load_epochs(self, dataset, **load_kwargs):
        epochs = mne.concatenate_epochs(
            [
                dataset.load_subject(kwargs["uid"], **remove_key(kwargs, "uid"), **load_kwargs)[0]
                for kwargs in self.kwargs_list
            ]
        )
        return epochs


def constrained_group_iterator(split_kwargs_dicts, d=None, level=None, constraining_function=None, level_idx_dict=None):
    constraining_function = constraining_function or (lambda l, i, kwargs: (kwargs, True))
    level = level or 0
    level_idx_dict = level_idx_dict or dict()
    d = d or dict()

    for idx, inner_d in enumerate(product_dict(**split_kwargs_dicts[level])):
        level_idx_dict[level] = idx

        kwargs = {**d, **inner_d}
        kwargs, valid = constraining_function(level, level_idx_dict, kwargs)
        if not valid:
            continue

        if level == (len(split_kwargs_dicts) - 1):
            yield kwargs
        else:
            yield constrained_group_iterator(
                split_kwargs_dicts,
                d=kwargs,
                level=level + 1,
                constraining_function=constraining_function,
                level_idx_dict=level_idx_dict
            )


def constrained_split_group_iterator(split_kwargs_dicts):

    for iteration_splits_kwargs in group_iterator(split_kwargs_dicts):
        yield [
            Split(
                [
                    dict(
                        **split_kwargs
                    )
                    for split_kwargs in splits_kwargs_list
                ]
            )
            for splits_kwargs_list in iteration_splits_kwargs
        ]


def create_splitter_constraint_fn(splitter, uids):

    # Creates a dataframe with columns fold, uid and group (train or test)
    # It is only used to later check if volunteer with uid is in train or test for each fold.
    splits_dfs = list()
    for fold, splits_idx in enumerate(splitter.split(uids)):
        for group_i, group_uid_idx in enumerate(splits_idx):
            split_uids = uids[group_uid_idx]
            split_df = pd.DataFrame()
            split_df["uid"] = split_uids
            split_df["fold"] = fold
            split_df["group"] = group_i
            splits_dfs.append(split_df)

    split_df = pd.concat(splits_dfs, axis=0)

    def my_constraint_fn(level, level_idx_dict, kwargs):
        kwargs = {**kwargs}

        if ("group" not in kwargs) or ("uid" not in kwargs):
            return kwargs, True

        r = (
            kwargs["uid"] in 
            split_df[(split_df.group == kwargs["group"]) & (split_df.fold == kwargs["fold"])].uid.to_numpy()
        )
        kwargs.pop("group")
        kwargs.pop("fold")

        return kwargs, r

    return my_constraint_fn


def kfold_split_group_iterator(splitter, uids, n_groups=2):
    
    kfold_iterable = constrained_group_iterator(
        [
            dict(fold=np.arange(splitter.get_n_splits())),
            dict(group=np.arange(n_groups)),
            dict(uid=uids),
        ],
        constraining_function=create_splitter_constraint_fn(splitter, uids)
    )
    for iteration_splits_kwargs in kfold_iterable:
        yield [
            Split(
                [
                    dict(
                        **split_kwargs
                    )
                    for split_kwargs in splits_kwargs_list
                ]
            )
            for splits_kwargs_list in iteration_splits_kwargs
        ]


class Splitter():

    SESSION_KWARGS = dict(intra=dict(), inter=dict())

    def default_splitter(self):
        splitter = KFold(4)
        warn("Using default splitter: " + str(splitter))
        return splitter

    def __init__(self, dataset, uids, sessions, runs, load_kwargs=None, splitter=None, unsafe=False, intra_session_shuffle=False, fold_sizes=None):
        self.dataset = dataset
        self.uids = uids
        self.sessions = sessions
        self.runs = runs
        self.load_kwargs = load_kwargs or load_kwargs
        self.splitter = splitter or self.default_splitter()
        self.intra_session_shuffle = intra_session_shuffle
        self.fold_sizes = fold_sizes

    def validate_config(self, mode):
        valid_modes = [
            "inter_subject",
            "inter_session",
            "intra_session_intra_run",
            "intra_session_inter_run",
            "intra_session_intra_run_merged"
        ]
        fold_sizes = self.fold_sizes
        assert mode in valid_modes, "Please choose one mode among the following: {}".format(", ".join(valid_modes))
        if mode == "inter_subject":
            if fold_sizes is not None:
                warn("You are using the inter_subject mode, so the fold_sizes argument will not be used")
        elif mode == "inter_session":
            if len(self.runs) > 1:
                warn("You are using inter session protocol with more than one run. All runs from each session will be concatenated and yielded in different steps.")
        elif (mode == "intra_session_inter_run"):
            if (len(self.runs) == 1):
                warn("You are using an intra session protocol, splitting by run, but only passed one run. The splitter can only yield one epoch at time (from the only run you passed as argument)")
        elif mode in ("intra_session_intra_run", "intra_session_intra_run_merged"):
            if fold_sizes is None:
                warn("You are using intra session intra run protocol with no fold sizes. The splitter will only yield one epoch at time")

    def kfold_split_group_iterator(self, n_groups=2):
    
        kfold_iterable = constrained_group_iterator(
            [
                dict(
                    fold=np.arange(
                        self.splitter.get_n_splits()
                    )
                ),
                dict(group=np.arange(n_groups)),
                dict(
                    uid=self.uids,
                    session=self.sessions,
                    run=self.runs
                ),
            ],
            constraining_function=create_splitter_constraint_fn(self.splitter, uids)
        )
        for iteration_splits_kwargs in kfold_iterable:
            yield [
                Split(
                    [
                        dict(
                            **split_kwargs
                        )
                        for split_kwargs in splits_kwargs_list
                    ]
                )
                for splits_kwargs_list in iteration_splits_kwargs
            ]
                
    def inter_subject(self, splitter=None):
        return self.kfold_split_group_iterator()

    def inter_session(self):
        inter_session_iterator = create_split_group_iterator(
            dict(uid=self.uids),
            dict(session=self.sessions),
            dict(run=self.runs),
        )
        return inter_session_iterator

    def intra_session_inter_run(self):
        intra_session_inter_run_iterator = create_split_group_iterator(
            dict(uid=self.uids, session=self.sessions),
            dict(run=self.runs),
            dict(),
        )
        return intra_session_inter_run_iterator

    def intra_session_intra_run(self, merge=False):
        # intra_session_intra_run_merge
        # Duas runs mergidas, precisa separar por pct
        if merge:
            return self.intra_session_intra_run_merged()
        else:
            # Cada run em seu experiment, mas ainda precisa separar por pct
            intra_run_iterator = create_split_group_iterator(
                dict(uid=self.uids, session=self.sessions, run=self.runs),
                dict(),
                dict(),
            )
            return intra_run_iterator

    def intra_session_intra_run_merged(self, merge=False):
        intra_session_iterator = create_split_group_iterator(
            dict(uid=self.uids, session=self.sessions),
            dict(),
            dict(run=self.runs),
        )
        return intra_session_iterator

    def yield_splits_epochs(self, mode):

        split_fn_dict = dict(
            # Intra subject, inter session
            inter_session=self.inter_session,
            # Inter subject, will concatenate all sessions and runs
            inter_subject=self.inter_subject,
            # Intra subject, intra_session, inter run (will split runs)
            intra_session_inter_run=self.intra_session_inter_run,
            # Intra subject, intra_session, intra run (will split using fold sizes)
            intra_session_intra_run=self.intra_session_intra_run,
            # Intra subject, intra_session, intra run (will merge all runs and split using fold sizes)
            intra_session_intra_run_merged=self.intra_session_intra_run_merged,
        )

        split_fn = split_fn_dict[mode]
        for fold_splits in split_fn():

            # if (len(fold_splits) == 1) and (fold_sizes is None):
            #     warn("This splitter return only one split and you passed no fold sizes for intra splitting. Is this what you want?")
                
            yield fold_splits

    def load_from_split(self, splits, fold_sizes=None):
        fold_sizes = fold_sizes or self.fold_sizes
        splits_epochs = [
            split.load_epochs(self.dataset, **self.load_kwargs)
            for split in splits
        ]
        print(fold_sizes)
        if fold_sizes is not None:
            assert len(splits_epochs) == 1, "You passed fold_sizes={} but there in more than one split".format(fold_sizes)
            splits_epochs = make_epochs_splits(
                splits_epochs[0],
                sizes=fold_sizes,
                shuffle=self.intra_session_shuffle
            )
        return splits_epochs




In [568]:
s = KFold(4)
kfold_iterable = constrained_group_iterator(
    [
        dict(
            fold=np.arange(
                s.get_n_splits()
            )
        ),
        dict(group=np.arange(2)),
        dict(
            uid=np.array(["1", "2", "3", "4"]),
            session=[1, 2],
            run=[1, 2]
        ),
    ],
    constraining_function=create_splitter_constraint_fn(s, uids)
)
list(unpack_deep_iterable(kfold_iterable))

  kwargs["uid"] in


[[[], []], [[], []], [[], []], [[], []]]

In [570]:
def create_splitter_constraint_fn(splitter, uids):

    # Creates a dataframe with columns fold, uid and group (train or test)
    # It is only used to later check if volunteer with uid is in train or test for each fold.
    splits_dfs = list()
    for fold, splits_idx in enumerate(splitter.split(uids)):
        for group_i, group_uid_idx in enumerate(splits_idx):
            split_uids = uids[group_uid_idx]
            split_df = pd.DataFrame()
            split_df["uid"] = split_uids
            split_df["fold"] = fold
            split_df["group"] = group_i
            splits_dfs.append(split_df)

    split_df = pd.concat(splits_dfs, axis=0)

    def my_constraint_fn(level, level_idx_dict, kwargs):
        kwargs = {**kwargs}

        if ("group" not in kwargs) or ("uid" not in kwargs):
            return kwargs, True

        r = (
            kwargs["uid"] in 
            split_df[(split_df.group == kwargs["group"]) & (split_df.fold == kwargs["fold"])].uid.to_numpy()
        )
        kwargs.pop("group")
        kwargs.pop("fold")

        return kwargs, r

    return my_constraint_fn
s = KFold(4)
f = create_splitter_constraint_fn(s, np.array(np.array(["1", "2", "3", "4"])))
uids = np.array(["1", "2", "3", "4"])
kfold_iterable = constrained_group_iterator(
    [
        dict(
            fold=np.arange(
                s.get_n_splits()
            )
        ),
        dict(group=np.arange(2)),
        dict(
            uid=uids,
            session=[1, 2],
            run=[1, 2]
        ),
    ],
    constraining_function=create_splitter_constraint_fn(s, uids)
)
list(unpack_deep_iterable(kfold_iterable))

[[[{'uid': '2', 'session': 1, 'run': 1},
   {'uid': '2', 'session': 1, 'run': 2},
   {'uid': '2', 'session': 2, 'run': 1},
   {'uid': '2', 'session': 2, 'run': 2},
   {'uid': '3', 'session': 1, 'run': 1},
   {'uid': '3', 'session': 1, 'run': 2},
   {'uid': '3', 'session': 2, 'run': 1},
   {'uid': '3', 'session': 2, 'run': 2},
   {'uid': '4', 'session': 1, 'run': 1},
   {'uid': '4', 'session': 1, 'run': 2},
   {'uid': '4', 'session': 2, 'run': 1},
   {'uid': '4', 'session': 2, 'run': 2}],
  [{'uid': '1', 'session': 1, 'run': 1},
   {'uid': '1', 'session': 1, 'run': 2},
   {'uid': '1', 'session': 2, 'run': 1},
   {'uid': '1', 'session': 2, 'run': 2}]],
 [[{'uid': '1', 'session': 1, 'run': 1},
   {'uid': '1', 'session': 1, 'run': 2},
   {'uid': '1', 'session': 2, 'run': 1},
   {'uid': '1', 'session': 2, 'run': 2},
   {'uid': '3', 'session': 1, 'run': 1},
   {'uid': '3', 'session': 1, 'run': 2},
   {'uid': '3', 'session': 2, 'run': 1},
   {'uid': '3', 'session': 2, 'run': 2},
   {'uid': '4

In [557]:

openbmi_dataset_folderpath = Path('/home/paulo/Documents/datasets/OpenBMI/edf/')
dataset = OpenBMI_Dataset(openbmi_dataset_folderpath)
fold_sizes = None
splitter = Splitter(
    dataset,
    uids=dataset.list_uids()[:4],
    sessions=dataset.SESSIONS,
    runs=dataset.RUNS,
    load_kwargs=dict(
        reject=False
    ),
    splitter=KFold(4),
    intra_session_shuffle=False,
    fold_sizes=fold_sizes
)

splits_iterable = splitter.yield_splits_epochs(mode="inter_subject")
for i, fold_splits in enumerate(splits_iterable):
    print(f"Fold {i}")
    print(f"\tSplits {fold_splits}")
#     epochs = splitter.load_from_split(fold_splits, fold_sizes=fold_sizes)
#     print(f"\tEpochs {epochs}")
#     print()
#     del epochs

Fold 0
	Splits [Split(), Split()]
Fold 1
	Splits [Split(), Split()]
Fold 2
	Splits [Split(), Split()]
Fold 3
	Splits [Split(), Split()]


  kwargs["uid"] in


In [560]:
splitter.sessions

[1, 2]

In [548]:
def create_splitter_constraint_fn(splitter, uids):
    
    # Creates a dataframe with columns fold, uid and group (train or test)
    # It is only used to later check if volunteer with uid is in train or test for each fold.
    splits_dfs = list()
    for fold, splits_idx in enumerate(splitter.split(uids)):
        for group_i, group_uid_idx in enumerate(splits_idx):
            split_uids = uids[group_uid_idx]
            split_df = pd.DataFrame()
            split_df["uid"] = split_uids
            split_df["fold"] = fold
            split_df["group"] = group_i
            splits_dfs.append(split_df)
            
    split_df = pd.concat(splits_dfs, axis=0)
    def my_constraint_fn(level, level_idx_dict, kwargs):
        kwargs = {**kwargs}

        if ("group" not in kwargs) or ("uid" not in kwargs):
            return kwargs, True

        r = (
            kwargs["uid"] in 
            split_df[(split_df.group == kwargs["group"]) & (split_df.fold == kwargs["fold"])].uid.to_numpy()
        )
        kwargs.pop("group")
        kwargs.pop("fold")

        return kwargs, r

    return my_constraint_fn


def kfold_split_group_iterator(splitter, uids, n_groups=2):
    
    kfold_iterable = constrained_group_iterator(
        [
            dict(fold=np.arange(splitter.get_n_splits())),
            dict(group=np.arange(n_groups)),
            dict(uid=uids),
        ],
        constraining_function=create_splitter_constraint_fn(splitter, uids)
    )
    for iteration_splits_kwargs in kfold_iterable:
        yield [
            Split(
                [
                    dict(
                        **split_kwargs
                    )
                    for split_kwargs in splits_kwargs_list
                ]
            )
            for splits_kwargs_list in iteration_splits_kwargs
        ]

list(unpack_deep_iterable(kfold_split_group_iterator(KFold(4), np.array([1, 2, 3, 4]))))

[[Split({'uid': 2},{'uid': 3},{'uid': 4}), Split({'uid': 1})],
 [Split({'uid': 1},{'uid': 3},{'uid': 4}), Split({'uid': 2})],
 [Split({'uid': 1},{'uid': 2},{'uid': 4}), Split({'uid': 3})],
 [Split({'uid': 1},{'uid': 2},{'uid': 3}), Split({'uid': 4})]]