From 603a8b6a8f0f3c342bbea08c87d5544164c8bb29 Mon Sep 17 00:00:00 2001
From: HideakiImamura
Date: Mon, 22 May 2023 23:57:14 +0900
Subject: [PATCH 1/2] Add tests and fix some bugs

---
 .../visualization/_terminator_improvement.py |   9 +-
 .../test_terminator_improvement.py           | 124 ++++++++++++++++++
 2 files changed, 132 insertions(+), 1 deletion(-)
 create mode 100644 tests/visualization_tests/test_terminator_improvement.py

diff --git a/optuna/visualization/_terminator_improvement.py b/optuna/visualization/_terminator_improvement.py
index e3cbaf1ea0..76b9013ca7 100644
--- a/optuna/visualization/_terminator_improvement.py
+++ b/optuna/visualization/_terminator_improvement.py
@@ -52,6 +52,9 @@ def _get_improvement_info(
     improvement_evaluator: BaseImprovementEvaluator | None = None,
     error_evaluator: BaseErrorEvaluator | None = None,
 ) -> _ImprovementInfo:
+    if study._is_multi_objective():
+        raise ValueError("This function does not support multi-objective optimization `study`.")
+
     if improvement_evaluator is None:
         improvement_evaluator = RegretBoundEvaluator()
     if error_evaluator is None:
@@ -63,10 +66,14 @@ def _get_improvement_info(
     errors = []

     for trial in tqdm.tqdm(study.trials):
-        trial_numbers.append(trial.number)
         if trial.state == optuna.trial.TrialState.COMPLETE:
             completed_trials.append(trial)

+        if len(completed_trials) == 0:
+            continue
+
+        trial_numbers.append(trial.number)
+
         improvement = improvement_evaluator.evaluate(
             trials=completed_trials, study_direction=study.direction
         )
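
[Note, not part of the patch: a minimal sketch of the case the reordered loop above guards
against. A study whose earliest trials failed has no completed trial yet, so those leading
trials are now skipped instead of handing an empty trial list to the improvement evaluator.
The sketch only uses the optuna API that also appears in the tests below.]

    import optuna
    from optuna.distributions import FloatDistribution
    from optuna.trial import TrialState, create_trial
    from optuna.visualization import plot_terminator_improvement

    study = optuna.create_study()
    # The first trial fails, so there is no completed trial to evaluate yet.
    study.add_trial(create_trial(state=TrialState.FAIL))
    # From the first completed trial on, an improvement value can be computed.
    trial = study.ask({"x": FloatDistribution(0, 1)})
    study.tell(trial, 0.5)

    figure = plot_terminator_improvement(study)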
diff --git a/tests/visualization_tests/test_terminator_improvement.py b/tests/visualization_tests/test_terminator_improvement.py
new file mode 100644
index 0000000000..293e58242b
--- /dev/null
+++ b/tests/visualization_tests/test_terminator_improvement.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+from io import BytesIO
+from typing import Any
+from typing import Callable
+
+import pytest
+
+from optuna.distributions import FloatDistribution
+from optuna.study import create_study
+from optuna.study import Study
+from optuna.terminator import BaseErrorEvaluator
+from optuna.terminator import BaseImprovementEvaluator
+from optuna.terminator import CrossValidationErrorEvaluator
+from optuna.terminator import RegretBoundEvaluator
+from optuna.terminator import report_cross_validation_scores
+from optuna.terminator import StaticErrorEvaluator
+from optuna.testing.objectives import fail_objective
+from optuna.testing.visualization import prepare_study_with_trials
+from optuna.visualization import plot_terminator_improvement as plotly_plot_terminator_improvement
+from optuna.visualization._terminator_improvement import _get_improvement_info
+from optuna.visualization._terminator_improvement import _get_y_range
+from optuna.visualization._terminator_improvement import _ImprovementInfo
+
+
+parametrize_plot_terminator_improvement = pytest.mark.parametrize(
+    "plot_terminator_improvement", [plotly_plot_terminator_improvement]
+)
+
+
+def _create_study_with_failed_trial() -> Study:
+    study = create_study()
+    study.optimize(fail_objective, n_trials=1, catch=(ValueError,))
+    return study
+
+
+def test_study_is_multi_obj() -> None:
+    study = create_study(directions=["minimize", "minimize"])
+    with pytest.raises(ValueError):
+        _get_improvement_info(study=study)
+
+
+@parametrize_plot_terminator_improvement
+@pytest.mark.parametrize(
+    "specific_create_study",
+    [
+        create_study,
+        _create_study_with_failed_trial,
+        prepare_study_with_trials,
+    ],
+)
+@pytest.mark.parametrize("plot_error", [False, True])
+def test_plot_terminator_improvement(
+    plot_terminator_improvement: Callable[..., Any],
+    specific_create_study: Callable[[], Study],
+    plot_error: bool,
+) -> None:
+    study = specific_create_study()
+    figure = plot_terminator_improvement(study)
+    figure.write_image(BytesIO())
+
+
+@pytest.mark.parametrize(
+    "specific_create_study",
+    [create_study, _create_study_with_failed_trial],
+)
+@pytest.mark.parametrize("plot_error", [False, True])
+def test_get_terminator_improvement_info_empty(
+    specific_create_study: Callable[[], Study], plot_error: bool
+) -> None:
+    study = specific_create_study()
+    info = _get_improvement_info(study, plot_error)
+    assert info == _ImprovementInfo(trial_numbers=[], improvements=[], errors=None)
+
+
+@pytest.mark.parametrize("get_error", [False, True])
+@pytest.mark.parametrize("improvement_evaluator_class", [lambda: RegretBoundEvaluator(), None])
+@pytest.mark.parametrize(
+    "error_evaluator_class",
+    [
+        lambda: CrossValidationErrorEvaluator(),
+        lambda: StaticErrorEvaluator(0),
+        None,
+    ],
+)
+def test_get_improvement_info(
+    get_error: bool,
+    improvement_evaluator_class: Callable[[], BaseImprovementEvaluator] | None,
+    error_evaluator_class: Callable[[], BaseErrorEvaluator] | None,
+) -> None:
+    study = create_study()
+    for _ in range(3):
+        trial = study.ask({"x": FloatDistribution(0, 1)})
+        report_cross_validation_scores(trial, [1.0, 2.0])
+        study.tell(trial, 0)
+
+    improvement_evaluator = (
+        None if improvement_evaluator_class is None else improvement_evaluator_class()
+    )
+    error_evaluator = None if error_evaluator_class is None else error_evaluator_class()
+    info = _get_improvement_info(study, get_error, improvement_evaluator, error_evaluator)
+    assert info.trial_numbers == [0, 1, 2]
+    assert len(info.improvements) == 3
+    if get_error:
+        assert info.errors is not None
+        assert len(info.errors) == 3
+        assert info.errors[0] == info.errors[1] == info.errors[2]
+    else:
+        assert info.errors is None
+
+
+@pytest.mark.parametrize(
+    "info",
+    [
+        _ImprovementInfo(trial_numbers=[0], improvements=[0], errors=None),
+        _ImprovementInfo(trial_numbers=[0], improvements=[0], errors=[0]),
+        _ImprovementInfo(trial_numbers=[0, 1], improvements=[0, 1], errors=[0, 1]),
+    ],
+)
+@pytest.mark.parametrize("min_n_trials", [1, 2])
+def test_get_y_range(info: _ImprovementInfo, min_n_trials: int) -> None:
+    y_range = _get_y_range(info, min_n_trials)
+    assert len(y_range) == 2
+    assert y_range[0] <= y_range[1]
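
[Note, not part of the patch series: the plot_error path exercised above relies on
cross-validation scores having been reported while the study ran, which is what
CrossValidationErrorEvaluator consumes. A minimal sketch of that usage, with the same
optuna API the test file imports; the reported scores are placeholders.]

    import optuna
    from optuna.distributions import FloatDistribution
    from optuna.terminator import report_cross_validation_scores
    from optuna.visualization import plot_terminator_improvement

    study = optuna.create_study()
    for _ in range(3):
        trial = study.ask({"x": FloatDistribution(0, 1)})
        # Placeholder scores; in practice these come from the user's own CV loop.
        report_cross_validation_scores(trial, [1.0, 2.0])
        study.tell(trial, 0)

    # plot_error=True also plots the evaluated error alongside the improvement values.
    figure = plot_terminator_improvement(study, plot_error=True)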
From 13f223549ef74274bb88efb52a6e4de5434099a2 Mon Sep 17 00:00:00 2001
From: HideakiImamura
Date: Wed, 24 May 2023 11:33:58 +0900
Subject: [PATCH 2/2] Follow review comments

---
 .../visualization/_terminator_improvement.py |  2 +-
 .../test_terminator_improvement.py           | 61 +++++++++++++------
 2 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/optuna/visualization/_terminator_improvement.py b/optuna/visualization/_terminator_improvement.py
index 76b9013ca7..f9b6edfa18 100644
--- a/optuna/visualization/_terminator_improvement.py
+++ b/optuna/visualization/_terminator_improvement.py
@@ -53,7 +53,7 @@ def _get_improvement_info(
     error_evaluator: BaseErrorEvaluator | None = None,
 ) -> _ImprovementInfo:
     if study._is_multi_objective():
-        raise ValueError("This function does not support multi-objective optimization `study`.")
+        raise ValueError("This function does not support multi-objective optimization study.")

     if improvement_evaluator is None:
         improvement_evaluator = RegretBoundEvaluator()
diff --git a/tests/visualization_tests/test_terminator_improvement.py b/tests/visualization_tests/test_terminator_improvement.py
index 293e58242b..25d62ba07f 100644
--- a/tests/visualization_tests/test_terminator_improvement.py
+++ b/tests/visualization_tests/test_terminator_improvement.py
@@ -17,6 +17,8 @@
 from optuna.terminator import StaticErrorEvaluator
 from optuna.testing.objectives import fail_objective
 from optuna.testing.visualization import prepare_study_with_trials
+from optuna.trial import create_trial
+from optuna.trial import TrialState
 from optuna.visualization import plot_terminator_improvement as plotly_plot_terminator_improvement
 from optuna.visualization._terminator_improvement import _get_improvement_info
 from optuna.visualization._terminator_improvement import _get_y_range
@@ -34,7 +36,16 @@ def _create_study_with_failed_trial() -> Study:
     return study


-def test_study_is_multi_obj() -> None:
+def _prepare_study_with_cross_validation_scores() -> Study:
+    study = create_study()
+    for _ in range(3):
+        trial = study.ask({"x": FloatDistribution(0, 1)})
+        report_cross_validation_scores(trial, [1.0, 2.0])
+        study.tell(trial, 0)
+    return study
+
+
+def test_study_is_multi_objective() -> None:
     study = create_study(directions=["minimize", "minimize"])
     with pytest.raises(ValueError):
         _get_improvement_info(study=study)
@@ -42,21 +53,22 @@ def test_study_is_multi_obj() -> None:

 @parametrize_plot_terminator_improvement
 @pytest.mark.parametrize(
-    "specific_create_study",
+    "specific_create_study, plot_error",
     [
-        create_study,
-        _create_study_with_failed_trial,
-        prepare_study_with_trials,
+        (create_study, False),
+        (_create_study_with_failed_trial, False),
+        (prepare_study_with_trials, False),
+        (_prepare_study_with_cross_validation_scores, False),
+        (_prepare_study_with_cross_validation_scores, True),
     ],
 )
-@pytest.mark.parametrize("plot_error", [False, True])
 def test_plot_terminator_improvement(
     plot_terminator_improvement: Callable[..., Any],
     specific_create_study: Callable[[], Study],
     plot_error: bool,
 ) -> None:
     study = specific_create_study()
-    figure = plot_terminator_improvement(study)
+    figure = plot_terminator_improvement(study, plot_error)
     figure.write_image(BytesIO())


@@ -74,31 +86,27 @@ def test_get_terminator_improvement_info_empty(


 @pytest.mark.parametrize("get_error", [False, True])
-@pytest.mark.parametrize("improvement_evaluator_class", [lambda: RegretBoundEvaluator(), None])
+@pytest.mark.parametrize(
+    "improvement_evaluator_class", [lambda: RegretBoundEvaluator(), lambda: None]
+)
 @pytest.mark.parametrize(
     "error_evaluator_class",
     [
         lambda: CrossValidationErrorEvaluator(),
         lambda: StaticErrorEvaluator(0),
-        None,
+        lambda: None,
     ],
 )
 def test_get_improvement_info(
     get_error: bool,
-    improvement_evaluator_class: Callable[[], BaseImprovementEvaluator] | None,
-    error_evaluator_class: Callable[[], BaseErrorEvaluator] | None,
+    improvement_evaluator_class: Callable[[], BaseImprovementEvaluator | None],
+    error_evaluator_class: Callable[[], BaseErrorEvaluator | None],
 ) -> None:
-    study = create_study()
-    for _ in range(3):
-        trial = study.ask({"x": FloatDistribution(0, 1)})
-        report_cross_validation_scores(trial, [1.0, 2.0])
-        study.tell(trial, 0)
+    study = _prepare_study_with_cross_validation_scores()

-    improvement_evaluator = (
-        None if improvement_evaluator_class is None else improvement_evaluator_class()
+    info = _get_improvement_info(
+        study, get_error, improvement_evaluator_class(), error_evaluator_class()
     )
-    error_evaluator = None if error_evaluator_class is None else error_evaluator_class()
-    info = _get_improvement_info(study, get_error, improvement_evaluator, error_evaluator)
     assert info.trial_numbers == [0, 1, 2]
     assert len(info.improvements) == 3
     if get_error:
@@ -109,6 +117,19 @@ def test_get_improvement_info(
         assert info.errors is None


+def test_get_improvement_info_started_with_failed_trials() -> None:
+    study = create_study()
+    for _ in range(3):
+        study.add_trial(create_trial(state=TrialState.FAIL))
+    trial = study.ask({"x": FloatDistribution(0, 1)})
+    study.tell(trial, 0)
+
+    info = _get_improvement_info(study)
+    assert info.trial_numbers == [3]
+    assert len(info.improvements) == 1
+    assert info.errors is None
+
+
 @pytest.mark.parametrize(
     "info",
     [
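
[Note, not part of the patch series: assuming a development checkout of optuna with the
plotly-based visualization dependencies installed, the new test module can be run on its
own from the repository root with "pytest tests/visualization_tests/test_terminator_improvement.py".]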