diff --git a/agent_assembly/__init__.py b/agent_assembly/__init__.py index fcdf256..3a8491e 100644 --- a/agent_assembly/__init__.py +++ b/agent_assembly/__init__.py @@ -8,6 +8,7 @@ AssemblyError, ConfigurationError, GatewayError, + MCPToolBlockedError, PolicyError, ToolExecutionBlockedError, ) @@ -37,6 +38,7 @@ "ConfigurationError", "AdapterValidationError", "ToolExecutionBlockedError", + "MCPToolBlockedError", ] if "RuntimeClient" in globals(): diff --git a/agent_assembly/adapters/mcp/patch.py b/agent_assembly/adapters/mcp/patch.py index 4923f0b..cfefe9a 100644 --- a/agent_assembly/adapters/mcp/patch.py +++ b/agent_assembly/adapters/mcp/patch.py @@ -3,8 +3,20 @@ from __future__ import annotations from dataclasses import dataclass +import importlib import importlib.util -from typing import Any +import inspect +from typing import Any, Literal, Mapping + +from agent_assembly.adapters.crewai.patch import ( + _get_pending_tool_approval_timeout_seconds as _resolve_pending_timeout_seconds, +) +from agent_assembly.adapters.crewai.patch import _normalize_decision as _normalize_governance_decision + +_ORIGINAL_CALL_TOOL = "_agent_assembly_original_mcp_call_tool" +_PATCHED_FLAG = "_agent_assembly_mcp_clientsession_patched" +_PROCESS_AGENT_ID: str | None = None +_MAX_AUDIT_RESULT_CHARS = 2000 @dataclass(slots=True) @@ -12,14 +24,290 @@ class MCPClientPatch: """Patch placeholder for MCP client interception.""" callback_handler: Any + process_agent_id: str | None = None def apply(self) -> bool: - _ = self.callback_handler - return _is_mcp_available() + set_process_agent_id(self.process_agent_id) + client_session_cls = _load_mcp_client_session_class() + if client_session_cls is None: + return False + _apply_client_session_patch(client_session_cls, self.callback_handler) + return True def revert(self) -> None: + client_session_cls = _load_mcp_client_session_class() + if client_session_cls is not None: + _revert_client_session_patch(client_session_cls) + set_process_agent_id(None) return None def _is_mcp_available() -> bool: return importlib.util.find_spec("mcp") is not None + + +def set_process_agent_id(agent_id: str | None) -> None: + global _PROCESS_AGENT_ID + _PROCESS_AGENT_ID = agent_id + + +def _get_process_agent_id() -> str | None: + if isinstance(_PROCESS_AGENT_ID, str) and _PROCESS_AGENT_ID: + return _PROCESS_AGENT_ID + return None + + +def _load_mcp_client_session_class() -> type[Any] | None: + try: + module = importlib.import_module("mcp") + except ImportError: + return None + + client_session_cls = getattr(module, "ClientSession", None) + if isinstance(client_session_cls, type): + return client_session_cls + return None + + +def _get_server_identifier(session: Any) -> str: + for attr in ("_server_url", "_server_name", "_ws_url"): + value = getattr(session, attr, None) + if isinstance(value, str) and value.strip(): + return value + + transport = getattr(session, "_transport", None) + if transport is not None: + for attr in ("url", "server_url", "server_name", "ws_url", "name"): + value = getattr(transport, attr, None) + if isinstance(value, str) and value.strip(): + return value + + return "mcp-unknown" + + +def _resolve_governance_target(callback_handler: Any) -> Any: + target = getattr(callback_handler, "_interceptor", None) + if target is not None: + return target + return callback_handler + + +def _extract_tool_call_inputs( + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> tuple[str, dict[str, Any]]: + raw_tool_name = kwargs.get("name") + if not isinstance(raw_tool_name, str): + raw_tool_name = str(args[0]) if args else "mcp-unknown-tool" + + raw_arguments = kwargs.get("arguments") + if raw_arguments is None and len(args) >= 2: + raw_arguments = args[1] + + if isinstance(raw_arguments, Mapping): + return raw_tool_name, dict(raw_arguments) + return raw_tool_name, {} + + +def _normalize_decision( + decision: object, +) -> tuple[Literal["allow", "deny", "pending"], str | None]: + return _normalize_governance_decision(decision) + + +def _get_pending_tool_approval_timeout_seconds(callback_handler: Any) -> int: + return _resolve_pending_timeout_seconds(callback_handler) + + +async def _invoke_async_tool_check( + callback_handler: Any, + *, + tool_name: str, + tool_args: dict[str, Any], + agent_id: str | None, + server_identifier: str, +) -> object: + target = _resolve_governance_target(callback_handler) + method = getattr(target, "check_tool_start", None) + if not callable(method): + return {"status": "allow"} + + result = method( + serialized={"name": tool_name}, + input_str=str(tool_args), + tool_name=tool_name, + args=tool_args, + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(result): + return await result + return result + + +async def _wait_for_async_tool_approval( + callback_handler: Any, + *, + tool_name: str, + timeout_seconds: int, + tool_args: dict[str, Any], + agent_id: str | None, + server_identifier: str, +) -> object: + target = _resolve_governance_target(callback_handler) + method = getattr(target, "wait_for_tool_approval", None) + if not callable(method): + return {"status": "deny", "reason": "Approval handler is unavailable."} + + result = method( + serialized={"name": tool_name}, + input_str=str(tool_args), + tool_name=tool_name, + timeout_seconds=timeout_seconds, + args=tool_args, + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(result): + return await result + return result + + +def _truncate_result_for_audit(result: object) -> str: + return str(result)[:_MAX_AUDIT_RESULT_CHARS] + + +async def _record_async_tool_result( + callback_handler: Any, + *, + tool_name: str, + result: object, + agent_id: str | None, + server_identifier: str, +) -> None: + target = _resolve_governance_target(callback_handler) + + record_method = getattr(target, "record_result", None) + if callable(record_method): + recorded = record_method( + tool_name=tool_name, + result=_truncate_result_for_audit(result), + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(recorded): + await recorded + return None + + tool_end_method = getattr(target, "on_tool_end", None) + if callable(tool_end_method): + recorded = tool_end_method( + output=_truncate_result_for_audit(result), + tool_name=tool_name, + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(recorded): + await recorded + + +def _build_blocked_error( + *, + tool_name: str, + server_identifier: str, + reason: str | None, + is_pending_rejection: bool, +) -> Exception: + from agent_assembly.exceptions import MCPToolBlockedError + + reason_text = reason or "No reason provided." + if is_pending_rejection: + message = ( + f"MCP tool '{tool_name}' on server '{server_identifier}' " + f"rejected during approval: {reason_text}" + ) + else: + message = ( + f"MCP tool '{tool_name}' on server '{server_identifier}' " + f"blocked by governance policy: {reason_text}" + ) + + return MCPToolBlockedError( + message, + tool_name=tool_name, + server=server_identifier, + ) + + +def _apply_client_session_patch(client_session_cls: type[Any], callback_handler: Any) -> None: + if getattr(client_session_cls, _PATCHED_FLAG, False): + return None + + original_call_tool = getattr(client_session_cls, "call_tool", None) + if not callable(original_call_tool): + return None + + async def patched_call_tool(self: Any, *args: Any, **kwargs: Any) -> Any: + tool_name, tool_args = _extract_tool_call_inputs(args, kwargs) + agent_id = _get_process_agent_id() + server_identifier = _get_server_identifier(self) + + decision = await _invoke_async_tool_check( + callback_handler, + tool_name=tool_name, + tool_args=tool_args, + agent_id=agent_id, + server_identifier=server_identifier, + ) + status, reason = _normalize_decision(decision) + is_pending_flow = False + if status == "pending": + is_pending_flow = True + timeout_seconds = _get_pending_tool_approval_timeout_seconds(callback_handler) + final_decision = await _wait_for_async_tool_approval( + callback_handler, + tool_name=tool_name, + timeout_seconds=timeout_seconds, + tool_args=tool_args, + agent_id=agent_id, + server_identifier=server_identifier, + ) + status, reason = _normalize_decision(final_decision) + + if status == "deny": + raise _build_blocked_error( + tool_name=tool_name, + server_identifier=server_identifier, + reason=reason, + is_pending_rejection=is_pending_flow, + ) + + result = original_call_tool(self, *args, **kwargs) + if inspect.isawaitable(result): + result = await result + await _record_async_tool_result( + callback_handler, + tool_name=tool_name, + result=result, + agent_id=agent_id, + server_identifier=server_identifier, + ) + return result + + setattr(client_session_cls, _ORIGINAL_CALL_TOOL, original_call_tool) + setattr(client_session_cls, "call_tool", patched_call_tool) + setattr(client_session_cls, _PATCHED_FLAG, True) + + +def _revert_client_session_patch(client_session_cls: type[Any]) -> None: + if not getattr(client_session_cls, _PATCHED_FLAG, False): + return None + + original_call_tool = getattr(client_session_cls, _ORIGINAL_CALL_TOOL, None) + if callable(original_call_tool): + setattr(client_session_cls, "call_tool", original_call_tool) + + if hasattr(client_session_cls, _ORIGINAL_CALL_TOOL): + delattr(client_session_cls, _ORIGINAL_CALL_TOOL) + if hasattr(client_session_cls, _PATCHED_FLAG): + delattr(client_session_cls, _PATCHED_FLAG) diff --git a/agent_assembly/core/assembly.py b/agent_assembly/core/assembly.py index 2347681..60b1fb5 100644 --- a/agent_assembly/core/assembly.py +++ b/agent_assembly/core/assembly.py @@ -189,7 +189,12 @@ def _build_patch_plan(client: GatewayClient, process_agent_id: str) -> list[Runt patch_plan.append(OpenAIAgentsPatch(callback_target)) if _is_installed("mcp"): # Keep MCP patch last as fallback for remaining tool dispatch paths. - patch_plan.append(MCPClientPatch(callback_target)) + patch_plan.append( + MCPClientPatch( + callback_handler=callback_target, + process_agent_id=process_agent_id, + ) + ) return patch_plan diff --git a/agent_assembly/exceptions/__init__.py b/agent_assembly/exceptions/__init__.py index 91b91d1..49b2b5b 100644 --- a/agent_assembly/exceptions/__init__.py +++ b/agent_assembly/exceptions/__init__.py @@ -10,6 +10,7 @@ "ConfigurationError", "AdapterValidationError", "ToolExecutionBlockedError", + "MCPToolBlockedError", "PolicyViolationError", ] @@ -49,6 +50,21 @@ class ToolExecutionBlockedError(AssemblyError): pass +class MCPToolBlockedError(ToolExecutionBlockedError): + """Exception raised when an MCP tool call is blocked by governance.""" + + def __init__( + self, + message: str, + *, + tool_name: str | None = None, + server: str | None = None, + ) -> None: + super().__init__(message) + self.tool_name = tool_name + self.server = server + + class PolicyViolationError(ToolExecutionBlockedError): """Exception raised when policy blocks tool execution.""" pass diff --git a/test/integration/mcp/__init__.py b/test/integration/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/integration/mcp/test_direct_clientsession_integration.py b/test/integration/mcp/test_direct_clientsession_integration.py new file mode 100644 index 0000000..affcdf1 --- /dev/null +++ b/test/integration/mcp/test_direct_clientsession_integration.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from agent_assembly.adapters.mcp import patch as mcp_patch +from agent_assembly.exceptions import MCPToolBlockedError + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_direct_mcp_clientsession_blocks_mid_flow_and_allows_followup( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeClientSession: + def __init__(self) -> None: + self._server_name = "direct-mcp-server" + + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> str: + payload = arguments or {} + return f"ok:{name}:{payload.get('step', 'none')}" + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + if kwargs.get("tool_name") == "blocked_tool": + return {"status": "deny", "reason": "blocked by policy"} + return {"status": "allow"} + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-custom") + assert patcher.apply() is True + + session = FakeClientSession() + with pytest.raises(MCPToolBlockedError, match="blocked by governance policy: blocked by policy"): + await session.call_tool("blocked_tool", {"step": "one"}) + + safe_result = await session.call_tool("safe_tool", {"step": "two"}) + assert safe_result == "ok:safe_tool:two" diff --git a/test/integration/mcp/test_langchain_mcp_coexistence_integration.py b/test/integration/mcp/test_langchain_mcp_coexistence_integration.py new file mode 100644 index 0000000..f429e77 --- /dev/null +++ b/test/integration/mcp/test_langchain_mcp_coexistence_integration.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +from agent_assembly.adapters.langchain import AssemblyCallbackHandler +from agent_assembly.adapters.mcp import patch as mcp_patch + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_langchain_and_mcp_layers_both_emit_governance_events( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeClientSession: + def __init__(self) -> None: + self._server_url = "https://mcp.shared.test" + + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> str: + payload = arguments or {} + return f"mcp:{name}:{payload.get('q', '')}" + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + + checks: list[dict[str, object]] = [] + records: list[dict[str, object]] = [] + + class Interceptor: + def check_tool_start(self, **kwargs: object) -> dict[str, str]: + checks.append(dict(kwargs)) + return {"status": "allow"} + + def record_result(self, **kwargs: object) -> None: + records.append(dict(kwargs)) + + callback_handler = AssemblyCallbackHandler(Interceptor()) + patcher = mcp_patch.MCPClientPatch(callback_handler, process_agent_id="agent-chain") + assert patcher.apply() is True + + callback_handler.on_tool_start( + serialized={"name": "langchain_tool"}, + input_str="{}", + run_id=uuid4(), + ) + result = await FakeClientSession().call_tool("mcp_tool", {"q": "hello"}) + + assert result == "mcp:mcp_tool:hello" + assert len(checks) == 2 + assert checks[0]["serialized"] == {"name": "langchain_tool"} + assert checks[1]["serialized"] == {"name": "mcp_tool"} + assert checks[1]["server"] == "https://mcp.shared.test" + assert len(records) == 1 + assert records[0]["tool_name"] == "mcp_tool" diff --git a/test/unit/adapters/mcp/__init__.py b/test/unit/adapters/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/adapters/mcp/test_patch.py b/test/unit/adapters/mcp/test_patch.py new file mode 100644 index 0000000..31d3f0d --- /dev/null +++ b/test/unit/adapters/mcp/test_patch.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from agent_assembly.adapters.mcp import patch as mcp_patch +from agent_assembly.exceptions import MCPToolBlockedError + + +class _ArgsMapping(dict[str, object]): + pass + + +def _install_fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> type[Any]: + class FakeClientSession: + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> dict[str, Any]: + return { + "name": name, + "arguments": arguments or {}, + } + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + + def fake_import_module(module_name: str) -> object: + if module_name == "mcp": + return fake_mcp_module + raise ImportError(module_name) + + monkeypatch.setattr(mcp_patch.importlib, "import_module", fake_import_module) + return FakeClientSession + + +@pytest.mark.asyncio +async def test_apply_patches_clientsession_call_tool_and_is_idempotent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-1") + assert patcher.apply() is True + first_call_ref = FakeClientSession.call_tool + + assert getattr(FakeClientSession, mcp_patch._PATCHED_FLAG, False) is True + assert mcp_patch._get_process_agent_id() == "agent-1" + + assert patcher.apply() is True + assert FakeClientSession.call_tool is first_call_ref + + +def test_revert_restores_clientsession_call_tool_and_clears_process_agent_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + original_call_tool = FakeClientSession.call_tool + mcp_patch.set_process_agent_id("agent-before-revert") + + patcher = mcp_patch.MCPClientPatch(object()) + assert patcher.apply() is True + assert FakeClientSession.call_tool is not original_call_tool + + patcher.revert() + assert FakeClientSession.call_tool is original_call_tool + assert getattr(FakeClientSession, mcp_patch._PATCHED_FLAG, False) is False + assert mcp_patch._get_process_agent_id() is None + + +def test_loaders_and_apply_false_when_mcp_clientsession_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda _: (_ for _ in ()).throw(ImportError("mcp")), + ) + assert mcp_patch._load_mcp_client_session_class() is None + assert mcp_patch.MCPClientPatch(callback_handler=object()).apply() is False + + fake_mcp_module = SimpleNamespace(ClientSession=object()) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + assert mcp_patch._load_mcp_client_session_class() is None + + +def test_server_identifier_extraction_handles_transport_variants() -> None: + sse_session = SimpleNamespace(_server_url="https://sse.mcp.test") + stdio_session = SimpleNamespace(_server_name="stdio-server") + ws_session = SimpleNamespace(_ws_url="wss://ws.mcp.test") + transport_session = SimpleNamespace(_transport=SimpleNamespace(server_name="transport-server")) + unknown_session = SimpleNamespace() + + assert mcp_patch._get_server_identifier(sse_session) == "https://sse.mcp.test" + assert mcp_patch._get_server_identifier(stdio_session) == "stdio-server" + assert mcp_patch._get_server_identifier(ws_session) == "wss://ws.mcp.test" + assert mcp_patch._get_server_identifier(transport_session) == "transport-server" + assert mcp_patch._get_server_identifier(unknown_session) == "mcp-unknown" + + +@pytest.mark.asyncio +async def test_denied_tool_raises_mcp_tool_blocked_error_with_metadata( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + assert kwargs["tool_name"] == "blocked_tool" + return {"status": "deny", "reason": "policy block"} + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-7") + assert patcher.apply() is True + + session = FakeClientSession() + session._server_name = "stdlib-server" # type: ignore[attr-defined] + with pytest.raises( + MCPToolBlockedError, + match="blocked by governance policy: policy block", + ) as captured: + await session.call_tool("blocked_tool", {"q": "secret"}) + + error = captured.value + assert error.tool_name == "blocked_tool" + assert error.server == "stdlib-server" + + +@pytest.mark.asyncio +async def test_pending_then_approved_runs_original_and_records_result( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + wait_calls: list[dict[str, object]] = [] + recorded_results: list[dict[str, object]] = [] + + class Interceptor: + pending_tool_approval_timeout_seconds = 17 + + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "pending", "reason": "approval required"} + + async def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]: + wait_calls.append(dict(kwargs)) + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + recorded_results.append(dict(kwargs)) + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-9") + assert patcher.apply() is True + + session = FakeClientSession() + session._server_url = "https://api.mcp.test" # type: ignore[attr-defined] + result = await session.call_tool("search", _ArgsMapping({"q": "hello"})) + + assert result["name"] == "search" + assert len(wait_calls) == 1 + assert wait_calls[0]["timeout_seconds"] == 17 + assert len(recorded_results) == 1 + assert recorded_results[0]["tool_name"] == "search" + assert recorded_results[0]["agent_id"] == "agent-9" + assert recorded_results[0]["server"] == "https://api.mcp.test" + + +@pytest.mark.asyncio +async def test_pending_then_rejected_raises_mcp_tool_blocked_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "pending", "reason": "approval required"} + + async def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "deny", "reason": "approval rejected"} + + patcher = mcp_patch.MCPClientPatch(Interceptor()) + assert patcher.apply() is True + + session = FakeClientSession() + session._ws_url = "wss://tools.mcp.test" # type: ignore[attr-defined] + with pytest.raises( + MCPToolBlockedError, + match="rejected during approval: approval rejected", + ) as captured: + await session.call_tool("deploy", {"service": "api"}) + + assert captured.value.tool_name == "deploy" + assert captured.value.server == "wss://tools.mcp.test" + + +@pytest.mark.asyncio +async def test_result_recording_truncates_to_2000_chars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeClientSession: + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> str: + del name, arguments + return "x" * 2500 + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + + observed_results: list[str] = [] + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + observed_results.append(str(kwargs["result"])) + + patcher = mcp_patch.MCPClientPatch(Interceptor()) + assert patcher.apply() is True + + result = await FakeClientSession().call_tool("long-output", {"debug": True}) + + assert isinstance(result, str) + assert len(result) == 2500 + assert len(observed_results) == 1 + assert len(observed_results[0]) == 2000 + + +@pytest.mark.asyncio +async def test_callback_wrapper_with_interceptor_is_supported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + seen: list[str] = [] + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + seen.append(str(kwargs.get("tool_name"))) + return {"status": "allow"} + + wrapper = SimpleNamespace(_interceptor=Interceptor()) + patcher = mcp_patch.MCPClientPatch(wrapper) + assert patcher.apply() is True + + result = await FakeClientSession().call_tool(name="from-wrapper", arguments={"ok": True}) + assert result["name"] == "from-wrapper" + assert seen == ["from-wrapper"] diff --git a/test/unit/adapters/test_optional_patches.py b/test/unit/adapters/test_optional_patches.py index 623a207..84fe1bc 100644 --- a/test/unit/adapters/test_optional_patches.py +++ b/test/unit/adapters/test_optional_patches.py @@ -2,25 +2,9 @@ import pytest -from agent_assembly.adapters.mcp.patch import MCPClientPatch -from agent_assembly.adapters.mcp import patch as mcp_patch from agent_assembly.adapters.openai_agents.patch import OpenAIAgentsPatch from agent_assembly.adapters.openai_agents import patch as openai_patch - -def test_mcp_patch_apply_and_revert(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(mcp_patch.importlib.util, "find_spec", lambda package: object()) - patcher = MCPClientPatch(callback_handler=object()) - assert patcher.apply() is True - patcher.revert() - - -def test_mcp_patch_apply_returns_false_when_module_missing(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(mcp_patch.importlib.util, "find_spec", lambda package: None) - patcher = MCPClientPatch(callback_handler=object()) - assert patcher.apply() is False - - def test_openai_agents_patch_apply_and_revert(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(openai_patch.importlib.util, "find_spec", lambda package: object()) patcher = OpenAIAgentsPatch(callback_handler=object()) diff --git a/test/unit/test_exceptions.py b/test/unit/test_exceptions.py new file mode 100644 index 0000000..b5eac6c --- /dev/null +++ b/test/unit/test_exceptions.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from agent_assembly import MCPToolBlockedError + + +def test_mcp_tool_blocked_error_exposes_tool_and_server_metadata() -> None: + error = MCPToolBlockedError( + "blocked", + tool_name="search_docs", + server="https://mcp.example.test", + ) + + assert str(error) == "blocked" + assert error.tool_name == "search_docs" + assert error.server == "https://mcp.example.test"