diff --git a/src/orion/core/worker/primary_algo.py b/src/orion/core/worker/primary_algo.py index 67c92698e..0370d4f3a 100644 --- a/src/orion/core/worker/primary_algo.py +++ b/src/orion/core/worker/primary_algo.py @@ -73,6 +73,7 @@ def __init__(self, space: Space, algorithm: AlgoType): original_registry=self.registry, transformed_registry=self.algorithm.registry, ) + self.max_suggest_attempts = 100 @property def original_space(self) -> Space: @@ -116,7 +117,7 @@ def set_state(self, state_dict: dict) -> None: self.registry.set_state(state_dict["registry"]) self.registry_mapping.set_state(state_dict["registry_mapping"]) - def suggest(self, num: int) -> list[Trial] | None: + def suggest(self, num: int) -> list[Trial]: """Suggest a `num` of new sets of parameters. Parameters @@ -137,48 +138,68 @@ def suggest(self, num: int) -> list[Trial] | None: New parameters must be compliant with the problem's domain `orion.algo.space.Space`. """ - transformed_trials = self.algorithm.suggest(num) - - if transformed_trials is None: - return None trials: list[Trial] = [] - for transformed_trial in transformed_trials: - if transformed_trial not in self.transformed_space: - raise ValueError( - f"Trial {transformed_trial.id} not contained in space:\n" - f"Params: {transformed_trial.params}\n" - f"Space: {self.transformed_space}" - ) - original = self.transformed_space.reverse(transformed_trial) - if original in self.registry: - logger.debug( - "Already have a trial that matches %s in the registry.", original - ) - # We already have a trial that is equivalent to this one. - # Fetch the actual trial (with the status and possibly results) - original = self.registry.get_existing(original) - logger.debug("Matching trial (with results/status): %s", original) - # Copy over the status and results from the original to the transformed trial - # and observe it. - transformed_trial = _copy_status_and_results( - original_trial=original, transformed_trial=transformed_trial - ) + for suggest_attempt in range(1, self.max_suggest_attempts + 1): + transformed_trials: list[Trial] | None = self.algorithm.suggest(num) + transformed_trials = transformed_trials or [] + + for transformed_trial in transformed_trials: + if transformed_trial not in self.transformed_space: + raise ValueError( + f"Trial {transformed_trial.id} not contained in space:\n" + f"Params: {transformed_trial.params}\n" + f"Space: {self.transformed_space}" + ) + original = self.transformed_space.reverse(transformed_trial) + if original in self.registry: + logger.debug( + "Already have a trial that matches %s in the registry.", + original, + ) + # We already have a trial that is equivalent to this one. + # Fetch the actual trial (with the status and possibly results) + original = self.registry.get_existing(original) + logger.debug("Matching trial (with results/status): %s", original) + + # Copy over the status and results from the original to the transformed trial + # and observe it. + transformed_trial = _copy_status_and_results( + original_trial=original, transformed_trial=transformed_trial + ) + logger.debug( + "Transformed trial (with results/status): %s", transformed_trial + ) + self.algorithm.observe([transformed_trial]) + else: + # We haven't seen this trial before. Register it. + self.registry.register(original) + trials.append(original) + + # NOTE: Here we DON'T register the transformed trial, we let the algorithm do it + # itself in its `suggest`. + # Register the equivalence between these trials. + self.registry_mapping.register(original, transformed_trial) + + if trials: + if suggest_attempt > 1: + logger.debug( + f"Succeeded in suggesting new trials after {suggest_attempt} attempts." + ) + return trials + + if self.is_done: logger.debug( - "Transformed trial (with results/status): %s", transformed_trial + f"Algorithm is done! (after {suggest_attempt} sampling attempts)." ) - self.algorithm.observe([transformed_trial]) - else: - # We haven't seen this trial before. Register it. - self.registry.register(original) - trials.append(original) - - # NOTE: Here we DON'T register the transformed trial, we let the algorithm do it itself - # in its `suggest`. - # Register the equivalence between these trials. - self.registry_mapping.register(original, transformed_trial) - return trials + break + + logger.warning( + f"Unable to sample a new trial from the algorithm, even after " + f"{self.max_suggest_attempts} attempts! Returning an empty list." + ) + return [] def observe(self, trials: list[Trial]) -> None: """Observe evaluated trials. diff --git a/tests/functional/algos/test_algos.py b/tests/functional/algos/test_algos.py index 50e415012..ef5fcc748 100644 --- a/tests/functional/algos/test_algos.py +++ b/tests/functional/algos/test_algos.py @@ -169,7 +169,12 @@ def test_cardinality_stop_loguniform(algorithm): @pytest.mark.parametrize( - "algorithm", algorithm_configs.values(), ids=list(algorithm_configs.keys()) + "algorithm", + [ + pytest.param(value, marks=pytest.mark.skipif(key == "tpe", reason="Flaky test")) + for key, value in algorithm_configs.items() + ], + ids=list(algorithm_configs.keys()), ) def test_with_fidelity(algorithm): """Test a scenario with fidelity.""" diff --git a/tests/unittests/core/test_primary_algo.py b/tests/unittests/core/test_primary_algo.py index 2a6c1ef32..e2190f9a2 100644 --- a/tests/unittests/core/test_primary_algo.py +++ b/tests/unittests/core/test_primary_algo.py @@ -4,10 +4,12 @@ from __future__ import annotations import copy +import logging import typing from typing import Any, ClassVar, TypeVar import pytest +from pytest import MonkeyPatch from orion.algo.base import BaseAlgorithm, algo_factory from orion.algo.space import Space @@ -147,6 +149,63 @@ def test_judge( del fixed_suggestion._params[-1] palgo.judge(fixed_suggestion, 8) + def test_insists_when_algo_doesnt_suggest_new_trials( + self, + algo_wrapper: SpaceTransformAlgoWrapper[StupidAlgo], + monkeypatch: MonkeyPatch, + ): + """Test that when the algo can't produce a new trial, the wrapper insists and asks again.""" + calls: int = 0 + algo_wrapper.max_suggest_attempts = 10 + + # Make the wrapper insist enough so that it actually + # gets a trial after asking enough times: + + def _suggest(num: int) -> list[Trial]: + nonlocal calls + calls += 1 + if calls < 5: + return [] + return [algo_wrapper.algorithm.fixed_suggestion] + + monkeypatch.setattr(algo_wrapper.algorithm, "suggest", _suggest) + trial = algo_wrapper.suggest(1)[0] + assert calls == 5 + assert trial in algo_wrapper.space + + def test_warns_when_unable_to_sample_new_trial( + self, + algo_wrapper: SpaceTransformAlgoWrapper[StupidAlgo], + caplog: pytest.LogCaptureFixture, + monkeypatch: MonkeyPatch, + ): + """Test that when the algo can't produce a new trial even after the max number of attempts, + a warning is logged and an empty list is returned. + """ + + calls: int = 0 + + def _suggest(num: int) -> list[Trial]: + nonlocal calls + calls += 1 + if calls < 5: + return [] + return [algo_wrapper.algorithm.fixed_suggestion] + + monkeypatch.setattr(algo_wrapper.algorithm, "suggest", _suggest) + + algo_wrapper.max_suggest_attempts = 3 + + with caplog.at_level(logging.WARNING): + out = algo_wrapper.suggest(1) + assert calls == 3 + assert out == [] + assert len(caplog.record_tuples) == 1 + log_record = caplog.record_tuples[0] + assert log_record[1] == logging.WARNING and log_record[2].startswith( + "Unable to sample a new trial" + ) + class StupidAlgo(BaseAlgorithm): """A dumb algo that always returns the same trial.""" @@ -155,14 +214,23 @@ class StupidAlgo(BaseAlgorithm): requires_shape: ClassVar[str | None] = "flattened" requires_dist: ClassVar[str | None] = "linear" - def __init__(self, space: Space, fixed_suggestion: Trial): + def __init__( + self, + space: Space, + fixed_suggestion: Trial, + ): super().__init__(space) self.fixed_suggestion = fixed_suggestion assert fixed_suggestion in space def suggest(self, num): - self.register(self.fixed_suggestion) - return [self.fixed_suggestion] + # NOTE: can't register the trial if it's already here. The fixed suggestion is always "new", + # but the algorithm actually observes it at some point. Therefore, we don't overwrite what's + # already in the registry. + if not self.has_suggested(self.fixed_suggestion): + self.register(self.fixed_suggestion) + return [self.fixed_suggestion] + return [] @pytest.fixture()