From cc1eb0afa3ef2f8da9cd14102a148eadda1399d4 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Fri, 15 May 2026 16:57:04 -0400 Subject: [PATCH 1/3] feat(cli): support agents with custom training loops in handle_dse_job - Agents that set HAS_CUSTOM_TRAINING_LOOP = True drive their own training loop; handle_dse_job calls agent.train() and skips the per-step env.step loop. - New _run_custom_training_loop helper logs exceptions, returns a process-style exit code, and always invokes agent.shutdown() (when defined) in a finally block so resources are released on both success and failure paths. - CustomTrainingLoopAgent Protocol documents the opt-in contract for type checkers and IDEs. --- src/cloudai/cli/handlers.py | 40 ++++++++++- tests/test_handlers.py | 130 +++++++++++++++++++++++++++++++++++- 2 files changed, 167 insertions(+), 3 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 0284fcd9e..49f750529 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Protocol, runtime_checkable from unittest.mock import Mock import toml @@ -118,6 +118,40 @@ def prepare_installation( return installables, installer +@runtime_checkable +class CustomTrainingLoopAgent(Protocol): + """ + Agent that drives its own training loop and skips the ``handle_dse_job`` step loop. + + Set ``HAS_CUSTOM_TRAINING_LOOP = True`` on the agent class to opt in. Used by + agents (e.g. RLlib-based) whose training loops are not modelled as a sequence + of independent ``select_action`` / ``env.step`` calls. + """ + + HAS_CUSTOM_TRAINING_LOOP: bool + + def train(self) -> None: ... + + +def _has_custom_training_loop(agent: object) -> bool: + return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False)) + + +def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int: + """Drive an agent's self-contained training loop and return a process-style exit code.""" + logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().") + try: + agent.train() + return 0 + except Exception: + logging.exception(f"Custom training loop failed for agent {agent_type}.") + return 1 + finally: + shutdown = getattr(agent, "shutdown", None) + if callable(shutdown): + shutdown() + + def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: registry = Registry() @@ -157,6 +191,10 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: agent = agent_class(env, agent_config) + if _has_custom_training_loop(agent): + err |= _run_custom_training_loop(agent, agent_type) + continue + for step in range(agent.max_steps): result = agent.select_action() if result is None: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 5124186c0..19e4b0eae 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -15,15 +15,22 @@ # limitations under the License. import argparse +import logging from pathlib import Path -from typing import Any, ClassVar, Iterator +from typing import Any, ClassVar, Iterator, Optional from unittest.mock import MagicMock import pandas as pd import pytest from pydantic import Field -from cloudai.cli.handlers import handle_dse_job, verify_system_configs, verify_test_configs, verify_test_scenarios +from cloudai.cli.handlers import ( + _run_custom_training_loop, + handle_dse_job, + verify_system_configs, + verify_test_configs, + verify_test_scenarios, +) from cloudai.core import ( BaseAgent, BaseAgentConfig, @@ -254,3 +261,122 @@ def test_verify_test_scenarios_logs_failure_details(tmp_path: Path, caplog: pyte assert str(broken_scenario) in caplog.text assert "duplicate TOML key 'name'" in caplog.text assert "1 out of 1 test scenarios have issues." in caplog.text + + +class CustomLoopStubAgentConfig(BaseAgentConfig): + pass + + +class CustomLoopStubAgent(BaseAgent): + """Stub agent that opts into the custom-training-loop dispatch path.""" + + HAS_CUSTOM_TRAINING_LOOP: ClassVar[bool] = True + + train_calls: ClassVar[int] = 0 + shutdown_calls: ClassVar[int] = 0 + train_raises: ClassVar[Optional[BaseException]] = None + + def __init__(self, env, config: CustomLoopStubAgentConfig): + self.env = env + self.config = config + self.max_steps = 0 + + @staticmethod + def get_config_class() -> type[CustomLoopStubAgentConfig]: + return CustomLoopStubAgentConfig + + def configure(self, config: dict[str, Any]) -> None: + raise NotImplementedError + + def select_action(self) -> tuple[int, dict[str, Any]]: # pragma: no cover - never called + raise AssertionError("select_action must not be called when HAS_CUSTOM_TRAINING_LOOP is True") + + def update_policy(self, _feedback: dict[str, Any]) -> None: + return + + def train(self) -> None: + CustomLoopStubAgent.train_calls += 1 + if CustomLoopStubAgent.train_raises is not None: + raise CustomLoopStubAgent.train_raises + + def shutdown(self) -> None: + CustomLoopStubAgent.shutdown_calls += 1 + + +@pytest.fixture +def custom_loop_agent_name() -> Iterator[str]: + registry = Registry() + agent_name = "test_handlers_custom_loop_agent" + old_agent = registry.agents_map.get(agent_name) + registry.update_agent(agent_name, CustomLoopStubAgent) + CustomLoopStubAgent.train_calls = 0 + CustomLoopStubAgent.shutdown_calls = 0 + CustomLoopStubAgent.train_raises = None + yield agent_name + CustomLoopStubAgent.train_calls = 0 + CustomLoopStubAgent.shutdown_calls = 0 + CustomLoopStubAgent.train_raises = None + if old_agent is None: + del registry.agents_map[agent_name] + else: + registry.update_agent(agent_name, old_agent) + + +def test_run_custom_training_loop_calls_train_and_shutdown() -> None: + agent = MagicMock() + agent.train = MagicMock() + agent.shutdown = MagicMock() + + assert _run_custom_training_loop(agent, "mock_agent") == 0 + agent.train.assert_called_once_with() + agent.shutdown.assert_called_once_with() + + +def test_run_custom_training_loop_returns_error_and_still_shuts_down( + caplog: pytest.LogCaptureFixture, +) -> None: + agent = MagicMock() + agent.train = MagicMock(side_effect=RuntimeError("boom")) + agent.shutdown = MagicMock() + + with caplog.at_level(logging.ERROR): + assert _run_custom_training_loop(agent, "mock_agent") == 1 + + agent.shutdown.assert_called_once_with() + assert "boom" in caplog.text + + +def test_run_custom_training_loop_tolerates_missing_shutdown() -> None: + agent = MagicMock(spec=["train"]) + agent.train = MagicMock() + + assert _run_custom_training_loop(agent, "mock_agent") == 0 + agent.train.assert_called_once_with() + + +def test_handle_dse_job_dispatches_to_custom_training_loop( + slurm_system: SlurmSystem, + dse_tr: TestRun, + custom_loop_agent_name: str, +) -> None: + dse_tr.test.agent = custom_loop_agent_name + test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) + runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) + + assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 0 + assert CustomLoopStubAgent.train_calls == 1 + assert CustomLoopStubAgent.shutdown_calls == 1 + + +def test_handle_dse_job_propagates_custom_loop_failure( + slurm_system: SlurmSystem, + dse_tr: TestRun, + custom_loop_agent_name: str, +) -> None: + CustomLoopStubAgent.train_raises = RuntimeError("training blew up") + dse_tr.test.agent = custom_loop_agent_name + test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) + runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) + + assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 1 + assert CustomLoopStubAgent.shutdown_calls == 1 From ede2ae5a7af2b9a9d4763cbcaa79f56b7afaa343 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Fri, 15 May 2026 18:30:15 -0400 Subject: [PATCH 2/3] fix(cli): narrow agent type via TypeGuard in custom-loop dispatch Pyright rejected calling _run_custom_training_loop(agent, ...) because the plain bool predicate did not narrow agent's static type from BaseAgent to CustomTrainingLoopAgent. Return TypeGuard[CustomTrainingLoopAgent] from _has_custom_training_loop so the truthy branch in handle_dse_job sees the opted-in shape and the helper can call agent.train() directly. --- src/cloudai/cli/handlers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 49f750529..7d8c33689 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional, Protocol, runtime_checkable +from typing import Callable, List, Optional, Protocol, TypeGuard, runtime_checkable from unittest.mock import Mock import toml @@ -133,7 +133,15 @@ class CustomTrainingLoopAgent(Protocol): def train(self) -> None: ... -def _has_custom_training_loop(agent: object) -> bool: +def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgent]: + """ + Narrow ``agent`` to :class:`CustomTrainingLoopAgent` when it opts into the dispatch path. + + Returning :class:`TypeGuard` (instead of plain ``bool``) lets the type checker + treat this predicate like ``isinstance``: callers inside the truthy branch see + ``agent`` as a :class:`CustomTrainingLoopAgent`, so ``agent.train()`` type-checks + without ``getattr`` or ``cast``. + """ return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False)) From 9552e5a5ff0bf821136c3a96f00ff2cc4c7faef1 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Mon, 18 May 2026 12:21:47 -0400 Subject: [PATCH 3/3] review: isolate shutdown() failures from the exit-code contract If agent.shutdown() raised from the finally block, Python suppressed the earlier return 0/1 from agent.train() and propagated the exception, breaking the outer test-run loop in handle_dse_job (skipped remaining scenarios, failed to accumulate err |= rc). Wrap shutdown() in its own try/except, log via logging.exception, set rc = 1, and return rc after finally so the helper always honours the (int) -> int contract. Adds tests for shutdown-only failure and combined train+shutdown failure. --- src/cloudai/cli/handlers.py | 20 ++++++++++++++++---- tests/test_handlers.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 7d8c33689..5f862f270 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -146,18 +146,30 @@ def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgen def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int: - """Drive an agent's self-contained training loop and return a process-style exit code.""" + """ + Drive an agent's self-contained training loop and return a process-style exit code. + + ``shutdown()`` runs inside its own ``try/except`` so a faulty teardown cannot + suppress the exit code from ``train()`` nor propagate out of this helper: + ``handle_dse_job`` relies on the returned ``rc`` to accumulate ``err |= rc`` + and continue with the remaining test runs. + """ logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().") + rc = 0 try: agent.train() - return 0 except Exception: logging.exception(f"Custom training loop failed for agent {agent_type}.") - return 1 + rc = 1 finally: shutdown = getattr(agent, "shutdown", None) if callable(shutdown): - shutdown() + try: + shutdown() + except Exception: + logging.exception(f"Shutdown failed for agent {agent_type}.") + rc = 1 + return rc def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 19e4b0eae..fec9f2eff 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -354,6 +354,38 @@ def test_run_custom_training_loop_tolerates_missing_shutdown() -> None: agent.train.assert_called_once_with() +def test_run_custom_training_loop_reports_shutdown_failure( + caplog: pytest.LogCaptureFixture, +) -> None: + """shutdown() raising must not suppress the exit code or propagate the exception.""" + agent = MagicMock() + agent.train = MagicMock() + agent.shutdown = MagicMock(side_effect=RuntimeError("teardown blew up")) + + with caplog.at_level(logging.ERROR): + assert _run_custom_training_loop(agent, "mock_agent") == 1 + + agent.train.assert_called_once_with() + agent.shutdown.assert_called_once_with() + assert "teardown blew up" in caplog.text + + +def test_run_custom_training_loop_reports_combined_train_and_shutdown_failures( + caplog: pytest.LogCaptureFixture, +) -> None: + """When both train() and shutdown() raise, the helper still returns 1 and logs both.""" + agent = MagicMock() + agent.train = MagicMock(side_effect=RuntimeError("training boom")) + agent.shutdown = MagicMock(side_effect=RuntimeError("teardown boom")) + + with caplog.at_level(logging.ERROR): + assert _run_custom_training_loop(agent, "mock_agent") == 1 + + agent.shutdown.assert_called_once_with() + assert "training boom" in caplog.text + assert "teardown boom" in caplog.text + + def test_handle_dse_job_dispatches_to_custom_training_loop( slurm_system: SlurmSystem, dse_tr: TestRun,