Skip to content

Commit

Permalink
Merge pull request optuna#4631 from not522/cached-storage-sync-owned-…
Browse files Browse the repository at this point in the history
…trials

Sync owned trials when calling `study.ask` and `study.get_trials`
  • Loading branch information
c-bata committed May 12, 2023
2 parents 4ef1a10 + 098d634 commit 466cddf
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 131 deletions.
117 changes: 8 additions & 109 deletions optuna/storages/_cached_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self) -> None:
# Trial number to corresponding FrozenTrial.
self.trials: Dict[int, FrozenTrial] = {}
# A list of trials which do not require storage access to read latest attributes.
self.owned_or_finished_trial_ids: Set[int] = set()
self.finished_trial_ids: Set[int] = set()
# Cache distributions to avoid storage access on distribution consistency check.
self.param_distribution: Dict[str, distributions.BaseDistribution] = {}
self.directions: Optional[List[StudyDirection]] = None
Expand Down Expand Up @@ -59,11 +59,6 @@ class _CachedStorage(BaseStorage, BaseHeartbeat):
the `state` attribute of `T`.
The same applies for `user_attrs`, `system_attrs` and `intermediate_values` attributes.
The current implementation of :class:`~optuna.storages._CachedStorage` assumes that each
RUNNING trial is only modified from a single process.
When a user modifies a RUNNING trial from multiple processes, the internal state of the storage
may become inconsistent. Consequences are undefined.
**Data persistence**
:class:`~optuna.storages._CachedStorage` does not guarantee that write operations are logged
Expand Down Expand Up @@ -176,15 +171,10 @@ def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial]
self._studies[study_id] = _StudyInfo()
study = self._studies[study_id]
self._add_trials_to_cache(study_id, [frozen_trial])
# Running trials can be modified from only one worker.
# If the state is RUNNING, since this worker is an owner of the trial, we do not need
# to access to the storage to get the latest attributes of the trial.
# Since finished trials will not be modified by any worker, we do not
# need storage access for them, too.
# WAITING trials are exception and they can be modified from arbitral worker.
# Thus, we cannot add them to a list of cached trials.
if frozen_trial.state != TrialState.WAITING:
study.owned_or_finished_trial_ids.add(frozen_trial._trial_id)
# need storage access for them.
if frozen_trial.state.is_finished():
study.finished_trial_ids.add(frozen_trial._trial_id)
return trial_id

