KFold(n_splits=n) not equivalent to LeaveOneOut() cv in CalibratedClassifierCV() #29000

Open
ethanresnick opened this issue May 11, 2024 · 4 comments

ethanresnick commented May 11, 2024

Describe the bug

Calling CalibratedClassifierCV() with cv=KFold(n_splits=n) (where n is the number of samples) can give different results than using cv=LeaveOneOut(), but the docs for LeaveOneOut() say these should be equivalent ("LeaveOneOut() is equivalent to KFold(n_splits=n) ... where n is the number of samples").

In particular, the KFold class has an "n_splits" attribute, which means the branch in CalibratedClassifierCV.fit that reads cv.n_splits runs when setting up sigmoid calibration, and the per-class sample-count check there can then raise (see the traceback below). With LeaveOneOut(), n_folds is set to None and that error is never hit.

I'm not sure whether that error is correct or desirable in every case (see the code to reproduce for my use case, where I think the error may be unnecessary), but either way, the two different cv values seem like they should behave equivalently.

Steps/Code to Reproduce

from sklearn.pipeline import make_pipeline
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import KFold, LeaveOneOut
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=20, random_state=42)

pipeline = make_pipeline(
    StandardScaler(),
    CalibratedClassifierCV(
        SVC(probability=False),
        ensemble=False,
        cv=LeaveOneOut()
    )
)
pipeline.fit(X, y)  # succeeds

pipeline2 = make_pipeline(
    StandardScaler(),
    CalibratedClassifierCV(
        SVC(probability=False),
        ensemble=False,
        cv=KFold(n_splits=20, shuffle=True)
    )
)
pipeline2.fit(X, y)  # raises ValueError (see traceback below)

Expected Results

pipeline and pipeline2 should behave identically. Instead, pipeline.fit() succeeds while pipeline2.fit() raises the error below.

Actual Results

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/python3.11/site-packages/sklearn/base.py", line 1152, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python3.11/site-packages/sklearn/pipeline.py", line 427, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/python3.11/site-packages/sklearn/base.py", line 1152, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python3.11/site-packages/sklearn/calibration.py", line 419, in fit
    raise ValueError(
ValueError: Requesting 20-fold cross-validation but provided less than 20 examples for at least one class.

Versions

System:
    python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ]
   machine: macOS-14.4.1-arm64-arm-64bit

Python dependencies:
      sklearn: 1.3.2
          pip: 24.0
   setuptools: 69.0.2
        numpy: 1.26.2
        scipy: 1.11.4
       Cython: None
       pandas: 2.1.3
   matplotlib: 3.8.2
       joblib: 1.3.2
threadpoolctl: 3.2.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
    num_threads: 12
         prefix: libomp
        version: None

       user_api: blas
   internal_api: openblas
    num_threads: 12
         prefix: libopenblas
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: armv8

       user_api: blas
   internal_api: openblas
    num_threads: 12
         prefix: libopenblas
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: armv8
ethanresnick added the Bug and Needs Triage labels on May 11, 2024

glemaitre (Member) commented

The error in KFold is actually expected: we expect to have at least one sample from each class in each fold. This cannot be achieved with LeaveOneOut cross-validation, so we should not accept this strategy.

So we could raise an error early for this strategy. However, I can also see other strategies leading to a single class being present when fitting the calibrator. I assume it would be safer to raise an error in that case as well; otherwise we get an ill-fitted calibrator anyway.
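
For illustration, such an early check could look roughly like the sketch below (a hypothetical helper, not the actual scikit-learn code):

from sklearn.model_selection import LeaveOneOut

def check_cv_for_calibration(cv):
    # Hypothetical sketch: every LeaveOneOut test fold holds exactly one
    # sample, so no fold can contain every class. Rejecting the strategy
    # up front avoids fitting the calibrator on degenerate folds.
    if isinstance(cv, LeaveOneOut):
        raise ValueError(
            "LeaveOneOut is not supported for calibration: each test fold "
            "contains a single sample and so cannot include every class."
        )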

ping @lucyleeow @ogrisel, who might have more insight into this part of the calibrator; I'd like to know their opinions.

glemaitre removed the Needs Triage label on May 16, 2024

ogrisel (Member) commented May 16, 2024

I think I agree on both counts but have not checked the details in the code yet.

ethanresnick (Author) commented May 16, 2024

> The error in KFold is actually expected: we expect to have at least one sample from each class in each fold.

Isn't it the case that KFold also doesn't guarantee one sample from each class in each fold (since it doesn't create stratified folds)?
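
For example, a quick illustrative check with an imbalanced, unshuffled target shows a training fold that contains only one class:

import numpy as np
from sklearn.model_selection import KFold

# Illustrative data: three samples of class 0, seven of class 1
y = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1])
for train_idx, test_idx in KFold(n_splits=3).split(y):
    print(np.unique(y[train_idx]), np.unique(y[test_idx]))
# -> [1] [0 1]     <- first training fold contains class 1 only
#    [0 1] [1]
#    [0 1] [1]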

> However, I can also see other strategies leading to a single class being present when fitting the calibrator.

Yeah, exactly. There are lots of ways to end up with a poorly-fitted calibrator, and I'm not sure the code's current check (even in the cases where it does apply) really catches them.

kyrajeep commented May 28, 2024

LeaveOneOut does not have different groups the way k-fold CV does (https://www.cs.cmu.edu/~schneide/tut5/node42.html). More accurately, it treats each sample as its own 'fold': it trains on n-1 samples at a time (where n is the training-set size), which makes it computationally expensive but very reliable. K-fold, on the other hand, divides the training data into k groups and trains the model k times, leaving out one group at a time. Perhaps this clarification was not the main issue, but I thought it might be helpful :)
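
Concretely, here is an illustrative check that LeaveOneOut() yields one split per sample, identical to an unshuffled KFold(n_splits=n):

import numpy as np
from sklearn.model_selection import KFold, LeaveOneOut

X = np.arange(6).reshape(-1, 1)
loo = [(list(tr), list(te)) for tr, te in LeaveOneOut().split(X)]
kf = [(list(tr), list(te)) for tr, te in KFold(n_splits=len(X)).split(X)]
print(loo == kf)  # True: the folds are identical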
