In [2]:
import numpy as np
from sklearn.model_selection import train_test_split as sklearn_train_test_split
from model_selection.train_test_split import train_test_split  # Adjust import path as needed

def test_train_test_split():
    """Check the custom stratified ``train_test_split`` against scikit-learn's.

    Both implementations are called with identical arguments on a small,
    perfectly balanced two-class dataset. The test verifies that the custom
    split:

    * produces train/test arrays of the same shape as scikit-learn,
    * yields the same per-class counts in each subset,
    * partitions the input (every sample appears exactly once), and
    * keeps each returned target aligned with its feature row.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # 50 samples with 2 features each. Row i holds (2*i, 2*i + 1), so the
    # first column uniquely identifies a sample's original position.
    X = np.arange(100).reshape(50, 2)
    y = np.tile(np.array([0, 1]), 25)  # alternating 0/1 -> 25 samples per class
    n_classes = int(y.max()) + 1

    # One shared argument set guarantees both splitters get the same request.
    split_kwargs = dict(test_size=0.2, random_state=42, shuffle=True, stratify=y)

    # Custom train_test_split with stratification
    X_train_custom, X_test_custom, y_train_custom, y_test_custom = train_test_split(
        X, y, **split_kwargs)

    # Sklearn train_test_split with stratification
    X_train_sklearn, X_test_sklearn, y_train_sklearn, y_test_sklearn = sklearn_train_test_split(
        X, y, **split_kwargs)

    # Assert the shapes are identical (features and targets).
    assert X_train_custom.shape == X_train_sklearn.shape, "Mismatch in train set size"
    assert X_test_custom.shape == X_test_sklearn.shape, "Mismatch in test set size"
    assert np.shape(y_train_custom) == y_train_sklearn.shape, "Mismatch in train target size"
    assert np.shape(y_test_custom) == y_test_sklearn.shape, "Mismatch in test target size"

    # Assert the distribution of classes is identical. ``minlength`` pins the
    # histogram length and ``array_equal`` refuses to broadcast, so a subset
    # that is missing a class cannot compare equal by accident.
    assert np.array_equal(
        np.bincount(y_train_custom, minlength=n_classes),
        np.bincount(y_train_sklearn, minlength=n_classes),
    ), "Mismatch in class distribution in train sets"
    assert np.array_equal(
        np.bincount(y_test_custom, minlength=n_classes),
        np.bincount(y_test_sklearn, minlength=n_classes),
    ), "Mismatch in class distribution in test sets"

    # Train and test together must contain every original sample exactly once.
    first_column = np.concatenate([X_train_custom[:, 0], X_test_custom[:, 0]])
    assert np.array_equal(np.sort(first_column), X[:, 0]), \
        "Custom split does not partition the input samples"

    # Sample i has first feature 2*i and target i % 2, so the expected target
    # can be recovered from the features to confirm X/y stayed aligned.
    for X_part, y_part, subset in (
        (X_train_custom, y_train_custom, "train"),
        (X_test_custom, y_test_custom, "test"),
    ):
        expected_targets = (X_part[:, 0] // 2) % 2
        assert np.array_equal(np.asarray(y_part), expected_targets), \
            f"Targets are not aligned with features in custom {subset} set"

    print(X_train_custom)
    print(X_test_sklearn)
    print("Custom train_test_split with stratification passed all tests successfully!")

# Run the check immediately so its printed output appears when this cell executes.
test_train_test_split()


[[32 33]
 [64 65]
 [ 0  1]
 [92 93]
 [44 45]
 [36 37]
 [52 53]
 [ 4  5]
 [88 89]
 [20 21]
 [ 8  9]
 [48 49]
 [60 61]
 [12 13]
 [16 17]
 [80 81]
 [68 69]
 [84 85]
 [72 73]
 [96 97]
 [54 55]
 [42 43]
 [22 23]
 [30 31]
 [70 71]
 [ 6  7]
 [82 83]
 [ 2  3]
 [94 95]
 [78 79]
 [50 51]
 [18 19]
 [10 11]
 [34 35]
 [14 15]
 [26 27]
 [90 91]
 [74 75]
 [58 59]
 [62 63]]
[[40 41]
 [24 25]
 [56 57]
 [76 77]
 [38 39]
 [98 99]
 [66 67]
 [46 47]
 [28 29]
 [86 87]]
Custom train_test_split with stratification passed all tests successfully!
