In [1]:
import numpy as np
from op_op.grouptype_crossval import GroupTypeKFold

## Group Type K-Fold Cross Validation

- Groups correspond to trials.
- Group types correspond to trial types.

- Goal: split the data such that
    - All data from a given group ends up in either the training set or the test set, never in both.
    - Each group type makes up the same proportion of the training set as of the test set.

In [2]:
def get_proportion_group_type(group_types, train_idx, test_idx):
    """Compute the share of each group type among the train and test samples.

    Both returned arrays are aligned with ``np.unique(group_types)``: entry
    ``k`` refers to the same group type in each of them. A group type that
    does not occur in a split is reported as ``0.0`` rather than being
    dropped, so the two arrays can always be compared element-wise.

    Parameters
    ----------
    group_types : array-like
        Group-type label of every sample.
    train_idx, test_idx : array-like of int
        Integer positions of the training / test samples in ``group_types``.

    Returns
    -------
    train_group_type_props, test_group_type_props : np.ndarray
        One proportion per distinct value of ``group_types`` (in sorted
        order). Each array sums to 1, or is all zeros for an empty split.
    """
    group_types = np.asarray(group_types)
    # Encode every label as its position in the sorted list of *all* labels,
    # so that counts from different splits share one common axis.
    categories, type_codes = np.unique(group_types, return_inverse=True)
    n_categories = len(categories)

    def _proportions(idx):
        # An empty split has no samples to count; avoid dividing by zero.
        if len(idx) == 0:
            return np.zeros(n_categories)
        # minlength pads categories that are absent from this split with 0.
        counts = np.bincount(type_codes[idx], minlength=n_categories)
        return counts / len(idx)

    train_group_type_props = _proportions(train_idx)
    test_group_type_props = _proportions(test_idx)

    return train_group_type_props, test_group_type_props

In [3]:
# simulate data
# Seed the global NumPy RNG so the simulated arrays below are identical on
# every "Restart Kernel -> Run All"; without it the displayed output can
# never be reproduced.
SEED = 0
np.random.seed(SEED)
X = np.random.rand(500, 2)
groups = np.random.randint(0, 10, 500)
# NOTE(review): group types are drawn per sample, independently of `groups`,
# so a single group contains samples of several types. Confirm that this is
# the input GroupTypeKFold is meant to be demonstrated on.
group_types = np.random.choice(["A", "B", "C"], 500)

# define cross-validation
n_splits = 5
gt_cv = GroupTypeKFold(n_splits=n_splits)

# run cross-validation
# Collect one summary record per fold: split sizes, which groups and group
# types landed on each side, and the per-type proportions on each side.
fold_results = []
for fold, (train_idx, test_idx) in enumerate(
    gt_cv.split(X, groups=groups, group_types=group_types)
):
    train_group_type_props, test_group_type_props = get_proportion_group_type(
        group_types, train_idx, test_idx
    )

    fold_results.append(
        {
            "fold": fold,
            "n_train": len(train_idx),
            "n_test": len(test_idx),
            "train_groups": np.unique(groups[train_idx]),
            "test_groups": np.unique(groups[test_idx]),
            "test_group_types": np.unique(group_types[test_idx]),
            "train_group_types": np.unique(group_types[train_idx]),
            "train_group_type_props": train_group_type_props,
            "test_group_type_props": test_group_type_props,
        }
    )


# last expression of the cell: display the per-fold summaries
fold_results

[{'fold': 0,
  'n_train': 412,
  'n_test': 88,
  'train_groups': array([1, 2, 3, 4, 6, 7, 8, 9]),
  'test_groups': array([0, 5]),
  'test_group_types': array(['A', 'B', 'C'], dtype='<U1'),
  'train_group_types': array(['A', 'B', 'C'], dtype='<U1'),
  'train_group_type_props': array([0.38349515, 0.30097087, 0.31553398]),
  'test_group_type_props': array([0.26136364, 0.36363636, 0.375     ])},
 {'fold': 1,
  'n_train': 401,
  'n_test': 99,
  'train_groups': array([0, 2, 3, 4, 5, 7, 8, 9]),
  'test_groups': array([1, 6]),
  'test_group_types': array(['A', 'B', 'C'], dtype='<U1'),
  'train_group_types': array(['A', 'B', 'C'], dtype='<U1'),
  'train_group_type_props': array([0.33665835, 0.32917706, 0.33416459]),
  'test_group_type_props': array([0.46464646, 0.24242424, 0.29292929])},
 {'fold': 2,
  'n_train': 407,
  'n_test': 93,
  'train_groups': array([0, 1, 3, 4, 5, 6, 8, 9]),
  'test_groups': array([2, 7]),
  'test_group_types': array(['A', 'B', 'C'], dtype='<U1'),
  'train_group_types'