def set_trial_param(
Expand All @@ -194,39 +184,6 @@ def set_trial_param(
param_value_internal: float,
distribution: distributions.BaseDistribution,
) -> None:
with self._lock:
cached_trial = self._get_cached_trial(trial_id)
if cached_trial is not None:
self._check_trial_is_updatable(cached_trial)

study_id, _ = self._trial_id_to_study_id_and_number[trial_id]
cached_dist = self._studies[study_id].param_distribution.get(param_name, None)
if cached_dist:
distributions.check_distribution_compatibility(cached_dist, distribution)
else:
# On cache miss, check compatibility against previous trials in the database
# and INSERT immediately to prevent other processes from creating incompatible
# ones. By INSERT, it is assumed that no previous entry has been persisted
# already.
self._backend._check_and_set_param_distribution(
study_id, trial_id, param_name, param_value_internal, distribution
)
self._studies[study_id].param_distribution[param_name] = distribution

params = copy.copy(cached_trial.params)
params[param_name] = distribution.to_external_repr(param_value_internal)
cached_trial.params = params

dists = copy.copy(cached_trial.distributions)
dists[param_name] = distribution
cached_trial.distributions = dists

if cached_dist: # Already persisted in case of cache miss so no need to update.
self._backend.set_trial_param(
trial_id, param_name, param_value_internal, distribution
)
return

self._backend.set_trial_param(trial_id, param_name, param_value_internal, distribution)

def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
Expand All @@ -243,76 +200,25 @@ def get_best_trial(self, study_id: int) -> FrozenTrial:
def set_trial_state_values(
self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None
) -> bool:
with self._lock:
cached_trial = self._get_cached_trial(trial_id)
if cached_trial is not None:
# When a waiting trial is updated to running, its `datetime_start` must be
# updated. However, a waiting trials is never cached so we do not have to account
# for this case.
assert cached_trial.state != TrialState.WAITING

self._check_trial_is_updatable(cached_trial)
ret = self._backend.set_trial_state_values(trial_id, state=state, values=values)

if values is not None:
cached_trial.values = values
cached_trial.state = state
if cached_trial.state.is_finished():
backend_trial = self._backend.get_trial(trial_id)
cached_trial.datetime_complete = backend_trial.datetime_complete
return ret

ret = self._backend.set_trial_state_values(trial_id, state=state, values=values)
if (
ret
and state == TrialState.RUNNING
and trial_id in self._trial_id_to_study_id_and_number
):
# Cache when the local thread pop WAITING trial and start evaluation.
with self._lock:
study_id, _ = self._trial_id_to_study_id_and_number[trial_id]
self._add_trials_to_cache(study_id, [self._backend.get_trial(trial_id)])
self._studies[study_id].owned_or_finished_trial_ids.add(trial_id)
return ret
return self._backend.set_trial_state_values(trial_id, state=state, values=values)

def set_trial_intermediate_value(
self, trial_id: int, step: int, intermediate_value: float
) -> None:
with self._lock:
cached_trial = self._get_cached_trial(trial_id)
if cached_trial is not None:
self._check_trial_is_updatable(cached_trial)
intermediate_values = copy.copy(cached_trial.intermediate_values)
intermediate_values[step] = intermediate_value
cached_trial.intermediate_values = intermediate_values
self._backend.set_trial_intermediate_value(trial_id, step, intermediate_value)

def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
with self._lock:
cached_trial = self._get_cached_trial(trial_id)
if cached_trial is not None:
self._check_trial_is_updatable(cached_trial)
attrs = copy.copy(cached_trial.user_attrs)
attrs[key] = value
cached_trial.user_attrs = attrs
self._backend.set_trial_user_attr(trial_id, key=key, value=value)

def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
with self._lock:
cached_trial = self._get_cached_trial(trial_id)
if cached_trial is not None:
self._check_trial_is_updatable(cached_trial)
attrs = copy.copy(cached_trial.system_attrs)
attrs[key] = value
cached_trial.system_attrs = attrs
self._backend.set_trial_system_attr(trial_id, key=key, value=value)

def _get_cached_trial(self, trial_id: int) -> Optional[FrozenTrial]:
if trial_id not in self._trial_id_to_study_id_and_number:
return None
study_id, number = self._trial_id_to_study_id_and_number[trial_id]
study = self._studies[study_id]
return study.trials[number] if trial_id in study.owned_or_finished_trial_ids else None
return study.trials[number] if trial_id in study.finished_trial_ids else None

def get_trial(self, trial_id: int) -> FrozenTrial:
with self._lock:
Expand Down Expand Up @@ -351,13 +257,13 @@ def read_trials_from_remote_storage(self, study_id: int) -> None:
self._studies[study_id] = _StudyInfo()
study = self._studies[study_id]
trials = self._backend._get_trials(
study_id, states=None, excluded_trial_ids=study.owned_or_finished_trial_ids
study_id, states=None, excluded_trial_ids=study.finished_trial_ids
)
if trials:
self._add_trials_to_cache(study_id, trials)
for trial in trials:
if trial.state.is_finished():
study.owned_or_finished_trial_ids.add(trial._trial_id)
study.finished_trial_ids.add(trial._trial_id)

def _add_trials_to_cache(self, study_id: int, trials: List[FrozenTrial]) -> None:
study = self._studies[study_id]
Expand All @@ -369,13 +275,6 @@ def _add_trials_to_cache(self, study_id: int, trials: List[FrozenTrial]) -> None
self._study_id_and_number_to_trial_id[(study_id, trial.number)] = trial._trial_id
study.trials[trial.number] = trial

@staticmethod
def _check_trial_is_updatable(trial: FrozenTrial) -> None:
if trial.state.is_finished():
raise RuntimeError(
"Trial#{} has already finished and can not be updated.".format(trial.number)
)

def record_heartbeat(self, trial_id: int) -> None:
self._backend.record_heartbeat(trial_id)

Expand Down
2 changes: 2 additions & 0 deletions optuna/study/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,8 @@ def objective(trial):
study_summaries = []

for s in frozen_studies:
if isinstance(storage, _CachedStorage):
storage.read_trials_from_remote_storage(s._study_id)
all_trials = storage.get_all_trials(s._study_id)
completed_trials = [t for t in all_trials if t.state == TrialState.COMPLETE]

Expand Down
41 changes: 29 additions & 12 deletions optuna/trial/_trial.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

from collections import UserDict
import copy
import datetime
from typing import Any
Expand Down Expand Up @@ -59,6 +62,7 @@ def __init__(self, study: "optuna.study.Study", trial_id: int) -> None:
study, self._cached_frozen_trial
)
self._relative_params: Optional[Dict[str, Any]] = None
self._fixed_params = self._cached_frozen_trial.system_attrs.get("fixed_params", {})

