diff --git a/README.md b/README.md index 4070f4c..d7c95cc 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,8 @@ pip install adcp ```python from adcp import ADCPMultiAgentClient, AgentConfig, GetProductsRequest -# Configure agents and handlers -client = ADCPMultiAgentClient( +# Configure agents and handlers (context manager ensures proper cleanup) +async with ADCPMultiAgentClient( agents=[ AgentConfig( id="agent_x", @@ -54,21 +54,21 @@ client = ADCPMultiAgentClient( if metadata.status == "completed" else None ) } -) - -# Execute operation - library handles operation IDs, webhook URLs, context management -agent = client.agent("agent_x") -request = GetProductsRequest(brief="Coffee brands") -result = await agent.get_products(request) +) as client: + # Execute operation - library handles operation IDs, webhook URLs, context management + agent = client.agent("agent_x") + request = GetProductsRequest(brief="Coffee brands") + result = await agent.get_products(request) -# Check result -if result.status == "completed": - # Agent completed synchronously! - print(f"✅ Sync completion: {len(result.data.products)} products") + # Check result + if result.status == "completed": + # Agent completed synchronously! + print(f"✅ Sync completion: {len(result.data.products)} products") -if result.status == "submitted": - # Agent will send webhook when complete - print(f"⏳ Async - webhook registered at: {result.submitted.webhook_url}") + if result.status == "submitted": + # Agent will send webhook when complete + print(f"⏳ Async - webhook registered at: {result.submitted.webhook_url}") +# Connections automatically cleaned up here ``` ## Features @@ -173,6 +173,51 @@ Or use the CLI: uvx adcp --debug myagent get_products '{"brief":"TV ads"}' ``` +### Resource Management + +**Why use async context managers?** +- Ensures HTTP connections are properly closed, preventing resource leaks +- Handles cleanup even when exceptions occur +- Required for production applications with connection pooling +- Prevents issues with async task group cleanup in MCP protocol + +The recommended pattern uses async context managers: + +```python +from adcp import ADCPClient, AgentConfig, GetProductsRequest + +# Recommended: Automatic cleanup with context manager +config = AgentConfig(id="agent_x", agent_uri="https://...", protocol="a2a") +async with ADCPClient(config) as client: + request = GetProductsRequest(brief="Coffee brands") + result = await client.get_products(request) + # Connection automatically closed on exit + +# Multi-agent client also supports context managers +async with ADCPMultiAgentClient(agents) as client: + # Execute across all agents in parallel + results = await client.get_products(request) + # All agent connections closed automatically (even if some failed) +``` + +Manual cleanup is available for special cases (e.g., managing client lifecycle manually): + +```python +# Use manual cleanup when you need fine-grained control over lifecycle +client = ADCPClient(config) +try: + result = await client.get_products(request) +finally: + await client.close() # Explicit cleanup +``` + +**When to use manual cleanup:** +- Managing client lifecycle across multiple functions +- Testing scenarios requiring explicit control +- Integration with frameworks that manage resources differently + +In most cases, prefer the context manager pattern. + ### Error Handling The library provides a comprehensive exception hierarchy with helpful error messages: diff --git a/examples/basic_usage.py b/examples/basic_usage.py index e82df13..fedb97d 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -23,32 +23,33 @@ async def main(): auth_token="your-token-here", # Optional ) - # Create client - client = ADCPClient( + # Use context manager for automatic resource cleanup + async with ADCPClient( config, webhook_url_template="https://myapp.com/webhook/{task_type}/{agent_id}/{operation_id}", on_activity=lambda activity: print(f"[{activity.type}] {activity.task_type}"), - ) + ) as client: + # Call get_products + print("Fetching products...") + result = await client.get_products(brief="Coffee brands targeting millennials") - # Call get_products - print("Fetching products...") - result = await client.get_products(brief="Coffee brands targeting millennials") + # Handle result + if result.status == "completed": + print(f"✅ Sync completion: Got {len(result.data.get('products', []))} products") + for product in result.data.get("products", []): + print(f" - {product.get('name')}: {product.get('description')}") - # Handle result - if result.status == "completed": - print(f"✅ Sync completion: Got {len(result.data.get('products', []))} products") - for product in result.data.get("products", []): - print(f" - {product.get('name')}: {product.get('description')}") + elif result.status == "submitted": + print(f"⏳ Async: Webhook will be sent to {result.submitted.webhook_url}") + print(f" Operation ID: {result.submitted.operation_id}") - elif result.status == "submitted": - print(f"⏳ Async: Webhook will be sent to {result.submitted.webhook_url}") - print(f" Operation ID: {result.submitted.operation_id}") + elif result.status == "needs_input": + print(f"❓ Agent needs clarification: {result.needs_input.message}") - elif result.status == "needs_input": - print(f"❓ Agent needs clarification: {result.needs_input.message}") + elif result.status == "failed": + print(f"❌ Error: {result.error}") - elif result.status == "failed": - print(f"❌ Error: {result.error}") + # Connection automatically closed here if __name__ == "__main__": diff --git a/examples/multi_agent.py b/examples/multi_agent.py index 3805b3e..f4f96af 100644 --- a/examples/multi_agent.py +++ b/examples/multi_agent.py @@ -34,8 +34,8 @@ async def main(): ), ] - # Create multi-agent client - client = ADCPMultiAgentClient( + # Use context manager for automatic resource cleanup + async with ADCPMultiAgentClient( agents=agents, webhook_url_template="https://myapp.com/webhook/{task_type}/{agent_id}/{operation_id}", on_activity=lambda activity: print( @@ -44,29 +44,30 @@ async def main(): handlers={ "on_get_products_status_change": handle_products_result, }, - ) + ) as client: + # Execute across all agents in parallel + print(f"Querying {len(agents)} agents in parallel...") + results = await client.get_products(brief="Coffee brands") - # Execute across all agents in parallel - print(f"Querying {len(agents)} agents in parallel...") - results = await client.get_products(brief="Coffee brands") + # Process results + sync_count = sum(1 for r in results if r.status == "completed") + async_count = sum(1 for r in results if r.status == "submitted") - # Process results - sync_count = sum(1 for r in results if r.status == "completed") - async_count = sum(1 for r in results if r.status == "submitted") + print(f"\n📊 Results:") + print(f" ✅ Sync completions: {sync_count}") + print(f" ⏳ Async (webhooks pending): {async_count}") - print(f"\n📊 Results:") - print(f" ✅ Sync completions: {sync_count}") - print(f" ⏳ Async (webhooks pending): {async_count}") + for i, result in enumerate(results): + agent_id = client.agent_ids[i] - for i, result in enumerate(results): - agent_id = client.agent_ids[i] + if result.status == "completed": + products = result.data.get("products", []) + print(f"\n{agent_id}: {len(products)} products (sync)") - if result.status == "completed": - products = result.data.get("products", []) - print(f"\n{agent_id}: {len(products)} products (sync)") + elif result.status == "submitted": + print(f"\n{agent_id}: webhook to {result.submitted.webhook_url}") - elif result.status == "submitted": - print(f"\n{agent_id}: webhook to {result.submitted.webhook_url}") + # All agent connections automatically closed here def handle_products_result(response, metadata): diff --git a/src/adcp/protocols/mcp.py b/src/adcp/protocols/mcp.py index 7c4698f..c863b34 100644 --- a/src/adcp/protocols/mcp.py +++ b/src/adcp/protocols/mcp.py @@ -40,6 +40,39 @@ def __init__(self, *args: Any, **kwargs: Any): self._session: Any = None self._exit_stack: Any = None + async def _cleanup_failed_connection(self, context: str) -> None: + """ + Clean up resources after a failed connection attempt. + + This method handles cleanup without raising exceptions to avoid + masking the original connection error. + + Args: + context: Description of the context for logging (e.g., "during connection attempt") + """ + if self._exit_stack is not None: + old_stack = self._exit_stack + self._exit_stack = None + self._session = None + try: + await old_stack.aclose() + except asyncio.CancelledError: + logger.debug(f"MCP session cleanup cancelled {context}") + except RuntimeError as cleanup_error: + # Known anyio task group cleanup issue + error_msg = str(cleanup_error).lower() + if "cancel scope" in error_msg or "async context" in error_msg: + logger.debug(f"Ignoring anyio cleanup error {context}: {cleanup_error}") + else: + logger.warning( + f"Unexpected RuntimeError during cleanup {context}: {cleanup_error}" + ) + except Exception as cleanup_error: + # Log unexpected cleanup errors but don't raise to preserve original error + logger.warning( + f"Unexpected error during cleanup {context}: {cleanup_error}", exc_info=True + ) + async def _get_session(self) -> ClientSession: """ Get or create MCP client session with URL fallback handling. @@ -115,35 +148,8 @@ async def _get_session(self) -> ClientSession: return self._session # type: ignore[no-any-return] except Exception as e: last_error = e - # Clean up the exit stack on failure to avoid async scope issues - if self._exit_stack is not None: - old_stack = self._exit_stack - self._exit_stack = None # Clear immediately to prevent reuse - self._session = None - try: - await old_stack.aclose() - except asyncio.CancelledError: - # Expected during shutdown - pass - except RuntimeError as cleanup_error: - # Known MCP SDK async cleanup issue - if ( - "async context" in str(cleanup_error).lower() - or "cancel scope" in str(cleanup_error).lower() - ): - logger.debug( - "Ignoring MCP SDK async context error during cleanup: " - f"{cleanup_error}" - ) - else: - logger.warning( - f"Unexpected RuntimeError during cleanup: {cleanup_error}" - ) - except Exception as cleanup_error: - # Unexpected cleanup errors should be logged - logger.warning( - f"Unexpected error during cleanup: {cleanup_error}", exc_info=True - ) + # Clean up the exit stack on failure to avoid resource leaks + await self._cleanup_failed_connection("during connection attempt") # If this isn't the last URL to try, create a new exit stack and continue if url != urls_to_try[-1]: @@ -352,15 +358,5 @@ async def list_tools(self) -> list[str]: return [tool.name for tool in result.tools] async def close(self) -> None: - """Close the MCP session.""" - if self._exit_stack is not None: - old_stack = self._exit_stack - self._exit_stack = None - self._session = None - try: - await old_stack.aclose() - except (asyncio.CancelledError, RuntimeError): - # Cleanup errors during shutdown are expected - pass - except Exception as e: - logger.debug(f"Error during MCP session cleanup: {e}") + """Close the MCP session and clean up resources.""" + await self._cleanup_failed_connection("during close") diff --git a/src/adcp/types/core.py b/src/adcp/types/core.py index 319a6b4..be7ad39 100644 --- a/src/adcp/types/core.py +++ b/src/adcp/types/core.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Generic, Literal, TypeVar -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator class Protocol(str, Enum): @@ -125,6 +125,8 @@ class DebugInfo(BaseModel): class TaskResult(BaseModel, Generic[T]): """Result from task execution.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + status: TaskStatus data: T | None = None message: str | None = None # Human-readable message from agent (e.g., MCP content text) @@ -135,9 +137,6 @@ class TaskResult(BaseModel, Generic[T]): metadata: dict[str, Any] | None = None debug_info: DebugInfo | None = None - class Config: - arbitrary_types_allowed = True - class ActivityType(str, Enum): """Types of activity events.""" diff --git a/tests/test_client.py b/tests/test_client.py index 33d6f0d..0dbde5c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -380,3 +380,110 @@ async def test_list_creative_formats_handles_invalid_response(): assert result.success is False assert result.status == TaskStatus.FAILED assert "Failed to parse response" in result.error + + +@pytest.mark.asyncio +async def test_client_context_manager(): + """Test that ADCPClient works as an async context manager.""" + from unittest.mock import AsyncMock, patch + + config = AgentConfig( + id="test_agent", + agent_uri="https://test.example.com", + protocol=Protocol.MCP, + ) + + # Mock the close method to verify it gets called + with patch.object(ADCPClient, "close", new_callable=AsyncMock) as mock_close: + async with ADCPClient(config) as client: + assert client.agent_config == config + + # Verify close was called on context exit + mock_close.assert_called_once() + + +@pytest.mark.asyncio +async def test_multi_agent_context_manager(): + """Test that ADCPMultiAgentClient works as an async context manager.""" + from unittest.mock import AsyncMock, patch + + agents = [ + AgentConfig( + id="agent1", + agent_uri="https://agent1.example.com", + protocol=Protocol.A2A, + ), + AgentConfig( + id="agent2", + agent_uri="https://agent2.example.com", + protocol=Protocol.MCP, + ), + ] + + # Mock the close method to verify it gets called + with patch.object(ADCPMultiAgentClient, "close", new_callable=AsyncMock) as mock_close: + async with ADCPMultiAgentClient(agents) as client: + assert len(client.agents) == 2 + + # Verify close was called on context exit + mock_close.assert_called_once() + + +@pytest.mark.asyncio +async def test_client_context_manager_with_exception(): + """Test that ADCPClient properly closes even when an exception occurs.""" + from unittest.mock import AsyncMock, patch + + config = AgentConfig( + id="test_agent", + agent_uri="https://test.example.com", + protocol=Protocol.MCP, + ) + + # Mock the close method to verify it gets called + with patch.object(ADCPClient, "close", new_callable=AsyncMock) as mock_close: + try: + async with ADCPClient(config) as client: + assert client.agent_config == config + raise ValueError("Test exception") + except ValueError: + pass # Expected + + # Verify close was called even after exception + mock_close.assert_called_once() + + +@pytest.mark.asyncio +async def test_multi_agent_close_handles_adapter_failures(): + """Test that multi-agent close handles individual adapter failures gracefully.""" + from unittest.mock import AsyncMock, patch + + agents = [ + AgentConfig( + id="agent1", + agent_uri="https://agent1.example.com", + protocol=Protocol.A2A, + ), + AgentConfig( + id="agent2", + agent_uri="https://agent2.example.com", + protocol=Protocol.MCP, + ), + ] + + client = ADCPMultiAgentClient(agents) + + # Mock one adapter to fail during close + mock_close_success = AsyncMock() + mock_close_failure = AsyncMock(side_effect=RuntimeError("Cleanup error")) + + with ( + patch.object(client.agents["agent1"].adapter, "close", mock_close_success), + patch.object(client.agents["agent2"].adapter, "close", mock_close_failure), + ): + # Should not raise despite one adapter failing + await client.close() + + # Verify both adapters had close called + mock_close_success.assert_called_once() + mock_close_failure.assert_called_once() diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 7f1fc4e..15a3f41 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -341,3 +341,116 @@ class MockTextContent(BaseModel): assert result[0] == {"type": "text", "text": "Plain dict"} assert result[1] == {"type": "text", "text": "Pydantic object"} assert all(isinstance(item, dict) for item in result) + + @pytest.mark.asyncio + async def test_connection_failure_cleanup(self, mcp_config): + """Test that connection failures clean up resources properly.""" + from contextlib import AsyncExitStack + from unittest.mock import MagicMock + + import httpcore + + adapter = MCPAdapter(mcp_config) + + # Mock the exit stack to simulate connection failure + mock_exit_stack = AsyncMock(spec=AsyncExitStack) + mock_exit_stack.enter_async_context = AsyncMock( + side_effect=httpcore.ConnectError("Connection refused") + ) + # Simulate the anyio cleanup error that occurs in production + mock_exit_stack.aclose = AsyncMock( + side_effect=RuntimeError("Attempted to exit cancel scope in a different task") + ) + + with patch("adcp.protocols.mcp.AsyncExitStack", return_value=mock_exit_stack): + # Try to get session - should fail but cleanup gracefully + try: + await adapter._get_session() + except Exception: + pass # Expected to fail + + # Verify cleanup was attempted + mock_exit_stack.aclose.assert_called() + + # Verify adapter state is clean after failed connection + assert adapter._exit_stack is None + assert adapter._session is None + + @pytest.mark.asyncio + async def test_close_with_runtime_error(self, mcp_config): + """Test that close() handles RuntimeError from anyio cleanup gracefully.""" + from contextlib import AsyncExitStack + + adapter = MCPAdapter(mcp_config) + + # Set up a mock exit stack that raises RuntimeError on cleanup + mock_exit_stack = AsyncMock(spec=AsyncExitStack) + mock_exit_stack.aclose = AsyncMock( + side_effect=RuntimeError("Attempted to exit cancel scope in a different task") + ) + adapter._exit_stack = mock_exit_stack + + # close() should not raise despite the RuntimeError + await adapter.close() + + # Verify cleanup was attempted and state is clean + mock_exit_stack.aclose.assert_called_once() + assert adapter._exit_stack is None + assert adapter._session is None + + @pytest.mark.asyncio + async def test_close_with_cancellation(self, mcp_config): + """Test that close() handles CancelledError during cleanup.""" + import asyncio + from contextlib import AsyncExitStack + + adapter = MCPAdapter(mcp_config) + + # Set up a mock exit stack that raises CancelledError + mock_exit_stack = AsyncMock(spec=AsyncExitStack) + mock_exit_stack.aclose = AsyncMock(side_effect=asyncio.CancelledError()) + adapter._exit_stack = mock_exit_stack + + # close() should not raise despite the CancelledError + await adapter.close() + + # Verify cleanup was attempted and state is clean + mock_exit_stack.aclose.assert_called_once() + assert adapter._exit_stack is None + assert adapter._session is None + + @pytest.mark.asyncio + async def test_multiple_connection_attempts_with_cleanup_failures(self, mcp_config): + """Test that multiple connection attempts handle cleanup failures properly.""" + from contextlib import AsyncExitStack + + adapter = MCPAdapter(mcp_config) + + # Mock exit stack creation and cleanup + call_count = 0 + + def create_mock_exit_stack(): + nonlocal call_count + call_count += 1 + mock_stack = AsyncMock(spec=AsyncExitStack) + mock_stack.enter_async_context = AsyncMock( + side_effect=ConnectionError(f"Connection attempt {call_count} failed") + ) + mock_stack.aclose = AsyncMock( + side_effect=RuntimeError("Cancel scope error") if call_count == 1 else None + ) + return mock_stack + + with patch("adcp.protocols.mcp.AsyncExitStack", side_effect=create_mock_exit_stack): + # Try to get session - should fail after trying all URLs + try: + await adapter._get_session() + except Exception: + pass # Expected to fail + + # Verify multiple connection attempts were made (original URL + /mcp suffix) + assert call_count >= 1 + + # Verify adapter state is clean after all failed attempts + assert adapter._exit_stack is None + assert adapter._session is None