Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions sdks/python/src/agent_control/control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,46 @@ async def _evaluate(
return result_dict


def _unexpected_control_failure_message(stage: str, error: Exception) -> str:
"""Build a user-facing message for unexpected control evaluation failures."""
if stage == "pre":
return (
"Control check failed unexpectedly before execution. "
f"Execution blocked for safety. Error: {error}"
)
return (
"Control check failed unexpectedly after execution. "
f"Result blocked for safety. Error: {error}"
)


async def _run_control_check(
ctx: ControlContext,
stage: str,
payload: dict[str, Any],
controls: list[dict[str, Any]] | None,
) -> None:
"""Run one control stage and enforce fail-closed behavior on unexpected errors."""
try:
result = await _evaluate(
ctx.agent_name,
payload,
stage,
ctx.server_url,
ctx.trace_id,
ctx.span_id,
controls=controls,
event_agent_name=ctx.agent_name,
)
ctx.process_result(result, stage)
except (ControlViolationError, ControlSteerError):
raise
except Exception as e:
stage_name = "Pre" if stage == "pre" else "Post"
logger.error("%s-execution control check failed: %s", stage_name, e, exc_info=True)
raise RuntimeError(_unexpected_control_failure_message(stage, e)) from e


def _extract_input_from_args(func: Callable, args: tuple, kwargs: dict) -> str:
"""
Extract input data from function arguments.
Expand Down Expand Up @@ -644,6 +684,8 @@ async def _execute_with_control(

Raises:
ControlViolationError: If any control triggers with "deny" action
ControlSteerError: If any control triggers with "steer" action
RuntimeError: If control evaluation fails unexpectedly
"""
agent = _get_current_agent()
if agent is None:
Expand Down Expand Up @@ -682,22 +724,7 @@ async def _execute_with_control(

try:
# PRE-EXECUTION: Check controls with check_stage="pre"
try:
result = await _evaluate(
ctx.agent_name, ctx.pre_payload(), "pre",
ctx.server_url, ctx.trace_id, ctx.span_id,
controls=controls,
event_agent_name=ctx.agent_name,
)
ctx.process_result(result, "pre")
except (ControlViolationError, ControlSteerError):
raise
except Exception as e:
# FAIL-SAFE: If control check fails, DO NOT execute the function
logger.error(f"Pre-execution control check failed: {e}")
raise RuntimeError(
f"Control check failed unexpectedly. Execution blocked for safety. Error: {e}"
) from e
await _run_control_check(ctx, "pre", ctx.pre_payload(), controls)

# Execute the function
if is_async:
Expand All @@ -706,18 +733,7 @@ async def _execute_with_control(
output = func(*args, **kwargs)

# POST-EXECUTION: Check controls with check_stage="post"
try:
result = await _evaluate(
ctx.agent_name, ctx.post_payload(output), "post",
ctx.server_url, ctx.trace_id, ctx.span_id,
controls=controls,
event_agent_name=ctx.agent_name,
)
ctx.process_result(result, "post")
except (ControlViolationError, ControlSteerError):
raise
except Exception as e:
logger.error(f"Post-execution control check failed: {e}")
await _run_control_check(ctx, "post", ctx.post_payload(output), controls)

return output
finally:
Expand All @@ -742,6 +758,8 @@ def control(policy: str | None = None, step_name: str | None = None) -> Callable

Raises:
ControlViolationError: If any control triggers with "deny" action
ControlSteerError: If any control triggers with "steer" action
RuntimeError: If control evaluation fails unexpectedly

How it works:
1. Before function execution: Calls server with stage="pre"
Expand Down
21 changes: 14 additions & 7 deletions sdks/python/tests/test_control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,9 +854,10 @@ async def test_func():
assert "Unexpected error" in str(exc_info.value)

@pytest.mark.asyncio
async def test_other_exceptions_logged_in_post_execution(self, mock_agent, mock_safe_response):
"""Test that non-control exceptions are logged (not raised) in post-execution."""
async def test_other_exceptions_wrapped_in_post_execution(self, mock_agent, mock_safe_response):
"""Test that non-control exceptions fail closed in post-execution."""
call_count = [0]
executed = {"value": False}

def mock_evaluate_side_effect(*args, **kwargs):
call_count[0] += 1
Expand All @@ -870,12 +871,18 @@ def mock_evaluate_side_effect(*args, **kwargs):

@control()
async def test_func():
executed["value"] = True
return "executed successfully"

# Function should still complete despite post-execution error
result = await test_func()
assert result == "executed successfully"
# Function still executes, but the result is withheld for safety.
with pytest.raises(RuntimeError) as exc_info:
await test_func()

assert executed["value"] is True
assert "Control check failed unexpectedly after execution" in str(exc_info.value)
assert "Post-execution error" in str(exc_info.value)

# Error should be logged
mock_logger.error.assert_called_once()
assert "Post-execution control check failed" in mock_logger.error.call_args[0][0]
assert mock_logger.error.call_args[0][0] == "%s-execution control check failed: %s"
assert mock_logger.error.call_args[0][1] == "Post"
assert str(mock_logger.error.call_args[0][2]) == "Post-execution error"
Loading