@property
def relative_params(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -528,7 +532,7 @@ def should_prune(self) -> bool:
"Trial.should_prune is not supported for multi-objective optimization."
)

trial = copy.deepcopy(self._get_latest_trial())
trial = self._get_latest_trial()
return self.study.pruner.prune(self.study, trial)

def set_user_attr(self, key: str, value: Any) -> None:
Expand Down Expand Up @@ -618,7 +622,7 @@ def _suggest(self, name: str, distribution: BaseDistribution) -> Any:
param_value = trial.params[name]
else:
if self._is_fixed_param(name, distribution):
param_value = trial.system_attrs["fixed_params"][name]
param_value = self._fixed_params[name]
elif distribution.single():
param_value = distributions._get_single_value(distribution)
elif self._is_relative_param(name, distribution):
Expand All @@ -638,14 +642,10 @@ def _suggest(self, name: str, distribution: BaseDistribution) -> Any:
return param_value

def _is_fixed_param(self, name: str, distribution: BaseDistribution) -> bool:
system_attrs = self._cached_frozen_trial.system_attrs
if "fixed_params" not in system_attrs:
return False

if name not in system_attrs["fixed_params"]:
if name not in self._fixed_params:
return False

param_value = system_attrs["fixed_params"][name]
param_value = self._fixed_params[name]
param_value_in_internal_repr = distribution.to_internal_repr(param_value)

contained = distribution._contains(param_value_in_internal_repr)
Expand Down Expand Up @@ -688,10 +688,12 @@ def _check_distribution(self, name: str, distribution: BaseDistribution) -> None
)

def _get_latest_trial(self) -> FrozenTrial:
# TODO(eukaryo): Remove this method after `system_attrs` property is deprecated.
system_attrs = copy.deepcopy(self.storage.get_trial_system_attrs(self._trial_id))
self._cached_frozen_trial.system_attrs = system_attrs
return self._cached_frozen_trial
# TODO(eukaryo): Remove this method after `system_attrs` property is removed.
latest_trial = copy.deepcopy(self._cached_frozen_trial)
latest_trial.system_attrs = _LazyTrialSystemAttrs( # type: ignore[assignment]
self._trial_id, self.storage
)
return latest_trial

