diff --git a/optuna/integration/sklearn.py b/optuna/integration/sklearn.py index f5a4b7a329..83a0d17db3 100644 --- a/optuna/integration/sklearn.py +++ b/optuna/integration/sklearn.py @@ -24,6 +24,7 @@ from optuna._imports import try_import from optuna.distributions import _convert_old_distribution_to_new_distribution from optuna.study import StudyDirection +from optuna.terminator import report_cross_validation_scores from optuna.trial import FrozenTrial from optuna.trial import Trial @@ -243,6 +244,10 @@ def __call__(self, trial: Trial) -> float: self._store_scores(trial, scores) + test_scores = scores["test_score"] + scores_list = test_scores if isinstance(test_scores, list) else test_scores.tolist() + report_cross_validation_scores(trial, scores_list) + return trial.user_attrs["mean_test_score"] def _cross_validate_with_pruning( diff --git a/tests/integration_tests/test_sklearn.py b/tests/integration_tests/test_sklearn.py index def1a222cd..d52bf89803 100644 --- a/tests/integration_tests/test_sklearn.py +++ b/tests/integration_tests/test_sklearn.py @@ -1,6 +1,7 @@ from __future__ import annotations from unittest.mock import MagicMock +from unittest.mock import patch import warnings import numpy as np @@ -20,6 +21,7 @@ from optuna import integration from optuna.samplers import BruteForceSampler from optuna.study import create_study +from optuna.terminator.erroreval import _CROSS_VALIDATION_SCORES_KEY pytestmark = pytest.mark.integration @@ -409,3 +411,22 @@ def test_callbacks() -> None: for trial in optuna_search.trials_: callback.assert_any_call(optuna_search.study_, trial) assert callback.call_count == n_trials + + +@pytest.mark.filterwarnings("ignore::UserWarning") +@patch("optuna.integration.sklearn.cross_validate") +def test_terminator_cv_score_reporting(mock: MagicMock) -> None: + scores = { + "fit_time": np.array([2.01, 1.78, 3.22]), + "score_time": np.array([0.33, 0.35, 0.48]), + "test_score": np.array([0.04, 0.80, 0.70]), + } + mock.return_value = scores + + X, _ = make_blobs(n_samples=10) + est = PCA() + optuna_search = integration.OptunaSearchCV(est, {}, cv=3, error_score="raise", random_state=0) + optuna_search.fit(X) + + for trial in optuna_search.study_.trials: + assert (trial.system_attrs[_CROSS_VALIDATION_SCORES_KEY] == scores["test_score"]).all()