feat: implementation of expected hypervolume improvement #825

Open · wants to merge 15 commits into base: main
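For context, a short summary of the acquisition function this PR implements (standard textbook definition, not taken from the PR's code): EHVI scores a candidate configuration by the expected increase in the hypervolume dominated by the current Pareto front when the candidate's (random) objective vector is added.

```latex
% P: current Pareto-front approximation, r: reference point,
% HV_r: hypervolume dominated by a point set w.r.t. r,
% y ~ p(f(x) | D): posterior predictive of the objective vector at x
\mathrm{EHVI}(x) \;=\; \mathbb{E}_{y \sim p(f(x) \mid \mathcal{D})}
  \bigl[\, \mathrm{HV}_r\bigl(P \cup \{y\}\bigr) - \mathrm{HV}_r(P) \,\bigr]
```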
55 changes: 55 additions & 0 deletions syne_tune/optimizer/baselines.py
@@ -1287,6 +1287,61 @@ def __init__(
)
)

try:
    from syne_tune.optimizer.schedulers.multiobjective.expected_hyper_volume_improvement import (
        ExpectedHyperVolumeImprovement,
    )

    class EHVI(FIFOScheduler):
        """
        Implements the Expected Hypervolume Improvement method.

        See :class:`~syne_tune.optimizer.schedulers.searchers.RandomSearcher`
        for ``kwargs["search_options"]`` parameters.

        :param config_space: Configuration space for evaluation function
        :param metric: Names of the metrics to optimize (one entry per objective)
        :param mode: Whether each metric is minimized ("min") or maximized
            ("max"); a single value is broadcast to all objectives. Defaults
            to "min"
        :param random_seed: Random seed, optional
        :param kwargs: Additional arguments to
            :class:`~syne_tune.optimizer.schedulers.FIFOScheduler`
        """

        def __init__(
            self,
            config_space: Dict[str, Any],
            metric: List[str],
            mode: Union[List[str], str] = "min",
            random_seed: Optional[int] = None,
            **kwargs,
        ):
            searcher_kwargs = _create_searcher_kwargs(
                config_space, metric, random_seed, kwargs
            )
            searcher_kwargs["mode"] = mode

            super(EHVI, self).__init__(
                config_space=config_space,
                metric=metric,
                mode=mode,
                searcher=ExpectedHyperVolumeImprovement(**searcher_kwargs),
                random_seed=random_seed,
                **kwargs,
            )

except ImportError:
    logging.info(
        _try_import_message(
            message_text="EHVI is not imported (not contained in extra)",
            tag="ehvi",
        )
    )
# Dictionary that also allows listing baselines that don't need a wrapper class,
# such as :class:`PopulationBasedTraining`
baselines_dict = {
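For reviewers, a minimal usage sketch of the new baseline. The config space, metric names, and `objective.py` training script are illustrative placeholders, and the import assumes the extra providing the EHVI searcher is installed:

```python
from syne_tune import StoppingCriterion, Tuner
from syne_tune.backend import LocalBackend
from syne_tune.config_space import uniform
from syne_tune.optimizer.baselines import EHVI

# Two hyperparameters; the training script reports two objectives
config_space = {"x1": uniform(0.0, 1.0), "x2": uniform(0.0, 1.0)}

scheduler = EHVI(
    config_space=config_space,
    metric=["cost", "latency"],  # one entry per objective (hypothetical names)
    mode=["min", "min"],         # or a single "min", broadcast to all objectives
    random_seed=0,
)

tuner = Tuner(
    trial_backend=LocalBackend(entry_point="objective.py"),  # placeholder script
    scheduler=scheduler,
    stop_criterion=StoppingCriterion(max_wallclock_time=600),
    n_workers=4,
)
tuner.run()
```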
2 changes: 1 addition & 1 deletion syne_tune/optimizer/schedulers/fifo.py
@@ -242,7 +242,7 @@ def _check_metric_mode(
    else:
        len_mode = 1
    if len_mode == 1:
-       mode = [mode * num_objectives]
+       mode = [mode] * num_objectives
    allowed_values = {"min", "max"}
    assert all(
        x in allowed_values for x in mode
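The one-line fix in `_check_metric_mode` addresses a broadcast bug: the old expression repeated the string inside a single list element instead of repeating the element itself. A quick plain-Python illustration:

```python
mode, num_objectives = "min", 2

# Before the fix: the string is repeated, producing one invalid entry
[mode * num_objectives]   # -> ["minmin"], rejected by the {"min", "max"} check

# After the fix: the list element is repeated, one mode per objective
[mode] * num_objectives   # -> ["min", "min"]
```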