Skip to content

Commit

Permalink
Unified common code, as requested
Browse files Browse the repository at this point in the history
  • Loading branch information
mseeger committed Jan 17, 2023
1 parent 6c9c4d3 commit bc54f77
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 122 deletions.
43 changes: 8 additions & 35 deletions syne_tune/optimizer/schedulers/searchers/bore/bore.py
Expand Up @@ -23,9 +23,7 @@
from syne_tune.optimizer.schedulers.searchers.searcher import (
SearcherWithRandomSeed,
sample_random_configuration,
)
from syne_tune.optimizer.schedulers.searchers.utils.hp_ranges_factory import (
make_hyperparameter_ranges,
FilterDuplicatesMixin,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.common import (
ExclusionList,
Expand All @@ -37,7 +35,7 @@
logger = logging.getLogger(__name__)


class Bore(SearcherWithRandomSeed):
class Bore(SearcherWithRandomSeed, FilterDuplicatesMixin):
"""
Implements "Bayesian optimization by Density Ratio Estimation" as described
in the following paper:
Expand Down Expand Up @@ -90,12 +88,16 @@ def __init__(
allow_duplicates: Optional[bool] = None,
**kwargs,
):
super().__init__(
SearcherWithRandomSeed.__init__(
self,
config_space=config_space,
metric=metric,
points_to_evaluate=points_to_evaluate,
**kwargs,
)
if allow_duplicates is None:
allow_duplicates = False
FilterDuplicatesMixin.__init__(self, config_space, allow_duplicates)
if mode is None:
mode = "min"
if gamma is None:
Expand All @@ -112,8 +114,6 @@ def __init__(
random_prob = 0.0
if init_random is None:
init_random = 6
if allow_duplicates is None:
allow_duplicates = False

self.calibrate = calibrate
self.gamma = gamma
Expand All @@ -125,14 +125,6 @@ def __init__(
self.random_prob = random_prob
self.mode = mode

self._hp_ranges = make_hyperparameter_ranges(self.config_space)
self._excl_list = ExclusionList.empty_list(self._hp_ranges)
self._allow_duplicates = allow_duplicates
# Maps ``trial_id`` to configuration. This is used to blacklist
# configurations whose trial has failed (only if
# ``allow_duplicates == True``)
self._config_for_trial_id = dict() if allow_duplicates else None

if classifier_kwargs is None:
classifier_kwargs = dict()
if self.classifier == "xgboost":
Expand Down Expand Up @@ -256,7 +248,7 @@ def wrapper(x):
f"config={config}] "
f"optimization time : {opt_time}"
)
if not self._allow_duplicates:
if not self.allow_duplicates:
self._excl_list.add(config) # Should not be suggested again

return config
Expand Down Expand Up @@ -304,24 +296,5 @@ def _update(self, trial_id: str, config: dict, result: dict):
self.inputs.append(self._hp_ranges.to_ndarray(config))
self.targets.append(result[self._metric])

def register_pending(
self,
trial_id: str,
config: Optional[dict] = None,
milestone: Optional[int] = None,
):
# Records the pending trial's configuration (only when duplicates are
# allowed), so that a later failure can blacklist that exact config in
# ``evaluation_failed``. ``milestone`` is unused here; it is accepted to
# keep the searcher interface.
if self._allow_duplicates and trial_id not in self._config_for_trial_id:
if config is not None:
self._config_for_trial_id[trial_id] = config
else:
logger.warning(
f"register_pending called for trial_id {trial_id} without passing config"
)

def evaluation_failed(self, trial_id: str):
# With ``allow_duplicates == True``, configs are not excluded when they
# are suggested, so a failed trial's config must be blacklisted here.
if self._allow_duplicates and trial_id in self._config_for_trial_id:
# Blacklist this configuration
self._excl_list.add(self._config_for_trial_id[trial_id])

def clone_from_state(self, state: dict):
# Cloning this searcher from a mutable state is not supported.
raise NotImplementedError
40 changes: 8 additions & 32 deletions syne_tune/optimizer/schedulers/searchers/kde/kde_searcher.py
Expand Up @@ -19,6 +19,7 @@
from syne_tune.optimizer.schedulers.searchers.searcher import (
SearcherWithRandomSeed,
sample_random_configuration,
FilterDuplicatesMixin,
)
import syne_tune.config_space as sp
from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.debug_log import (
Expand All @@ -34,7 +35,7 @@
logger = logging.getLogger(__name__)


class KernelDensityEstimator(SearcherWithRandomSeed):
class KernelDensityEstimator(SearcherWithRandomSeed, FilterDuplicatesMixin):
"""
Fits two kernel density estimators (KDE) to model the density of the top N
configurations as well as the density of the configurations that are not
Expand Down Expand Up @@ -101,13 +102,17 @@ def __init__(
allow_duplicates: Optional[bool] = None,
**kwargs,
):
super().__init__(
SearcherWithRandomSeed.__init__(
self,
config_space=config_space,
metric=metric,
points_to_evaluate=points_to_evaluate,
mode="min" if mode is None else mode,
**kwargs,
)
if allow_duplicates is None:
allow_duplicates = False
FilterDuplicatesMixin.__init__(self, config_space, allow_duplicates)
if top_n_percent is None:
top_n_percent = 15
if min_bandwidth is None:
Expand All @@ -118,8 +123,6 @@ def __init__(
bandwidth_factor = 3
if random_fraction is None:
random_fraction = 0.33
if allow_duplicates is None:
allow_duplicates = False
self.num_evaluations = 0
self.min_bandwidth = min_bandwidth
self.random_fraction = random_fraction
Expand Down Expand Up @@ -163,14 +166,6 @@ def __init__(
self.vartypes
), f"num_min_data_points = {num_min_data_points}, must be >= {len(self.vartypes)}"
self._resource_attr = kwargs.get("resource_attr")
# Used for sampling initial random configs, and to avoid duplicates
self._hp_ranges = make_hyperparameter_ranges(self.config_space)
self._excl_list = ExclusionList.empty_list(self._hp_ranges)
self._allow_duplicates = allow_duplicates
# Maps ``trial_id`` to configuration. This is used to blacklist
# configurations whose trial has failed (only if
# ``allow_duplicates == True``)
self._config_for_trial_id = dict() if allow_duplicates else None
# Debug log printing (switched on by default)
debug_log = kwargs.get("debug_log", True)
if isinstance(debug_log, bool):
Expand Down Expand Up @@ -373,7 +368,7 @@ def acquisition_function(x):
)
suggestion = self._get_random_config()

if suggestion is not None and not self._allow_duplicates:
if suggestion is not None and not self.allow_duplicates:
self._excl_list.add(suggestion) # Should not be suggested again
return suggestion

Expand Down Expand Up @@ -413,24 +408,5 @@ def _train_kde(self, train_data, train_targets):

return bad_kde, good_kde

def register_pending(
self,
trial_id: str,
config: Optional[dict] = None,
milestone: Optional[int] = None,
):
# Records the pending trial's configuration (only when duplicates are
# allowed), so that a later failure can blacklist that exact config in
# ``evaluation_failed``. ``milestone`` is unused here; it is accepted to
# keep the searcher interface.
if self._allow_duplicates and trial_id not in self._config_for_trial_id:
if config is not None:
self._config_for_trial_id[trial_id] = config
else:
logger.warning(
f"register_pending called for trial_id {trial_id} without passing config"
)

def evaluation_failed(self, trial_id: str):
# With ``allow_duplicates == True``, configs are not excluded when they
# are suggested, so a failed trial's config must be blacklisted here.
if self._allow_duplicates and trial_id in self._config_for_trial_id:
# Blacklist this configuration
self._excl_list.add(self._config_for_trial_id[trial_id])

def clone_from_state(self, state: dict):
# Cloning this searcher from a mutable state is not supported.
raise NotImplementedError
50 changes: 10 additions & 40 deletions syne_tune/optimizer/schedulers/searchers/random_grid_searcher.py
Expand Up @@ -28,6 +28,7 @@
from syne_tune.optimizer.schedulers.searchers.searcher import (
SearcherWithRandomSeed,
sample_random_configuration,
FilterDuplicatesMixin,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.common import (
ExclusionList,
Expand All @@ -40,7 +41,7 @@
logger = logging.getLogger(__name__)


class RandomSearcher(SearcherWithRandomSeed):
class RandomSearcher(SearcherWithRandomSeed, FilterDuplicatesMixin):
"""
Searcher which randomly samples configurations to try next.
Expand All @@ -65,19 +66,11 @@ def __init__(
allow_duplicates: bool = False,
**kwargs,
):
super().__init__(
config_space, metric, points_to_evaluate=points_to_evaluate, **kwargs
SearcherWithRandomSeed.__init__(
self, config_space, metric, points_to_evaluate=points_to_evaluate, **kwargs
)
self._hp_ranges = make_hyperparameter_ranges(config_space)
FilterDuplicatesMixin.__init__(self, config_space, allow_duplicates)
self._resource_attr = resource_attr
self._allow_duplicates = allow_duplicates
# Used to avoid returning the same config more than once. If
# ``allow_duplicates == True``, this is used to block failed trials
self._excl_list = ExclusionList.empty_list(self._hp_ranges)
# Maps ``trial_id`` to configuration. This is used to blacklist
# configurations whose trial has failed (only if
# `allow_duplicates == True``)
self._config_for_trial_id = dict() if allow_duplicates else None
# Debug log printing (switched off by default)
if isinstance(debug_log, bool):
if debug_log:
Expand Down Expand Up @@ -118,7 +111,7 @@ def get_config(self, **kwargs) -> Optional[dict]:
exclusion_list=self._excl_list,
)
if new_config is not None:
if not self._allow_duplicates:
if not self.allow_duplicates:
self._excl_list.add(new_config) # Should not be suggested again
if self._debug_log is not None:
trial_id = kwargs.get("trial_id")
Expand Down Expand Up @@ -147,37 +140,14 @@ def _update(self, trial_id: str, config: dict, result: dict):
msg = f"Update for trial_id {trial_id}: metric = {metric_val:.3f}"
logger.info(msg)

def register_pending(
self,
trial_id: str,
config: Optional[dict] = None,
milestone: Optional[int] = None,
):
# Records the pending trial's configuration (only when duplicates are
# allowed), so that a later failure can blacklist that exact config in
# ``evaluation_failed``. ``milestone`` is unused here; it is accepted to
# keep the searcher interface.
if self._allow_duplicates and trial_id not in self._config_for_trial_id:
if config is not None:
self._config_for_trial_id[trial_id] = config
else:
logger.warning(
f"register_pending called for trial_id {trial_id} without passing config"
)

def evaluation_failed(self, trial_id: str):
# With ``allow_duplicates == True``, configs are not excluded when they
# are suggested, so a failed trial's config must be blacklisted here.
if self._allow_duplicates and trial_id in self._config_for_trial_id:
# Blacklist this configuration
self._excl_list.add(self._config_for_trial_id[trial_id])

def get_state(self) -> dict:
state = dict(super().get_state(), excl_list=self._excl_list.get_state())
if self._allow_duplicates:
state["config_for_trial_id"] = self._config_for_trial_id
state = SearcherWithRandomSeed.get_state(self)
state.update(FilterDuplicatesMixin.get_state(self))
return state

def _restore_from_state(self, state: dict):
super()._restore_from_state(state)
self._excl_list = ExclusionList.empty_list(self._hp_ranges)
self._excl_list.clone_from_state(state["excl_list"])
if self._allow_duplicates:
self._config_for_trial_id = state["config_for_trial_id"]
SearcherWithRandomSeed._restore_from_state(self, state)
FilterDuplicatesMixin._restore_from_state(self, state)

def clone_from_state(self, state: dict):
new_searcher = RandomSearcher(
Expand Down
69 changes: 69 additions & 0 deletions syne_tune/optimizer/schedulers/searchers/searcher.py
Expand Up @@ -33,6 +33,7 @@
)
from syne_tune.optimizer.schedulers.searchers.utils import (
HyperparameterRanges,
make_hyperparameter_ranges,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -450,3 +451,71 @@ def _restore_from_state(self, state: dict):

def set_random_state(self, random_state: np.random.RandomState):
self.random_state = random_state


class FilterDuplicatesMixin:
    """
    Mixin for :class:`BaseSearcher`. Maintains an exclusion list to filter out
    duplicates in
    :meth:`~syne_tune.optimizer.schedulers.searchers.BaseSearcher.get_config` if
    ``allow_duplicates == False``. If this is ``True``, duplicates are not
    filtered, and the exclusion list is used only to avoid configurations of
    failed trials. In order to make use of these features:

    * Reject configurations in :meth:`get_config` if they are in the exclusion
      list. If the configuration is drawn at random, please use
      :func:`~syne_tune.optimizer.schedulers.searchers.searcher.sample_random_configuration`
      and pass the exclusion list
    * Before returning a configuration in :meth:`get_config`, add it to the
      exclusion list iff ``allow_duplicates == False``
    * If you implement :meth:`get_state` and :meth:`_restore_from_state`, call
      the methods here as well

    Note: Not all searchers which filter duplicates make use of this class.

    :param config_space: Configuration space
    :param allow_duplicates: See above. Defaults to ``False``
    """

    def __init__(self, config_space: dict, allow_duplicates: bool = False):
        self._hp_ranges = make_hyperparameter_ranges(config_space)
        self._allow_duplicates = allow_duplicates
        # Used to avoid returning the same config more than once. If
        # ``allow_duplicates == True``, this is used to block failed trials
        self._excl_list = ExclusionList.empty_list(self._hp_ranges)
        # Maps ``trial_id`` to configuration. This is used to blacklist
        # configurations whose trial has failed (only if
        # ``allow_duplicates == True``)
        self._config_for_trial_id = dict() if allow_duplicates else None

    @property
    def allow_duplicates(self) -> bool:
        return self._allow_duplicates

    def register_pending(
        self,
        trial_id: str,
        config: Optional[dict] = None,
        milestone: Optional[int] = None,
    ):
        """
        Records the configuration of a pending trial (only if
        ``allow_duplicates == True``), so it can be blacklisted in
        :meth:`evaluation_failed` should the trial fail.

        :param trial_id: ID of pending trial
        :param config: Configuration of pending trial
        :param milestone: Not used here; part of the searcher interface
        """
        if self._allow_duplicates and trial_id not in self._config_for_trial_id:
            if config is not None:
                self._config_for_trial_id[trial_id] = config
            else:
                logger.warning(
                    f"register_pending called for trial_id {trial_id} without passing config"
                )

    def evaluation_failed(self, trial_id: str):
        """
        Adds the configuration of a failed trial to the exclusion list, so it
        is not suggested again. Only needed if ``allow_duplicates == True``
        (otherwise, configs are excluded when suggested).

        :param trial_id: ID of failed trial
        """
        if self._allow_duplicates and trial_id in self._config_for_trial_id:
            # Blacklist this configuration
            self._excl_list.add(self._config_for_trial_id[trial_id])

    def get_state(self) -> dict:
        """
        :return: Mutable state of the mixin, to be merged into the searcher
            state by the caller
        """
        state = {"excl_list": self._excl_list.get_state()}
        if self._allow_duplicates:
            state["config_for_trial_id"] = self._config_for_trial_id
        return state

    def _restore_from_state(self, state: dict):
        """
        Restores the exclusion list (and, if ``allow_duplicates == True``, the
        trial ID to config mapping) from ``state``.

        :param state: As returned by :meth:`get_state`
        """
        self._excl_list = ExclusionList.empty_list(self._hp_ranges)
        self._excl_list.clone_from_state(state["excl_list"])
        if self._allow_duplicates:
            self._config_for_trial_id = state["config_for_trial_id"]

0 comments on commit bc54f77

Please sign in to comment.