Commit 13f2235
Follow review comments
HideakiImamura committed May 24, 2023
1 parent 603a8b6 commit 13f2235
Showing 2 changed files with 42 additions and 21 deletions.
2 changes: 1 addition & 1 deletion optuna/visualization/_terminator_improvement.py
@@ -53,7 +53,7 @@ def _get_improvement_info(
     error_evaluator: BaseErrorEvaluator | None = None,
 ) -> _ImprovementInfo:
     if study._is_multi_objective():
-        raise ValueError("This function does not support multi-objective optimization `study`.")
+        raise ValueError("This function does not support multi-objective optimization study.")
 
     if improvement_evaluator is None:
         improvement_evaluator = RegretBoundEvaluator()
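A minimal sketch of the behavior this guard enforces (assuming an Optuna build that ships the terminator-improvement plot; `_get_improvement_info` is a private helper, imported here purely for illustration):

    import optuna
    from optuna.visualization._terminator_improvement import _get_improvement_info

    # Multi-objective studies are rejected with the reworded message.
    study = optuna.create_study(directions=["minimize", "minimize"])
    try:
        _get_improvement_info(study=study)
    except ValueError as err:
        print(err)  # This function does not support multi-objective optimization study.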
61 changes: 41 additions & 20 deletions tests/visualization_tests/test_terminator_improvement.py
@@ -17,6 +17,8 @@
 from optuna.terminator import StaticErrorEvaluator
 from optuna.testing.objectives import fail_objective
 from optuna.testing.visualization import prepare_study_with_trials
+from optuna.trial import create_trial
+from optuna.trial import TrialState
 from optuna.visualization import plot_terminator_improvement as plotly_plot_terminator_improvement
 from optuna.visualization._terminator_improvement import _get_improvement_info
 from optuna.visualization._terminator_improvement import _get_y_range
@@ -34,29 +36,39 @@ def _create_study_with_failed_trial() -> Study:
     return study
 
 
-def test_study_is_multi_obj() -> None:
+def _prepare_study_with_cross_validation_scores() -> Study:
+    study = create_study()
+    for _ in range(3):
+        trial = study.ask({"x": FloatDistribution(0, 1)})
+        report_cross_validation_scores(trial, [1.0, 2.0])
+        study.tell(trial, 0)
+    return study
+
+
+def test_study_is_multi_objective() -> None:
     study = create_study(directions=["minimize", "minimize"])
     with pytest.raises(ValueError):
         _get_improvement_info(study=study)
 
 
 @parametrize_plot_terminator_improvement
 @pytest.mark.parametrize(
-    "specific_create_study",
+    "specific_create_study, plot_error",
     [
-        create_study,
-        _create_study_with_failed_trial,
-        prepare_study_with_trials,
+        (create_study, False),
+        (_create_study_with_failed_trial, False),
+        (prepare_study_with_trials, False),
+        (_prepare_study_with_cross_validation_scores, False),
+        (_prepare_study_with_cross_validation_scores, True),
     ],
 )
-@pytest.mark.parametrize("plot_error", [False, True])
 def test_plot_terminator_improvement(
     plot_terminator_improvement: Callable[..., Any],
     specific_create_study: Callable[[], Study],
     plot_error: bool,
 ) -> None:
     study = specific_create_study()
-    figure = plot_terminator_improvement(study)
+    figure = plot_terminator_improvement(study, plot_error)
     figure.write_image(BytesIO())
 
 
Expand All @@ -74,31 +86,27 @@ def test_get_terminator_improvement_info_empty(


@pytest.mark.parametrize("get_error", [False, True])
@pytest.mark.parametrize("improvement_evaluator_class", [lambda: RegretBoundEvaluator(), None])
@pytest.mark.parametrize(
"improvement_evaluator_class", [lambda: RegretBoundEvaluator(), lambda: None]
)
@pytest.mark.parametrize(
"error_evaluator_class",
[
lambda: CrossValidationErrorEvaluator(),
lambda: StaticErrorEvaluator(0),
None,
lambda: None,
],
)
def test_get_improvement_info(
get_error: bool,
improvement_evaluator_class: Callable[[], BaseImprovementEvaluator] | None,
error_evaluator_class: Callable[[], BaseErrorEvaluator] | None,
improvement_evaluator_class: Callable[[], BaseImprovementEvaluator | None],
error_evaluator_class: Callable[[], BaseErrorEvaluator | None],
) -> None:
study = create_study()
for _ in range(3):
trial = study.ask({"x": FloatDistribution(0, 1)})
report_cross_validation_scores(trial, [1.0, 2.0])
study.tell(trial, 0)
study = _prepare_study_with_cross_validation_scores()

improvement_evaluator = (
None if improvement_evaluator_class is None else improvement_evaluator_class()
info = _get_improvement_info(
study, get_error, improvement_evaluator_class(), error_evaluator_class()
)
error_evaluator = None if error_evaluator_class is None else error_evaluator_class()
info = _get_improvement_info(study, get_error, improvement_evaluator, error_evaluator)
assert info.trial_numbers == [0, 1, 2]
assert len(info.improvements) == 3
if get_error:
@@ -109,6 +117,19 @@ def test_get_improvement_info(
         assert info.errors is None
 
 
+def test_get_improvement_info_started_with_failed_trials() -> None:
+    study = create_study()
+    for _ in range(3):
+        study.add_trial(create_trial(state=TrialState.FAIL))
+    trial = study.ask({"x": FloatDistribution(0, 1)})
+    study.tell(trial, 0)
+
+    info = _get_improvement_info(study)
+    assert info.trial_numbers == [3]
+    assert len(info.improvements) == 1
+    assert info.errors is None
+
+
 @pytest.mark.parametrize(
     "info",
     [

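A hedged end-to-end sketch of the workflow the new tests exercise, built only from names that appear in this diff (assuming an Optuna build with the terminator module): each trial reports cross-validation scores so the error term can be evaluated, and the plot is then rendered with plot_error=True.

    import optuna
    from optuna.distributions import FloatDistribution
    from optuna.terminator import report_cross_validation_scores
    from optuna.visualization import plot_terminator_improvement

    study = optuna.create_study()
    for _ in range(3):
        trial = study.ask({"x": FloatDistribution(0, 1)})
        # The variance of these scores is what the cross-validation
        # error evaluator uses to estimate the error bars.
        report_cross_validation_scores(trial, [1.0, 2.0])
        study.tell(trial, 0)

    figure = plot_terminator_improvement(study, plot_error=True)
    figure.show()  # or figure.write_image(...), as the tests do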