diff --git a/pyproject.toml b/pyproject.toml index 4e14aa151..8b6906293 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,9 @@ requires-python = ">=3.10" "import-linter~=2.10", "pytest-deadfixtures~=3.1", "taplo~=0.9.3", + "gymnasium~=1.2", ] + rl = ["gymnasium~=1.2"] docs = [ "sphinx~=8.1", "nvidia-sphinx-theme~=0.0.8", diff --git a/src/cloudai/_core/registry.py b/src/cloudai/_core/registry.py index 2e2adf6b7..17f73e8c3 100644 --- a/src/cloudai/_core/registry.py +++ b/src/cloudai/_core/registry.py @@ -57,6 +57,7 @@ class Registry(metaclass=Singleton): scenario_reports: ClassVar[dict[str, type[Reporter]]] = {} report_configs: ClassVar[dict[str, ReportConfig]] = {} reward_functions_map: ClassVar[dict[str, RewardFunction]] = {} + env_factories_map: ClassVar[dict[str, Callable]] = {} command_gen_strategies_map: ClassVar[dict[tuple[Type[System], Type[TestDefinition]], Type[CommandGenStrategy]]] = {} json_gen_strategies_map: ClassVar[dict[tuple[Type[System], Type[TestDefinition]], Type[JsonGenStrategy]]] = {} grading_strategies_map: ClassVar[dict[Tuple[Type[System], Type[TestDefinition]], Type[GradingStrategy]]] = {} @@ -249,6 +250,19 @@ def get_reward_function(self, name: str) -> RewardFunction: ) return self.reward_functions_map[name] + def add_env_factory(self, name: str, factory: Callable) -> None: + if name in self.env_factories_map: + raise ValueError(f"Duplicating implementation for '{name}', use 'update()' for replacement.") + self.update_env_factory(name, factory) + + def update_env_factory(self, name: str, factory: Callable) -> None: + self.env_factories_map[name] = factory + + def get_env_factory(self, name: str) -> Callable: + if name not in self.env_factories_map: + raise KeyError(f"Env factory '{name}' not found. Available: {list(self.env_factories_map.keys())}") + return self.env_factories_map[name] + def add_command_gen_strategy( self, system_type: Type[System], tdef_type: Type[TestDefinition], value: Type[CommandGenStrategy] ) -> None: diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 1337976c2..143f740aa 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -116,6 +116,20 @@ def prepare_installation( return installables, installer +def _run_custom_training_loop(agent: object, agent_type: str) -> int: + """Delegate to an agent's own training loop (e.g. RLlib PPO).""" + logging.info(f"Agent {agent_type} uses a custom training loop, delegating to agent.train()") + try: + agent.train() # type: ignore[union-attr] + return 0 + except Exception as e: + logging.error(f"Agent training failed for {agent_type}: {e}", exc_info=True) + return 1 + finally: + if hasattr(agent, "shutdown"): + agent.shutdown() # type: ignore[union-attr] + + def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: registry = Registry() @@ -132,6 +146,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: err = 0 for tr in runner.runner.test_scenario.test_runs: test_run = copy.deepcopy(tr) + test_run.output_path = runner.runner.get_job_output_path(test_run) agent_type = test_run.test.agent agent_class = registry.agents_map.get(agent_type) @@ -151,15 +166,29 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: agent = agent_class(env, agent_config) + if getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False): + err |= _run_custom_training_loop(agent, agent_type) + continue + + observation, _ = env.reset() + for step in range(agent.max_steps): - result = agent.select_action() + result = agent.select_action(observation=observation) if result is None: break step, action = result env.test_run.step = step logging.info(f"Running step {step} (of {agent.max_steps}) with action {action}") - observation, reward, *_ = env.step(action) - feedback = {"trial_index": step, "value": reward} + prev_obs = observation + observation, reward, done, *_ = env.step(action) + feedback = { + "trial_index": step, + "value": reward, + "observation": observation, + "prev_observation": prev_obs, + "action": action, + "done": done, + } agent.update_policy(feedback) logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}") @@ -304,11 +333,12 @@ def handle_dry_run_and_run(args: argparse.Namespace) -> int: logging.info(f"Scenario results will be stored at: {runner.runner.scenario_root}") has_dse = any(tr.is_dse_job for tr in test_scenario.test_runs) - if args.single_sbatch or not has_dse: # in this mode cases are unrolled using grid search + has_live_rl = any(getattr(tr.test.cmd_args, "live_rl_mode", False) for tr in test_scenario.test_runs) + if args.single_sbatch or (not has_dse and not has_live_rl): handle_non_dse_job(runner, args) return 0 - if all(tr.is_dse_job for tr in test_scenario.test_runs): + if all(tr.is_dse_job or getattr(tr.test.cmd_args, "live_rl_mode", False) for tr in test_scenario.test_runs): return handle_dse_job(runner, args) logging.error("Mixing DSE and non-DSE jobs is not allowed.") diff --git a/src/cloudai/configurator/__init__.py b/src/cloudai/configurator/__init__.py index f05b65c5b..1734e6b2c 100644 --- a/src/cloudai/configurator/__init__.py +++ b/src/cloudai/configurator/__init__.py @@ -16,13 +16,16 @@ from .base_agent import BaseAgent from .base_gym import BaseGym -from .cloudai_gym import CloudAIGymEnv, TrajectoryEntry +from .cloudai_gym import CloudAIGymEnv, GymServer, TrajectoryEntry from .grid_search import GridSearchAgent +from .gymnasium_adapter import GymnasiumAdapter __all__ = [ "BaseAgent", "BaseGym", "CloudAIGymEnv", "GridSearchAgent", + "GymServer", + "GymnasiumAdapter", "TrajectoryEntry", ] diff --git a/src/cloudai/configurator/base_agent.py b/src/cloudai/configurator/base_agent.py index 4f5813199..8b63059c1 100644 --- a/src/cloudai/configurator/base_agent.py +++ b/src/cloudai/configurator/base_agent.py @@ -15,7 +15,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Literal +from typing import Any, Dict, Literal, Optional from pydantic import BaseModel, ConfigDict @@ -68,10 +68,13 @@ def configure(self, config: dict[str, Any]) -> None: pass @abstractmethod - def select_action(self) -> tuple[int, dict[str, Any]]: + def select_action(self, observation: Optional[list] = None) -> tuple[int, dict[str, Any]]: """ Select an action from the action space. + Args: + observation: Optional environment observation from the previous step. + Returns: Tuple[int, Dict[str, Any]]: The current step index and a dictionary mapping action keys to selected values. """ diff --git a/src/cloudai/configurator/cloudai_gym.py b/src/cloudai/configurator/cloudai_gym.py index a7643992c..ccf1435bd 100644 --- a/src/cloudai/configurator/cloudai_gym.py +++ b/src/cloudai/configurator/cloudai_gym.py @@ -17,9 +17,11 @@ import copy import csv import dataclasses +import importlib import logging +import random as stdlib_random from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Protocol, Tuple from cloudai.core import METRIC_ERROR, BaseRunner, Registry, TestRun from cloudai.util.lazy_imports import lazy @@ -35,238 +37,324 @@ class TrajectoryEntry: action: dict[str, Any] reward: float observation: list + info: dict[str, Any] = dataclasses.field(default_factory=dict) + + +class GymServer(Protocol): + """Protocol for gym server objects used in online mode.""" + + def reset(self) -> Tuple[List[float], Dict[str, Any]]: ... + def step(self, action: Dict[str, Any]) -> Tuple[List[float], float, bool, Dict[str, Any]]: ... + def get_action_space(self) -> Dict[str, Any]: ... + def get_observation_space(self) -> List[float]: ... + + +class _StepBackend(Protocol): + """Internal protocol for execution backends.""" + + def get_action_space(self) -> Dict[str, Any]: ... + def get_observation_space(self) -> list: ... + def reset(self, seed: Optional[int] = None) -> Tuple[list, dict[str, Any]]: ... + def step(self, action: Any) -> Tuple[list, bool, dict[str, Any]]: ... + + +class _RunnerBackend: + """Backend that launches real workloads via the CloudAI runner.""" + + def __init__(self, test_run: TestRun, runner: BaseRunner) -> None: + self._test_run = test_run + self._original_test_run = copy.deepcopy(test_run) + self._runner = runner + self._trajectory_cache: dict[int, list[TrajectoryEntry]] = {} + + @property + def test_run(self) -> TestRun: + return self._test_run + + @test_run.setter + def test_run(self, value: TestRun) -> None: + self._test_run = value + + def get_action_space(self) -> Dict[str, Any]: + return self._test_run.param_space + + def get_observation_space(self) -> list: + n_metrics = max(len(self._test_run.test.agent_metrics), 1) + return [0.0] * n_metrics + + def reset(self, seed: Optional[int] = None) -> Tuple[list, dict[str, Any]]: + if seed is not None: + lazy.np.random.seed(seed) + self._test_run.current_iteration = 0 + return self.get_observation_space(), {} + + def step(self, action: Any) -> Tuple[list, bool, dict[str, Any]]: + self._test_run = self._test_run.apply_params_set(action) + + cached = self._get_cached_result(action) + if cached is not None: + logging.info("Retrieved cached result with reward %s. Skipping step.", cached.reward) + return cached.observation, False, cached.info + + if not self._test_run.test.constraint_check(self._test_run, self._runner.system): + logging.info("Constraint check failed. Skipping step.") + return [-1.0], True, {"reason": "constraint_check_failed"} + + new_tr = copy.deepcopy(self._test_run) + new_tr.output_path = self._runner.get_job_output_path(new_tr) + self._runner.test_scenario.test_runs = [new_tr] + + self._runner.shutting_down = False + self._runner.jobs.clear() + self._runner.testrun_to_job_map.clear() + + try: + self._runner.run() + except Exception as e: + logging.error(f"Error running step {self._test_run.step}: {e}") + + if self._runner.test_scenario.test_runs and self._runner.test_scenario.test_runs[0].output_path.exists(): + self._test_run = self._runner.test_scenario.test_runs[0] + else: + self._test_run = copy.deepcopy(self._original_test_run) + self._test_run.step = new_tr.step + self._test_run.output_path = new_tr.output_path + + observation = self._get_observation(action) + return observation, False, {} + + def get_observation(self, action: Any) -> list: + return self._get_observation(action) + + def _get_observation(self, action: Any) -> list: + all_metrics = self._test_run.test.agent_metrics + if not all_metrics: + raise ValueError("No agent metrics defined for the test run") + + observation = [] + for metric in all_metrics: + v = self._test_run.get_metric_value(self._runner.system, metric) + if v == METRIC_ERROR: + v = -1.0 + observation.append(v) + return observation + + def cache_trajectory(self, entry: TrajectoryEntry) -> None: + self._trajectory_cache.setdefault(self._test_run.current_iteration, []).append(entry) + + def _get_cached_result(self, action: Any) -> Optional[TrajectoryEntry]: + for entry in self._trajectory_cache.get(self._test_run.current_iteration, []): + if _values_match_exact(entry.action, action): + return entry + return None + + +class _GymServerBackend: + """Backend that delegates to an in-process GymServer for fast, stateful interaction.""" + + def __init__(self, server: Any) -> None: + self._server = server + self._step_count = 0 + + def get_action_space(self) -> Dict[str, Any]: + return self._server.get_action_space() + + def get_observation_space(self) -> list: + return self._server.get_observation_space() + + def reset(self, seed: Optional[int] = None) -> Tuple[list, dict[str, Any]]: + self._step_count = 0 + return self._server.reset() + + def step(self, action: Any) -> Tuple[list, bool, dict[str, Any]]: + self._step_count += 1 + observation, _raw_reward, done, info = self._server.step(action) + return observation, done, info + + +def _create_gym_server(test_run: TestRun) -> Any: + """Instantiate a GymServer from the env_class path in cmd_args.""" + from cloudai.util import flatten_dict + + cmd_args = test_run.test.cmd_args + args_dict = flatten_dict(cmd_args.model_dump()) + + env_class_path = args_dict.pop("env_class", None) + if not env_class_path: + raise ValueError("online mode requires 'env_class' in cmd_args pointing to a GymServer class") + + for key in ("live_rl_mode", "docker_image_url"): + args_dict.pop(key, None) + + module_path, class_name = env_class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + server_cls = getattr(module, class_name) + + import inspect + sig = inspect.signature(server_cls.__init__) + valid_params = set(sig.parameters.keys()) - {"self"} + if valid_params and not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): + args_dict = {k: v for k, v in args_dict.items() if k in valid_params} + + return server_cls(**args_dict) class CloudAIGymEnv(BaseGym): - """ - Custom Gym environment for CloudAI integration. + """Unified Gym environment for CloudAI. - Uses the TestRun object and actual runner methods to execute jobs. - """ + Supports two execution modes selected automatically: - def __init__(self, test_run: TestRun, runner: BaseRunner): - """ - Initialize the Gym environment using the TestRun object. + - **Runner mode** (default): launches real workloads via the CloudAI runner, + reads metrics from job output. Used for standard DSE. + - **Online mode** (``live_rl_mode=true`` in cmd_args): delegates to an + in-process GymServer for fast, stateful interaction. Used for + online RL / simulation-based optimization (e.g. kvpilot). + + Agents interact with the same interface regardless of mode. + """ - Args: - test_run (TestRun): A test run object that encapsulates cmd_args, extra_cmd_args, etc. - runner (BaseRunner): The runner object to execute jobs. - """ + def __init__( + self, + test_run: TestRun, + runner: BaseRunner, + gym_server: Optional[Any] = None, + ): self.test_run = test_run - self.original_test_run = copy.deepcopy(test_run) # Preserve clean state for DSE self.runner = runner self.max_steps = test_run.test.agent_steps self.reward_function = Registry().get_reward_function(test_run.test.agent_reward_function) - self.trajectory: dict[int, list[TrajectoryEntry]] = {} - super().__init__() + self._step_count = 0 + self._rng = stdlib_random.Random(42) + self._trajectory: list[TrajectoryEntry] = [] + self._trajectory_by_iteration: dict[int, list[TrajectoryEntry]] = {} + + if gym_server is not None: + self._backend: _StepBackend = _GymServerBackend(gym_server) + elif getattr(test_run.test.cmd_args, "live_rl_mode", False): + server = _create_gym_server(test_run) + self._backend = _GymServerBackend(server) + else: + self._backend = _RunnerBackend(test_run, runner) - def define_action_space(self) -> Dict[str, list[Any]]: - return self.test_run.param_space + super().__init__() @property - def first_sweep(self) -> dict[str, Any]: - """Builds a sweep using first elements of each explorable parameter.""" - return {k: v[0] for k, v in self.define_action_space().items()} + def _is_online(self) -> bool: + return isinstance(self._backend, _GymServerBackend) + + def define_action_space(self) -> Dict[str, Any]: + return self._backend.get_action_space() def define_observation_space(self) -> list: - """ - Define the observation space for the environment. + return self._backend.get_observation_space() - Returns: - list: The observation space. - """ - return [0.0] + @property + def first_sweep(self) -> Any: + space = self.define_action_space() + if isinstance(space, dict) and space.get("type") == "continuous": + shape = int(space.get("shape", 1)) + low = float(space.get("low", -1.0)) + return [low] * shape + return {k: v[0] if isinstance(v, list) else v for k, v in space.items()} def reset( self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None, # noqa: Vulture ) -> Tuple[list, dict[str, Any]]: - """ - Reset the environment and reinitialize the TestRun. - - Args: - seed (Optional[int]): Seed for the environment's random number generator. - options (Optional[dict]): Additional options for reset. - - Returns: - Tuple: A tuple containing: - - observation (list): Initial observation. - - info (dict): Additional info for debugging. - """ if seed is not None: - lazy.np.random.seed(seed) - self.test_run.current_iteration = 0 - observation = [0.0] - info = {} - return observation, info + self._rng = stdlib_random.Random(seed) + self._step_count = 0 + return self._backend.reset(seed) def step(self, action: Any) -> Tuple[list, float, bool, dict]: - """ - Execute one step in the environment. - - Args: - action (Any): Action chosen by the agent. - - Returns: - Tuple: A tuple containing: - - observation (list): Updated system state. - - reward (float): Reward for the action taken. - - done (bool): Whether the episode is done. - - info (dict): Additional info for debugging. - """ - self.test_run = self.test_run.apply_params_set(action) - - cached_result = self.get_cached_trajectory_result(action) - if cached_result is not None: - logging.info( - "Retrieved cached result from trajectory with reward %s. Skipping step.", - cached_result.reward, - ) - return cached_result.observation, cached_result.reward, False, {} - - if not self.test_run.test.constraint_check(self.test_run, self.runner.system): - logging.info("Constraint check failed. Skipping step.") - return [-1.0], -1.0, True, {} - - new_tr = copy.deepcopy(self.test_run) - new_tr.output_path = self.runner.get_job_output_path(new_tr) - self.runner.test_scenario.test_runs = [new_tr] - - self.runner.shutting_down = False - self.runner.jobs.clear() - self.runner.testrun_to_job_map.clear() - - try: - self.runner.run() - except Exception as e: - logging.error(f"Error running step {self.test_run.step}: {e}") - - if self.runner.test_scenario.test_runs and self.runner.test_scenario.test_runs[0].output_path.exists(): - self.test_run = self.runner.test_scenario.test_runs[0] - else: - self.test_run = copy.deepcopy(self.original_test_run) - self.test_run.step = new_tr.step - self.test_run.output_path = new_tr.output_path - - observation = self.get_observation(action) - reward = self.compute_reward(observation) - - self.write_trajectory( - TrajectoryEntry( - step=self.test_run.step, - action=action, - reward=reward, - observation=observation, - ) + self._step_count += 1 + observation, done, info = self._backend.step(action) + reward = self.reward_function(observation) + + entry = TrajectoryEntry( + step=self._step_count, + action=action, + reward=reward, + observation=observation, + info=info, ) + self._write_trajectory(entry) - return observation, reward, False, {} + if isinstance(self._backend, _RunnerBackend): + self._backend.cache_trajectory(entry) - def render(self, mode: str = "human"): - """ - Render the current state of the TestRun. + return observation, reward, done, info - Args: - mode (str): The mode to render with. Default is "human". - """ - print(f"Step {self.test_run.current_iteration}: Parameters {self.test_run.test.cmd_args}") + def render(self, mode: str = "human"): + if self._is_online: + logging.info(f"CloudAIGymEnv [online] step {self._step_count}") + else: + print(f"Step {self.test_run.current_iteration}: Parameters {self.test_run.test.cmd_args}") def seed(self, seed: Optional[int] = None): - """ - Set the seed for the environment's random number generator. - - Args: - seed (Optional[int]): Seed for the environment's random number generator. - """ if seed is not None: + self._rng = stdlib_random.Random(seed) lazy.np.random.seed(seed) def compute_reward(self, observation: list) -> float: - """ - Compute a reward based on the TestRun result. - - Args: - observation (list): The observation list containing the average value. - - Returns: - float: Reward value. - """ return self.reward_function(observation) def get_observation(self, action: Any) -> list: - """ - Get the observation from the TestRun object. - - Args: - action (Any): Action taken by the agent. + if isinstance(self._backend, _RunnerBackend): + return self._backend.get_observation(action) + return self._backend.get_observation_space() - Returns: - list: The observation. - """ - all_metrics = self.test_run.test.agent_metrics - if not all_metrics: - raise ValueError("No agent metrics defined for the test run") - - observation = [] - for metric in all_metrics: - v = self.test_run.get_metric_value(self.runner.system, metric) - if v == METRIC_ERROR: - v = -1.0 - observation.append(v) - return observation + _MAX_OBS_CSV_ELEMENTS = 1024 - def write_trajectory(self, entry: TrajectoryEntry): - """Append the trajectory to the CSV file and to the local attribute.""" + def _write_trajectory(self, entry: TrajectoryEntry) -> None: + self._trajectory.append(entry) self.current_trajectory.append(entry) file_exists = self.trajectory_file_path.exists() logging.debug(f"Writing trajectory into {self.trajectory_file_path} (exists: {file_exists})") self.trajectory_file_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.trajectory_file_path, mode="a", newline="") as file: - writer = csv.writer(file) + with open(self.trajectory_file_path, mode="a", newline="") as f: + writer = csv.writer(f) if not file_exists: - writer.writerow(["step", "action", "reward", "observation"]) - writer.writerow([entry.step, entry.action, entry.reward, entry.observation]) + writer.writerow(["step", "action", "reward", "observation", "info"]) + obs = entry.observation + if isinstance(obs, list) and len(obs) > self._MAX_OBS_CSV_ELEMENTS: + obs = f"[truncated len={len(obs)}]" + writer.writerow([entry.step, entry.action, entry.reward, obs, entry.info]) + + def write_trajectory(self, entry: TrajectoryEntry) -> None: + """Public method for external callers (e.g. single_sbatch_runner).""" + self._write_trajectory(entry) @property def trajectory_file_path(self) -> Path: + if self._is_online: + return self.test_run.output_path / "trajectory.csv" return self.runner.scenario_root / self.test_run.name / f"{self.test_run.current_iteration}" / "trajectory.csv" @property def current_trajectory(self) -> list[TrajectoryEntry]: - return self.trajectory.setdefault(self.test_run.current_iteration, []) + return self._trajectory_by_iteration.setdefault(self.test_run.current_iteration, []) - def get_cached_trajectory_result(self, action: Any) -> TrajectoryEntry | None: + def get_cached_trajectory_result(self, action: Any) -> Optional[TrajectoryEntry]: for entry in self.current_trajectory: - if self._values_match_exact(entry.action, action): + if _values_match_exact(entry.action, action): return entry - return None - @classmethod - def _values_match_exact(cls, left: Any, right: Any) -> bool: - if type(left) is not type(right): - return False - - elif isinstance(left, dict): - left_keys = set(left.keys()) - right_keys = set(right.keys()) - if left_keys != right_keys: - return False - - return all(cls._values_match_exact(left[key], right[key]) for key in left_keys) - elif isinstance(left, (list, tuple)): - if len(left) != len(right): - return False - - for left_item, right_item in zip(left, right, strict=True): - if not cls._values_match_exact(left_item, right_item): - return False - - return True - - else: - return left == right +def _values_match_exact(left: Any, right: Any) -> bool: + if type(left) is not type(right): + return False + elif isinstance(left, dict): + if set(left.keys()) != set(right.keys()): + return False + return all(_values_match_exact(left[key], right[key]) for key in left) + elif isinstance(left, (list, tuple)): + if len(left) != len(right): + return False + return all(_values_match_exact(l, r) for l, r in zip(left, right, strict=True)) + else: + return left == right diff --git a/src/cloudai/configurator/grid_search.py b/src/cloudai/configurator/grid_search.py index 631660ca4..5a0af3484 100644 --- a/src/cloudai/configurator/grid_search.py +++ b/src/cloudai/configurator/grid_search.py @@ -15,7 +15,7 @@ # limitations under the License. import itertools -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from .base_agent import BaseAgent, BaseAgentConfig from .cloudai_gym import CloudAIGymEnv @@ -71,7 +71,7 @@ def get_all_combinations(self) -> List[Dict[str, Any]]: keys = list(self.action_space.keys()) return [dict(zip(keys, combination, strict=True)) for combination in self.action_combinations] - def select_action(self) -> Tuple[int, Dict[str, Any]]: + def select_action(self, observation: Optional[list] = None) -> Tuple[int, Dict[str, Any]]: """ Select the next action from the grid. diff --git a/src/cloudai/configurator/gymnasium_adapter.py b/src/cloudai/configurator/gymnasium_adapter.py new file mode 100644 index 000000000..83ffa3ead --- /dev/null +++ b/src/cloudai/configurator/gymnasium_adapter.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, ClassVar + +from .base_gym import BaseGym + + +def _import_gymnasium(): + """Import gymnasium lazily; raise a clear error when it is absent.""" + try: + import gymnasium + from gymnasium import spaces + + return gymnasium, spaces + except ImportError as exc: + raise ImportError("gymnasium is required for GymnasiumAdapter. Install it with: pip install gymnasium") from exc + + +class GymnasiumAdapter: + """ + Wrap a CloudAI BaseGym environment as a standard gymnasium.Env. + + gymnasium is imported lazily so it remains an optional dependency. + """ + + metadata: ClassVar[dict[str, Any]] = {"render_modes": ["human"]} + + def __init__(self, env: BaseGym) -> None: + import numpy as np + + gymnasium, spaces = _import_gymnasium() + + gymnasium.Env.__init__(self) + + self._np = np + self._env = env + self._continuous = False + + raw_action_space = env.define_action_space() + + if isinstance(raw_action_space, dict) and raw_action_space.get("type") == "continuous": + self._continuous = True + shape = int(raw_action_space["shape"]) + low = float(raw_action_space.get("low", -1e6)) + high = float(raw_action_space.get("high", 1e6)) + self.action_space = spaces.Box(low=low, high=high, shape=(shape,), dtype=np.float32) + self._tunable_params: dict[str, list] = {} + self._fixed_params: dict[str, Any] = {} + else: + self._tunable_params = {k: v for k, v in raw_action_space.items() if len(v) > 1} + self._fixed_params = {k: v[0] for k, v in raw_action_space.items() if len(v) == 1} + self.action_space = spaces.Dict({k: spaces.Discrete(len(v)) for k, v in self._tunable_params.items()}) + + obs = env.define_observation_space() + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(len(obs),), dtype=np.float32) + + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[Any, dict[str, Any]]: + """Reset the environment and return (observation, info).""" + obs, info = self._env.reset(seed=seed, options=options) + return self._np.asarray(obs, dtype=self._np.float32), info + + def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: + """Execute one step and return the gymnasium 5-tuple.""" + if self._continuous: + decoded = self._np.asarray(action, dtype=self._np.float32).tolist() + else: + decoded = {**self._fixed_params, **self.decode_action(action)} + obs, reward, done, info = self._env.step(decoded) + return self._np.asarray(obs, dtype=self._np.float32), float(reward), bool(done), False, info + + def step_raw(self, param_dict: dict[str, Any]) -> tuple[Any, float, bool, bool, dict[str, Any]]: + """Execute one step with a pre-decoded parameter dictionary.""" + obs, reward, done, info = self._env.step(param_dict) + return self._np.asarray(obs, dtype=self._np.float32), float(reward), bool(done), False, info + + def decode_action(self, action: dict[str, int]) -> dict[str, Any]: + """Map discrete indices back to the original parameter values.""" + return {k: self._tunable_params[k][idx] for k, idx in action.items()} + + def render(self) -> None: + """Render the underlying environment.""" + self._env.render() + + @property + def unwrapped(self) -> BaseGym: + """Return the wrapped CloudAI BaseGym instance.""" + return self._env diff --git a/src/cloudai/core.py b/src/cloudai/core.py index d5fe0f4b9..f7bfab753 100644 --- a/src/cloudai/core.py +++ b/src/cloudai/core.py @@ -41,7 +41,7 @@ from ._core.system import System from ._core.test_scenario import METRIC_ERROR, TestDependency, TestRun, TestScenario from .configurator.base_agent import BaseAgent, BaseAgentConfig -from .configurator.cloudai_gym import CloudAIGymEnv +from .configurator.cloudai_gym import CloudAIGymEnv, GymServer from .configurator.grid_search import GridSearchAgent from .models.workload import CmdArgs, NsysConfiguration, PredictorConfig, TestDefinition from .parser import Parser @@ -66,6 +66,7 @@ "Grader", "GradingStrategy", "GridSearchAgent", + "GymServer", "HFModel", "InstallStatusResult", "Installable", diff --git a/tests/test_gymnasium_adapter.py b/tests/test_gymnasium_adapter.py new file mode 100644 index 000000000..c08527be0 --- /dev/null +++ b/tests/test_gymnasium_adapter.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Optional +from unittest.mock import patch + +import gymnasium +import numpy as np +import pytest + +from cloudai.configurator.base_gym import BaseGym +from cloudai.configurator.gymnasium_adapter import GymnasiumAdapter + + +class _FakeGym(BaseGym): + """Minimal BaseGym implementation with a known, deterministic interface.""" + + def __init__(self) -> None: + self._action_space: dict[str, Any] = { + "param_a": [1, 2, 3], + "param_b": [10, 20], + } + self._observation_space: list[float] = [0.0, 0.0, 0.0] + super().__init__() + + def define_action_space(self) -> dict[str, Any]: + return self._action_space + + def define_observation_space(self) -> list: + return self._observation_space + + def reset( + self, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[list, dict[str, Any]]: + return [0.0, 0.0, 0.0], {} + + def step(self, action: Any) -> tuple[list, float, bool, dict]: + return [1.0, 2.0, 3.0], 0.5, False, {"info": "test"} + + def render(self, mode: str = "human") -> None: + return None + + def seed(self, seed: Optional[int] = None) -> None: + pass + + +@pytest.fixture +def fake_gym() -> _FakeGym: + return _FakeGym() + + +@pytest.fixture +def adapter(fake_gym: _FakeGym) -> GymnasiumAdapter: + return GymnasiumAdapter(fake_gym) + + +def test_adapter_action_space_structure(adapter: GymnasiumAdapter) -> None: + assert isinstance(adapter.action_space, gymnasium.spaces.Dict) + + assert "param_a" in adapter.action_space.spaces + assert "param_b" in adapter.action_space.spaces + + space_a = adapter.action_space.spaces["param_a"] + space_b = adapter.action_space.spaces["param_b"] + + assert isinstance(space_a, gymnasium.spaces.Discrete) + assert isinstance(space_b, gymnasium.spaces.Discrete) + + assert space_a.n == 3 + assert space_b.n == 2 + + +def test_adapter_observation_space_structure(adapter: GymnasiumAdapter) -> None: + assert isinstance(adapter.observation_space, gymnasium.spaces.Box) + assert adapter.observation_space.shape == (3,) + assert adapter.observation_space.dtype == np.float32 + + +def test_adapter_reset_returns_numpy_array(adapter: GymnasiumAdapter) -> None: + obs, info = adapter.reset() + + assert isinstance(obs, np.ndarray) + assert obs.dtype == np.float32 + assert obs.shape == (3,) + np.testing.assert_array_equal(obs, [0.0, 0.0, 0.0]) + assert isinstance(info, dict) + + +def test_adapter_step_returns_five_tuple(adapter: GymnasiumAdapter) -> None: + adapter.reset() + result = adapter.step({"param_a": 0, "param_b": 1}) + + assert len(result) == 5 + obs, reward, terminated, truncated, info = result + + assert isinstance(obs, np.ndarray) + assert obs.dtype == np.float32 + np.testing.assert_array_equal(obs, [1.0, 2.0, 3.0]) + + assert isinstance(reward, float) + assert reward == 0.5 + + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) + + assert isinstance(info, dict) + + +def test_adapter_decode_action(adapter: GymnasiumAdapter) -> None: + decoded = adapter.decode_action({"param_a": 0, "param_b": 1}) + assert decoded == {"param_a": 1, "param_b": 20} + + +def test_adapter_single_value_params_excluded() -> None: + """Params with a single value offer no choice and should be excluded.""" + + class _SingleValueGym(_FakeGym): + def define_action_space(self) -> dict[str, Any]: + return { + "param_a": [1, 2, 3], + "param_b": [10, 20], + "fixed_param": [42], + } + + adapter = GymnasiumAdapter(_SingleValueGym()) + + assert "fixed_param" not in adapter.action_space.spaces + assert "param_a" in adapter.action_space.spaces + assert "param_b" in adapter.action_space.spaces + + +def test_adapter_import_error_without_gymnasium() -> None: + with ( + patch( + "cloudai.configurator.gymnasium_adapter._import_gymnasium", + side_effect=ImportError( + "gymnasium is required for GymnasiumAdapter. Install it with: pip install gymnasium" + ), + ), + pytest.raises(ImportError, match="pip install gymnasium"), + ): + GymnasiumAdapter(_FakeGym()) + + +def test_adapter_unwrapped_returns_original(fake_gym: _FakeGym, adapter: GymnasiumAdapter) -> None: + assert adapter.unwrapped is fake_gym + + +class _FixedParamGym(_FakeGym): + def define_action_space(self) -> dict[str, Any]: + return { + "param_a": [1, 2, 3], + "param_b": [10, 20], + "fixed_param": [42], + } + + def step(self, action: Any) -> tuple[list, float, bool, dict]: + self.last_action = action + return [1.0, 2.0, 3.0], 0.5, False, {"info": "test"} + + +def test_step_merges_fixed_params() -> None: + gym = _FixedParamGym() + adapter = GymnasiumAdapter(gym) + adapter.reset() + adapter.step({"param_a": 0, "param_b": 1}) + + assert "fixed_param" in gym.last_action + assert gym.last_action["fixed_param"] == 42 + assert gym.last_action["param_a"] == 1 + assert gym.last_action["param_b"] == 20 + + +def test_step_raw_bypasses_decode() -> None: + gym = _FixedParamGym() + adapter = GymnasiumAdapter(gym) + adapter.reset() + + raw_params = {"param_a": 999, "param_b": 888, "fixed_param": 777} + result = adapter.step_raw(raw_params) + + assert len(result) == 5 + _obs, _reward, terminated, truncated, _info = result + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) + assert gym.last_action == raw_params + + +def test_fixed_params_stored_correctly() -> None: + gym = _FixedParamGym() + adapter = GymnasiumAdapter(gym) + assert adapter._fixed_params == {"fixed_param": 42} diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e495162da..a2fd4b91c 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -15,14 +15,15 @@ # limitations under the License. import argparse -from typing import Any, ClassVar, Iterator +import logging +from typing import Any, ClassVar, Iterator, Optional from unittest.mock import MagicMock import pandas as pd import pytest from pydantic import Field -from cloudai.cli.handlers import handle_dse_job +from cloudai.cli.handlers import _run_custom_training_loop, handle_dse_job from cloudai.core import ( BaseAgent, BaseAgentConfig, @@ -58,7 +59,7 @@ def get_config_class() -> type[StubAgentConfig]: def configure(self, config: dict[str, Any]) -> None: raise NotImplementedError - def select_action(self) -> tuple[int, dict[str, Any]]: + def select_action(self, observation: Optional[list] = None) -> tuple[int, dict[str, Any]]: raise NotImplementedError def update_policy(self, _feedback: dict[str, Any]) -> None: @@ -207,3 +208,95 @@ def _job_output_path(tr: TestRun, create: bool = True): pd.testing.assert_frame_equal(actual_trajectory, expected_trajectory) assert [tr.step for tr in reporter.trs] == [1, 3] + + +class CustomLoopStubAgentConfig(BaseAgentConfig): + pass + + +class CustomLoopStubAgent(BaseAgent): + HAS_CUSTOM_TRAINING_LOOP: ClassVar[bool] = True + train_called: ClassVar[bool] = False + shutdown_called: ClassVar[bool] = False + + def __init__(self, env, config: CustomLoopStubAgentConfig): + self.env = env + self.config = config + self.max_steps = 0 + CustomLoopStubAgent.train_called = False + CustomLoopStubAgent.shutdown_called = False + + @staticmethod + def get_config_class() -> type[CustomLoopStubAgentConfig]: + return CustomLoopStubAgentConfig + + def configure(self, config: dict[str, Any]) -> None: + raise NotImplementedError + + def select_action(self, observation: Optional[list] = None) -> tuple[int, dict[str, Any]]: + raise NotImplementedError + + def update_policy(self, _feedback: dict[str, Any]) -> None: + return + + def train(self) -> None: + CustomLoopStubAgent.train_called = True + + def shutdown(self) -> None: + CustomLoopStubAgent.shutdown_called = True + + +@pytest.fixture +def custom_loop_agent_name() -> Iterator[str]: + registry = Registry() + agent_name = "test_handlers_custom_loop_agent" + old_agent = registry.agents_map.get(agent_name) + registry.update_agent(agent_name, CustomLoopStubAgent) + CustomLoopStubAgent.train_called = False + CustomLoopStubAgent.shutdown_called = False + yield agent_name + CustomLoopStubAgent.train_called = False + CustomLoopStubAgent.shutdown_called = False + if old_agent is None: + del registry.agents_map[agent_name] + else: + registry.update_agent(agent_name, old_agent) + + +def test_custom_training_loop_success() -> None: + agent = MagicMock() + agent.train = MagicMock() + agent.shutdown = MagicMock() + + result = _run_custom_training_loop(agent, "mock_agent") + + assert result == 0 + agent.train.assert_called_once() + agent.shutdown.assert_called_once() + + +def test_custom_training_loop_failure(caplog: pytest.LogCaptureFixture) -> None: + agent = MagicMock() + agent.train = MagicMock(side_effect=RuntimeError("boom")) + agent.shutdown = MagicMock() + + with caplog.at_level(logging.ERROR): + result = _run_custom_training_loop(agent, "mock_agent") + + assert result == 1 + agent.shutdown.assert_called_once() + assert "boom" in caplog.text + + +def test_dse_delegates_to_custom_loop( + slurm_system: SlurmSystem, + dse_tr: TestRun, + custom_loop_agent_name: str, +) -> None: + dse_tr.test.agent = custom_loop_agent_name + test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) + runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) + + assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 0 + assert CustomLoopStubAgent.train_called is True + assert CustomLoopStubAgent.shutdown_called is True diff --git a/uv.lock b/uv.lock index 580571b97..dd7770e30 100644 --- a/uv.lock +++ b/uv.lock @@ -311,6 +311,7 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "build" }, + { name = "gymnasium" }, { name = "import-linter" }, { name = "pandas-stubs" }, { name = "pre-commit" }, @@ -341,6 +342,9 @@ docs-cms = [ { name = "sphinx-rtd-theme" }, { name = "sphinxcontrib-mermaid" }, ] +rl = [ + { name = "gymnasium" }, +] [package.metadata] requires-dist = [ @@ -350,6 +354,8 @@ requires-dist = [ { name = "bokeh", specifier = "~=3.8" }, { name = "build", marker = "extra == 'dev'", specifier = "~=1.4" }, { name = "click", specifier = "~=8.3" }, + { name = "gymnasium", marker = "extra == 'dev'", specifier = "~=1.2" }, + { name = "gymnasium", marker = "extra == 'rl'", specifier = "~=1.2" }, { name = "huggingface-hub", specifier = "~=1.4" }, { name = "import-linter", marker = "extra == 'dev'", specifier = "~=2.10" }, { name = "jinja2", specifier = "~=3.1.6" }, @@ -380,7 +386,16 @@ requires-dist = [ { name = "vulture", marker = "extra == 'dev'", specifier = "==2.14" }, { name = "websockets", specifier = "~=16.0" }, ] -provides-extras = ["dev", "docs", "docs-cms"] +provides-extras = ["dev", "rl", "docs", "docs-cms"] + +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] [[package]] name = "colorama" @@ -713,6 +728,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "farama-notifications" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/2c/8384832b7a6b1fd6ba95bbdcae26e7137bb3eedc955c42fd5cdcc086cfbf/Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18", size = 2131, upload-time = "2023-02-27T18:28:41.047Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/2c/ffc08c54c05cdce6fbed2aeebc46348dbe180c6d2c541c7af7ba0aa5f5f8/Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae", size = 2511, upload-time = "2023-02-27T18:28:39.447Z" }, +] + [[package]] name = "fastapi" version = "0.128.6" @@ -980,6 +1004,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/b2/b096ccce418882fbfda4f7496f9357aaa9a5af1896a9a7f60d9f2b275a06/grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb", size = 4929852, upload-time = "2026-02-06T09:56:45.885Z" }, ] +[[package]] +name = "gymnasium" +version = "1.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "farama-notifications" }, + { name = "numpy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/59/653a9417d98ed3e29ef9734ba52c3495f6c6823b8d5c0c75369f25111708/gymnasium-1.2.3.tar.gz", hash = "sha256:2b2cb5b5fbbbdf3afb9f38ca952cc48aa6aa3e26561400d940747fda3ad42509", size = 829230, upload-time = "2025-12-18T16:51:10.234Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/d3/ea5f088e3638dbab12e5c20d6559d5b3bdaeaa1f2af74e526e6815836285/gymnasium-1.2.3-py3-none-any.whl", hash = "sha256:e6314bba8f549c7fdcc8677f7cd786b64908af6e79b57ddaa5ce1825bffb5373", size = 952113, upload-time = "2025-12-18T16:51:08.445Z" }, +] + [[package]] name = "h11" version = "0.16.0"