diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 0284fcd9e..5f862f270 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Protocol, TypeGuard, runtime_checkable from unittest.mock import Mock import toml @@ -118,6 +118,60 @@ def prepare_installation( return installables, installer +@runtime_checkable +class CustomTrainingLoopAgent(Protocol): + """ + Agent that drives its own training loop and skips the ``handle_dse_job`` step loop. + + Set ``HAS_CUSTOM_TRAINING_LOOP = True`` on the agent class to opt in. Used by + agents (e.g. RLlib-based) whose training loops are not modelled as a sequence + of independent ``select_action`` / ``env.step`` calls. + """ + + HAS_CUSTOM_TRAINING_LOOP: bool + + def train(self) -> None: ... + + +def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgent]: + """ + Narrow ``agent`` to :class:`CustomTrainingLoopAgent` when it opts into the dispatch path. + + Returning :class:`TypeGuard` (instead of plain ``bool``) lets the type checker + treat this predicate like ``isinstance``: callers inside the truthy branch see + ``agent`` as a :class:`CustomTrainingLoopAgent`, so ``agent.train()`` type-checks + without ``getattr`` or ``cast``. + """ + return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False)) + + +def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int: + """ + Drive an agent's self-contained training loop and return a process-style exit code. + + ``shutdown()`` runs inside its own ``try/except`` so a faulty teardown cannot + suppress the exit code from ``train()`` nor propagate out of this helper: + ``handle_dse_job`` relies on the returned ``rc`` to accumulate ``err |= rc`` + and continue with the remaining test runs. 
+ """ + logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().") + rc = 0 + try: + agent.train() + except Exception: + logging.exception(f"Custom training loop failed for agent {agent_type}.") + rc = 1 + finally: + shutdown = getattr(agent, "shutdown", None) + if callable(shutdown): + try: + shutdown() + except Exception: + logging.exception(f"Shutdown failed for agent {agent_type}.") + rc = 1 + return rc + + def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: registry = Registry() @@ -157,6 +211,10 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: agent = agent_class(env, agent_config) + if _has_custom_training_loop(agent): + err |= _run_custom_training_loop(agent, agent_type) + continue + for step in range(agent.max_steps): result = agent.select_action() if result is None: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 5124186c0..fec9f2eff 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -15,15 +15,22 @@ # limitations under the License. import argparse +import logging from pathlib import Path -from typing import Any, ClassVar, Iterator +from typing import Any, ClassVar, Iterator, Optional from unittest.mock import MagicMock import pandas as pd import pytest from pydantic import Field -from cloudai.cli.handlers import handle_dse_job, verify_system_configs, verify_test_configs, verify_test_scenarios +from cloudai.cli.handlers import ( + _run_custom_training_loop, + handle_dse_job, + verify_system_configs, + verify_test_configs, + verify_test_scenarios, +) from cloudai.core import ( BaseAgent, BaseAgentConfig, @@ -254,3 +261,154 @@ def test_verify_test_scenarios_logs_failure_details(tmp_path: Path, caplog: pyte assert str(broken_scenario) in caplog.text assert "duplicate TOML key 'name'" in caplog.text assert "1 out of 1 test scenarios have issues." 
class CustomLoopStubAgentConfig(BaseAgentConfig):
    # Empty on purpose: the stub agent needs no settings beyond the base config.
    pass


class CustomLoopStubAgent(BaseAgent):
    """Test double whose class-level flag routes it through the custom-training-loop path."""

    HAS_CUSTOM_TRAINING_LOOP: ClassVar[bool] = True

    # Counters and the injected failure live on the class because the handler
    # builds the agent instance itself; tests can only observe class state.
    train_calls: ClassVar[int] = 0
    shutdown_calls: ClassVar[int] = 0
    train_raises: ClassVar[Optional[BaseException]] = None

    def __init__(self, env, config: CustomLoopStubAgentConfig):
        self.env = env
        self.config = config
        self.max_steps = 0

    @staticmethod
    def get_config_class() -> type[CustomLoopStubAgentConfig]:
        return CustomLoopStubAgentConfig

    def configure(self, config: dict[str, Any]) -> None:
        raise NotImplementedError

    def select_action(self) -> tuple[int, dict[str, Any]]:  # pragma: no cover - never called
        raise AssertionError("select_action must not be called when HAS_CUSTOM_TRAINING_LOOP is True")

    def update_policy(self, _feedback: dict[str, Any]) -> None:
        return None

    def train(self) -> None:
        """Record the call, then raise the injected exception if one is configured."""
        CustomLoopStubAgent.train_calls += 1
        injected = CustomLoopStubAgent.train_raises
        if injected is None:
            return
        raise injected

    def shutdown(self) -> None:
        """Record that teardown was requested."""
        CustomLoopStubAgent.shutdown_calls += 1


