Skip to content

Commit

Permalink
Merge pull request optuna#5098 from adjeiv/terminator_callback
Browse files Browse the repository at this point in the history
Report CV scores from within OptunaSearchCV
  • Loading branch information
Alnusjaponica committed Nov 9, 2023
2 parents 5252052 + b4dd960 commit f966055
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
5 changes: 5 additions & 0 deletions optuna/integration/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from optuna._imports import try_import
from optuna.distributions import _convert_old_distribution_to_new_distribution
from optuna.study import StudyDirection
from optuna.terminator import report_cross_validation_scores
from optuna.trial import FrozenTrial
from optuna.trial import Trial

Expand Down Expand Up @@ -243,6 +244,10 @@ def __call__(self, trial: Trial) -> float:

self._store_scores(trial, scores)

test_scores = scores["test_score"]
scores_list = test_scores if isinstance(test_scores, list) else test_scores.tolist()
report_cross_validation_scores(trial, scores_list)

return trial.user_attrs["mean_test_score"]

def _cross_validate_with_pruning(
Expand Down
21 changes: 21 additions & 0 deletions tests/integration_tests/test_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from unittest.mock import MagicMock
from unittest.mock import patch
import warnings

import numpy as np
Expand All @@ -20,6 +21,7 @@
from optuna import integration
from optuna.samplers import BruteForceSampler
from optuna.study import create_study
from optuna.terminator.erroreval import _CROSS_VALIDATION_SCORES_KEY


pytestmark = pytest.mark.integration
Expand Down Expand Up @@ -409,3 +411,22 @@ def test_callbacks() -> None:
for trial in optuna_search.trials_:
callback.assert_any_call(optuna_search.study_, trial)
assert callback.call_count == n_trials


@pytest.mark.filterwarnings("ignore::UserWarning")
@patch("optuna.integration.sklearn.cross_validate")
def test_terminator_cv_score_reporting(mock: MagicMock) -> None:
scores = {
"fit_time": np.array([2.01, 1.78, 3.22]),
"score_time": np.array([0.33, 0.35, 0.48]),
"test_score": np.array([0.04, 0.80, 0.70]),
}
mock.return_value = scores

X, _ = make_blobs(n_samples=10)
est = PCA()
optuna_search = integration.OptunaSearchCV(est, {}, cv=3, error_score="raise", random_state=0)
optuna_search.fit(X)

for trial in optuna_search.study_.trials:
assert (trial.system_attrs[_CROSS_VALIDATION_SCORES_KEY] == scores["test_score"]).all()

0 comments on commit f966055

Please sign in to comment.