In [1]:
import numpy as np

def split_dataset(X, Y, n_splits, split_size, with_replacement=True):
    """Split a dataset into ``n_splits`` ``(X_subset, Y_subset)`` pairs.

    The sampling is deterministic: a private random generator seeded with 0
    is created on every call, so repeated calls with the same arguments
    return the same splits. The global NumPy random state is left untouched.

    Parameters
    ----------
    X : array-like supporting ``.shape`` and integer-array indexing
        Samples; the first axis indexes samples.
    Y : array-like supporting integer-array indexing
        Targets, indexed with the same sample indices as ``X``.
    n_splits : int
        Number of splits to return. Values <= 0 yield an empty list.
    split_size : int
        Number of samples per split (must be >= 0).
    with_replacement : bool, default True
        If True, every split is an independent draw of ``split_size``
        indices with replacement (bootstrap-style), so samples may repeat
        within and across splits.
        If False, the samples are shuffled once and cut into consecutive
        chunks of ``split_size``. The last split always runs to the end of
        the shuffled data, so it absorbs any leftover samples when
        ``n_splits * split_size < n_samples`` and may be shorter (or empty)
        when ``n_splits * split_size > n_samples``.

    Returns
    -------
    list of tuple
        ``n_splits`` tuples ``(X[idx], Y[idx])``.

    Raises
    ------
    ValueError
        If ``split_size`` is negative.
    """
    if split_size < 0:
        raise ValueError("split_size must be non-negative")

    # A private generator seeded with 0 yields the same draws that
    # ``np.random.seed(0)`` followed by the module-level functions would,
    # but without resetting the caller's global random state.
    rng = np.random.RandomState(0)
    n_samples = X.shape[0]

    if with_replacement:
        splits = []
        for _ in range(n_splits):
            chosen = rng.choice(n_samples, split_size, replace=True)
            splits.append((X[chosen], Y[chosen]))
        return splits

    indices = rng.permutation(n_samples)
    splits = []
    for i in range(n_splits):
        start_idx = i * split_size
        if i == n_splits - 1:
            # Final split: take everything that is left so no sample is dropped.
            end_idx = n_samples
        else:
            end_idx = min(start_idx + split_size, n_samples)
        chunk = indices[start_idx:end_idx]
        splits.append((X[chunk], Y[chunk]))

    return splits


In [19]:
# Toy data: 8 samples with 3 features each, plus one target value per sample.
X, Y = np.random.rand(8, 3), np.random.rand(8)
(X, Y)

(array([[0.43758721, 0.891773  , 0.96366276],
        [0.38344152, 0.79172504, 0.52889492],
        [0.56804456, 0.92559664, 0.07103606],
        [0.0871293 , 0.0202184 , 0.83261985],
        [0.77815675, 0.87001215, 0.97861834],
        [0.79915856, 0.46147936, 0.78052918],
        [0.11827443, 0.63992102, 0.14335329],
        [0.94466892, 0.52184832, 0.41466194]]),
 array([0.26455561, 0.77423369, 0.45615033, 0.56843395, 0.0187898 ,
        0.6176355 , 0.61209572, 0.616934  ]))

In [20]:
# Fix the number of splits and derive the per-split size (rounded up).
n_splits = 3
per_split = np.ceil(len(X) / n_splits).astype(int)
splits = split_dataset(X, Y, n_splits, per_split, with_replacement=False)

# Report how many samples landed in each split.
for part in splits:
    print(len(part[0]))

3
3
2


In [21]:
# Fix the per-split size and derive the number of splits (rounded up).
points_per_split = 3
split_count = np.ceil(len(X) / points_per_split).astype(int)
splits = split_dataset(X, Y, split_count, points_per_split, with_replacement=False)

# Report how many samples landed in each split.
for part in splits:
    print(len(part[0]))

3
3
2
