diff --git a/syne_tune/optimizer/schedulers/searchers/bore/bore.py b/syne_tune/optimizer/schedulers/searchers/bore/bore.py index ceb5eb2a6..545e208c9 100644 --- a/syne_tune/optimizer/schedulers/searchers/bore/bore.py +++ b/syne_tune/optimizer/schedulers/searchers/bore/bore.py @@ -296,5 +296,16 @@ def _update(self, trial_id: str, config: dict, result: dict): self.inputs.append(self._hp_ranges.to_ndarray(config)) self.targets.append(result[self._metric]) + def register_pending( + self, + trial_id: str, + config: Optional[dict] = None, + milestone: Optional[int] = None, + ): + FilterDuplicatesMixin.register_pending(self, trial_id, config, milestone) + + def evaluation_failed(self, trial_id: str): + FilterDuplicatesMixin.evaluation_failed(self, trial_id) + def clone_from_state(self, state: dict): raise NotImplementedError diff --git a/syne_tune/optimizer/schedulers/searchers/gp_fifo_searcher.py b/syne_tune/optimizer/schedulers/searchers/gp_fifo_searcher.py index 31d414afa..8c626cd43 100644 --- a/syne_tune/optimizer/schedulers/searchers/gp_fifo_searcher.py +++ b/syne_tune/optimizer/schedulers/searchers/gp_fifo_searcher.py @@ -339,7 +339,9 @@ def _get_config_not_modelbased( model-based search. If False is returned, model-based search must be called. - :param exclusion_candidates: Configs to be avoided + :param exclusion_candidates: Configs to be avoided, even if + ``allow_duplicates == True`` (in this case, we avoid configs of + failed or pending trials) :return: ``(config, use_get_config_modelbased)`` """ self._assign_random_searcher() @@ -357,9 +359,7 @@ def _get_config_not_modelbased( # If ``RandomSearcher`` returns no config at all, the # search space is exhausted break - if self._allow_duplicates or ( - not exclusion_candidates.contains(_config) - ): + if not exclusion_candidates.contains(_config): config = _config break if self.do_profile: diff --git a/syne_tune/optimizer/schedulers/searchers/kde/kde_searcher.py b/syne_tune/optimizer/schedulers/searchers/kde/kde_searcher.py index adb8e97b4..c954126a9 100644 --- a/syne_tune/optimizer/schedulers/searchers/kde/kde_searcher.py +++ b/syne_tune/optimizer/schedulers/searchers/kde/kde_searcher.py @@ -25,9 +25,6 @@ from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.debug_log import ( DebugLogPrinter, ) -from syne_tune.optimizer.schedulers.searchers.utils.hp_ranges_factory import ( - make_hyperparameter_ranges, -) from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.common import ( ExclusionList, ) @@ -408,5 +405,16 @@ def _train_kde(self, train_data, train_targets): return bad_kde, good_kde + def register_pending( + self, + trial_id: str, + config: Optional[dict] = None, + milestone: Optional[int] = None, + ): + FilterDuplicatesMixin.register_pending(self, trial_id, config, milestone) + + def evaluation_failed(self, trial_id: str): + FilterDuplicatesMixin.evaluation_failed(self, trial_id) + def clone_from_state(self, state: dict): raise NotImplementedError diff --git a/syne_tune/optimizer/schedulers/searchers/random_grid_searcher.py b/syne_tune/optimizer/schedulers/searchers/random_grid_searcher.py index ce363b665..b0f6160cf 100644 --- a/syne_tune/optimizer/schedulers/searchers/random_grid_searcher.py +++ b/syne_tune/optimizer/schedulers/searchers/random_grid_searcher.py @@ -140,6 +140,17 @@ def _update(self, trial_id: str, config: dict, result: dict): msg = f"Update for trial_id {trial_id}: metric = {metric_val:.3f}" logger.info(msg) + def register_pending( + self, + trial_id: str, + config: Optional[dict] = None, + milestone: Optional[int] = None, + ): + FilterDuplicatesMixin.register_pending(self, trial_id, config, milestone) + + def evaluation_failed(self, trial_id: str): + FilterDuplicatesMixin.evaluation_failed(self, trial_id) + def get_state(self) -> dict: state = SearcherWithRandomSeed.get_state(self) state.update(FilterDuplicatesMixin.get_state(self)) diff --git a/syne_tune/optimizer/schedulers/synchronous/dehb.py b/syne_tune/optimizer/schedulers/synchronous/dehb.py index b1b550f8f..880870795 100644 --- a/syne_tune/optimizer/schedulers/synchronous/dehb.py +++ b/syne_tune/optimizer/schedulers/synchronous/dehb.py @@ -437,18 +437,23 @@ def _encoded_config_from_searcher(self, trial_id: int) -> np.ndarray: def _encoded_config_by_promotion( self, ext_slot: ExtendedSlotInRung - ) -> (np.ndarray, int): + ) -> (Optional[np.ndarray], Optional[int]): parent_trial_id = self.bracket_manager.top_of_previous_rung( bracket_id=ext_slot.bracket_id, pos=ext_slot.slot_index ) - trial_info = self._trial_info[parent_trial_id] - assert trial_info.metric_val is not None # Sanity check - if self._debug_log is not None: - logger.info( - f"Promote config from trial_id {parent_trial_id}" - f", level {trial_info.level}" - ) - return trial_info.encoded_config, parent_trial_id + if parent_trial_id is not None: + trial_info = self._trial_info[parent_trial_id] + assert trial_info.metric_val is not None # Sanity check + if self._debug_log is not None: + logger.info( + f"Promote config from trial_id {parent_trial_id}" + f", level {trial_info.level}" + ) + encoded_config = trial_info.encoded_config + else: + # This can happen when all trials in the previous rung failed + encoded_config = None + return encoded_config, parent_trial_id def _extended_config_by_mutation_crossover( self, ext_slot: ExtendedSlotInRung @@ -498,6 +503,8 @@ def _register_new_config_and_make_suggestion( self._debug_log.set_final_config(config) self._debug_log.write_block() config = cast_config_values(config, self.config_space) + if self.searcher is not None: + self.searcher.register_pending(trial_id=str(trial_id), config=config) if self.max_resource_attr is not None: config = dict(config, **{self.max_resource_attr: ext_slot.level}) return TrialSuggestion.start_suggestion(config=config) @@ -733,6 +740,8 @@ def on_trial_error(self, trial: Trial): if trial_id in self._trial_to_pending_slot: ext_slot = self._trial_to_pending_slot[trial_id] self._report_as_failed(ext_slot) + # A failed trial is not pending anymore + del self._trial_to_pending_slot[trial_id] else: logger.warning( f"Trial trial_id {trial_id} not registered at pending: " diff --git a/syne_tune/optimizer/schedulers/synchronous/hyperband.py b/syne_tune/optimizer/schedulers/synchronous/hyperband.py index 414bf4858..cd488f158 100644 --- a/syne_tune/optimizer/schedulers/synchronous/hyperband.py +++ b/syne_tune/optimizer/schedulers/synchronous/hyperband.py @@ -290,6 +290,7 @@ def _suggest(self, trial_id: int) -> Optional[TrialSuggestion]: config = self.searcher.get_config(trial_id=str(trial_id)) if config is not None: config = cast_config_values(config, self.config_space) + self.searcher.register_pending(trial_id=str(trial_id), config=config) if self.max_resource_attr is not None: config[self.max_resource_attr] = slot_in_rung.level self._trial_to_config[trial_id] = config @@ -393,6 +394,8 @@ def on_trial_error(self, trial: Trial): if trial_id in self._trial_to_pending_slot: bracket_id, slot_in_rung = self._trial_to_pending_slot[trial_id] self._report_as_failed(bracket_id, slot_in_rung) + # A failed trial is not pending anymore + del self._trial_to_pending_slot[trial_id] else: logger.warning( f"Trial trial_id {trial_id} not registered at pending: " diff --git a/tst/schedulers/test_searchers.py b/tst/schedulers/test_searchers.py index 2e5aee754..d52fb15b8 100644 --- a/tst/schedulers/test_searchers.py +++ b/tst/schedulers/test_searchers.py @@ -12,6 +12,7 @@ # permissions and limitations under the License. import pytest from datetime import datetime +import itertools import numpy as np from syne_tune.optimizer.baselines import ( @@ -36,26 +37,39 @@ from syne_tune.optimizer.scheduler import SchedulerDecision +SCHEDULERS = [ + (GridSearch, False), + (RandomSearch, False), + (BayesianOptimization, False), + (ASHA, True), + (HyperTune, True), + (SyncHyperband, True), + (DEHB, True), + (BOHB, True), + (SyncBOHB, True), + (BORE, False), + (KDE, False), +] + + +DUPLICATES_AND_FAIL = [(False, False), (True, False), (True, True)] + + +COMBINATIONS = list(itertools.product(SCHEDULERS, DUPLICATES_AND_FAIL[:-1])) + list( + itertools.product(SCHEDULERS[1:], DUPLICATES_AND_FAIL[-1:]) +) + + # Does not contain ASHABORE, because >10 secs on CI # TODO: Dig more, why is ASHABORE more expensive than BORE here? @pytest.mark.timeout(10) -@pytest.mark.parametrize( - "scheduler_cls, multifid", - [ - (RandomSearch, False), - (GridSearch, False), - (BayesianOptimization, False), - (ASHA, True), - (HyperTune, True), - (SyncHyperband, True), - (DEHB, True), - (BOHB, True), - (SyncBOHB, True), - (BORE, False), - (KDE, False), - ], -) -def test_allow_duplicates_or_not(scheduler_cls, multifid): +@pytest.mark.parametrize("tpl1, tpl2", COMBINATIONS) +def test_allow_duplicates_or_not(tpl1, tpl2): + # If ``trials_fail == True``, we let all trials fail. In that case, corresponding + # configs are filtered out, even if ``allow_duplicates == True`` (for all searchers + # except ``GridSearcher``). + scheduler_cls, multifid = tpl1 + allow_duplicates, trials_fail = tpl2 random_seed = 31415927 np.random.seed(random_seed) @@ -72,73 +86,73 @@ def test_allow_duplicates_or_not(scheduler_cls, multifid): cs_size = config_space_size(config_space) assert cs_size == 3 * 3 - # for allow_duplicates, trials_fail in [(False, False), (True, False), (True, True)]: - for allow_duplicates, trials_fail in [(False, False), (True, False)]: - if multifid: - kwargs = dict(resource_attr=resource_attr) - else: - kwargs = dict() - scheduler = scheduler_cls( - config_space, - metric=metric, - mode=mode, - max_resource_attr=max_resource_attr, - search_options={ - "allow_duplicates": allow_duplicates, - "debug_log": False, - }, - **kwargs, + if multifid: + kwargs = dict(resource_attr=resource_attr) + else: + kwargs = dict() + scheduler = scheduler_cls( + config_space, + metric=metric, + mode=mode, + max_resource_attr=max_resource_attr, + search_options={ + "allow_duplicates": allow_duplicates, + "debug_log": False, + }, + **kwargs, + ) + trial_id = 0 + trials = dict() + while trial_id <= cs_size: + err_msg = ( + f"trial_id {trial_id}, cs_size {cs_size}, allow_duplicates " + f"{allow_duplicates}, trials_fail {trials_fail}" ) - trial_id = 0 - trials = dict() - while trial_id <= cs_size: - err_msg = ( - f"trial_id {trial_id}, cs_size {cs_size}, allow_duplicates " - f"{allow_duplicates}, trials_fail {trials_fail}" - ) - suggestion = scheduler.suggest(trial_id) - if trial_id < cs_size or allow_duplicates: - assert suggestion is not None, err_msg - if suggestion.spawn_new_trial_id: - # Start new trial - assert suggestion.checkpoint_trial_id is None, err_msg - trial = Trial( - trial_id=trial_id, - config=suggestion.config, - creation_time=datetime.now(), - ) - scheduler.on_trial_add(trial) - trials[trial_id] = trial - trial_id += 1 - else: - # Resume existing trial - # Note that we do not implement checkpointing, so resume - # means start from scratch - assert suggestion.checkpoint_trial_id is not None, err_msg - trial = trials[suggestion.checkpoint_trial_id] - if not trials_fail: - # Return results - result = None - it = None - for it in range(max_resource_val): - result = { - metric: np.random.rand(), - resource_attr: it + 1, - } - decision = scheduler.on_trial_result(trial=trial, result=result) - if decision != SchedulerDecision.CONTINUE: - break - if it >= max_resource_val - 1: - scheduler.on_trial_complete(trial=trial, result=result) - else: - # Trial fails - scheduler.on_trial_error(trial=trial) - else: - # Maybe trials are being resumed? - while ( - suggestion is not None - and suggestion.checkpoint_trial_id is not None - ): - suggestion = scheduler.suggest(trial_id) - assert suggestion is None, err_msg + print(f"suggest: trial_id = {trial_id}") + suggestion = scheduler.suggest(trial_id) + if trial_id < cs_size or (allow_duplicates and not trials_fail): + assert suggestion is not None, err_msg + if suggestion.spawn_new_trial_id: + # Start new trial + print(f"Start new trial: {suggestion.config}") + assert suggestion.checkpoint_trial_id is None, err_msg + trial = Trial( + trial_id=trial_id, + config=suggestion.config, + creation_time=datetime.now(), + ) + scheduler.on_trial_add(trial) + trials[trial_id] = trial trial_id += 1 + else: + # Resume existing trial + # Note that we do not implement checkpointing, so resume + # means start from scratch + print(f"Resume trial: {suggestion.checkpoint_trial_id}") + assert suggestion.checkpoint_trial_id is not None, err_msg + trial = trials[suggestion.checkpoint_trial_id] + if not trials_fail: + # Return results + result = None + it = None + for it in range(max_resource_val): + result = { + metric: np.random.rand(), + resource_attr: it + 1, + } + decision = scheduler.on_trial_result(trial=trial, result=result) + if decision != SchedulerDecision.CONTINUE: + break + if it >= max_resource_val - 1: + scheduler.on_trial_complete(trial=trial, result=result) + else: + # Trial fails + scheduler.on_trial_error(trial=trial) + else: + # Maybe trials are being resumed? + print(f"OK, I am here. trial_id = {trial_id}") + while suggestion is not None and suggestion.checkpoint_trial_id is not None: + print(f"suggest: trial_id = {trial_id}") + suggestion = scheduler.suggest(trial_id) + assert suggestion is None, err_msg + trial_id += 1