Several fixes, and extended test
mseeger committed Jan 20, 2023
1 parent 90864c8 commit 4afa556
Showing 7 changed files with 157 additions and 101 deletions.
11 changes: 11 additions & 0 deletions syne_tune/optimizer/schedulers/searchers/bore/bore.py
@@ -296,5 +296,16 @@ def _update(self, trial_id: str, config: dict, result: dict):
         self.inputs.append(self._hp_ranges.to_ndarray(config))
         self.targets.append(result[self._metric])
 
+    def register_pending(
+        self,
+        trial_id: str,
+        config: Optional[dict] = None,
+        milestone: Optional[int] = None,
+    ):
+        FilterDuplicatesMixin.register_pending(self, trial_id, config, milestone)
+
+    def evaluation_failed(self, trial_id: str):
+        FilterDuplicatesMixin.evaluation_failed(self, trial_id)
+
     def clone_from_state(self, state: dict):
         raise NotImplementedError
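The new methods forward explicitly to FilterDuplicatesMixin; with several bases defining the same hooks, explicit delegation makes it unambiguous whose implementation runs. A minimal sketch of the pattern (BaseSearcher and MySearcher are hypothetical stand-ins, not Syne Tune classes):

from typing import Optional


class FilterDuplicatesMixin:
    # Sketch: remember configs of pending trials and ids of failed trials,
    # so they can be excluded from future suggestions.
    def __init__(self):
        self._pending_configs = {}
        self._failed_trials = set()

    def register_pending(
        self,
        trial_id: str,
        config: Optional[dict] = None,
        milestone: Optional[int] = None,
    ):
        if config is not None:
            self._pending_configs[trial_id] = config

    def evaluation_failed(self, trial_id: str):
        self._failed_trials.add(trial_id)


class BaseSearcher:
    # A base whose hooks do no duplicate bookkeeping.
    def register_pending(self, trial_id, config=None, milestone=None):
        pass

    def evaluation_failed(self, trial_id):
        pass


class MySearcher(FilterDuplicatesMixin, BaseSearcher):
    def register_pending(self, trial_id, config=None, milestone=None):
        # Explicit delegation: always the mixin's version, never BaseSearcher's.
        FilterDuplicatesMixin.register_pending(self, trial_id, config, milestone)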
8 changes: 4 additions & 4 deletions syne_tune/optimizer/schedulers/searchers/gp_fifo_searcher.py
@@ -339,7 +339,9 @@ def _get_config_not_modelbased(
             model-based search. If False is returned, model-based search must be
             called.
-        :param exclusion_candidates: Configs to be avoided
+        :param exclusion_candidates: Configs to be avoided, even if
+            ``allow_duplicates == True`` (in this case, we avoid configs of
+            failed or pending trials)
         :return: ``(config, use_get_config_modelbased)``
         """
         self._assign_random_searcher()
@@ -357,9 +359,7 @@ def _get_config_not_modelbased(
                 # If ``RandomSearcher`` returns no config at all, the
                 # search space is exhausted
                 break
-            if self._allow_duplicates or (
-                not exclusion_candidates.contains(_config)
-            ):
+            if not exclusion_candidates.contains(_config):
                 config = _config
                 break
         if self.do_profile:
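The docstring change and the simplified condition belong together: exclusion_candidates is now built so that it already reflects the allow_duplicates policy (with duplicates allowed, only configs of failed or pending trials are excluded), which makes the extra ``self._allow_duplicates or ...`` escape hatch in the loop redundant. A hedged sketch of that contract; the real code uses ExclusionList, while the helper below is hypothetical and treats configs as hashable:

def make_exclusion_candidates(completed, failed, pending, allow_duplicates: bool) -> set:
    # With allow_duplicates=True, only configs of failed or pending trials
    # are avoided; otherwise every config seen so far is avoided.
    excluded = set(failed) | set(pending)
    if not allow_duplicates:
        excluded |= set(completed)
    return excluded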
14 changes: 11 additions & 3 deletions syne_tune/optimizer/schedulers/searchers/kde/kde_searcher.py
@@ -25,9 +25,6 @@
 from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.debug_log import (
     DebugLogPrinter,
 )
-from syne_tune.optimizer.schedulers.searchers.utils.hp_ranges_factory import (
-    make_hyperparameter_ranges,
-)
 from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.common import (
     ExclusionList,
 )
@@ -408,5 +405,16 @@ def _train_kde(self, train_data, train_targets):
 
         return bad_kde, good_kde
 
+    def register_pending(
+        self,
+        trial_id: str,
+        config: Optional[dict] = None,
+        milestone: Optional[int] = None,
+    ):
+        FilterDuplicatesMixin.register_pending(self, trial_id, config, milestone)
+
+    def evaluation_failed(self, trial_id: str):
+        FilterDuplicatesMixin.evaluation_failed(self, trial_id)
+
     def clone_from_state(self, state: dict):
         raise NotImplementedError
11 changes: 11 additions & 0 deletions syne_tune/optimizer/schedulers/searchers/random_grid_searcher.py
@@ -140,6 +140,17 @@ def _update(self, trial_id: str, config: dict, result: dict):
         msg = f"Update for trial_id {trial_id}: metric = {metric_val:.3f}"
         logger.info(msg)
 
+    def register_pending(
+        self,
+        trial_id: str,
+        config: Optional[dict] = None,
+        milestone: Optional[int] = None,
+    ):
+        FilterDuplicatesMixin.register_pending(self, trial_id, config, milestone)
+
+    def evaluation_failed(self, trial_id: str):
+        FilterDuplicatesMixin.evaluation_failed(self, trial_id)
+
     def get_state(self) -> dict:
         state = SearcherWithRandomSeed.get_state(self)
         state.update(FilterDuplicatesMixin.get_state(self))
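The get_state method above merges the contributions of two bases into a single dict. A runnable sketch of this cooperative pattern (attribute values are placeholders):

class SearcherWithRandomSeed:
    def __init__(self):
        self.random_state = 31415927

    def get_state(self) -> dict:
        return {"random_state": self.random_state}


class FilterDuplicatesMixin:
    def __init__(self):
        self._excluded_configs = []

    def get_state(self) -> dict:
        return {"excluded_configs": list(self._excluded_configs)}


class GridSearcher(SearcherWithRandomSeed, FilterDuplicatesMixin):
    def __init__(self):
        SearcherWithRandomSeed.__init__(self)
        FilterDuplicatesMixin.__init__(self)

    def get_state(self) -> dict:
        # Merge per-base state; keys must not collide across bases.
        state = SearcherWithRandomSeed.get_state(self)
        state.update(FilterDuplicatesMixin.get_state(self))
        return state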
27 changes: 18 additions & 9 deletions syne_tune/optimizer/schedulers/synchronous/dehb.py
@@ -437,18 +437,23 @@ def _encoded_config_from_searcher(self, trial_id: int) -> np.ndarray:
 
     def _encoded_config_by_promotion(
         self, ext_slot: ExtendedSlotInRung
-    ) -> (np.ndarray, int):
+    ) -> (Optional[np.ndarray], Optional[int]):
         parent_trial_id = self.bracket_manager.top_of_previous_rung(
             bracket_id=ext_slot.bracket_id, pos=ext_slot.slot_index
         )
-        trial_info = self._trial_info[parent_trial_id]
-        assert trial_info.metric_val is not None  # Sanity check
-        if self._debug_log is not None:
-            logger.info(
-                f"Promote config from trial_id {parent_trial_id}"
-                f", level {trial_info.level}"
-            )
-        return trial_info.encoded_config, parent_trial_id
+        if parent_trial_id is not None:
+            trial_info = self._trial_info[parent_trial_id]
+            assert trial_info.metric_val is not None  # Sanity check
+            if self._debug_log is not None:
+                logger.info(
+                    f"Promote config from trial_id {parent_trial_id}"
+                    f", level {trial_info.level}"
+                )
+            encoded_config = trial_info.encoded_config
+        else:
+            # This can happen when all trials in the previous rung failed
+            encoded_config = None
+        return encoded_config, parent_trial_id
 
     def _extended_config_by_mutation_crossover(
         self, ext_slot: ExtendedSlotInRung
@@ -498,6 +503,8 @@ def _register_new_config_and_make_suggestion(
             self._debug_log.set_final_config(config)
             self._debug_log.write_block()
         config = cast_config_values(config, self.config_space)
+        if self.searcher is not None:
+            self.searcher.register_pending(trial_id=str(trial_id), config=config)
         if self.max_resource_attr is not None:
             config = dict(config, **{self.max_resource_attr: ext_slot.level})
         return TrialSuggestion.start_suggestion(config=config)
@@ -733,6 +740,8 @@ def on_trial_error(self, trial: Trial):
         if trial_id in self._trial_to_pending_slot:
             ext_slot = self._trial_to_pending_slot[trial_id]
             self._report_as_failed(ext_slot)
+            # A failed trial is not pending anymore
+            del self._trial_to_pending_slot[trial_id]
         else:
             logger.warning(
                 f"Trial trial_id {trial_id} not registered at pending: "
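With the widened return type, _encoded_config_by_promotion can now signal "nothing to promote". A sketch of how a caller might consume that (hypothetical function; the actual fallback logic in dehb.py may differ):

from typing import Optional

import numpy as np


def encoded_config_for_slot(scheduler, ext_slot) -> Optional[np.ndarray]:
    # Prefer promotion; fall back to mutation/crossover when the previous
    # rung contains no usable (non-failed) trial to promote.
    encoded_config, parent_trial_id = scheduler._encoded_config_by_promotion(ext_slot)
    if encoded_config is None:
        encoded_config = scheduler._extended_config_by_mutation_crossover(ext_slot)
    return encoded_config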
3 changes: 3 additions & 0 deletions syne_tune/optimizer/schedulers/synchronous/hyperband.py
@@ -290,6 +290,7 @@ def _suggest(self, trial_id: int) -> Optional[TrialSuggestion]:
         config = self.searcher.get_config(trial_id=str(trial_id))
         if config is not None:
             config = cast_config_values(config, self.config_space)
+            self.searcher.register_pending(trial_id=str(trial_id), config=config)
             if self.max_resource_attr is not None:
                 config[self.max_resource_attr] = slot_in_rung.level
             self._trial_to_config[trial_id] = config
@@ -393,6 +394,8 @@ def on_trial_error(self, trial: Trial):
         if trial_id in self._trial_to_pending_slot:
             bracket_id, slot_in_rung = self._trial_to_pending_slot[trial_id]
             self._report_as_failed(bracket_id, slot_in_rung)
+            # A failed trial is not pending anymore
+            del self._trial_to_pending_slot[trial_id]
         else:
             logger.warning(
                 f"Trial trial_id {trial_id} not registered at pending: "
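This fix and the matching one in dehb.py close the same gap: without the del, a failed trial would linger in _trial_to_pending_slot, so a repeated error notification could report the same slot as failed twice. A compressed, hypothetical sketch of the corrected bookkeeping:

pending_slots = {}  # trial_id -> slot, maintained by the scheduler


def on_trial_error(trial_id, report_as_failed):
    if trial_id in pending_slots:
        report_as_failed(pending_slots[trial_id])
        # A failed trial is not pending anymore; deleting the entry makes
        # repeated notifications hit the warning branch instead.
        del pending_slots[trial_id]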
184 changes: 99 additions & 85 deletions tst/schedulers/test_searchers.py
@@ -12,6 +12,7 @@
 # permissions and limitations under the License.
 import pytest
 from datetime import datetime
+import itertools
 import numpy as np

from syne_tune.optimizer.baselines import (
@@ -36,26 +37,39 @@
from syne_tune.optimizer.scheduler import SchedulerDecision


+SCHEDULERS = [
+    (GridSearch, False),
+    (RandomSearch, False),
+    (BayesianOptimization, False),
+    (ASHA, True),
+    (HyperTune, True),
+    (SyncHyperband, True),
+    (DEHB, True),
+    (BOHB, True),
+    (SyncBOHB, True),
+    (BORE, False),
+    (KDE, False),
+]
+
+
+DUPLICATES_AND_FAIL = [(False, False), (True, False), (True, True)]
+
+
+COMBINATIONS = list(itertools.product(SCHEDULERS, DUPLICATES_AND_FAIL[:-1])) + list(
+    itertools.product(SCHEDULERS[1:], DUPLICATES_AND_FAIL[-1:])
+)
+
+
# Does not contain ASHABORE, because >10 secs on CI
# TODO: Dig more, why is ASHABORE more expensive than BORE here?
@pytest.mark.timeout(10)
-@pytest.mark.parametrize(
-    "scheduler_cls, multifid",
-    [
-        (RandomSearch, False),
-        (GridSearch, False),
-        (BayesianOptimization, False),
-        (ASHA, True),
-        (HyperTune, True),
-        (SyncHyperband, True),
-        (DEHB, True),
-        (BOHB, True),
-        (SyncBOHB, True),
-        (BORE, False),
-        (KDE, False),
-    ],
-)
-def test_allow_duplicates_or_not(scheduler_cls, multifid):
+@pytest.mark.parametrize("tpl1, tpl2", COMBINATIONS)
+def test_allow_duplicates_or_not(tpl1, tpl2):
+    # If ``trials_fail == True``, we let all trials fail. In that case, corresponding
+    # configs are filtered out, even if ``allow_duplicates == True`` (for all searchers
+    # except ``GridSearcher``).
+    scheduler_cls, multifid = tpl1
+    allow_duplicates, trials_fail = tpl2
     random_seed = 31415927
     np.random.seed(random_seed)

@@ -72,73 +86,73 @@ def test_allow_duplicates_or_not(scheduler_cls, multifid):
     cs_size = config_space_size(config_space)
     assert cs_size == 3 * 3
 
-    # for allow_duplicates, trials_fail in [(False, False), (True, False), (True, True)]:
-    for allow_duplicates, trials_fail in [(False, False), (True, False)]:
-        if multifid:
-            kwargs = dict(resource_attr=resource_attr)
-        else:
-            kwargs = dict()
-        scheduler = scheduler_cls(
-            config_space,
-            metric=metric,
-            mode=mode,
-            max_resource_attr=max_resource_attr,
-            search_options={
-                "allow_duplicates": allow_duplicates,
-                "debug_log": False,
-            },
-            **kwargs,
+    if multifid:
+        kwargs = dict(resource_attr=resource_attr)
+    else:
+        kwargs = dict()
+    scheduler = scheduler_cls(
+        config_space,
+        metric=metric,
+        mode=mode,
+        max_resource_attr=max_resource_attr,
+        search_options={
+            "allow_duplicates": allow_duplicates,
+            "debug_log": False,
+        },
+        **kwargs,
+    )
-        trial_id = 0
-        trials = dict()
-        while trial_id <= cs_size:
-            err_msg = (
-                f"trial_id {trial_id}, cs_size {cs_size}, allow_duplicates "
-                f"{allow_duplicates}, trials_fail {trials_fail}"
-            )
+    trial_id = 0
+    trials = dict()
+    while trial_id <= cs_size:
+        err_msg = (
+            f"trial_id {trial_id}, cs_size {cs_size}, allow_duplicates "
+            f"{allow_duplicates}, trials_fail {trials_fail}"
+        )
-            suggestion = scheduler.suggest(trial_id)
-            if trial_id < cs_size or allow_duplicates:
-                assert suggestion is not None, err_msg
-                if suggestion.spawn_new_trial_id:
-                    # Start new trial
-                    assert suggestion.checkpoint_trial_id is None, err_msg
-                    trial = Trial(
-                        trial_id=trial_id,
-                        config=suggestion.config,
-                        creation_time=datetime.now(),
-                    )
-                    scheduler.on_trial_add(trial)
-                    trials[trial_id] = trial
-                    trial_id += 1
-                else:
-                    # Resume existing trial
-                    # Note that we do not implement checkpointing, so resume
-                    # means start from scratch
-                    assert suggestion.checkpoint_trial_id is not None, err_msg
-                    trial = trials[suggestion.checkpoint_trial_id]
-                if not trials_fail:
-                    # Return results
-                    result = None
-                    it = None
-                    for it in range(max_resource_val):
-                        result = {
-                            metric: np.random.rand(),
-                            resource_attr: it + 1,
-                        }
-                        decision = scheduler.on_trial_result(trial=trial, result=result)
-                        if decision != SchedulerDecision.CONTINUE:
-                            break
-                    if it >= max_resource_val - 1:
-                        scheduler.on_trial_complete(trial=trial, result=result)
-                else:
-                    # Trial fails
-                    scheduler.on_trial_error(trial=trial)
-            else:
-                # Maybe trials are being resumed?
-                while (
-                    suggestion is not None
-                    and suggestion.checkpoint_trial_id is not None
-                ):
-                    suggestion = scheduler.suggest(trial_id)
-                assert suggestion is None, err_msg
+        print(f"suggest: trial_id = {trial_id}")
+        suggestion = scheduler.suggest(trial_id)
+        if trial_id < cs_size or (allow_duplicates and not trials_fail):
+            assert suggestion is not None, err_msg
+            if suggestion.spawn_new_trial_id:
+                # Start new trial
+                print(f"Start new trial: {suggestion.config}")
+                assert suggestion.checkpoint_trial_id is None, err_msg
+                trial = Trial(
+                    trial_id=trial_id,
+                    config=suggestion.config,
+                    creation_time=datetime.now(),
+                )
+                scheduler.on_trial_add(trial)
+                trials[trial_id] = trial
+                trial_id += 1
+            else:
+                # Resume existing trial
+                # Note that we do not implement checkpointing, so resume
+                # means start from scratch
+                print(f"Resume trial: {suggestion.checkpoint_trial_id}")
+                assert suggestion.checkpoint_trial_id is not None, err_msg
+                trial = trials[suggestion.checkpoint_trial_id]
+            if not trials_fail:
+                # Return results
+                result = None
+                it = None
+                for it in range(max_resource_val):
+                    result = {
+                        metric: np.random.rand(),
+                        resource_attr: it + 1,
+                    }
+                    decision = scheduler.on_trial_result(trial=trial, result=result)
+                    if decision != SchedulerDecision.CONTINUE:
+                        break
+                if it >= max_resource_val - 1:
+                    scheduler.on_trial_complete(trial=trial, result=result)
+            else:
+                # Trial fails
+                scheduler.on_trial_error(trial=trial)
+        else:
+            # Maybe trials are being resumed?
+            print(f"OK, I am here. trial_id = {trial_id}")
+            while suggestion is not None and suggestion.checkpoint_trial_id is not None:
+                print(f"suggest: trial_id = {trial_id}")
+                suggestion = scheduler.suggest(trial_id)
+            assert suggestion is None, err_msg
+        trial_id += 1
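For reference, the new parametrization pairs every scheduler with the first two (allow_duplicates, trials_fail) settings, and every scheduler except GridSearch, whose searcher does not filter out failed configs, additionally with (True, True). A standalone check of the expansion (scheduler classes abbreviated to strings):

import itertools

SCHEDULERS = ["GridSearch", "RandomSearch", "BayesianOptimization"]  # abbreviated
DUPLICATES_AND_FAIL = [(False, False), (True, False), (True, True)]

COMBINATIONS = list(itertools.product(SCHEDULERS, DUPLICATES_AND_FAIL[:-1])) + list(
    itertools.product(SCHEDULERS[1:], DUPLICATES_AND_FAIL[-1:])
)
print(COMBINATIONS)
# [('GridSearch', (False, False)), ('GridSearch', (True, False)),
#  ('RandomSearch', (False, False)), ('RandomSearch', (True, False)),
#  ('BayesianOptimization', (False, False)), ('BayesianOptimization', (True, False)),
#  ('RandomSearch', (True, True)), ('BayesianOptimization', (True, True))]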
