In [6]:
from sklearn.model_selection import StratifiedGroupKFold
import numpy as np

# Extended Example Data
X = np.array([[1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16]])
y = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 2, 2])  # Multiclass target (0, 1, 2)
groups = np.array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8])  # Group IDs

# Create the StratifiedGroupKFold object
sgkf = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=42)

# Perform the split
for fold, (train_idx, test_idx) in enumerate(sgkf.split(X, y, groups)):
    print(f"Fold {fold + 1}")
    print("Train indices:", train_idx, "Test indices:", test_idx)
    print(f"Train classes distribution:{np.bincount(y[train_idx])}    values {y[train_idx]}")
    print(f"Test classes distribution:{np.bincount(y[test_idx ])}    values {y[test_idx]}")
    print("Groups in test set:", np.unique(groups[test_idx]))
    print("-" * 50)

Fold 1
Train indices: [ 2  3  4  5  8  9 12 13 14 15] Test indices: [ 0  1  6  7 10 11]
Train classes distribution:[4 2 4]    values [1 1 0 0 0 0 2 2 2 2]
Test classes distribution:[2 4]    values [0 0 1 1 1 1]
Groups in test set: [1 4 6]
--------------------------------------------------
Fold 2
Train indices: [ 0  1  4  5  6  7 10 11 12 13] Test indices: [ 2  3  8  9 14 15]
Train classes distribution:[4 4 2]    values [0 0 0 0 1 1 1 1 2 2]
Test classes distribution:[2 2 2]    values [1 1 0 0 2 2]
Groups in test set: [2 5 8]
--------------------------------------------------
Fold 3
Train indices: [ 0  1  2  3  6  7  8  9 10 11 14 15] Test indices: [ 4  5 12 13]
Train classes distribution:[4 6 2]    values [0 0 1 1 1 1 0 0 1 1 2 2]
Test classes distribution:[2 0 2]    values [0 0 2 2]
Groups in test set: [3 7]
--------------------------------------------------


In [7]:
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Example data
data = {
    "feature1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    "target": [1.1, 1.2, 1.8, 2.0, 2.5, 2.7, 3.0, 3.3, 3.8, 4.0, 4.5, 5.0]
}
df = pd.DataFrame(data)

# Discretize the target using pd.cut
n_bins = 3  # Number of bins for stratification
df['binned_target'] = pd.cut(df.target, bins=n_bins, labels=False)

# Create StratifiedKFold object
n_splits = 3
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Add a column for fold assignment
df['Fold'] = -1

# Perform stratified splitting
for fold_no, (train_idx, test_idx) in enumerate(skf.split(df, df['binned_target'])):
    # Assign the fold number to the test indices
    df.loc[test_idx, 'Fold'] = fold_no

    # Display details for each fold
    print(f"Fold {fold_no + 1}:")
    print(f"  Train indices: {train_idx.tolist()}")
    print(f"  Test indices: {test_idx.tolist()}")
    print(f"  Train binned target distribution: {np.bincount(df['binned_target'][train_idx])}")
    print(f"  Test binned target distribution: {np.bincount(df['binned_target'][test_idx])}")
    print("-" * 50)

# Print the final DataFrame
print("\nFinal DataFrame with Fold Assignments:")
print(df)

Fold 1:
  Train indices: [1, 3, 4, 5, 7, 8, 10, 11]
  Test indices: [0, 2, 6, 9]
  Train binned target distribution: [2 3 3]
  Test binned target distribution: [2 1 1]
--------------------------------------------------
Fold 2:
  Train indices: [0, 1, 2, 5, 6, 8, 9, 11]
  Test indices: [3, 4, 7, 10]
  Train binned target distribution: [3 2 3]
  Test binned target distribution: [1 2 1]
--------------------------------------------------
Fold 3:
  Train indices: [0, 2, 3, 4, 6, 7, 9, 10]
  Test indices: [1, 5, 8, 11]
  Train binned target distribution: [3 3 2]
  Test binned target distribution: [1 1 2]
--------------------------------------------------

Final DataFrame with Fold Assignments:
    feature1  target  binned_target  Fold
0          1     1.1              0     0
1          2     1.2              0     2
2          3     1.8              0     0
3          4     2.0              0     1
4          5     2.5              1     1
5          6     2.7              1     2
6        