In [8]:
import torch

def k_fold_cross_validation(n_samples: int, k: int = 5, shuffle: bool = True) -> list:
    """Generate K-Fold cross-validation train/test index splits using torch.

    The sample indices ``0 .. n_samples - 1`` are optionally permuted and then
    cut into ``k`` contiguous folds. When ``n_samples`` is not divisible by
    ``k``, the first ``n_samples % k`` folds receive one extra sample. Each
    fold serves as the test set once, with the remaining folds concatenated
    (in fold order) as the training set.

    Args:
        n_samples: Total number of samples to split.
        k: Number of folds. Intended for ``k >= 2``; with ``k == 1`` there are
            no training folds left to concatenate.
        shuffle: If True, permute the indices with ``torch.randperm`` before
            folding (seed with ``torch.manual_seed`` for reproducibility).

    Returns:
        A list of ``k`` tuples ``(train_indices, test_indices)``, each element
        being a plain Python list of ints.

    Raises:
        ZeroDivisionError: If ``k`` is 0.
    """
    # Create indices: torch.arange(10) -> tensor([0, 1, 2, ..., 9])
    indices = torch.arange(n_samples)

    # Shuffle: torch.randperm gives a random permutation of 0 .. n_samples - 1
    if shuffle:
        indices = indices[torch.randperm(n_samples)]

    # Calculate fold sizes
    base_size = n_samples // k
    remainder = n_samples % k

    # Create folds as consecutive slices of the (possibly shuffled) indices
    folds = []
    start_idx = 0

    for i in range(k):
        # The first `remainder` folds absorb the leftover samples, one each.
        fold_size = base_size + (1 if i < remainder else 0)
        end_idx = start_idx + fold_size
        folds.append(indices[start_idx:end_idx])
        start_idx = end_idx

    # Generate splits: fold i is the test set, all other folds form the train set
    splits = []
    for i in range(k):
        test_indices = folds[i]
        train_indices = torch.cat([folds[j] for j in range(k) if j != i])

        splits.append((train_indices.tolist(), test_indices.tolist()))

    # The original fell off the end of the function here and returned None.
    return splits

In [15]:
# Demo: split 10 samples into 5 folds without shuffling and print the splits.
import numpy as np
# Seeds NumPy's global RNG. With shuffle=False no permutation is drawn, so the
# printed splits do not depend on this seed.
np.random.seed(42)
# NOTE(review): this cell's execution count (In [15]) is later than that of the
# definition cell listed after it (In [14]) - confirm the notebook reproduces
# this output under Restart Kernel -> Run All.
print(k_fold_cross_validation(n_samples=10, k=5, shuffle=False))

[([2, 3, 4, 5, 6, 7, 8, 9], [0, 1]), ([0, 1, 4, 5, 6, 7, 8, 9], [2, 3]), ([0, 1, 2, 3, 6, 7, 8, 9], [4, 5]), ([0, 1, 2, 3, 4, 5, 8, 9], [6, 7]), ([0, 1, 2, 3, 4, 5, 6, 7], [8, 9])]


In [14]:
import numpy as np

def k_fold_cross_validation(n_samples: int, k: int = 5, shuffle: bool = True) -> list:
    """Build K-Fold cross-validation index splits with NumPy.

    The indices ``0 .. n_samples - 1`` are optionally shuffled in place with
    ``np.random.shuffle`` (NumPy's global RNG) and then partitioned into ``k``
    consecutive folds. If ``n_samples`` does not divide evenly, the first
    ``n_samples % k`` folds are one element longer than the rest.

    Args:
        n_samples: Number of samples to partition.
        k: Number of folds.
        shuffle: Whether to randomly permute the indices before folding.

    Returns:
        A list with one ``(train_indices, test_indices)`` tuple per fold; both
        entries are plain Python lists of ints, and the training indices keep
        fold order.
    """
    sample_ids = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(sample_ids)

    # Every fold has `quotient` samples; the first `extra` folds get one more.
    quotient, extra = divmod(n_samples, k)
    fold_lengths = [quotient + 1 if pos < extra else quotient for pos in range(k)]

    # Walk a cursor across the index array, slicing off one fold at a time.
    chunks = []
    cursor = 0
    for length in fold_lengths:
        chunks.append(sample_ids[cursor:cursor + length])
        cursor += length

    # Hold out each fold in turn; everything else becomes the training set.
    result = []
    for held_out, test_chunk in enumerate(chunks):
        remaining = [chunk for pos, chunk in enumerate(chunks) if pos != held_out]
        train_chunk = np.concatenate(remaining)
        result.append((train_chunk.tolist(), test_chunk.tolist()))

    return result