Skip to content

Commit

Permalink
Add number of suggest() attempts in Algo wrapper (#883)
Browse files Browse the repository at this point in the history
* Add number of suggest() attempts in Algo wrapper

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix pylint issue

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Remove TODO note in test_primary_algo.py

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Remove note from primary_algo.py

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Add `max_suggest_attempts` as an attribute

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* skip flaky test

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix bug in primary_algo.py

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix type annotation on algo wrapper suggest

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Add tests for the new feature of the algo wrapper

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
lebrice committed Apr 26, 2022
1 parent 9b6eeb8 commit e13326d
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 42 deletions.
97 changes: 59 additions & 38 deletions src/orion/core/worker/primary_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, space: Space, algorithm: AlgoType):
original_registry=self.registry,
transformed_registry=self.algorithm.registry,
)
self.max_suggest_attempts = 100

@property
def original_space(self) -> Space:
Expand Down Expand Up @@ -116,7 +117,7 @@ def set_state(self, state_dict: dict) -> None:
self.registry.set_state(state_dict["registry"])
self.registry_mapping.set_state(state_dict["registry_mapping"])

def suggest(self, num: int) -> list[Trial] | None:
def suggest(self, num: int) -> list[Trial]:
"""Suggest a `num` of new sets of parameters.
Parameters
Expand All @@ -137,48 +138,68 @@ def suggest(self, num: int) -> list[Trial] | None:
New parameters must be compliant with the problem's domain `orion.algo.space.Space`.
"""
transformed_trials = self.algorithm.suggest(num)

if transformed_trials is None:
return None

trials: list[Trial] = []
for transformed_trial in transformed_trials:
if transformed_trial not in self.transformed_space:
raise ValueError(
f"Trial {transformed_trial.id} not contained in space:\n"
f"Params: {transformed_trial.params}\n"
f"Space: {self.transformed_space}"
)
original = self.transformed_space.reverse(transformed_trial)
if original in self.registry:
logger.debug(
"Already have a trial that matches %s in the registry.", original
)
# We already have a trial that is equivalent to this one.
# Fetch the actual trial (with the status and possibly results)
original = self.registry.get_existing(original)
logger.debug("Matching trial (with results/status): %s", original)

# Copy over the status and results from the original to the transformed trial
# and observe it.
transformed_trial = _copy_status_and_results(
original_trial=original, transformed_trial=transformed_trial
)
for suggest_attempt in range(1, self.max_suggest_attempts + 1):
transformed_trials: list[Trial] | None = self.algorithm.suggest(num)
transformed_trials = transformed_trials or []

for transformed_trial in transformed_trials:
if transformed_trial not in self.transformed_space:
raise ValueError(
f"Trial {transformed_trial.id} not contained in space:\n"
f"Params: {transformed_trial.params}\n"
f"Space: {self.transformed_space}"
)
original = self.transformed_space.reverse(transformed_trial)
if original in self.registry:
logger.debug(
"Already have a trial that matches %s in the registry.",
original,
)
# We already have a trial that is equivalent to this one.
# Fetch the actual trial (with the status and possibly results)
original = self.registry.get_existing(original)
logger.debug("Matching trial (with results/status): %s", original)

# Copy over the status and results from the original to the transformed trial
# and observe it.
transformed_trial = _copy_status_and_results(
original_trial=original, transformed_trial=transformed_trial
)
logger.debug(
"Transformed trial (with results/status): %s", transformed_trial
)
self.algorithm.observe([transformed_trial])
else:
# We haven't seen this trial before. Register it.
self.registry.register(original)
trials.append(original)

# NOTE: Here we DON'T register the transformed trial, we let the algorithm do it
# itself in its `suggest`.
# Register the equivalence between these trials.
self.registry_mapping.register(original, transformed_trial)

if trials:
if suggest_attempt > 1:
logger.debug(
f"Succeeded in suggesting new trials after {suggest_attempt} attempts."
)
return trials

if self.is_done:
logger.debug(
"Transformed trial (with results/status): %s", transformed_trial
f"Algorithm is done! (after {suggest_attempt} sampling attempts)."
)
self.algorithm.observe([transformed_trial])
else:
# We haven't seen this trial before. Register it.
self.registry.register(original)
trials.append(original)

