Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import signal
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Protocol, TypeGuard, runtime_checkable
from unittest.mock import Mock

import toml
Expand Down Expand Up @@ -118,6 +118,60 @@ def prepare_installation(
return installables, installer


@runtime_checkable
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's move this code into base_agent.py. handlers.py is already too long

as for the tests against _run_custom_training_loop: I'm starting to make the tests folder structure replicate the main code structure. so in this case, I'd place all the relevant tests you added into tests/configurator/test_base_agent.py

(not related to tests against handle_dse_job)

class CustomTrainingLoopAgent(Protocol):
"""
Agent that drives its own training loop and skips the ``handle_dse_job`` step loop.

Set ``HAS_CUSTOM_TRAINING_LOOP = True`` on the agent class to opt in. Used by
agents (e.g. RLlib-based) whose training loops are not modelled as a sequence
of independent ``select_action`` / ``env.step`` calls.
"""

HAS_CUSTOM_TRAINING_LOOP: bool

def train(self) -> None: ...


def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgent]:
"""
Narrow ``agent`` to :class:`CustomTrainingLoopAgent` when it opts into the dispatch path.

Returning :class:`TypeGuard` (instead of plain ``bool``) lets the type checker
treat this predicate like ``isinstance``: callers inside the truthy branch see
``agent`` as a :class:`CustomTrainingLoopAgent`, so ``agent.train()`` type-checks
without ``getattr`` or ``cast``.
"""
return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False))


def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int:
"""
Drive an agent's self-contained training loop and return a process-style exit code.

``shutdown()`` runs inside its own ``try/except`` so a faulty teardown cannot
suppress the exit code from ``train()`` nor propagate out of this helper:
``handle_dse_job`` relies on the returned ``rc`` to accumulate ``err |= rc``
and continue with the remaining test runs.
"""
logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().")
rc = 0
try:
agent.train()
except Exception:
logging.exception(f"Custom training loop failed for agent {agent_type}.")
rc = 1
finally:
shutdown = getattr(agent, "shutdown", None)
if callable(shutdown):
try:
shutdown()
except Exception:
logging.exception(f"Shutdown failed for agent {agent_type}.")
rc = 1
return rc


def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
registry = Registry()

Expand Down Expand Up @@ -157,6 +211,10 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:

agent = agent_class(env, agent_config)

if _has_custom_training_loop(agent):
err |= _run_custom_training_loop(agent, agent_type)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we exit (immediate return err) if err is greater than zero? The existing code above doesn't really treat the err well but maybe it's the time to start doing so :D

continue

for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
Expand Down
162 changes: 160 additions & 2 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,22 @@
# limitations under the License.

import argparse
import logging
from pathlib import Path
from typing import Any, ClassVar, Iterator
from typing import Any, ClassVar, Iterator, Optional
from unittest.mock import MagicMock

import pandas as pd
import pytest
from pydantic import Field