@property
def params(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -752,3 +754,18 @@ def number(self) -> int:
"""

return self._cached_frozen_trial.number


class _LazyTrialSystemAttrs(UserDict):
    """Mapping of a trial's system attributes that is loaded from storage on first use.

    Constructing the object performs no storage access. The first read of the
    underlying ``data`` mapping fetches the attributes with
    ``storage.get_trial_system_attrs(trial_id)`` and copies them into the mapping;
    later reads reuse the loaded values without touching the storage again.

    If the fetch raises, the exception propagates to the caller and the object
    stays unloaded, so the next access retries instead of silently exposing an
    empty mapping.

    Args:
        trial_id: ID of the trial whose system attributes are exposed.
        storage: Object providing ``get_trial_system_attrs(trial_id)``.
    """

    def __init__(self, trial_id: int, storage: "optuna.storages.BaseStorage") -> None:
        super().__init__()
        self._trial_id = trial_id
        self._storage = storage
        self._initialized = False

    def __getattribute__(self, key: str) -> Any:
        """Return attribute ``key``, loading the system attributes before ``data`` is read."""
        if key == "data" and not self._initialized:
            # Fetch first: a failing storage call must not leave the object flagged as
            # loaded while its mapping is still empty.
            attrs = self._storage.get_trial_system_attrs(self._trial_id)
            # The flag has to be set before ``update`` because ``update`` itself reads
            # ``self.data`` and would otherwise re-enter this branch.
            self._initialized = True
            super().update(attrs)
        return super().__getattribute__(key)
18 changes: 9 additions & 9 deletions tests/storages_tests/test_cached_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,6 @@ def test_uncached_set() -> None:
storage.set_trial_state_values(trial_id, state=trial.state, values=(0.3,))
assert set_mock.call_count == 1

trial_id = storage.create_new_trial(study_id)
with patch.object(
base_storage, "_check_and_set_param_distribution", return_value=True
) as set_mock:
storage.set_trial_param(
trial_id, "paramA", 1.2, optuna.distributions.FloatDistribution(-0.2, 2.3)
)
assert set_mock.call_count == 1

trial_id = storage.create_new_trial(study_id)
with patch.object(base_storage, "set_trial_param", return_value=True) as set_mock:
storage.set_trial_param(
Expand Down Expand Up @@ -117,3 +108,12 @@ def test_read_trials_from_remote_storage() -> None:
# Non-existent study.
with pytest.raises(KeyError):
storage.read_trials_from_remote_storage(study_id + 1)

# Create a trial via CachedStorage and update it via backend storage directly.
trial_id = storage.create_new_trial(study_id)
base_storage.set_trial_param(
trial_id, "paramA", 1.2, optuna.distributions.FloatDistribution(-0.2, 2.3)
)
base_storage.set_trial_state_values(trial_id, TrialState.COMPLETE, values=[0.0])
storage.read_trials_from_remote_storage(study_id)
assert storage.get_trial(trial_id).state == TrialState.COMPLETE
2 changes: 1 addition & 1 deletion tests/study_tests/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ def test_tell_from_another_process() -> None:
assert study.best_value == 1.2

# Should fail because the trial0 is already finished.
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
pool.starmap(_process_tell, [(study, trial0, 1.2)])


Expand Down
33 changes: 33 additions & 0 deletions tests/trial_tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from optuna.testing.storages import StorageSupplier
from optuna.testing.tempfile_pool import NamedTemporaryFilePool
from optuna.trial import Trial
from optuna.trial._trial import _LazyTrialSystemAttrs


@pytest.mark.filterwarnings("ignore::FutureWarning")
Expand Down Expand Up @@ -678,3 +679,35 @@ def test_persisted_param() -> None:
study = load_study(storage=storage, study_name=study_name)

assert all("x" in t.params for t in study.trials)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_lazy_trial_system_attrs(storage_mode: str) -> None:
    """Check that ``_LazyTrialSystemAttrs`` yields the stored attributes on every accessor."""
    with StorageSupplier(storage_mode) as storage:
        study = optuna.create_study(storage=storage)
        trial = study.ask()
        trial_id = trial._trial_id
        storage.set_trial_system_attr(trial_id, "int", 0)
        storage.set_trial_system_attr(trial_id, "str", "A")

        # The attributes are fetched the first time they are needed, so each accessor
        # gets its own fresh instance and is exercised twice: the first use and a
        # second use on the same instance.

        lazy_attrs = _LazyTrialSystemAttrs(trial_id, storage)
        for _ in range(2):
            assert lazy_attrs == {"int": 0, "str": "A"}  # type: ignore[comparison-overlap]

        lazy_attrs = _LazyTrialSystemAttrs(trial_id, storage)
        for _ in range(2):
            assert len(lazy_attrs) == 2

        lazy_attrs = _LazyTrialSystemAttrs(trial_id, storage)
        for _ in range(2):
            assert set(lazy_attrs.keys()) == {"int", "str"}

        lazy_attrs = _LazyTrialSystemAttrs(trial_id, storage)
        for _ in range(2):
            assert set(lazy_attrs.values()) == {0, "A"}

        lazy_attrs = _LazyTrialSystemAttrs(trial_id, storage)
        for _ in range(2):
            assert set(lazy_attrs.items()) == {("int", 0), ("str", "A")}

0 comments on commit 466cddf

Please sign in to comment.