From 10bad34bb1599ef78c8e1a81a910ec3403909794 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:51:56 +0800 Subject: [PATCH 01/13] =?UTF-8?q?=E2=9C=A8=20(exceptions):=20Add=20MCPTool?= =?UTF-8?q?BlockedError=20for=20MCP=20governance=20blocking?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_assembly/exceptions/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/agent_assembly/exceptions/__init__.py b/agent_assembly/exceptions/__init__.py index 91b91d1..49b2b5b 100644 --- a/agent_assembly/exceptions/__init__.py +++ b/agent_assembly/exceptions/__init__.py @@ -10,6 +10,7 @@ "ConfigurationError", "AdapterValidationError", "ToolExecutionBlockedError", + "MCPToolBlockedError", "PolicyViolationError", ] @@ -49,6 +50,21 @@ class ToolExecutionBlockedError(AssemblyError): pass +class MCPToolBlockedError(ToolExecutionBlockedError): + """Exception raised when an MCP tool call is blocked by governance.""" + + def __init__( + self, + message: str, + *, + tool_name: str | None = None, + server: str | None = None, + ) -> None: + super().__init__(message) + self.tool_name = tool_name + self.server = server + + class PolicyViolationError(ToolExecutionBlockedError): """Exception raised when policy blocks tool execution.""" pass From 5389907c728c6484a9dcad6a0dc98702929bcd40 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:52:03 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20(exports):=20Expose?= =?UTF-8?q?=20MCPToolBlockedError=20in=20package=20public=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_assembly/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agent_assembly/__init__.py b/agent_assembly/__init__.py index fcdf256..3a8491e 100644 --- a/agent_assembly/__init__.py +++ b/agent_assembly/__init__.py @@ -8,6 +8,7 @@ AssemblyError, ConfigurationError, GatewayError, + MCPToolBlockedError, PolicyError, ToolExecutionBlockedError, ) @@ -37,6 +38,7 @@ "ConfigurationError", "AdapterValidationError", "ToolExecutionBlockedError", + "MCPToolBlockedError", ] if "RuntimeClient" in globals(): From 53947c1ec589bb50e9d5389b94680c7c5f5efa68 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:52:22 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E2=9C=A8=20(mcp):=20Add=20patch=20state?= =?UTF-8?q?=20constants=20and=20process=20agent=20context=20helpers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_assembly/adapters/mcp/patch.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/agent_assembly/adapters/mcp/patch.py b/agent_assembly/adapters/mcp/patch.py index 4923f0b..13a3176 100644 --- a/agent_assembly/adapters/mcp/patch.py +++ b/agent_assembly/adapters/mcp/patch.py @@ -3,23 +3,44 @@ from __future__ import annotations from dataclasses import dataclass +import importlib import importlib.util +import inspect from typing import Any +_ORIGINAL_CALL_TOOL = "_agent_assembly_original_mcp_call_tool" +_PATCHED_FLAG = "_agent_assembly_mcp_clientsession_patched" +_PROCESS_AGENT_ID: str | None = None +_MAX_AUDIT_RESULT_CHARS = 2000 + @dataclass(slots=True) class MCPClientPatch: """Patch placeholder for MCP client interception.""" callback_handler: Any + process_agent_id: str | None = None def apply(self) -> bool: + set_process_agent_id(self.process_agent_id) _ = self.callback_handler return _is_mcp_available() def revert(self) -> None: + set_process_agent_id(None) return None def _is_mcp_available() -> bool: return importlib.util.find_spec("mcp") is not None + + +def set_process_agent_id(agent_id: str | None) -> None: + global _PROCESS_AGENT_ID + _PROCESS_AGENT_ID = agent_id + + +def _get_process_agent_id() -> str | None: + if isinstance(_PROCESS_AGENT_ID, str) and _PROCESS_AGENT_ID: + return _PROCESS_AGENT_ID + return None From 1c5d3d3ed7683827b48e277f86ea4556a2e371eb Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:52:34 +0800 Subject: [PATCH 04/13] =?UTF-8?q?=E2=9C=A8=20(mcp):=20Add=20ClientSession?= =?UTF-8?q?=20loader=20and=20multi-transport=20server=20identifier=20extra?= =?UTF-8?q?ction?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_assembly/adapters/mcp/patch.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/agent_assembly/adapters/mcp/patch.py b/agent_assembly/adapters/mcp/patch.py index 13a3176..a3c3283 100644 --- a/agent_assembly/adapters/mcp/patch.py +++ b/agent_assembly/adapters/mcp/patch.py @@ -44,3 +44,31 @@ def _get_process_agent_id() -> str | None: if isinstance(_PROCESS_AGENT_ID, str) and _PROCESS_AGENT_ID: return _PROCESS_AGENT_ID return None + + +def _load_mcp_client_session_class() -> type[Any] | None: + try: + module = importlib.import_module("mcp") + except ImportError: + return None + + client_session_cls = getattr(module, "ClientSession", None) + if isinstance(client_session_cls, type): + return client_session_cls + return None + + +def _get_server_identifier(session: Any) -> str: + for attr in ("_server_url", "_server_name", "_ws_url"): + value = getattr(session, attr, None) + if isinstance(value, str) and value.strip(): + return value + + transport = getattr(session, "_transport", None) + if transport is not None: + for attr in ("url", "server_url", "server_name", "ws_url", "name"): + value = getattr(transport, attr, None) + if isinstance(value, str) and value.strip(): + return value + + return "mcp-unknown" From 2c1d2acf9ecbbb2f2226e0191c2a7a89900aabfb Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:52:55 +0800 Subject: [PATCH 05/13] =?UTF-8?q?=E2=9C=A8=20(mcp):=20Add=20async=20govern?= =?UTF-8?q?ance=20decision=20helpers=20for=20call=5Ftool=20interception?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_assembly/adapters/mcp/patch.py | 95 +++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/agent_assembly/adapters/mcp/patch.py b/agent_assembly/adapters/mcp/patch.py index a3c3283..aa41f9f 100644 --- a/agent_assembly/adapters/mcp/patch.py +++ b/agent_assembly/adapters/mcp/patch.py @@ -6,7 +6,12 @@ import importlib import importlib.util import inspect -from typing import Any +from typing import Any, Literal, Mapping + +from agent_assembly.adapters.crewai.patch import ( + _get_pending_tool_approval_timeout_seconds as _resolve_pending_timeout_seconds, +) +from agent_assembly.adapters.crewai.patch import _normalize_decision as _normalize_governance_decision _ORIGINAL_CALL_TOOL = "_agent_assembly_original_mcp_call_tool" _PATCHED_FLAG = "_agent_assembly_mcp_clientsession_patched" @@ -72,3 +77,91 @@ def _get_server_identifier(session: Any) -> str: return value return "mcp-unknown" + + +def _resolve_governance_target(callback_handler: Any) -> Any: + target = getattr(callback_handler, "_interceptor", None) + if target is not None: + return target + return callback_handler + + +def _extract_tool_call_inputs( + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> tuple[str, dict[str, Any]]: + raw_tool_name = kwargs.get("name") + if not isinstance(raw_tool_name, str): + raw_tool_name = str(args[0]) if args else "mcp-unknown-tool" + + raw_arguments = kwargs.get("arguments") + if raw_arguments is None and len(args) >= 2: + raw_arguments = args[1] + + if isinstance(raw_arguments, Mapping): + return raw_tool_name, dict(raw_arguments) + return raw_tool_name, {} + + +def _normalize_decision( + decision: object, +) -> tuple[Literal["allow", "deny", "pending"], str | None]: + return _normalize_governance_decision(decision) + + +def _get_pending_tool_approval_timeout_seconds(callback_handler: Any) -> int: + return _resolve_pending_timeout_seconds(callback_handler) + + +async def _invoke_async_tool_check( + callback_handler: Any, + *, + tool_name: str, + tool_args: dict[str, Any], + agent_id: str | None, + server_identifier: str, +) -> object: + target = _resolve_governance_target(callback_handler) + method = getattr(target, "check_tool_start", None) + if not callable(method): + return {"status": "allow"} + + result = method( + serialized={"name": tool_name}, + input_str=str(tool_args), + tool_name=tool_name, + args=tool_args, + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(result): + return await result + return result + + +async def _wait_for_async_tool_approval( + callback_handler: Any, + *, + tool_name: str, + timeout_seconds: int, + tool_args: dict[str, Any], + agent_id: str | None, + server_identifier: str, +) -> object: + target = _resolve_governance_target(callback_handler) + method = getattr(target, "wait_for_tool_approval", None) + if not callable(method): + return {"status": "deny", "reason": "Approval handler is unavailable."} + + result = method( + serialized={"name": tool_name}, + input_str=str(tool_args), + tool_name=tool_name, + timeout_seconds=timeout_seconds, + args=tool_args, + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(result): + return await result + return result From a644b83494ab4679d3e9ebda5f4c557bd20cbe2c Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:53:12 +0800 Subject: [PATCH 06/13] =?UTF-8?q?=E2=9C=A8=20(mcp):=20Add=20blocked-error?= =?UTF-8?q?=20and=20truncated=20audit=20result=20recording=20helpers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_assembly/adapters/mcp/patch.py | 66 ++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/agent_assembly/adapters/mcp/patch.py b/agent_assembly/adapters/mcp/patch.py index aa41f9f..25fd9fe 100644 --- a/agent_assembly/adapters/mcp/patch.py +++ b/agent_assembly/adapters/mcp/patch.py @@ -165,3 +165,69 @@ async def _wait_for_async_tool_approval( if inspect.isawaitable(result): return await result return result + + +def _truncate_result_for_audit(result: object) -> str: + return str(result)[:_MAX_AUDIT_RESULT_CHARS] + + +async def _record_async_tool_result( + callback_handler: Any, + *, + tool_name: str, + result: object, + agent_id: str | None, + server_identifier: str, +) -> None: + target = _resolve_governance_target(callback_handler) + + record_method = getattr(target, "record_result", None) + if callable(record_method): + recorded = record_method( + tool_name=tool_name, + result=_truncate_result_for_audit(result), + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(recorded): + await recorded + return None + + tool_end_method = getattr(target, "on_tool_end", None) + if callable(tool_end_method): + recorded = tool_end_method( + output=_truncate_result_for_audit(result), + tool_name=tool_name, + agent_id=agent_id, + server=server_identifier, + ) + if inspect.isawaitable(recorded): + await recorded + + +def _build_blocked_error( + *, + tool_name: str, + server_identifier: str, + reason: str | None, + is_pending_rejection: bool, +) -> Exception: + from agent_assembly.exceptions import MCPToolBlockedError + + reason_text = reason or "No reason provided." + if is_pending_rejection: + message = ( + f"MCP tool '{tool_name}' on server '{server_identifier}' " + f"rejected during approval: {reason_text}" + ) + else: + message = ( + f"MCP tool '{tool_name}' on server '{server_identifier}' " + f"blocked by governance policy: {reason_text}" + ) + + return MCPToolBlockedError( + message, + tool_name=tool_name, + server=server_identifier, + ) From a367279f3248356e95d10a82c17c56eb6ccc2df7 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:53:42 +0800 Subject: [PATCH 07/13] =?UTF-8?q?=E2=9C=A8=20(mcp):=20Patch=20ClientSessio?= =?UTF-8?q?n.call=5Ftool=20with=20idempotent=20async=20governance=20interc?= =?UTF-8?q?eption?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_assembly/adapters/mcp/patch.py | 84 +++++++++++++++++++++++++++- agent_assembly/core/assembly.py | 7 ++- 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/agent_assembly/adapters/mcp/patch.py b/agent_assembly/adapters/mcp/patch.py index 25fd9fe..cfefe9a 100644 --- a/agent_assembly/adapters/mcp/patch.py +++ b/agent_assembly/adapters/mcp/patch.py @@ -28,10 +28,16 @@ class MCPClientPatch: def apply(self) -> bool: set_process_agent_id(self.process_agent_id) - _ = self.callback_handler - return _is_mcp_available() + client_session_cls = _load_mcp_client_session_class() + if client_session_cls is None: + return False + _apply_client_session_patch(client_session_cls, self.callback_handler) + return True def revert(self) -> None: + client_session_cls = _load_mcp_client_session_class() + if client_session_cls is not None: + _revert_client_session_patch(client_session_cls) set_process_agent_id(None) return None @@ -231,3 +237,77 @@ def _build_blocked_error( tool_name=tool_name, server=server_identifier, ) + + +def _apply_client_session_patch(client_session_cls: type[Any], callback_handler: Any) -> None: + if getattr(client_session_cls, _PATCHED_FLAG, False): + return None + + original_call_tool = getattr(client_session_cls, "call_tool", None) + if not callable(original_call_tool): + return None + + async def patched_call_tool(self: Any, *args: Any, **kwargs: Any) -> Any: + tool_name, tool_args = _extract_tool_call_inputs(args, kwargs) + agent_id = _get_process_agent_id() + server_identifier = _get_server_identifier(self) + + decision = await _invoke_async_tool_check( + callback_handler, + tool_name=tool_name, + tool_args=tool_args, + agent_id=agent_id, + server_identifier=server_identifier, + ) + status, reason = _normalize_decision(decision) + is_pending_flow = False + if status == "pending": + is_pending_flow = True + timeout_seconds = _get_pending_tool_approval_timeout_seconds(callback_handler) + final_decision = await _wait_for_async_tool_approval( + callback_handler, + tool_name=tool_name, + timeout_seconds=timeout_seconds, + tool_args=tool_args, + agent_id=agent_id, + server_identifier=server_identifier, + ) + status, reason = _normalize_decision(final_decision) + + if status == "deny": + raise _build_blocked_error( + tool_name=tool_name, + server_identifier=server_identifier, + reason=reason, + is_pending_rejection=is_pending_flow, + ) + + result = original_call_tool(self, *args, **kwargs) + if inspect.isawaitable(result): + result = await result + await _record_async_tool_result( + callback_handler, + tool_name=tool_name, + result=result, + agent_id=agent_id, + server_identifier=server_identifier, + ) + return result + + setattr(client_session_cls, _ORIGINAL_CALL_TOOL, original_call_tool) + setattr(client_session_cls, "call_tool", patched_call_tool) + setattr(client_session_cls, _PATCHED_FLAG, True) + + +def _revert_client_session_patch(client_session_cls: type[Any]) -> None: + if not getattr(client_session_cls, _PATCHED_FLAG, False): + return None + + original_call_tool = getattr(client_session_cls, _ORIGINAL_CALL_TOOL, None) + if callable(original_call_tool): + setattr(client_session_cls, "call_tool", original_call_tool) + + if hasattr(client_session_cls, _ORIGINAL_CALL_TOOL): + delattr(client_session_cls, _ORIGINAL_CALL_TOOL) + if hasattr(client_session_cls, _PATCHED_FLAG): + delattr(client_session_cls, _PATCHED_FLAG) diff --git a/agent_assembly/core/assembly.py b/agent_assembly/core/assembly.py index 2347681..60b1fb5 100644 --- a/agent_assembly/core/assembly.py +++ b/agent_assembly/core/assembly.py @@ -189,7 +189,12 @@ def _build_patch_plan(client: GatewayClient, process_agent_id: str) -> list[Runt patch_plan.append(OpenAIAgentsPatch(callback_target)) if _is_installed("mcp"): # Keep MCP patch last as fallback for remaining tool dispatch paths. - patch_plan.append(MCPClientPatch(callback_target)) + patch_plan.append( + MCPClientPatch( + callback_handler=callback_target, + process_agent_id=process_agent_id, + ) + ) return patch_plan From 6a481ef3950e2eed3c8914f6d4157d8624d054e8 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:53:53 +0800 Subject: [PATCH 08/13] =?UTF-8?q?=E2=9C=A8=20(tests):=20Create=20dedicated?= =?UTF-8?q?=20MCP=20adapter=20unit=20and=20integration=20test=20packages?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/integration/mcp/__init__.py | 0 test/unit/adapters/mcp/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/integration/mcp/__init__.py create mode 100644 test/unit/adapters/mcp/__init__.py diff --git a/test/integration/mcp/__init__.py b/test/integration/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/adapters/mcp/__init__.py b/test/unit/adapters/mcp/__init__.py new file mode 100644 index 0000000..e69de29 From 7acd897b1354f43d264e6c7fac0acb0664d0fc12 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:54:04 +0800 Subject: [PATCH 09/13] =?UTF-8?q?=E2=9C=85=20(tests):=20Cover=20MCPToolBlo?= =?UTF-8?q?ckedError=20metadata=20fields=20and=20message=20behavior?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/unit/test_exceptions.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 test/unit/test_exceptions.py diff --git a/test/unit/test_exceptions.py b/test/unit/test_exceptions.py new file mode 100644 index 0000000..b5eac6c --- /dev/null +++ b/test/unit/test_exceptions.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from agent_assembly import MCPToolBlockedError + + +def test_mcp_tool_blocked_error_exposes_tool_and_server_metadata() -> None: + error = MCPToolBlockedError( + "blocked", + tool_name="search_docs", + server="https://mcp.example.test", + ) + + assert str(error) == "blocked" + assert error.tool_name == "search_docs" + assert error.server == "https://mcp.example.test" From 7c5a43652b45345fdfada532b79596118b00ec2f Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:55:09 +0800 Subject: [PATCH 10/13] =?UTF-8?q?=E2=9C=85=20(tests):=20Add=20dedicated=20?= =?UTF-8?q?MCP=20patch=20unit=20coverage=20and=20retire=20placeholder=20op?= =?UTF-8?q?tional=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/unit/adapters/mcp/test_patch.py | 258 ++++++++++++++++++++ test/unit/adapters/test_optional_patches.py | 16 -- 2 files changed, 258 insertions(+), 16 deletions(-) create mode 100644 test/unit/adapters/mcp/test_patch.py diff --git a/test/unit/adapters/mcp/test_patch.py b/test/unit/adapters/mcp/test_patch.py new file mode 100644 index 0000000..31d3f0d --- /dev/null +++ b/test/unit/adapters/mcp/test_patch.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from agent_assembly.adapters.mcp import patch as mcp_patch +from agent_assembly.exceptions import MCPToolBlockedError + + +class _ArgsMapping(dict[str, object]): + pass + + +def _install_fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> type[Any]: + class FakeClientSession: + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> dict[str, Any]: + return { + "name": name, + "arguments": arguments or {}, + } + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + + def fake_import_module(module_name: str) -> object: + if module_name == "mcp": + return fake_mcp_module + raise ImportError(module_name) + + monkeypatch.setattr(mcp_patch.importlib, "import_module", fake_import_module) + return FakeClientSession + + +@pytest.mark.asyncio +async def test_apply_patches_clientsession_call_tool_and_is_idempotent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-1") + assert patcher.apply() is True + first_call_ref = FakeClientSession.call_tool + + assert getattr(FakeClientSession, mcp_patch._PATCHED_FLAG, False) is True + assert mcp_patch._get_process_agent_id() == "agent-1" + + assert patcher.apply() is True + assert FakeClientSession.call_tool is first_call_ref + + +def test_revert_restores_clientsession_call_tool_and_clears_process_agent_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + original_call_tool = FakeClientSession.call_tool + mcp_patch.set_process_agent_id("agent-before-revert") + + patcher = mcp_patch.MCPClientPatch(object()) + assert patcher.apply() is True + assert FakeClientSession.call_tool is not original_call_tool + + patcher.revert() + assert FakeClientSession.call_tool is original_call_tool + assert getattr(FakeClientSession, mcp_patch._PATCHED_FLAG, False) is False + assert mcp_patch._get_process_agent_id() is None + + +def test_loaders_and_apply_false_when_mcp_clientsession_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda _: (_ for _ in ()).throw(ImportError("mcp")), + ) + assert mcp_patch._load_mcp_client_session_class() is None + assert mcp_patch.MCPClientPatch(callback_handler=object()).apply() is False + + fake_mcp_module = SimpleNamespace(ClientSession=object()) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + assert mcp_patch._load_mcp_client_session_class() is None + + +def test_server_identifier_extraction_handles_transport_variants() -> None: + sse_session = SimpleNamespace(_server_url="https://sse.mcp.test") + stdio_session = SimpleNamespace(_server_name="stdio-server") + ws_session = SimpleNamespace(_ws_url="wss://ws.mcp.test") + transport_session = SimpleNamespace(_transport=SimpleNamespace(server_name="transport-server")) + unknown_session = SimpleNamespace() + + assert mcp_patch._get_server_identifier(sse_session) == "https://sse.mcp.test" + assert mcp_patch._get_server_identifier(stdio_session) == "stdio-server" + assert mcp_patch._get_server_identifier(ws_session) == "wss://ws.mcp.test" + assert mcp_patch._get_server_identifier(transport_session) == "transport-server" + assert mcp_patch._get_server_identifier(unknown_session) == "mcp-unknown" + + +@pytest.mark.asyncio +async def test_denied_tool_raises_mcp_tool_blocked_error_with_metadata( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + assert kwargs["tool_name"] == "blocked_tool" + return {"status": "deny", "reason": "policy block"} + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-7") + assert patcher.apply() is True + + session = FakeClientSession() + session._server_name = "stdlib-server" # type: ignore[attr-defined] + with pytest.raises( + MCPToolBlockedError, + match="blocked by governance policy: policy block", + ) as captured: + await session.call_tool("blocked_tool", {"q": "secret"}) + + error = captured.value + assert error.tool_name == "blocked_tool" + assert error.server == "stdlib-server" + + +@pytest.mark.asyncio +async def test_pending_then_approved_runs_original_and_records_result( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + wait_calls: list[dict[str, object]] = [] + recorded_results: list[dict[str, object]] = [] + + class Interceptor: + pending_tool_approval_timeout_seconds = 17 + + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "pending", "reason": "approval required"} + + async def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]: + wait_calls.append(dict(kwargs)) + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + recorded_results.append(dict(kwargs)) + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-9") + assert patcher.apply() is True + + session = FakeClientSession() + session._server_url = "https://api.mcp.test" # type: ignore[attr-defined] + result = await session.call_tool("search", _ArgsMapping({"q": "hello"})) + + assert result["name"] == "search" + assert len(wait_calls) == 1 + assert wait_calls[0]["timeout_seconds"] == 17 + assert len(recorded_results) == 1 + assert recorded_results[0]["tool_name"] == "search" + assert recorded_results[0]["agent_id"] == "agent-9" + assert recorded_results[0]["server"] == "https://api.mcp.test" + + +@pytest.mark.asyncio +async def test_pending_then_rejected_raises_mcp_tool_blocked_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "pending", "reason": "approval required"} + + async def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "deny", "reason": "approval rejected"} + + patcher = mcp_patch.MCPClientPatch(Interceptor()) + assert patcher.apply() is True + + session = FakeClientSession() + session._ws_url = "wss://tools.mcp.test" # type: ignore[attr-defined] + with pytest.raises( + MCPToolBlockedError, + match="rejected during approval: approval rejected", + ) as captured: + await session.call_tool("deploy", {"service": "api"}) + + assert captured.value.tool_name == "deploy" + assert captured.value.server == "wss://tools.mcp.test" + + +@pytest.mark.asyncio +async def test_result_recording_truncates_to_2000_chars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeClientSession: + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> str: + del name, arguments + return "x" * 2500 + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + + observed_results: list[str] = [] + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + observed_results.append(str(kwargs["result"])) + + patcher = mcp_patch.MCPClientPatch(Interceptor()) + assert patcher.apply() is True + + result = await FakeClientSession().call_tool("long-output", {"debug": True}) + + assert isinstance(result, str) + assert len(result) == 2500 + assert len(observed_results) == 1 + assert len(observed_results[0]) == 2000 + + +@pytest.mark.asyncio +async def test_callback_wrapper_with_interceptor_is_supported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeClientSession = _install_fake_mcp_module(monkeypatch) + seen: list[str] = [] + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + seen.append(str(kwargs.get("tool_name"))) + return {"status": "allow"} + + wrapper = SimpleNamespace(_interceptor=Interceptor()) + patcher = mcp_patch.MCPClientPatch(wrapper) + assert patcher.apply() is True + + result = await FakeClientSession().call_tool(name="from-wrapper", arguments={"ok": True}) + assert result["name"] == "from-wrapper" + assert seen == ["from-wrapper"] diff --git a/test/unit/adapters/test_optional_patches.py b/test/unit/adapters/test_optional_patches.py index 623a207..84fe1bc 100644 --- a/test/unit/adapters/test_optional_patches.py +++ b/test/unit/adapters/test_optional_patches.py @@ -2,25 +2,9 @@ import pytest -from agent_assembly.adapters.mcp.patch import MCPClientPatch -from agent_assembly.adapters.mcp import patch as mcp_patch from agent_assembly.adapters.openai_agents.patch import OpenAIAgentsPatch from agent_assembly.adapters.openai_agents import patch as openai_patch - -def test_mcp_patch_apply_and_revert(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(mcp_patch.importlib.util, "find_spec", lambda package: object()) - patcher = MCPClientPatch(callback_handler=object()) - assert patcher.apply() is True - patcher.revert() - - -def test_mcp_patch_apply_returns_false_when_module_missing(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(mcp_patch.importlib.util, "find_spec", lambda package: None) - patcher = MCPClientPatch(callback_handler=object()) - assert patcher.apply() is False - - def test_openai_agents_patch_apply_and_revert(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(openai_patch.importlib.util, "find_spec", lambda package: object()) patcher = OpenAIAgentsPatch(callback_handler=object()) From 0d5070c8b8917ab8f550d621b790ec875f9ba3a8 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:55:38 +0800 Subject: [PATCH 11/13] =?UTF-8?q?=E2=9C=85=20(integration):=20Validate=20d?= =?UTF-8?q?irect=20MCP=20ClientSession=20flow=20blocks=20denied=20tools=20?= =?UTF-8?q?and=20continues?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_direct_clientsession_integration.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 test/integration/mcp/test_direct_clientsession_integration.py diff --git a/test/integration/mcp/test_direct_clientsession_integration.py b/test/integration/mcp/test_direct_clientsession_integration.py new file mode 100644 index 0000000..affcdf1 --- /dev/null +++ b/test/integration/mcp/test_direct_clientsession_integration.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from agent_assembly.adapters.mcp import patch as mcp_patch +from agent_assembly.exceptions import MCPToolBlockedError + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_direct_mcp_clientsession_blocks_mid_flow_and_allows_followup( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeClientSession: + def __init__(self) -> None: + self._server_name = "direct-mcp-server" + + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> str: + payload = arguments or {} + return f"ok:{name}:{payload.get('step', 'none')}" + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + if kwargs.get("tool_name") == "blocked_tool": + return {"status": "deny", "reason": "blocked by policy"} + return {"status": "allow"} + + patcher = mcp_patch.MCPClientPatch(Interceptor(), process_agent_id="agent-custom") + assert patcher.apply() is True + + session = FakeClientSession() + with pytest.raises(MCPToolBlockedError, match="blocked by governance policy: blocked by policy"): + await session.call_tool("blocked_tool", {"step": "one"}) + + safe_result = await session.call_tool("safe_tool", {"step": "two"}) + assert safe_result == "ok:safe_tool:two" From 92be6b7c0cd494912e42f514eba91223a745935c Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:55:59 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E2=9C=85=20(integration):=20Cover=20MCP?= =?UTF-8?q?=20and=20LangChain=20adapter=20coexistence=20audit=20flow?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...t_langchain_mcp_coexistence_integration.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 test/integration/mcp/test_langchain_mcp_coexistence_integration.py diff --git a/test/integration/mcp/test_langchain_mcp_coexistence_integration.py b/test/integration/mcp/test_langchain_mcp_coexistence_integration.py new file mode 100644 index 0000000..c948c5c --- /dev/null +++ b/test/integration/mcp/test_langchain_mcp_coexistence_integration.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +from agent_assembly.adapters.langchain import AssemblyCallbackHandler +from agent_assembly.adapters.mcp import patch as mcp_patch + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_langchain_and_mcp_layers_both_emit_governance_events( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeClientSession: + def __init__(self) -> None: + self._server_url = "https://mcp.shared.test" + + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> str: + payload = arguments or {} + return f"mcp:{name}:{payload.get('q', '')}" + + fake_mcp_module = SimpleNamespace(ClientSession=FakeClientSession) + monkeypatch.setattr( + mcp_patch.importlib, + "import_module", + lambda name: fake_mcp_module if name == "mcp" else (_ for _ in ()).throw(ImportError(name)), + ) + + checks: list[dict[str, object]] = [] + records: list[dict[str, object]] = [] + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + checks.append(dict(kwargs)) + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + records.append(dict(kwargs)) + + callback_handler = AssemblyCallbackHandler(Interceptor()) + patcher = mcp_patch.MCPClientPatch(callback_handler, process_agent_id="agent-chain") + assert patcher.apply() is True + + callback_handler.on_tool_start( + serialized={"name": "langchain_tool"}, + input_str="{}", + run_id=uuid4(), + ) + result = await FakeClientSession().call_tool("mcp_tool", {"q": "hello"}) + + assert result == "mcp:mcp_tool:hello" + assert len(checks) == 2 + assert checks[0]["serialized"] == {"name": "langchain_tool"} + assert checks[1]["serialized"] == {"name": "mcp_tool"} + assert checks[1]["server"] == "https://mcp.shared.test" + assert len(records) == 1 + assert records[0]["tool_name"] == "mcp_tool" From 87a221e24a287744a244807bb315ea8c63c5f135 Mon Sep 17 00:00:00 2001 From: Chisanan232 Date: Thu, 30 Apr 2026 09:56:21 +0800 Subject: [PATCH 13/13] =?UTF-8?q?=F0=9F=A9=B9=20(integration):=20Align=20L?= =?UTF-8?q?angChain=20coexistence=20test=20interceptor=20with=20sync=20cal?= =?UTF-8?q?lback=20contract?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../mcp/test_langchain_mcp_coexistence_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration/mcp/test_langchain_mcp_coexistence_integration.py b/test/integration/mcp/test_langchain_mcp_coexistence_integration.py index c948c5c..f429e77 100644 --- a/test/integration/mcp/test_langchain_mcp_coexistence_integration.py +++ b/test/integration/mcp/test_langchain_mcp_coexistence_integration.py @@ -34,11 +34,11 @@ async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> records: list[dict[str, object]] = [] class Interceptor: - async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + def check_tool_start(self, **kwargs: object) -> dict[str, str]: checks.append(dict(kwargs)) return {"status": "allow"} - async def record_result(self, **kwargs: object) -> None: + def record_result(self, **kwargs: object) -> None: records.append(dict(kwargs)) callback_handler = AssemblyCallbackHandler(Interceptor())