From 94dcd58126daf10e711ff382b43ac2f615aa79b0 Mon Sep 17 00:00:00 2001 From: c-bata Date: Fri, 17 Feb 2023 13:17:02 +0900 Subject: [PATCH 01/11] Sync owned trials when calling study.ask and study.get_trials --- optuna/storages/_cached_storage.py | 25 +++++++++++++-------- optuna/study/study.py | 4 ++-- tests/storages_tests/test_cached_storage.py | 17 ++++++++++++-- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index a6cb6ce5d5..96eb35976a 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -27,12 +27,17 @@ def __init__(self) -> None: # Trial number to corresponding FrozenTrial. self.trials: Dict[int, FrozenTrial] = {} # A list of trials which do not require storage access to read latest attributes. - self.owned_or_finished_trial_ids: Set[int] = set() + self.owned_trial_ids: Set[int] = set() + self.finished_trial_ids: Set[int] = set() # Cache distributions to avoid storage access on distribution consistency check. self.param_distribution: Dict[str, distributions.BaseDistribution] = {} self.directions: Optional[List[StudyDirection]] = None self.name: Optional[str] = None + @property + def owned_or_finished_trial_ids(self) -> Set[int]: + return self.owned_trial_ids | self.finished_trial_ids + class _CachedStorage(BaseStorage, BaseHeartbeat): """A wrapper class of storage backends. @@ -183,7 +188,7 @@ def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial] # WAITING trials are exception and they can be modified from arbitral worker. # Thus, we cannot add them to a list of cached trials. if frozen_trial.state != TrialState.WAITING: - study.owned_or_finished_trial_ids.add(frozen_trial._trial_id) + study.owned_trial_ids.add(frozen_trial._trial_id) return trial_id def set_trial_param( @@ -271,7 +276,7 @@ def set_trial_state_values( with self._lock: study_id, _ = self._trial_id_to_study_id_and_number[trial_id] self._add_trials_to_cache(study_id, [self._backend.get_trial(trial_id)]) - self._studies[study_id].owned_or_finished_trial_ids.add(trial_id) + self._studies[study_id].owned_trial_ids.add(trial_id) return ret def set_trial_intermediate_value( @@ -328,7 +333,7 @@ def get_all_trials( states: Optional[Container[TrialState]] = None, ) -> List[FrozenTrial]: if study_id not in self._studies: - self.read_trials_from_remote_storage(study_id) + self.read_trials_from_remote_storage(study_id, sync_owned_trials=False) with self._lock: study = self._studies[study_id] @@ -344,19 +349,21 @@ def get_all_trials( trials = list(sorted(trials.values(), key=lambda t: t.number)) return copy.deepcopy(trials) if deepcopy else trials - def read_trials_from_remote_storage(self, study_id: int) -> None: + def read_trials_from_remote_storage(self, study_id: int, sync_owned_trials: bool) -> None: with self._lock: if study_id not in self._studies: self._studies[study_id] = _StudyInfo() study = self._studies[study_id] - trials = self._backend._get_trials( - study_id, states=None, excluded_trial_ids=study.owned_or_finished_trial_ids - ) + if sync_owned_trials: + excluded = study.finished_trial_ids + else: + excluded = study.owned_or_finished_trial_ids + trials = self._backend._get_trials(study_id, states=None, excluded_trial_ids=excluded) if trials: self._add_trials_to_cache(study_id, trials) for trial in trials: if trial.state.is_finished(): - study.owned_or_finished_trial_ids.add(trial._trial_id) + study.finished_trial_ids.add(trial._trial_id) def _add_trials_to_cache(self, study_id: int, trials: List[FrozenTrial]) -> None: study = self._studies[study_id] diff --git a/optuna/study/study.py b/optuna/study/study.py index 8ada7f7c75..1fccc90b1e 100644 --- a/optuna/study/study.py +++ b/optuna/study/study.py @@ -261,7 +261,7 @@ def objective(trial): A list of :class:`~optuna.trial.FrozenTrial` objects. """ if isinstance(self._storage, _CachedStorage): - self._storage.read_trials_from_remote_storage(self._study_id) + self._storage.read_trials_from_remote_storage(self._study_id, sync_owned_trials=True) return self._storage.get_all_trials(self._study_id, deepcopy=deepcopy, states=states) @@ -508,7 +508,7 @@ def ask( # Sync storage once every trial. if isinstance(self._storage, _CachedStorage): - self._storage.read_trials_from_remote_storage(self._study_id) + self._storage.read_trials_from_remote_storage(self._study_id, sync_owned_trials=True) trial_id = self._pop_waiting_trial_id() if trial_id is None: diff --git a/tests/storages_tests/test_cached_storage.py b/tests/storages_tests/test_cached_storage.py index 9f7b71148d..d2d933d686 100644 --- a/tests/storages_tests/test_cached_storage.py +++ b/tests/storages_tests/test_cached_storage.py @@ -112,8 +112,21 @@ def test_read_trials_from_remote_storage() -> None: directions=[StudyDirection.MINIMIZE], study_name="test-study" ) - storage.read_trials_from_remote_storage(study_id) + storage.read_trials_from_remote_storage(study_id, sync_owned_trials=False) # Non-existent study. with pytest.raises(KeyError): - storage.read_trials_from_remote_storage(study_id + 1) + storage.read_trials_from_remote_storage(study_id + 1, sync_owned_trials=False) + + # Create a trial via CachedStorage and update it via backend storage directly. + trial_id = storage.create_new_trial(study_id) + base_storage.set_trial_param( + trial_id, "paramA", 1.2, optuna.distributions.FloatDistribution(-0.2, 2.3) + ) + base_storage.set_trial_state_values(trial_id, TrialState.COMPLETE, values=[0.0]) + + storage.read_trials_from_remote_storage(study_id, sync_owned_trials=False) + assert storage.get_trial(trial_id).state == TrialState.RUNNING + + storage.read_trials_from_remote_storage(study_id, sync_owned_trials=True) + assert storage.get_trial(trial_id).state == TrialState.COMPLETE From b7dcb02eb761437bca5e475a95d94d26edc70f65 Mon Sep 17 00:00:00 2001 From: c-bata Date: Fri, 17 Feb 2023 14:45:25 +0900 Subject: [PATCH 02/11] Fix a broken test --- tests/study_tests/test_study.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/study_tests/test_study.py b/tests/study_tests/test_study.py index 1bf0d7373b..573501b1a1 100644 --- a/tests/study_tests/test_study.py +++ b/tests/study_tests/test_study.py @@ -1502,7 +1502,7 @@ def test_tell_from_another_process() -> None: assert study.best_value == 1.2 # Should fail because the trial0 is already finished. - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): pool.starmap(_process_tell, [(study, trial0, 1.2)]) From ea25ad6926f162484cd77bdbd0218d35eaec0589 Mon Sep 17 00:00:00 2001 From: c-bata Date: Fri, 24 Mar 2023 14:59:32 +0900 Subject: [PATCH 03/11] Remove sync_owned_trials option and always sync them --- optuna/storages/_cached_storage.py | 35 ++++++--------------- optuna/study/study.py | 4 +-- tests/storages_tests/test_cached_storage.py | 10 ++---- 3 files changed, 14 insertions(+), 35 deletions(-) diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index 96eb35976a..2f391d2493 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -34,10 +34,6 @@ def __init__(self) -> None: self.directions: Optional[List[StudyDirection]] = None self.name: Optional[str] = None - @property - def owned_or_finished_trial_ids(self) -> Set[int]: - return self.owned_trial_ids | self.finished_trial_ids - class _CachedStorage(BaseStorage, BaseHeartbeat): """A wrapper class of storage backends. @@ -63,11 +59,6 @@ class _CachedStorage(BaseStorage, BaseHeartbeat): the `state` attribute of `T`. The same applies for `user_attrs', 'system_attrs' and 'intermediate_values` attributes. - The current implementation of :class:`~optuna.storages._CachedStorage` assumes that each - RUNNING trial is only modified from a single process. - When a user modifies a RUNNING trial from multiple processes, the internal state of the storage - may become inconsistent. Consequences are undefined. - **Data persistence** :class:`~optuna.storages._CachedStorage` does not guarantee that write operations are logged @@ -180,15 +171,10 @@ def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial] self._studies[study_id] = _StudyInfo() study = self._studies[study_id] self._add_trials_to_cache(study_id, [frozen_trial]) - # Running trials can be modified from only one worker. - # If the state is RUNNING, since this worker is an owner of the trial, we do not need - # to access to the storage to get the latest attributes of the trial. # Since finished trials will not be modified by any worker, we do not - # need storage access for them, too. - # WAITING trials are exception and they can be modified from arbitral worker. - # Thus, we cannot add them to a list of cached trials. - if frozen_trial.state != TrialState.WAITING: - study.owned_trial_ids.add(frozen_trial._trial_id) + # need storage access for them. + if frozen_trial.state.is_finished(): + study.finished_trial_ids.add(frozen_trial._trial_id) return trial_id def set_trial_param( @@ -276,7 +262,6 @@ def set_trial_state_values( with self._lock: study_id, _ = self._trial_id_to_study_id_and_number[trial_id] self._add_trials_to_cache(study_id, [self._backend.get_trial(trial_id)]) - self._studies[study_id].owned_trial_ids.add(trial_id) return ret def set_trial_intermediate_value( @@ -316,7 +301,7 @@ def _get_cached_trial(self, trial_id: int) -> Optional[FrozenTrial]: return None study_id, number = self._trial_id_to_study_id_and_number[trial_id] study = self._studies[study_id] - return study.trials[number] if trial_id in study.owned_or_finished_trial_ids else None + return study.trials[number] if trial_id in study.finished_trial_ids else None def get_trial(self, trial_id: int) -> FrozenTrial: with self._lock: @@ -333,7 +318,7 @@ def get_all_trials( states: Optional[Container[TrialState]] = None, ) -> List[FrozenTrial]: if study_id not in self._studies: - self.read_trials_from_remote_storage(study_id, sync_owned_trials=False) + self.read_trials_from_remote_storage(study_id) with self._lock: study = self._studies[study_id] @@ -349,16 +334,14 @@ def get_all_trials( trials = list(sorted(trials.values(), key=lambda t: t.number)) return copy.deepcopy(trials) if deepcopy else trials - def read_trials_from_remote_storage(self, study_id: int, sync_owned_trials: bool) -> None: + def read_trials_from_remote_storage(self, study_id: int) -> None: with self._lock: if study_id not in self._studies: self._studies[study_id] = _StudyInfo() study = self._studies[study_id] - if sync_owned_trials: - excluded = study.finished_trial_ids - else: - excluded = study.owned_or_finished_trial_ids - trials = self._backend._get_trials(study_id, states=None, excluded_trial_ids=excluded) + trials = self._backend._get_trials( + study_id, states=None, excluded_trial_ids=study.finished_trial_ids + ) if trials: self._add_trials_to_cache(study_id, trials) for trial in trials: diff --git a/optuna/study/study.py b/optuna/study/study.py index 1fccc90b1e..8ada7f7c75 100644 --- a/optuna/study/study.py +++ b/optuna/study/study.py @@ -261,7 +261,7 @@ def objective(trial): A list of :class:`~optuna.trial.FrozenTrial` objects. """ if isinstance(self._storage, _CachedStorage): - self._storage.read_trials_from_remote_storage(self._study_id, sync_owned_trials=True) + self._storage.read_trials_from_remote_storage(self._study_id) return self._storage.get_all_trials(self._study_id, deepcopy=deepcopy, states=states) @@ -508,7 +508,7 @@ def ask( # Sync storage once every trial. if isinstance(self._storage, _CachedStorage): - self._storage.read_trials_from_remote_storage(self._study_id, sync_owned_trials=True) + self._storage.read_trials_from_remote_storage(self._study_id) trial_id = self._pop_waiting_trial_id() if trial_id is None: diff --git a/tests/storages_tests/test_cached_storage.py b/tests/storages_tests/test_cached_storage.py index d2d933d686..1d88ca953f 100644 --- a/tests/storages_tests/test_cached_storage.py +++ b/tests/storages_tests/test_cached_storage.py @@ -112,11 +112,11 @@ def test_read_trials_from_remote_storage() -> None: directions=[StudyDirection.MINIMIZE], study_name="test-study" ) - storage.read_trials_from_remote_storage(study_id, sync_owned_trials=False) + storage.read_trials_from_remote_storage(study_id) # Non-existent study. with pytest.raises(KeyError): - storage.read_trials_from_remote_storage(study_id + 1, sync_owned_trials=False) + storage.read_trials_from_remote_storage(study_id + 1) # Create a trial via CachedStorage and update it via backend storage directly. trial_id = storage.create_new_trial(study_id) @@ -124,9 +124,5 @@ def test_read_trials_from_remote_storage() -> None: trial_id, "paramA", 1.2, optuna.distributions.FloatDistribution(-0.2, 2.3) ) base_storage.set_trial_state_values(trial_id, TrialState.COMPLETE, values=[0.0]) - - storage.read_trials_from_remote_storage(study_id, sync_owned_trials=False) - assert storage.get_trial(trial_id).state == TrialState.RUNNING - - storage.read_trials_from_remote_storage(study_id, sync_owned_trials=True) + storage.read_trials_from_remote_storage(study_id) assert storage.get_trial(trial_id).state == TrialState.COMPLETE From ad013110494def2163744a48dd515d1ed25aa2b3 Mon Sep 17 00:00:00 2001 From: c-bata Date: Fri, 24 Mar 2023 15:00:26 +0900 Subject: [PATCH 04/11] Remove unused attributes from CachedStorage --- optuna/storages/_cached_storage.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index 52e5183b23..c91a8b5a0c 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -28,7 +28,6 @@ def __init__(self) -> None: # Trial number to corresponding FrozenTrial. self.trials: Dict[int, FrozenTrial] = {} # A list of trials which do not require storage access to read latest attributes. - self.owned_trial_ids: Set[int] = set() self.finished_trial_ids: Set[int] = set() # Cache distributions to avoid storage access on distribution consistency check. self.param_distribution: Dict[str, distributions.BaseDistribution] = {} From 9aca86c6a2a4536cb7c5485da5c7587d1450723d Mon Sep 17 00:00:00 2001 From: c-bata Date: Fri, 24 Mar 2023 16:04:46 +0900 Subject: [PATCH 05/11] Remove the logic to update owned trials cache --- optuna/storages/_cached_storage.py | 45 +-------------------- tests/storages_tests/test_cached_storage.py | 9 ----- 2 files changed, 1 insertion(+), 53 deletions(-) diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index c91a8b5a0c..6bfec46b00 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -184,39 +184,6 @@ def set_trial_param( param_value_internal: float, distribution: distributions.BaseDistribution, ) -> None: - with self._lock: - cached_trial = self._get_cached_trial(trial_id) - if cached_trial is not None: - self._check_trial_is_updatable(cached_trial) - - study_id, _ = self._trial_id_to_study_id_and_number[trial_id] - cached_dist = self._studies[study_id].param_distribution.get(param_name, None) - if cached_dist: - distributions.check_distribution_compatibility(cached_dist, distribution) - else: - # On cache miss, check compatibility against previous trials in the database - # and INSERT immediately to prevent other processes from creating incompatible - # ones. By INSERT, it is assumed that no previous entry has been persisted - # already. - self._backend._check_and_set_param_distribution( - study_id, trial_id, param_name, param_value_internal, distribution - ) - self._studies[study_id].param_distribution[param_name] = distribution - - params = copy.copy(cached_trial.params) - params[param_name] = distribution.to_external_repr(param_value_internal) - cached_trial.params = params - - dists = copy.copy(cached_trial.distributions) - dists[param_name] = distribution - cached_trial.distributions = dists - - if cached_dist: # Already persisted in case of cache miss so no need to update. - self._backend.set_trial_param( - trial_id, param_name, param_value_internal, distribution - ) - return - self._backend.set_trial_param(trial_id, param_name, param_value_internal, distribution) def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int: @@ -252,17 +219,7 @@ def set_trial_state_values( cached_trial.datetime_complete = backend_trial.datetime_complete return ret - ret = self._backend.set_trial_state_values(trial_id, state=state, values=values) - if ( - ret - and state == TrialState.RUNNING - and trial_id in self._trial_id_to_study_id_and_number - ): - # Cache when the local thread pop WAITING trial and start evaluation. - with self._lock: - study_id, _ = self._trial_id_to_study_id_and_number[trial_id] - self._add_trials_to_cache(study_id, [self._backend.get_trial(trial_id)]) - return ret + return self._backend.set_trial_state_values(trial_id, state=state, values=values) def set_trial_intermediate_value( self, trial_id: int, step: int, intermediate_value: float diff --git a/tests/storages_tests/test_cached_storage.py b/tests/storages_tests/test_cached_storage.py index 1d88ca953f..a6ccb5e0d0 100644 --- a/tests/storages_tests/test_cached_storage.py +++ b/tests/storages_tests/test_cached_storage.py @@ -67,15 +67,6 @@ def test_uncached_set() -> None: storage.set_trial_state_values(trial_id, state=trial.state, values=(0.3,)) assert set_mock.call_count == 1 - trial_id = storage.create_new_trial(study_id) - with patch.object( - base_storage, "_check_and_set_param_distribution", return_value=True - ) as set_mock: - storage.set_trial_param( - trial_id, "paramA", 1.2, optuna.distributions.FloatDistribution(-0.2, 2.3) - ) - assert set_mock.call_count == 1 - trial_id = storage.create_new_trial(study_id) with patch.object(base_storage, "set_trial_param", return_value=True) as set_mock: storage.set_trial_param( From 41e69da81ac42ca8409d6f65a7e965d9fab4557e Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Wed, 12 Apr 2023 17:12:47 +0900 Subject: [PATCH 06/11] Get system_attrs lazily --- optuna/trial/_trial.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/optuna/trial/_trial.py b/optuna/trial/_trial.py index 531ac2afc8..4f3bb8a09a 100644 --- a/optuna/trial/_trial.py +++ b/optuna/trial/_trial.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from collections import UserDict import copy import datetime from typing import Any @@ -59,6 +62,7 @@ def __init__(self, study: "optuna.study.Study", trial_id: int) -> None: study, self._cached_frozen_trial ) self._relative_params: Optional[Dict[str, Any]] = None + self._fixed_params = self._cached_frozen_trial.system_attrs.get("fixed_params", []) @property def relative_params(self) -> Dict[str, Any]: @@ -528,7 +532,7 @@ def should_prune(self) -> bool: "Trial.should_prune is not supported for multi-objective optimization." ) - trial = copy.deepcopy(self._get_latest_trial()) + trial = self._get_latest_trial() return self.study.pruner.prune(self.study, trial) def set_user_attr(self, key: str, value: Any) -> None: @@ -618,7 +622,7 @@ def _suggest(self, name: str, distribution: BaseDistribution) -> Any: param_value = trial.params[name] else: if self._is_fixed_param(name, distribution): - param_value = trial.system_attrs["fixed_params"][name] + param_value = self._fixed_params[name] elif distribution.single(): param_value = distributions._get_single_value(distribution) elif self._is_relative_param(name, distribution): @@ -638,14 +642,10 @@ def _suggest(self, name: str, distribution: BaseDistribution) -> Any: return param_value def _is_fixed_param(self, name: str, distribution: BaseDistribution) -> bool: - system_attrs = self._cached_frozen_trial.system_attrs - if "fixed_params" not in system_attrs: - return False - - if name not in system_attrs["fixed_params"]: + if name not in self._fixed_params: return False - param_value = system_attrs["fixed_params"][name] + param_value = self._fixed_params[name] param_value_in_internal_repr = distribution.to_internal_repr(param_value) contained = distribution._contains(param_value_in_internal_repr) @@ -688,10 +688,12 @@ def _check_distribution(self, name: str, distribution: BaseDistribution) -> None ) def _get_latest_trial(self) -> FrozenTrial: - # TODO(eukaryo): Remove this method after `system_attrs` property is deprecated. - system_attrs = copy.deepcopy(self.storage.get_trial_system_attrs(self._trial_id)) - self._cached_frozen_trial.system_attrs = system_attrs - return self._cached_frozen_trial + # TODO(eukaryo): Remove this method after `system_attrs` property is removed. + latest_trial = copy.deepcopy(self._cached_frozen_trial) + latest_trial.system_attrs = _LazyTrialSystemAttrs( # type: ignore[assignment] + self._trial_id, self.storage + ) + return latest_trial @property def params(self) -> Dict[str, Any]: @@ -752,3 +754,18 @@ def number(self) -> int: """ return self._cached_frozen_trial.number + + +class _LazyTrialSystemAttrs(UserDict): + def __init__(self, trial_id: int, storage: optuna.storages.BaseStorage) -> None: + super().__init__() + self._trial_id = trial_id + self._storage = storage + self._initialized = False + + def __getattribute__(self, key: str) -> Any: + if key == "data": + if not self._initialized: + self._initialized = True + super().update(self._storage.get_trial_system_attrs(self._trial_id)) + return super().__getattribute__(key) From 1c97089e623820b9e7fa059d03f63f4d2e2df31c Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Mon, 17 Apr 2023 13:31:11 +0900 Subject: [PATCH 07/11] Add test for _LazyTrialSystemAttrs --- tests/trial_tests/test_trial.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/trial_tests/test_trial.py b/tests/trial_tests/test_trial.py index 81a8b9aba0..1e4e3c8d94 100644 --- a/tests/trial_tests/test_trial.py +++ b/tests/trial_tests/test_trial.py @@ -678,3 +678,35 @@ def test_persisted_param() -> None: study = load_study(storage=storage, study_name=study_name) assert all("x" in t.params for t in study.trials) + + +@pytest.mark.parametrize("storage_mode", STORAGE_MODES) +def test_lazy_trial_system_attrs(storage_mode: str) -> None: + with StorageSupplier(storage_mode) as storage: + study = optuna.create_study(storage=storage) + trial = study.ask() + storage.set_trial_system_attr(trial._trial_id, "int", 0) + storage.set_trial_system_attr(trial._trial_id, "str", "A") + + # _LazyTrialSystemAttrs gets attrs the first time it is needed. + # Then, we create the instance for each method, and test the first and second use. + + system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + assert system_attrs == {"int": 0, "str": "A"} # type: ignore[comparison-overlap] + assert system_attrs == {"int": 0, "str": "A"} # type: ignore[comparison-overlap] + + system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + assert len(system_attrs) == 2 + assert len(system_attrs) == 2 + + system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + assert set(system_attrs.keys()) == {"int", "str"} + assert set(system_attrs.keys()) == {"int", "str"} + + system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + assert set(system_attrs.values()) == {0, "A"} + assert set(system_attrs.values()) == {0, "A"} + + system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + assert set(system_attrs.items()) == {("int", 0), ("str", "A")} + assert set(system_attrs.items()) == {("int", 0), ("str", "A")} From ae2d8a0dd628a61834c322a6384c97bebc52ea82 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Mon, 24 Apr 2023 13:32:02 +0900 Subject: [PATCH 08/11] Fix to sync with remote storage --- optuna/study/study.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optuna/study/study.py b/optuna/study/study.py index b118add6a4..3e21658966 100644 --- a/optuna/study/study.py +++ b/optuna/study/study.py @@ -1574,6 +1574,8 @@ def objective(trial): study_summaries = [] for s in frozen_studies: + if isinstance(storage, _CachedStorage): + storage.read_trials_from_remote_storage(s._study_id) all_trials = storage.get_all_trials(s._study_id) completed_trials = [t for t in all_trials if t.state == TrialState.COMPLETE] From 7c52f7469551e8a68c43f847a52ac3b734a66179 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Mon, 24 Apr 2023 15:30:25 +0900 Subject: [PATCH 09/11] Remove dead code --- optuna/storages/_cached_storage.py | 47 ------------------------------ 1 file changed, 47 deletions(-) diff --git a/optuna/storages/_cached_storage.py b/optuna/storages/_cached_storage.py index 6bfec46b00..5846c600b3 100644 --- a/optuna/storages/_cached_storage.py +++ b/optuna/storages/_cached_storage.py @@ -200,57 +200,17 @@ def get_best_trial(self, study_id: int) -> FrozenTrial: def set_trial_state_values( self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None ) -> bool: - with self._lock: - cached_trial = self._get_cached_trial(trial_id) - if cached_trial is not None: - # When a waiting trial is updated to running, its `datetime_start` must be - # updated. However, a waiting trials is never cached so we do not have to account - # for this case. - assert cached_trial.state != TrialState.WAITING - - self._check_trial_is_updatable(cached_trial) - ret = self._backend.set_trial_state_values(trial_id, state=state, values=values) - - if values is not None: - cached_trial.values = values - cached_trial.state = state - if cached_trial.state.is_finished(): - backend_trial = self._backend.get_trial(trial_id) - cached_trial.datetime_complete = backend_trial.datetime_complete - return ret - return self._backend.set_trial_state_values(trial_id, state=state, values=values) def set_trial_intermediate_value( self, trial_id: int, step: int, intermediate_value: float ) -> None: - with self._lock: - cached_trial = self._get_cached_trial(trial_id) - if cached_trial is not None: - self._check_trial_is_updatable(cached_trial) - intermediate_values = copy.copy(cached_trial.intermediate_values) - intermediate_values[step] = intermediate_value - cached_trial.intermediate_values = intermediate_values self._backend.set_trial_intermediate_value(trial_id, step, intermediate_value) def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None: - with self._lock: - cached_trial = self._get_cached_trial(trial_id) - if cached_trial is not None: - self._check_trial_is_updatable(cached_trial) - attrs = copy.copy(cached_trial.user_attrs) - attrs[key] = value - cached_trial.user_attrs = attrs self._backend.set_trial_user_attr(trial_id, key=key, value=value) def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None: - with self._lock: - cached_trial = self._get_cached_trial(trial_id) - if cached_trial is not None: - self._check_trial_is_updatable(cached_trial) - attrs = copy.copy(cached_trial.system_attrs) - attrs[key] = value - cached_trial.system_attrs = attrs self._backend.set_trial_system_attr(trial_id, key=key, value=value) def _get_cached_trial(self, trial_id: int) -> Optional[FrozenTrial]: @@ -315,13 +275,6 @@ def _add_trials_to_cache(self, study_id: int, trials: List[FrozenTrial]) -> None self._study_id_and_number_to_trial_id[(study_id, trial.number)] = trial._trial_id study.trials[trial.number] = trial - @staticmethod - def _check_trial_is_updatable(trial: FrozenTrial) -> None: - if trial.state.is_finished(): - raise RuntimeError( - "Trial#{} has already finished and can not be updated.".format(trial.number) - ) - def record_heartbeat(self, trial_id: int) -> None: self._backend.record_heartbeat(trial_id) From 48e5ad68f4dba0e41d8833049c2980c92937d793 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Mon, 8 May 2023 13:02:42 +0900 Subject: [PATCH 10/11] Update optuna/trial/_trial.py Co-authored-by: Gen <54583542+gen740@users.noreply.github.com> --- optuna/trial/_trial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna/trial/_trial.py b/optuna/trial/_trial.py index bfb4a58a5c..f8fde28140 100644 --- a/optuna/trial/_trial.py +++ b/optuna/trial/_trial.py @@ -62,7 +62,7 @@ def __init__(self, study: "optuna.study.Study", trial_id: int) -> None: study, self._cached_frozen_trial ) self._relative_params: Optional[Dict[str, Any]] = None - self._fixed_params = self._cached_frozen_trial.system_attrs.get("fixed_params", []) + self._fixed_params = self._cached_frozen_trial.system_attrs.get("fixed_params", {}) @property def relative_params(self) -> Dict[str, Any]: From 098d6342285072f2ad2a1a70a0af549066e07fb6 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Tue, 9 May 2023 08:07:22 +0900 Subject: [PATCH 11/11] Import _LazyTrialSystemAttrs directly --- tests/trial_tests/test_trial.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/trial_tests/test_trial.py b/tests/trial_tests/test_trial.py index bf72192e60..b0564dfd8e 100644 --- a/tests/trial_tests/test_trial.py +++ b/tests/trial_tests/test_trial.py @@ -27,6 +27,7 @@ from optuna.testing.storages import StorageSupplier from optuna.testing.tempfile_pool import NamedTemporaryFilePool from optuna.trial import Trial +from optuna.trial._trial import _LazyTrialSystemAttrs @pytest.mark.filterwarnings("ignore::FutureWarning") @@ -691,22 +692,22 @@ def test_lazy_trial_system_attrs(storage_mode: str) -> None: # _LazyTrialSystemAttrs gets attrs the first time it is needed. # Then, we create the instance for each method, and test the first and second use. - system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + system_attrs = _LazyTrialSystemAttrs(trial._trial_id, storage) assert system_attrs == {"int": 0, "str": "A"} # type: ignore[comparison-overlap] assert system_attrs == {"int": 0, "str": "A"} # type: ignore[comparison-overlap] - system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + system_attrs = _LazyTrialSystemAttrs(trial._trial_id, storage) assert len(system_attrs) == 2 assert len(system_attrs) == 2 - system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + system_attrs = _LazyTrialSystemAttrs(trial._trial_id, storage) assert set(system_attrs.keys()) == {"int", "str"} assert set(system_attrs.keys()) == {"int", "str"} - system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + system_attrs = _LazyTrialSystemAttrs(trial._trial_id, storage) assert set(system_attrs.values()) == {0, "A"} assert set(system_attrs.values()) == {0, "A"} - system_attrs = optuna.trial._trial._LazyTrialSystemAttrs(trial._trial_id, storage) + system_attrs = _LazyTrialSystemAttrs(trial._trial_id, storage) assert set(system_attrs.items()) == {("int", 0), ("str", "A")} assert set(system_attrs.items()) == {("int", 0), ("str", "A")}