# NOTE: Here we DON'T register the transformed trial, we let the algorithm do it itself
# in its `suggest`.
# Register the equivalence between these trials.
self.registry_mapping.register(original, transformed_trial)
return trials
break

logger.warning(
f"Unable to sample a new trial from the algorithm, even after "
f"{self.max_suggest_attempts} attempts! Returning an empty list."
)
return []

def observe(self, trials: list[Trial]) -> None:
"""Observe evaluated trials.
Expand Down
7 changes: 6 additions & 1 deletion tests/functional/algos/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,12 @@ def test_cardinality_stop_loguniform(algorithm):


@pytest.mark.parametrize(
"algorithm", algorithm_configs.values(), ids=list(algorithm_configs.keys())
"algorithm",
[
pytest.param(value, marks=pytest.mark.skipif(key == "tpe", reason="Flaky test"))
for key, value in algorithm_configs.items()
],
ids=list(algorithm_configs.keys()),
)
def test_with_fidelity(algorithm):
"""Test a scenario with fidelity."""
Expand Down
74 changes: 71 additions & 3 deletions tests/unittests/core/test_primary_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from __future__ import annotations

import copy
import logging
import typing
from typing import Any, ClassVar, TypeVar

import pytest
from pytest import MonkeyPatch

from orion.algo.base import BaseAlgorithm, algo_factory
from orion.algo.space import Space
Expand Down Expand Up @@ -147,6 +149,63 @@ def test_judge(
del fixed_suggestion._params[-1]
palgo.judge(fixed_suggestion, 8)

def test_insists_when_algo_doesnt_suggest_new_trials(
self,
algo_wrapper: SpaceTransformAlgoWrapper[StupidAlgo],
monkeypatch: MonkeyPatch,
):
"""Test that when the algo can't produce a new trial, the wrapper insists and asks again."""
calls: int = 0
algo_wrapper.max_suggest_attempts = 10

# Make the wrapper insist enough so that it actually
# gets a trial after asking enough times:

def _suggest(num: int) -> list[Trial]:
nonlocal calls
calls += 1
if calls < 5:
return []
return [algo_wrapper.algorithm.fixed_suggestion]

monkeypatch.setattr(algo_wrapper.algorithm, "suggest", _suggest)
trial = algo_wrapper.suggest(1)[0]
assert calls == 5
assert trial in algo_wrapper.space

def test_warns_when_unable_to_sample_new_trial(
self,
algo_wrapper: SpaceTransformAlgoWrapper[StupidAlgo],
caplog: pytest.LogCaptureFixture,
monkeypatch: MonkeyPatch,
):
"""Test that when the algo can't produce a new trial even after the max number of attempts,
a warning is logged and an empty list is returned.
"""

calls: int = 0

def _suggest(num: int) -> list[Trial]:
nonlocal calls
calls += 1
if calls < 5:
return []
return [algo_wrapper.algorithm.fixed_suggestion]

monkeypatch.setattr(algo_wrapper.algorithm, "suggest", _suggest)

algo_wrapper.max_suggest_attempts = 3

with caplog.at_level(logging.WARNING):
out = algo_wrapper.suggest(1)
assert calls == 3
assert out == []
assert len(caplog.record_tuples) == 1
log_record = caplog.record_tuples[0]
assert log_record[1] == logging.WARNING and log_record[2].startswith(
"Unable to sample a new trial"
)


class StupidAlgo(BaseAlgorithm):
"""A dumb algo that always returns the same trial."""
Expand All @@ -155,14 +214,23 @@ class StupidAlgo(BaseAlgorithm):
requires_shape: ClassVar[str | None] = "flattened"
requires_dist: ClassVar[str | None] = "linear"

def __init__(self, space: Space, fixed_suggestion: Trial):
def __init__(
self,
space: Space,
fixed_suggestion: Trial,
):
super().__init__(space)
self.fixed_suggestion = fixed_suggestion
assert fixed_suggestion in space

def suggest(self, num):
self.register(self.fixed_suggestion)
return [self.fixed_suggestion]
# NOTE: can't register the trial if it's already here. The fixed suggestion is always "new",
# but the algorithm actually observes it at some point. Therefore, we don't overwrite what's
# already in the registry.
if not self.has_suggested(self.fixed_suggestion):
self.register(self.fixed_suggestion)
return [self.fixed_suggestion]
return []


@pytest.fixture()
Expand Down

0 comments on commit e13326d

Please sign in to comment.