StratifiedShuffleSplit requires three copies of a lower class, rather than 2 #28994

Open
jeremycg opened this issue May 10, 2024 · 2 comments

jeremycg commented May 10, 2024

Describe the bug

When using StratifiedShuffleSplit for a stratified train/test split, we would expect to need only 2 samples of the least-represented class: one for train and one for test. That is not what happens: we need 3 samples of that class before it shows up in both sets.

sklearn version 1.2.1

Steps/Code to Reproduce

from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np

# 50k samples: all ones except two zeros
X = np.ones((50000,2))
y = np.ones((50000,1))
y[0] = 0
y[1] = 0

splitter = StratifiedShuffleSplit(n_splits=1, test_size=max(2, int(0.2*X.shape[0])))
for train, test in splitter.split(X,y):
    train_indices = train
    test_indices = test
    
X_train, X_test, y_train, y_test = X[train_indices,:], X[test_indices,:], y[train_indices],  y[test_indices]
np.unique(y_train), np.unique(y_test)
#(array([0., 1.]), array([1.]))
# why no 0s in test?

# same thing, but with three zeros
X = np.ones((50000,2))
y = np.ones((50000,1))
y[0] = 0
y[1] = 0
y[2] = 0

splitter = StratifiedShuffleSplit(n_splits=1, test_size=max(2, int(0.2*X.shape[0])))
for train, test in splitter.split(X,y):
    train_indices = train
    test_indices = test
    
X_train, X_test, y_train, y_test = X[train_indices,:], X[test_indices,:], y[train_indices],  y[test_indices]
np.unique(y_train), np.unique(y_test)
#(array([0., 1.]), array([0., 1.]))
#as expected!

Expected Results

With 2 representatives of the minority class, we expect the train set and the test set to each receive one of them, so that both classes appear in both sets.

(array([0., 1.]), array([0., 1.]))

Actual Results

(array([0., 1.]), array([1.]))

Versions

System:
    python: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ]
executable: bin/python
   machine: macOS-14.2.1-arm64-arm-64bit

Python dependencies:
      sklearn: 1.2.1
          pip: 23.3.1
   setuptools: 68.2.2
        numpy: 1.26.4
        scipy: 1.10.0
       Cython: 3.0.0
       pandas: 2.2.2
   matplotlib: 3.7.0
       joblib: 1.3.2
threadpoolctl: 3.2.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
    num_threads: 12
         prefix: libomp
       filepath: lib/python3.11/site-packages/sklearn/.dylibs/libomp.dylib
        version: None

       user_api: blas
   internal_api: openblas
    num_threads: 12
         prefix: libopenblas
       filepath: lib/python3.11/site-packages/numpy/.dylibs/libopenblas64_.0.dylib
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: armv8

       user_api: blas
   internal_api: openblas
    num_threads: 12
         prefix: libopenblas
       filepath: lib/python3.11/site-packages/scipy/.dylibs/libopenblas.0.dylib
        version: 0.3.18
threading_layer: pthreads
   architecture: armv8
jeremycg added the Bug and Needs Triage labels May 10, 2024
glemaitre added the Documentation label and removed the Bug and Needs Triage labels May 15, 2024
glemaitre (Member) commented

I think this is not a bug but rather a known implementation detail: looking at the code, we use _approximate_mode, which is known to be an approximation that can be off by 1 (the value that you observed). I assume this approximation is made for computational reasons.

However, since the behaviour is surprising, I think that we could document it in a "Note" section to mention this corner case.
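To make the corner case concrete, here is a minimal sketch of the arithmetic, assuming the per-class counts are allocated to the train set first with a largest-remainder rule and the test set is then drawn from what is left. The function names `largest_remainder_allocation` and `stratified_counts` are hypothetical and only illustrate the idea; this is not the code of `_approximate_mode`, which to my understanding also breaks ties at random.

```python
def largest_remainder_allocation(class_counts, n_draws):
    """Split n_draws across classes in proportion to class_counts.

    Illustrative helper (hypothetical name), using integer arithmetic.
    Ties are broken by lowest class index, which is a simplification.
    """
    total = sum(class_counts)
    # Integer part and remainder of the proportional share c * n_draws / total.
    floored = [c * n_draws // total for c in class_counts]
    remainders = [c * n_draws % total for c in class_counts]
    need_to_add = n_draws - sum(floored)
    # Give the leftover draws to the classes with the largest remainders.
    order = sorted(range(len(class_counts)), key=lambda i: -remainders[i])
    for i in order[:need_to_add]:
        floored[i] += 1
    return floored


def stratified_counts(class_counts, n_train, n_test):
    """Per-class train counts first, then test counts from the remainder."""
    train = largest_remainder_allocation(class_counts, n_train)
    remaining = [c - t for c, t in zip(class_counts, train)]
    test = largest_remainder_allocation(remaining, n_test)
    return train, test


# Two zeros: the proportional train share is 1.6, which is bumped up to 2,
# so no zero is left for the test set.
print(stratified_counts([2, 49998], 40000, 10000))  # ([2, 39998], [0, 10000])

# Three zeros: the train share is 2.4, which stays at 2, leaving one for test.
print(stratified_counts([3, 49997], 40000, 10000))  # ([2, 39998], [1, 9999])
```

Under these assumptions, an 80/20 split gives the test set an expected 0.4 samples of a class with two members, so the second sample goes to train; a third sample is what tips the allocation.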

jeremycg (Author) commented

Fair enough. The code in the function:

        if n_train < n_classes:
            raise ValueError(
                "The train_size = %d should be greater or "
                "equal to the number of classes = %d" % (n_train, n_classes)
            )
        if n_test < n_classes:
            raise ValueError(
                "The test_size = %d should be greater or "
                "equal to the number of classes = %d" % (n_test, n_classes)
            )

does a check that I'd expect to catch the issue here, so I'd suggest modifying it too.
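Until the check or the documentation changes, a user-side guard can flag the situation after splitting. The following is a minimal sketch, assuming 1-D hashable labels (for a NumPy column vector such as the one in the report, something like `y.ravel().tolist()` would be needed first). The helper name `missing_classes` is hypothetical and not part of scikit-learn.

```python
def missing_classes(y, train_indices, test_indices):
    """Return the classes absent from the train split and from the test split.

    Hypothetical helper: y is a 1-D sequence of hashable labels, and the
    index arguments are the integer indices produced by a splitter.
    """
    all_classes = set(y)
    train_classes = {y[i] for i in train_indices}
    test_classes = {y[i] for i in test_indices}
    return sorted(all_classes - train_classes), sorted(all_classes - test_classes)


# Toy example: both zeros end up in train, so class 0 is missing from test.
labels = [0, 0, 1, 1, 1, 1]
missing_train, missing_test = missing_classes(labels, [0, 1, 2, 3], [4, 5])
print(missing_train, missing_test)  # [] [0]
```

A caller could raise or warn when either list is non-empty, which is the behaviour the size check quoted above does not provide on its own.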
