Skip to content

Commit

Permalink
Merge pull request optuna#4747 from Alnusjaponica/fix-BruteForceSampler
Browse files Browse the repository at this point in the history
Make BruteForceSampler consider failed trials
  • Loading branch information
not522 committed Jun 22, 2023
2 parents 67d976e + 820a523 commit 0af57a1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
2 changes: 2 additions & 0 deletions optuna/samplers/_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def sample_independent(
TrialState.COMPLETE,
TrialState.PRUNED,
TrialState.RUNNING,
TrialState.FAIL,
),
)
tree = self._build_tree((t for t in trials if t.number != trial.number), trial.params)
Expand All @@ -210,6 +211,7 @@ def after_trial(
TrialState.COMPLETE,
TrialState.PRUNED,
TrialState.RUNNING,
TrialState.FAIL,
),
)
tree = self._build_tree(
Expand Down
15 changes: 15 additions & 0 deletions tests/samplers_tests/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,18 @@ def objective_decreasing_variable(trial: Trial) -> float:
)
with pytest.raises(ValueError):
study.optimize(objective_decreasing_variable, n_trials=10)


def test_study_optimize_with_failed_trials() -> None:
def objective(trial: Trial) -> float:
x = trial.suggest_int("x", 0, 99) # NOQA[F811]
return np.nan

study = optuna.create_study(sampler=samplers.BruteForceSampler())
study.optimize(objective, n_trials=100)

expected_suggested_values = [{"x": i} for i in range(100)]
all_suggested_values = [t.params for t in study.trials]
assert len(all_suggested_values) == len(expected_suggested_values)
for a in expected_suggested_values:
assert a in all_suggested_values

0 comments on commit 0af57a1

Please sign in to comment.