from cloudai.cli.handlers import handle_dse_job, verify_system_configs, verify_test_configs, verify_test_scenarios
from cloudai.cli.handlers import (
_run_custom_training_loop,
handle_dse_job,
verify_system_configs,
verify_test_configs,
verify_test_scenarios,
)
from cloudai.core import (
BaseAgent,
BaseAgentConfig,
Expand Down Expand Up @@ -254,3 +261,154 @@ def test_verify_test_scenarios_logs_failure_details(tmp_path: Path, caplog: pyte
assert str(broken_scenario) in caplog.text
assert "duplicate TOML key 'name'" in caplog.text
assert "1 out of 1 test scenarios have issues." in caplog.text


class CustomLoopStubAgentConfig(BaseAgentConfig):
pass


class CustomLoopStubAgent(BaseAgent):
"""Stub agent that opts into the custom-training-loop dispatch path."""

HAS_CUSTOM_TRAINING_LOOP: ClassVar[bool] = True

train_calls: ClassVar[int] = 0
shutdown_calls: ClassVar[int] = 0
train_raises: ClassVar[Optional[BaseException]] = None

def __init__(self, env, config: CustomLoopStubAgentConfig):
self.env = env
self.config = config
self.max_steps = 0

@staticmethod
def get_config_class() -> type[CustomLoopStubAgentConfig]:
return CustomLoopStubAgentConfig

def configure(self, config: dict[str, Any]) -> None:
raise NotImplementedError

def select_action(self) -> tuple[int, dict[str, Any]]: # pragma: no cover - never called
raise AssertionError("select_action must not be called when HAS_CUSTOM_TRAINING_LOOP is True")

def update_policy(self, _feedback: dict[str, Any]) -> None:
return

def train(self) -> None:
CustomLoopStubAgent.train_calls += 1
if CustomLoopStubAgent.train_raises is not None:
raise CustomLoopStubAgent.train_raises

def shutdown(self) -> None:
CustomLoopStubAgent.shutdown_calls += 1


@pytest.fixture
def custom_loop_agent_name() -> Iterator[str]:
registry = Registry()
agent_name = "test_handlers_custom_loop_agent"
old_agent = registry.agents_map.get(agent_name)
registry.update_agent(agent_name, CustomLoopStubAgent)
CustomLoopStubAgent.train_calls = 0
CustomLoopStubAgent.shutdown_calls = 0
CustomLoopStubAgent.train_raises = None
yield agent_name
CustomLoopStubAgent.train_calls = 0
CustomLoopStubAgent.shutdown_calls = 0
CustomLoopStubAgent.train_raises = None
if old_agent is None:
del registry.agents_map[agent_name]
else:
registry.update_agent(agent_name, old_agent)


def test_run_custom_training_loop_calls_train_and_shutdown() -> None:
agent = MagicMock()
agent.train = MagicMock()
agent.shutdown = MagicMock()

assert _run_custom_training_loop(agent, "mock_agent") == 0
agent.train.assert_called_once_with()
agent.shutdown.assert_called_once_with()


def test_run_custom_training_loop_returns_error_and_still_shuts_down(
caplog: pytest.LogCaptureFixture,
) -> None:
agent = MagicMock()
agent.train = MagicMock(side_effect=RuntimeError("boom"))
agent.shutdown = MagicMock()

with caplog.at_level(logging.ERROR):
assert _run_custom_training_loop(agent, "mock_agent") == 1

agent.shutdown.assert_called_once_with()
assert "boom" in caplog.text


def test_run_custom_training_loop_tolerates_missing_shutdown() -> None:
agent = MagicMock(spec=["train"])
agent.train = MagicMock()

assert _run_custom_training_loop(agent, "mock_agent") == 0
agent.train.assert_called_once_with()


def test_run_custom_training_loop_reports_shutdown_failure(
caplog: pytest.LogCaptureFixture,
) -> None:
"""shutdown() raising must not suppress the exit code or propagate the exception."""
agent = MagicMock()
agent.train = MagicMock()
agent.shutdown = MagicMock(side_effect=RuntimeError("teardown blew up"))

with caplog.at_level(logging.ERROR):
assert _run_custom_training_loop(agent, "mock_agent") == 1

agent.train.assert_called_once_with()
agent.shutdown.assert_called_once_with()
assert "teardown blew up" in caplog.text


def test_run_custom_training_loop_reports_combined_train_and_shutdown_failures(
caplog: pytest.LogCaptureFixture,
) -> None:
"""When both train() and shutdown() raise, the helper still returns 1 and logs both."""
agent = MagicMock()
agent.train = MagicMock(side_effect=RuntimeError("training boom"))
agent.shutdown = MagicMock(side_effect=RuntimeError("teardown boom"))

with caplog.at_level(logging.ERROR):
assert _run_custom_training_loop(agent, "mock_agent") == 1

agent.shutdown.assert_called_once_with()
assert "training boom" in caplog.text
assert "teardown boom" in caplog.text


def test_handle_dse_job_dispatches_to_custom_training_loop(
slurm_system: SlurmSystem,
dse_tr: TestRun,
custom_loop_agent_name: str,
) -> None:
dse_tr.test.agent = custom_loop_agent_name
test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr])
runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario)

assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 0
assert CustomLoopStubAgent.train_calls == 1
assert CustomLoopStubAgent.shutdown_calls == 1


def test_handle_dse_job_propagates_custom_loop_failure(
slurm_system: SlurmSystem,
dse_tr: TestRun,
custom_loop_agent_name: str,
) -> None:
CustomLoopStubAgent.train_raises = RuntimeError("training blew up")
dse_tr.test.agent = custom_loop_agent_name
test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr])
runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario)

assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 1
assert CustomLoopStubAgent.shutdown_calls == 1
Loading