def _reset_custom_loop_stub() -> None:
    """Return the stub's class-level counters and injected failure to their defaults."""
    CustomLoopStubAgent.train_calls = 0
    CustomLoopStubAgent.shutdown_calls = 0
    CustomLoopStubAgent.train_raises = None


@pytest.fixture
def custom_loop_agent_name() -> Iterator[str]:
    """Register the stub agent under a dedicated name and undo the registration afterwards."""
    agent_name = "test_handlers_custom_loop_agent"
    registry = Registry()
    previous = registry.agents_map.get(agent_name)
    registry.update_agent(agent_name, CustomLoopStubAgent)
    _reset_custom_loop_stub()
    yield agent_name
    _reset_custom_loop_stub()
    if previous is not None:
        registry.update_agent(agent_name, previous)
    else:
        del registry.agents_map[agent_name]


def _dry_run_runner(system: SlurmSystem, test_run: TestRun, agent_name: str) -> Runner:
    """Point ``test_run`` at ``agent_name`` and wrap it in a single-run dry-run ``Runner``."""
    test_run.test.agent = agent_name
    scenario = TestScenario(name="test_scenario", test_runs=[test_run])
    return Runner(mode="dry-run", system=system, test_scenario=scenario)


def test_run_custom_training_loop_calls_train_and_shutdown() -> None:
    """A clean run calls train() and shutdown() exactly once and returns 0."""
    stub = MagicMock(train=MagicMock(), shutdown=MagicMock())

    exit_code = _run_custom_training_loop(stub, "mock_agent")

    assert exit_code == 0
    stub.train.assert_called_once_with()
    stub.shutdown.assert_called_once_with()


def test_run_custom_training_loop_returns_error_and_still_shuts_down(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A failing train() yields 1, is logged, and does not prevent shutdown()."""
    stub = MagicMock(train=MagicMock(side_effect=RuntimeError("boom")), shutdown=MagicMock())

    with caplog.at_level(logging.ERROR):
        exit_code = _run_custom_training_loop(stub, "mock_agent")

    assert exit_code == 1
    stub.shutdown.assert_called_once_with()
    assert "boom" in caplog.text


def test_run_custom_training_loop_tolerates_missing_shutdown() -> None:
    """An agent exposing only train() is still driven successfully."""
    stub = MagicMock(spec=["train"], train=MagicMock())

    exit_code = _run_custom_training_loop(stub, "mock_agent")

    assert exit_code == 0
    stub.train.assert_called_once_with()


def test_run_custom_training_loop_reports_shutdown_failure(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A raising shutdown() is reported through the exit code instead of propagating."""
    stub = MagicMock(
        train=MagicMock(),
        shutdown=MagicMock(side_effect=RuntimeError("teardown blew up")),
    )

    with caplog.at_level(logging.ERROR):
        exit_code = _run_custom_training_loop(stub, "mock_agent")

    assert exit_code == 1
    stub.train.assert_called_once_with()
    stub.shutdown.assert_called_once_with()
    assert "teardown blew up" in caplog.text


def test_run_custom_training_loop_reports_combined_train_and_shutdown_failures(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """With both train() and shutdown() raising, the result is 1 and both errors are logged."""
    stub = MagicMock(
        train=MagicMock(side_effect=RuntimeError("training boom")),
        shutdown=MagicMock(side_effect=RuntimeError("teardown boom")),
    )

    with caplog.at_level(logging.ERROR):
        exit_code = _run_custom_training_loop(stub, "mock_agent")

    assert exit_code == 1
    stub.shutdown.assert_called_once_with()
    assert "training boom" in caplog.text
    assert "teardown boom" in caplog.text


def test_handle_dse_job_dispatches_to_custom_training_loop(
    slurm_system: SlurmSystem,
    dse_tr: TestRun,
    custom_loop_agent_name: str,
) -> None:
    """handle_dse_job hands an opted-in agent to its own train()/shutdown() cycle."""
    runner = _dry_run_runner(slurm_system, dse_tr, custom_loop_agent_name)

    exit_code = handle_dse_job(runner, argparse.Namespace(mode="dry-run"))

    assert exit_code == 0
    assert CustomLoopStubAgent.train_calls == 1
    assert CustomLoopStubAgent.shutdown_calls == 1


def test_handle_dse_job_propagates_custom_loop_failure(
    slurm_system: SlurmSystem,
    dse_tr: TestRun,
    custom_loop_agent_name: str,
) -> None:
    """A failure inside the custom loop surfaces as a non-zero handle_dse_job result."""
    CustomLoopStubAgent.train_raises = RuntimeError("training blew up")
    runner = _dry_run_runner(slurm_system, dse_tr, custom_loop_agent_name)

    exit_code = handle_dse_job(runner, argparse.Namespace(mode="dry-run"))

    assert exit_code == 1
    assert CustomLoopStubAgent.shutdown_calls == 1