diff --git a/src/orion/algo/asha.py b/src/orion/algo/asha.py index 163c98882..1e7b274e0 100644 --- a/src/orion/algo/asha.py +++ b/src/orion/algo/asha.py @@ -232,9 +232,12 @@ def get_candidates(self, rung_id): rung = list( sorted( - (objective, trial) - for objective, trial in rung.values() - if objective is not None + ( + (objective, trial) + for objective, trial in rung.values() + if objective is not None + ), + key=lambda item: item[0], ) ) k = len(rung) // self.hyperband.reduction_factor diff --git a/src/orion/algo/hyperband.py b/src/orion/algo/hyperband.py index 39ee75010..dbc3cdc92 100644 --- a/src/orion/algo/hyperband.py +++ b/src/orion/algo/hyperband.py @@ -341,7 +341,8 @@ def suggest(self, num): ) else: logger.warning( - f"{self.__class__.__name__} cannot suggest new samples, exit." + f"{self.__class__.__name__} cannot suggest new samples and must wait " + "for trials to complete." ) return [] diff --git a/tests/unittests/algo/test_asha.py b/tests/unittests/algo/test_asha.py index 74ec08ecf..f86dec829 100644 --- a/tests/unittests/algo/test_asha.py +++ b/tests/unittests/algo/test_asha.py @@ -574,6 +574,36 @@ def test_suggest_promote_many_plus_random( == 20 - 2 - 3 * 3 ) + def test_suggest_promote_identic_objectives( + self, asha, bracket, big_rung_0, big_rung_1 + ): + """Test that identic objectives are handled properly""" + asha.brackets = [bracket] + bracket.asha = asha + + n_trials = 9 + resources = 1 + + results = {} + for param in np.linspace(0, 8, 9): + trial = create_trial_for_hb((resources, param), objective=0) + trial_hash = trial.compute_trial_hash( + trial, + ignore_fidelity=True, + ignore_experiment=True, + ) + results[trial_hash] = (trial.objective.value, trial) + + bracket.rungs[0] = dict(n_trials=n_trials, resources=resources, results=results) + + candidates = asha.suggest(2) + + assert len(candidates) == 2 + assert ( + sum(1 for trial in candidates if trial.params[asha.fidelity_index] == 3) + == 2 + ) + class TestGenericASHA(BaseAlgoTests): algo_name = "asha" diff --git a/tests/unittests/algo/test_hyperband.py b/tests/unittests/algo/test_hyperband.py index dad44b9b8..58da7da18 100644 --- a/tests/unittests/algo/test_hyperband.py +++ b/tests/unittests/algo/test_hyperband.py @@ -584,6 +584,36 @@ def test_suggest_promote(self, hyperband, bracket, rung_0): assert points[1].params == {"epoch": 3, "lr": 1} assert points[2].params == {"epoch": 3, "lr": 2} + def test_suggest_promote_identic_objectives(self, hyperband, bracket): + """Test that identic objectives are handled properly""" + hyperband.brackets = [bracket] + bracket.hyperband = hyperband + + n_trials = 9 + resources = 1 + + results = {} + for param in np.linspace(0, 8, 9): + trial = create_trial_for_hb((resources, param), objective=0) + trial_hash = trial.compute_trial_hash( + trial, + ignore_fidelity=True, + ignore_experiment=True, + ) + results[trial_hash] = (trial.objective.value, trial) + + bracket.rungs[0] = dict(n_trials=n_trials, resources=resources, results=results) + + candidates = hyperband.suggest(2) + + assert len(candidates) == 2 + assert ( + sum( + 1 for trial in candidates if trial.params[hyperband.fidelity_index] == 3 + ) + == 2 + ) + def test_is_filled(self, hyperband, bracket, rung_0, rung_1, rung_2): """Test that Hyperband bracket detects when rung is filled.""" hyperband.brackets = [bracket]