diff --git a/pyproject.toml b/pyproject.toml index 398326fc3..99e01ae81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,9 @@ requires-python = ">=3.10" "import-linter~=2.10", "pytest-deadfixtures~=3.1", "taplo~=0.9.3", + "gymnasium~=1.2", ] + rl = ["gymnasium~=1.2"] docs = [ "sphinx~=8.1", "nvidia-sphinx-theme~=0.0.8", diff --git a/src/cloudai/configurator/__init__.py b/src/cloudai/configurator/__init__.py index f05b65c5b..a88432c41 100644 --- a/src/cloudai/configurator/__init__.py +++ b/src/cloudai/configurator/__init__.py @@ -18,11 +18,13 @@ from .base_gym import BaseGym from .cloudai_gym import CloudAIGymEnv, TrajectoryEntry from .grid_search import GridSearchAgent +from .gymnasium_adapter import GymnasiumAdapter __all__ = [ "BaseAgent", "BaseGym", "CloudAIGymEnv", "GridSearchAgent", + "GymnasiumAdapter", "TrajectoryEntry", ] diff --git a/src/cloudai/configurator/cloudai_gym.py b/src/cloudai/configurator/cloudai_gym.py index d1bdba1f1..72f030627 100644 --- a/src/cloudai/configurator/cloudai_gym.py +++ b/src/cloudai/configurator/cloudai_gym.py @@ -76,9 +76,10 @@ def define_observation_space(self) -> list: Define the observation space for the environment. Returns: - list: The observation space. + list: One float slot per agent metric (at least one), giving the correct shape + for adapters that derive ``gymnasium.spaces.Box`` from this output. 
""" - return [0.0] + return [0.0] * max(len(self.test_run.test.agent_metrics), 1) def reset( self, @@ -100,7 +101,7 @@ def reset( if seed is not None: lazy.np.random.seed(seed) self.test_run.current_iteration = 0 - observation = [0.0] + observation = self.define_observation_space() info = {} return observation, info diff --git a/src/cloudai/configurator/gymnasium_adapter.py b/src/cloudai/configurator/gymnasium_adapter.py new file mode 100644 index 000000000..6b1053a85 --- /dev/null +++ b/src/cloudai/configurator/gymnasium_adapter.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, ClassVar, Optional + +from .base_gym import BaseGym + +_GYMNASIUM_INSTALL_HINT = "gymnasium is required for GymnasiumAdapter. Install it with: pip install gymnasium" + + +def _import_gymnasium(): + """ + Import gymnasium + numpy lazily; raise a clear, actionable error when absent. + + Kept as a single seam so that: + * cloudai installs without ``gymnasium`` continue to work for users that don't + need this adapter (the import is gated behind ``GymnasiumAdapter()``); + * tests can patch this helper to simulate a missing install. 
+ """ + try: + import gymnasium + import numpy as np + from gymnasium import spaces + + return gymnasium, spaces, np + except ImportError as exc: + raise ImportError(_GYMNASIUM_INSTALL_HINT) from exc + + +class GymnasiumAdapter: + """ + Expose a CloudAI :class:`BaseGym` as a standard ``gymnasium.Env``-shaped object. + + The adapter: + + * builds a ``gymnasium.spaces.Dict`` of ``Discrete`` action spaces over the + *tunable* parameters (those with more than one candidate value), and + injects the *fixed* parameters (single candidate) automatically on every + step so agents never see them. + * converts observations to ``float32`` ``numpy`` arrays sized by + ``env.define_observation_space()``. + * returns the gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)`` + from :meth:`step` and :meth:`step_raw`. + * keeps ``env.test_run.step`` in sync (1-based) so artifact paths produced by + ``CloudAIGymEnv`` match those produced by ``handle_dse_job`` (i.e. + ``////`` for every evaluation), which is + required when a custom training loop (e.g. RLlib) front-ends the env. + + ``gymnasium`` and ``numpy`` are optional dependencies; importing this module + is cheap, but instantiating the adapter without them raises ``ImportError``. 
+ """ + + metadata: ClassVar[dict[str, Any]] = {"render_modes": ["human"]} + + def __init__(self, env: BaseGym) -> None: + _, spaces, np = _import_gymnasium() + + self._np = np + self._env = env + self._step_count = 0 + + raw_action_space = env.define_action_space() + self._tunable_params: dict[str, list] = {k: v for k, v in raw_action_space.items() if len(v) > 1} + self._fixed_params: dict[str, Any] = {k: v[0] for k, v in raw_action_space.items() if len(v) == 1} + + self.action_space = spaces.Dict( + {name: spaces.Discrete(len(values)) for name, values in self._tunable_params.items()} + ) + + obs_shape = (len(env.define_observation_space()),) + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=obs_shape, dtype=np.float32) + + @property + def unwrapped(self) -> BaseGym: + return self._env + + def decode_action(self, action: dict[str, int]) -> dict[str, Any]: + """ + Map discrete action indices back to the original parameter values. + + Raises: + ValueError: if ``action`` is missing tunable params, contains unknown keys, + or carries an index outside the discrete range for any tunable param. 
+ """ + self._assert_keys(action.keys(), set(self._tunable_params), "action") + decoded: dict[str, Any] = {} + for name, idx in action.items(): + values = self._tunable_params[name] + if not 0 <= idx < len(values): + raise ValueError(f"Action index out of range for '{name}': {idx} (expected 0..{len(values) - 1})") + decoded[name] = values[idx] + return decoded + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[Any, dict[str, Any]]: + self._step_count = 0 + obs, info = self._env.reset(seed=seed, options=options) + return self._as_obs_array(obs), info + + def step(self, action: dict[str, int]) -> tuple[Any, float, bool, bool, dict[str, Any]]: + params = {**self._fixed_params, **self.decode_action(action)} + return self._step_with_params(params) + + def step_raw(self, params: dict[str, Any]) -> tuple[Any, float, bool, bool, dict[str, Any]]: + """ + Step the env with an already-decoded parameter dict; bypasses index decoding. + + Raises: + ValueError: if ``params`` does not cover exactly the tunable + fixed param keys. 
+ """ + self._assert_keys(params.keys(), set(self._tunable_params) | set(self._fixed_params), "raw params") + return self._step_with_params(params) + + def render(self) -> None: + self._env.render() + + @staticmethod + def _assert_keys(received: Any, expected: set[str], ctx: str) -> None: + received_set = set(received) + if received_set == expected: + return + missing = sorted(expected - received_set) + extra = sorted(received_set - expected) + raise ValueError(f"{ctx} keys mismatch; missing={missing}, extra={extra}") + + def _step_with_params(self, params: dict[str, Any]) -> tuple[Any, float, bool, bool, dict[str, Any]]: + self._sync_underlying_step_counter() + obs, reward, done, info = self._env.step(params) + self._step_count += 1 + return self._as_obs_array(obs), float(reward), bool(done), False, info + + def _sync_underlying_step_counter(self) -> None: + """ + Mirror ``handle_dse_job``'s 1-based ``test_run.step`` so artifact paths match. + + The first step is written under ``…//1/``, matching how + ``handle_dse_job`` numbers steps; this keeps reports and trajectory + analysis consistent regardless of whether the env is driven by the + DSE loop or by an external training loop wrapping the adapter. + """ + test_run = getattr(self._env, "test_run", None) + if test_run is not None: + test_run.step = self._step_count + 1 + + def _as_obs_array(self, obs: Any) -> Any: + return self._np.asarray(obs, dtype=self._np.float32) diff --git a/tests/test_gymnasium_adapter.py b/tests/test_gymnasium_adapter.py new file mode 100644 index 000000000..e6ec62b79 --- /dev/null +++ b/tests/test_gymnasium_adapter.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class _FakeGym(BaseGym):
    """Deterministic BaseGym fixture with two tunable params and a 3-dim observation."""

    def __init__(self) -> None:
        self._action_space: dict[str, Any] = {"param_a": [1, 2, 3], "param_b": [10, 20]}
        self._observation_space: list[float] = [0.0, 0.0, 0.0]
        self.last_action: Optional[dict[str, Any]] = None
        super().__init__()

    def define_action_space(self) -> dict[str, Any]:
        return self._action_space

    def define_observation_space(self) -> list:
        return self._observation_space

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[list, dict[str, Any]]:
        return [0.0, 0.0, 0.0], {}

    def step(self, action: Any) -> tuple[list, float, bool, dict]:
        # Record what the adapter forwarded so tests can assert on it.
        self.last_action = action
        return [1.0, 2.0, 3.0], 0.5, False, {"info": "test"}

    def render(self, mode: str = "human") -> None:
        return None

    def seed(self, seed: Optional[int] = None) -> None:
        pass


class _FixedParamGym(_FakeGym):
    """Adds a single-value parameter that the adapter must treat as fixed."""

    def define_action_space(self) -> dict[str, Any]:
        return {"param_a": [1, 2, 3], "param_b": [10, 20], "fixed_param": [42]}


class _GymWithTestRun(_FakeGym):
    """Carries a CloudAI-like ``test_run`` so we can verify the step-counter sync."""

    def __init__(self) -> None:
        super().__init__()
        self.test_run = SimpleNamespace(step=0)


@pytest.fixture
def fake_gym() -> _FakeGym:
    return _FakeGym()


@pytest.fixture
def adapter(fake_gym: _FakeGym) -> GymnasiumAdapter:
    return GymnasiumAdapter(fake_gym)


def test_action_space_is_dict_of_discrete(adapter: GymnasiumAdapter) -> None:
    assert isinstance(adapter.action_space, gymnasium.spaces.Dict)
    assert set(adapter.action_space.spaces) == {"param_a", "param_b"}

    sub_a = adapter.action_space.spaces["param_a"]
    sub_b = adapter.action_space.spaces["param_b"]
    assert isinstance(sub_a, gymnasium.spaces.Discrete) and sub_a.n == 3
    assert isinstance(sub_b, gymnasium.spaces.Discrete) and sub_b.n == 2


def test_observation_space_shape_matches_env(adapter: GymnasiumAdapter) -> None:
    assert isinstance(adapter.observation_space, gymnasium.spaces.Box)
    assert adapter.observation_space.shape == (3,)
    assert adapter.observation_space.dtype == np.float32


def test_reset_returns_float32_array(adapter: GymnasiumAdapter) -> None:
    obs, info = adapter.reset()
    assert isinstance(obs, np.ndarray) and obs.dtype == np.float32 and obs.shape == (3,)
    np.testing.assert_array_equal(obs, [0.0, 0.0, 0.0])
    assert info == {}


def test_step_returns_gymnasium_five_tuple(adapter: GymnasiumAdapter) -> None:
    adapter.reset()
    obs, reward, terminated, truncated, info = adapter.step({"param_a": 0, "param_b": 1})

    assert isinstance(obs, np.ndarray) and obs.dtype == np.float32
    np.testing.assert_array_equal(obs, [1.0, 2.0, 3.0])
    assert reward == 0.5
    assert terminated is False
    assert truncated is False
    assert info == {"info": "test"}


def test_decode_action_maps_indices_back_to_values(adapter: GymnasiumAdapter) -> None:
    assert adapter.decode_action({"param_a": 0, "param_b": 1}) == {"param_a": 1, "param_b": 20}


def test_unwrapped_returns_original_env(fake_gym: _FakeGym, adapter: GymnasiumAdapter) -> None:
    assert adapter.unwrapped is fake_gym


def test_single_value_params_are_excluded_from_action_space() -> None:
    adapter = GymnasiumAdapter(_FixedParamGym())

    assert set(adapter.action_space.spaces) == {"param_a", "param_b"}
    assert adapter._fixed_params == {"fixed_param": 42}


def test_step_merges_fixed_params_into_underlying_action() -> None:
    gym = _FixedParamGym()
    adapter = GymnasiumAdapter(gym)
    adapter.reset()

    adapter.step({"param_a": 0, "param_b": 1})

    assert gym.last_action == {"param_a": 1, "param_b": 20, "fixed_param": 42}


def test_step_raw_bypasses_decode_and_fixed_injection() -> None:
    gym = _FixedParamGym()
    adapter = GymnasiumAdapter(gym)
    adapter.reset()
    raw = {"param_a": 999, "param_b": 888, "fixed_param": 777}

    obs, _reward, terminated, truncated, _info = adapter.step_raw(raw)

    assert gym.last_action == raw
    assert isinstance(obs, np.ndarray)
    assert terminated is False
    assert truncated is False


def test_step_assigns_one_based_step_to_test_run() -> None:
    gym = _GymWithTestRun()
    adapter = GymnasiumAdapter(gym)
    adapter.reset()

    adapter.step({"param_a": 0, "param_b": 1})
    assert gym.test_run.step == 1

    adapter.step({"param_a": 1, "param_b": 0})
    assert gym.test_run.step == 2


def test_step_raw_also_syncs_test_run_step() -> None:
    gym = _GymWithTestRun()
    adapter = GymnasiumAdapter(gym)
    adapter.reset()

    adapter.step_raw({"param_a": 2, "param_b": 1})
    assert gym.test_run.step == 1


def test_reset_restarts_step_counter() -> None:
    gym = _GymWithTestRun()
    adapter = GymnasiumAdapter(gym)
    adapter.reset()
    adapter.step({"param_a": 0, "param_b": 1})
    adapter.step({"param_a": 1, "param_b": 0})
    assert gym.test_run.step == 2

    adapter.reset()
    adapter.step({"param_a": 0, "param_b": 0})
    assert gym.test_run.step == 1


def test_missing_gymnasium_raises_clear_error(monkeypatch: pytest.MonkeyPatch) -> None:
    import sys

    # A ``None`` entry in ``sys.modules`` makes ``import gymnasium`` fail, so the
    # real ``_import_gymnasium`` helper (and its real install hint) is exercised
    # instead of a stub that raises the asserted message itself.
    monkeypatch.setitem(sys.modules, "gymnasium", None)

    with pytest.raises(ImportError, match="pip install gymnasium"):
        GymnasiumAdapter(_FakeGym())


def test_decode_action_rejects_missing_keys(adapter: GymnasiumAdapter) -> None:
    with pytest.raises(ValueError, match=r"missing=\['param_b'\]"):
        adapter.decode_action({"param_a": 0})


def test_decode_action_rejects_unknown_keys(adapter: GymnasiumAdapter) -> None:
    with pytest.raises(ValueError, match=r"extra=\['bogus'\]"):
        adapter.decode_action({"param_a": 0, "param_b": 1, "bogus": 0})


def test_decode_action_rejects_out_of_range_index(adapter: GymnasiumAdapter) -> None:
    with pytest.raises(ValueError, match=r"out of range for 'param_a'"):
        adapter.decode_action({"param_a": 99, "param_b": 0})


def test_step_raw_rejects_missing_fixed_param() -> None:
    adapter = GymnasiumAdapter(_FixedParamGym())
    adapter.reset()
    with pytest.raises(ValueError, match=r"missing=\['fixed_param'\]"):
        adapter.step_raw({"param_a": 1, "param_b": 10})


def test_step_raw_rejects_unknown_keys() -> None:
    adapter = GymnasiumAdapter(_FixedParamGym())
    adapter.reset()
    with pytest.raises(ValueError, match=r"extra=\['bogus'\]"):
        adapter.step_raw({"param_a": 1, "param_b": 10, "fixed_param": 42, "bogus": 0})
"~=8.3" }, + { name = "gymnasium", marker = "extra == 'dev'", specifier = "~=1.2" }, + { name = "gymnasium", marker = "extra == 'rl'", specifier = "~=1.2" }, { name = "huggingface-hub", specifier = "~=1.4" }, { name = "import-linter", marker = "extra == 'dev'", specifier = "~=2.10" }, { name = "jinja2", specifier = "~=3.1.6" }, @@ -350,7 +356,16 @@ requires-dist = [ { name = "vulture", marker = "extra == 'dev'", specifier = "==2.14" }, { name = "websockets", specifier = "~=16.0" }, ] -provides-extras = ["dev", "docs", "docs-cms"] +provides-extras = ["dev", "rl", "docs", "docs-cms"] + +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] [[package]] name = "colorama" @@ -674,6 +689,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "farama-notifications" +version = "0.0.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/91/14397890dde30adc4bee6462158933806207bc5dd10d7b4d09d5c33845cf/farama_notifications-0.0.6.tar.gz", hash = 
"sha256:b19acac4bb41d76e59e03394b5dd165f4761c86fa327f56307a35cbee3b60158", size = 2517, upload-time = "2026-04-24T08:43:57.603Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl", hash = "sha256:f84839188efa1ce5bb361c2a84881b2dc2c0d0d7fb661ff00421820170930935", size = 2897, upload-time = "2026-04-24T08:43:56.785Z" }, +] + [[package]] name = "fastapi" version = "0.128.6" @@ -884,6 +908,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/b2/b096ccce418882fbfda4f7496f9357aaa9a5af1896a9a7f60d9f2b275a06/grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb", size = 4929852, upload-time = "2026-02-06T09:56:45.885Z" }, ] +[[package]] +name = "gymnasium" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "farama-notifications" }, + { name = "numpy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4d/ff/14b6880d703dfaca204490979d3254ccd280c99550798993319902873658/gymnasium-1.3.0.tar.gz", hash = "sha256:6939e86e835d6b71b6ba6bfd360487420876deafc79bfb7bacba83a7c446bcf3", size = 830646, upload-time = "2026-04-22T13:47:14.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl", hash = "sha256:6b8c159a8540dcbcb221722d7efda24d78ebbcbc3bd2ea1c2611aa2a34471fc2", size = 953904, upload-time = "2026-04-22T13:47:12.13Z" }, +] + [[package]] name = "h11" version = "0.16.0"