diff --git a/docs/migration-guide-0.7.md b/docs/migration-guide-0.7.md new file mode 100644 index 0000000..63aa906 --- /dev/null +++ b/docs/migration-guide-0.7.md @@ -0,0 +1,109 @@ +# Migrating to ACP Python SDK 0.7 + +ACP 0.7 reshapes the public surface so that Python-facing names, runtime helpers, and schema models line up with the evolving Agent Client Protocol schema. This guide covers the major changes in 0.7.0 and calls out the mechanical steps you need to apply in downstream agents, clients, and transports. + +## 1. `acp.schema` models now expose `snake_case` fields + +- Every generated model in `acp.schema` (see `src/acp/schema.py`) now uses Pythonic attribute names such as `session_id`, `stop_reason`, and `field_meta`. The JSON aliases (e.g., `alias="sessionId"`) stay intact so over-the-wire payloads remain camelCase. +- Instantiating a model or accessing response values must now use the `snake_case` form: + +```python +# Before (0.6 and earlier) +PromptResponse(stopReason="end_turn") +params.sessionId + +# After (0.7 and later) +PromptResponse(stop_reason="end_turn") +params.session_id +``` + +- If you relied on `model_dump()` to emit camelCase keys automatically, switch to `model_dump(by_alias=True)` (or use helpers such as `text_block`, `start_tool_call`, etc.) so responses continue to match the protocol. +- `field_meta` stays available for extension data. Any extra keys that were nested under `_meta` should now be provided via keyword arguments when constructing the schema models (see section 3). + +## 2. `acp.run_agent` and `acp.connect_to_agent` replace manual connection wiring + +`AgentSideConnection` and `ClientSideConnection` still exist internally, but the top-level entry points now prefer the helper functions implemented in `src/acp/core.py`. + +### Updating agents + +- Old pattern: + +```python +conn = AgentSideConnection(lambda conn: Agent(), writer, reader) +await asyncio.Event().wait() # keep running +``` + +- New pattern: + +```python +await run_agent(MyAgent(), input_stream=writer, output_stream=reader) +``` + +- When your agent just runs over stdio, call `await run_agent(MyAgent())` and the helper will acquire asyncio streams via `stdio_streams()` for you. + +### Updating clients and tests + +- Old pattern: + +```python +conn = ClientSideConnection(lambda conn: MyClient(), proc.stdin, proc.stdout) +``` + +- New pattern: + +```python +conn = connect_to_agent(MyClient(), proc.stdin, proc.stdout) +``` + +- `spawn_agent_process` / `spawn_client_process` now accept concrete `Agent`/`Client` instances instead of factories that received the connection. Instantiate your implementation first and pass it in. +- Importing the legacy connection classes via `acp.AgentSideConnection` / `acp.ClientSideConnection` issues a `DeprecationWarning` (see `src/acp/__init__.py:82-96`). Update your imports to `run_agent` and `connect_to_agent` to silence the warning. + +## 3. `Agent` and `Client` interface methods take explicit parameters + +Both interfaces in `src/acp/interfaces.py` now look like idiomatic Python protocols: methods use `snake_case` names and receive the individual schema fields rather than a single request model. + +### What changed + +- Method names follow `snake_case` (`request_permission`, `session_update`, `new_session`, `set_session_model`, etc.). +- Parameters represent the schema fields, so there is no need to unpack `params` manually. +- Each method is decorated with `@param_model(...)`. Combined with the `compatible_class` helper (see `src/acp/utils.py`), this keeps the camelCase wrappers alive for callers that still pass a full Pydantic request object—but those wrappers now emit `DeprecationWarning`s to encourage migration. + +### How to update your implementations + +1. Rename your method overrides to their `snake_case` equivalents. +2. Replace `params: Model` arguments with the concrete fields plus `**kwargs` to collect future `_meta` keys. +3. Access schema data directly via those parameters. + +Example migration for an agent: + +```python +# Before +class EchoAgent: + async def prompt(self, params: PromptRequest) -> PromptResponse: + text = params.prompt[0].text + return PromptResponse(stopReason="end_turn") + +# After +class EchoAgent: + async def prompt(self, prompt, session_id, **kwargs) -> PromptResponse: + text = prompt[0].text + return PromptResponse(stop_reason="end_turn") +``` + +Similarly, a client method such as `requestPermission` becomes: + +```python +class RecordingClient(Client): + async def request_permission(self, options, session_id, tool_call, **kwargs): + ... +``` + +### Additional notes + +- The connection layers automatically assemble the right request/response models using the `param_model` metadata, so callers do not need to build Pydantic objects manually anymore. +- For extension points (`field_meta`), pass keyword arguments from the connection into your handler signature: they arrive inside `**kwargs`. + +### Backward compatibility + +- The change should be 100% backward compatible as long as you update your method names and signatures. The `compatible_class` wrapper ensures that existing callers passing full request models continue to work. The old style API will remain functional before the next major release(1.0). +- Because camelCase wrappers remain for now, you can migrate file-by-file while still running against ACP 0.7. Just watch for the new deprecation warnings in your logs/tests. diff --git a/docs/quickstart.md b/docs/quickstart.md index cd15df8..ba5b80c 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -90,7 +90,7 @@ class SimpleClient(Client): async def main() -> None: script = Path("examples/echo_agent.py") - async with spawn_agent_process(lambda _agent: SimpleClient(), sys.executable, str(script)) as (conn, _proc): + async with spawn_agent_process(SimpleClient(), sys.executable, str(script)) as (conn, _proc): await conn.initialize(protocol_version=1) session = await conn.new_session(cwd=str(script.parent), mcp_servers=[]) await conn.prompt( @@ -119,7 +119,7 @@ class MyAgent(Agent): return PromptResponse(stop_reason="end_turn") ``` -Hook it up with `AgentSideConnection` inside an async entrypoint and wire it to your client. Refer to: +Run it with `run_agent()` inside an async entrypoint and wire it to your client. Refer to: - [`examples/echo_agent.py`](https://github.com/agentclientprotocol/python-sdk/blob/main/examples/echo_agent.py) for the smallest streaming agent - [`examples/agent.py`](https://github.com/agentclientprotocol/python-sdk/blob/main/examples/agent.py) for an implementation that negotiates capabilities and streams richer updates diff --git a/examples/agent.py b/examples/agent.py index 7f66216..3366817 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -11,11 +11,12 @@ NewSessionResponse, PromptResponse, SetSessionModeResponse, - stdio_streams, + run_agent, text_block, update_agent_message, PROTOCOL_VERSION, ) +from acp.interfaces import Client from acp.schema import ( AgentCapabilities, AgentMessageChunk, @@ -33,11 +34,15 @@ class ExampleAgent(Agent): - def __init__(self, conn: AgentSideConnection) -> None: - self._conn = conn + _conn: Client + + def __init__(self) -> None: self._next_session_id = 0 self._sessions: set[str] = set() + def on_connect(self, conn: Client) -> None: + self._conn = conn + async def _send_agent_message(self, session_id: str, content: Any) -> None: update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content) await self._conn.session_update(session_id, update) @@ -114,9 +119,7 @@ async def ext_notification(self, method: str, params: dict[str, Any]) -> None: async def main() -> None: logging.basicConfig(level=logging.INFO) - reader, writer = await stdio_streams() - AgentSideConnection(ExampleAgent, writer, reader) - await asyncio.Event().wait() + await run_agent(ExampleAgent()) if __name__ == "__main__": diff --git a/examples/client.py b/examples/client.py index 6da8121..2cba0d3 100644 --- a/examples/client.py +++ b/examples/client.py @@ -10,9 +10,7 @@ from acp import ( Client, ClientSideConnection, - InitializeRequest, - NewSessionRequest, - PromptRequest, + connect_to_agent, RequestError, text_block, PROTOCOL_VERSION, @@ -190,7 +188,7 @@ async def main(argv: list[str]) -> int: return 1 client_impl = ExampleClient() - conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout) + conn = connect_to_agent(client_impl, proc.stdin, proc.stdout) await conn.initialize( protocol_version=PROTOCOL_VERSION, diff --git a/examples/duet.py b/examples/duet.py index e9c5e2f..f2c2871 100644 --- a/examples/duet.py +++ b/examples/duet.py @@ -30,7 +30,7 @@ async def main() -> int: client_module = _load_client_module(root / "client.py") client = client_module.ExampleClient() - async with spawn_agent_process(lambda _agent: client, sys.executable, str(agent_path), env=env) as ( + async with spawn_agent_process(client, sys.executable, str(agent_path), env=env) as ( conn, process, ): diff --git a/examples/echo_agent.py b/examples/echo_agent.py index f99c539..282a205 100644 --- a/examples/echo_agent.py +++ b/examples/echo_agent.py @@ -8,10 +8,11 @@ InitializeResponse, NewSessionResponse, PromptResponse, - stdio_streams, + run_agent, text_block, update_agent_message, ) +from acp.interfaces import Client from acp.schema import ( AudioContentBlock, ClientCapabilities, @@ -27,7 +28,9 @@ class EchoAgent(Agent): - def __init__(self, conn: AgentSideConnection) -> None: + _conn: Client + + def on_connect(self, conn: Client) -> None: self._conn = conn async def initialize( @@ -67,9 +70,7 @@ async def prompt( async def main() -> None: - reader, writer = await stdio_streams() - AgentSideConnection(EchoAgent, writer, reader) - await asyncio.Event().wait() + await run_agent(EchoAgent()) if __name__ == "__main__": diff --git a/examples/gemini.py b/examples/gemini.py index d1214fb..e9ec79c 100644 --- a/examples/gemini.py +++ b/examples/gemini.py @@ -13,11 +13,12 @@ from acp import ( Client, - ClientSideConnection, + connect_to_agent, PROTOCOL_VERSION, RequestError, text_block, ) +from acp.core import ClientSideConnection from acp.schema import ( AgentMessageChunk, AgentPlanUpdate, @@ -316,7 +317,7 @@ async def run(argv: list[str]) -> int: return 1 client_impl = GeminiClient(auto_approve=args.yolo) - conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout) + conn = connect_to_agent(client_impl, proc.stdin, proc.stdout) try: init_resp = await conn.initialize( diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 700fad5..d340426 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -6,6 +6,7 @@ import re import subprocess import sys +import textwrap from collections.abc import Callable from dataclasses import dataclass from pathlib import Path @@ -327,11 +328,20 @@ def _ensure_custom_base_model(content: str) -> str: if not has_config: new_imports.append("ConfigDict") lines[idx] = "from pydantic import " + ", ".join(new_imports) + to_insert = textwrap.dedent("""\ + class BaseModel(_BaseModel): + model_config = ConfigDict(populate_by_name=True) + + def __getattr__(self, item: str) -> Any: + if item.lower() != item: + snake_cased = "".join("_" + c.lower() if c.isupper() and i > 0 else c.lower() for i, c in enumerate(item)) + return getattr(self, snake_cased) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + """) insert_idx = idx + 1 lines.insert(insert_idx, "") - lines.insert(insert_idx + 1, "class BaseModel(_BaseModel):") - lines.insert(insert_idx + 2, " model_config = ConfigDict(populate_by_name=True)") - lines.insert(insert_idx + 3, "") + for offset, line in enumerate(to_insert.splitlines(), 1): + lines.insert(insert_idx + offset, line) break return "\n".join(lines) + "\n" diff --git a/src/acp/__init__.py b/src/acp/__init__.py index 3f5e72f..a25e664 100644 --- a/src/acp/__init__.py +++ b/src/acp/__init__.py @@ -1,10 +1,12 @@ +from typing import Any + from .core import ( Agent, - AgentSideConnection, Client, - ClientSideConnection, RequestError, TerminalHandle, + connect_to_agent, + run_agent, ) from .helpers import ( audio_block, @@ -73,6 +75,19 @@ from .stdio import spawn_agent_process, spawn_client_process, spawn_stdio_connection, stdio_streams from .transports import default_environment, spawn_stdio_transport +_DEPRECATED_NAMES = [ + ( + "AgentSideConnection", + "acp.core:AgentSideConnection", + "Using `AgentSideConnection` directly is deprecated, please use `acp.run_agent` instead.", + ), + ( + "ClientSideConnection", + "acp.core:ClientSideConnection", + "Using `ClientSideConnection` directly is deprecated, please use `acp.connect_to_agent` instead.", + ), +] + __all__ = [ # noqa: RUF022 # constants "PROTOCOL_VERSION", @@ -113,8 +128,8 @@ "ReleaseTerminalRequest", "ReleaseTerminalResponse", # core - "AgentSideConnection", - "ClientSideConnection", + "run_agent", + "connect_to_agent", "RequestError", "Agent", "Client", @@ -151,3 +166,16 @@ "start_edit_tool_call", "update_tool_call", ] + + +def __getattr__(name: str) -> Any: + import warnings + from importlib import import_module + + for deprecated_name, new_path, warning in _DEPRECATED_NAMES: + if name == deprecated_name: + warnings.warn(warning, DeprecationWarning, stacklevel=2) + module_name, attr_name = new_path.split(":") + module = import_module(module_name) + return getattr(module, attr_name) + raise AttributeError(f"module {__name__} has no attribute {name}") # noqa: TRY003 diff --git a/src/acp/agent/connection.py b/src/acp/agent/connection.py index 0b4d542..18a82dd 100644 --- a/src/acp/agent/connection.py +++ b/src/acp/agent/connection.py @@ -2,10 +2,10 @@ import asyncio from collections.abc import Callable -from typing import Any +from typing import Any, cast, final from ..connection import Connection -from ..interfaces import Agent +from ..interfaces import Agent, Client from ..meta import CLIENT_METHODS from ..schema import ( AgentMessageChunk, @@ -38,28 +38,37 @@ WriteTextFileResponse, ) from ..terminal import TerminalHandle -from ..utils import notify_model, param_model, request_model, request_optional_model +from ..utils import compatible_class, notify_model, param_model, request_model, request_optional_model from .router import build_agent_router __all__ = ["AgentSideConnection"] _AGENT_CONNECTION_ERROR = "AgentSideConnection requires asyncio StreamWriter/StreamReader" +@final +@compatible_class class AgentSideConnection: """Agent-side connection wrapper that dispatches JSON-RPC messages to a Client implementation.""" def __init__( self, - to_agent: Callable[[AgentSideConnection], Agent], + to_agent: Callable[[Client], Agent] | Agent, input_stream: Any, output_stream: Any, + listening: bool = False, **connection_kwargs: Any, ) -> None: - agent = to_agent(self) - handler = build_agent_router(agent) + agent = to_agent(cast(Client, self)) if callable(to_agent) else to_agent if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_AGENT_CONNECTION_ERROR) - self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) + handler = build_agent_router(cast(Agent, agent)) + self._conn = Connection(handler, input_stream, output_stream, listening=listening, **connection_kwargs) + if on_connect := getattr(agent, "on_connect", None): + on_connect(self) + + async def listen(self) -> None: + """Start listening for incoming messages.""" + await self._conn.main_loop() @param_model(SessionNotification) async def session_update( diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index 857e7d8..292755a 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -2,7 +2,7 @@ import asyncio from collections.abc import Callable -from typing import Any +from typing import Any, cast, final from ..connection import Connection from ..interfaces import Agent, Client @@ -34,24 +34,32 @@ StdioMcpServer, TextContentBlock, ) -from ..utils import notify_model, param_model, request_model, request_model_from_dict +from ..utils import compatible_class, notify_model, param_model, request_model, request_model_from_dict from .router import build_client_router __all__ = ["ClientSideConnection"] _CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader" +@final +@compatible_class class ClientSideConnection: """Client-side connection wrapper that dispatches JSON-RPC messages to an Agent implementation.""" def __init__( - self, to_client: Callable[[Agent], Client], input_stream: Any, output_stream: Any, **connection_kwargs: Any + self, + to_client: Callable[[Agent], Client] | Client, + input_stream: Any, + output_stream: Any, + **connection_kwargs: Any, ) -> None: if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_CLIENT_CONNECTION_ERROR) - client = to_client(self) - handler = build_client_router(client) + client = to_client(cast(Agent, self)) if callable(to_client) else to_client + handler = build_client_router(cast(Client, client)) self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) + if on_connect := getattr(client, "on_connect", None): + on_connect(self) @param_model(InitializeRequest) async def initialize( diff --git a/src/acp/connection.py b/src/acp/connection.py index 34142d7..aca1c19 100644 --- a/src/acp/connection.py +++ b/src/acp/connection.py @@ -72,6 +72,7 @@ def __init__( dispatcher_factory: DispatcherFactory | None = None, sender_factory: SenderFactory | None = None, observers: list[StreamObserver] | None = None, + listening: bool = True, ) -> None: self._handler = handler self._writer = writer @@ -83,11 +84,14 @@ def __init__( self._queue = queue or InMemoryMessageQueue() self._closed = False self._sender = (sender_factory or self._default_sender_factory)(self._writer, self._tasks) - self._recv_task = self._tasks.create( - self._receive_loop(), - name="acp.Connection.receive", - on_error=self._on_receive_error, - ) + if listening: + self._recv_task = self._tasks.create( + self._receive_loop(), + name="acp.Connection.receive", + on_error=self._on_receive_error, + ) + else: + self._recv_task = None dispatcher_factory = dispatcher_factory or self._default_dispatcher_factory self._dispatcher = dispatcher_factory( self._queue, @@ -109,6 +113,14 @@ async def close(self) -> None: await self._tasks.shutdown() self._state.reject_all_outgoing(ConnectionError("Connection closed")) + async def main_loop(self) -> None: + try: + await self._receive_loop() + except Exception as exc: + logging.exception("Connection main loop failed", exc_info=exc) + self._on_receive_error(None, exc) # type: ignore[arg-type] + raise + async def __aenter__(self) -> Connection: return self diff --git a/src/acp/core.py b/src/acp/core.py index 8afa468..dd6342e 100644 --- a/src/acp/core.py +++ b/src/acp/core.py @@ -7,6 +7,8 @@ from __future__ import annotations +from typing import Any + from .agent.connection import AgentSideConnection from .client.connection import ClientSideConnection from .connection import Connection, JsonValue, MethodHandler @@ -24,4 +26,47 @@ "MethodHandler", "RequestError", "TerminalHandle", + "connect_to_agent", + "run_agent", ] + + +async def run_agent( + agent: Agent, input_stream: Any = None, output_stream: Any = None, **connection_kwargs: Any +) -> None: + """Run an ACP agent over the given input/output streams. + + This is a convenience function that creates an :class:`AgentSideConnection` + and starts listening for incoming messages. + + Args: + agent: The agent implementation to run. + input_stream: The (client) input stream to write to (defaults: ``sys.stdin``). + output_stream: The (client) output stream to read from (defaults: ``sys.stdout``). + **connection_kwargs: Additional keyword arguments to pass to the + :class:`AgentSideConnection` constructor. + """ + from .stdio import stdio_streams + + if input_stream is None and output_stream is None: + output_stream, input_stream = await stdio_streams() + conn = AgentSideConnection(agent, input_stream, output_stream, **connection_kwargs) + await conn.listen() + + +def connect_to_agent( + client: Client, input_stream: Any, output_stream: Any, **connection_kwargs: Any +) -> ClientSideConnection: + """Create a ClientSideConnection to an ACP agent over the given input/output streams. + + Args: + client: The client implementation to use. + input_stream: The (agent) input stream to write to (default: ``sys.stdin``). + output_stream: The (agent) output stream to read from (default: ``sys.stdout``). + **connection_kwargs: Additional keyword arguments to pass to the + :class:`ClientSideConnection` constructor. + + Returns: + A :class:`ClientSideConnection` instance connected to the agent. + """ + return ClientSideConnection(client, input_stream, output_stream, **connection_kwargs) diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 61c4f27..26021cb 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -127,6 +127,8 @@ async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any async def ext_notification(self, method: str, params: dict[str, Any]) -> None: ... + def on_connect(self, conn: Agent) -> None: ... + class Agent(Protocol): @param_model(InitializeRequest) @@ -179,3 +181,5 @@ async def cancel(self, session_id: str, **kwargs: Any) -> None: ... async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... async def ext_notification(self, method: str, params: dict[str, Any]) -> None: ... + + def on_connect(self, conn: Client) -> None: ... diff --git a/src/acp/router.py b/src/acp/router.py index 1d53473..e3e0929 100644 --- a/src/acp/router.py +++ b/src/acp/router.py @@ -1,11 +1,15 @@ from __future__ import annotations +import inspect +import warnings from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Any, Literal, TypeVar from pydantic import BaseModel +from acp.utils import to_camel_case + from .exceptions import RequestError __all__ = ["MessageRouter", "Route"] @@ -50,12 +54,34 @@ def add_route(self, route: Route) -> None: self._notifications[route.method] = route def _make_func(self, model: type[BaseModel], obj: Any, attr: str) -> AsyncHandler | None: + legacy_api = False func = getattr(obj, attr, None) + if func is None and "_" in attr: + attr = to_camel_case(attr) + func = getattr(obj, attr, None) + legacy_api = True + elif callable(func) and "_" not in attr: + original_func = func + if hasattr(func, "__func__"): + original_func = func.__func__ + parameters = inspect.signature(original_func).parameters + if len(parameters) == 2 and "params" in parameters: + legacy_api = True + if func is None or not callable(func): return None async def wrapper(params: Any) -> Any: + if legacy_api: + warnings.warn( + f"The old style method {type(obj).__name__}.{attr} is deprecated, " + "please update to the snake-cased form.", + DeprecationWarning, + stacklevel=3, + ) model_obj = model.model_validate(params) + if legacy_api: + return await func(model_obj) # type: ignore[arg-type] params = {k: getattr(model_obj, k) for k in model.model_fields if k != "field_meta"} if meta := getattr(model_obj, "field_meta", None): params.update(meta) diff --git a/src/acp/schema.py b/src/acp/schema.py index 1d8b91e..3d8ae81 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -19,6 +19,12 @@ class BaseModel(_BaseModel): model_config = ConfigDict(populate_by_name=True) + def __getattr__(self, item: str) -> Any: + if item.lower() != item: + snake_cased = "".join("_" + c.lower() if c.isupper() and i > 0 else c.lower() for i, c in enumerate(item)) + return getattr(self, snake_cased) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + class Jsonrpc(Enum): field_2_0 = "2.0" diff --git a/src/acp/stdio.py b/src/acp/stdio.py index 88917c9..d58644e 100644 --- a/src/acp/stdio.py +++ b/src/acp/stdio.py @@ -160,7 +160,7 @@ async def spawn_stdio_connection( @asynccontextmanager async def spawn_agent_process( - to_client: Callable[[Agent], Client], + to_client: Callable[[Agent], Client] | Client, command: str, *args: str, env: Mapping[str, str] | None = None, @@ -185,7 +185,7 @@ async def spawn_agent_process( @asynccontextmanager async def spawn_client_process( - to_agent: Callable[[AgentSideConnection], Agent], + to_agent: Callable[[Client], Agent] | Agent, command: str, *args: str, env: Mapping[str, str] | None = None, diff --git a/src/acp/utils.py b/src/acp/utils.py index e32f467..1be9c19 100644 --- a/src/acp/utils.py +++ b/src/acp/utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +import functools +import warnings from collections.abc import Callable from typing import Any, TypeVar @@ -23,6 +25,7 @@ ModelT = TypeVar("ModelT", bound=BaseModel) MethodT = TypeVar("MethodT", bound=Callable) ClassT = TypeVar("ClassT", bound=type) +T = TypeVar("T") def serialize_params(params: BaseModel) -> dict[str, Any]: @@ -105,6 +108,67 @@ def param_model(param_cls: type[BaseModel]) -> Callable[[MethodT], MethodT]: """ def decorator(func: MethodT) -> MethodT: + func.__param_model__ = param_cls # type: ignore[attr-defined] return func return decorator + + +def to_camel_case(snake_str: str) -> str: + """Convert snake_case strings to camelCase.""" + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def _make_legacy_func(func: Callable[..., T], model: type[BaseModel]) -> Callable[[Any, BaseModel], T]: + @functools.wraps(func) + def wrapped(self, params: BaseModel) -> T: + warnings.warn( + f"Calling {func.__name__} with {model.__name__} parameter is " # type: ignore[attr-defined] + "deprecated, please update to the new API style.", + DeprecationWarning, + stacklevel=3, + ) + kwargs = {k: getattr(params, k) for k in model.model_fields if k != "field_meta"} + if meta := getattr(params, "field_meta", None): + kwargs.update(meta) + return func(self, **kwargs) # type: ignore[arg-type] + + return wrapped + + +def _make_compatible_func(func: Callable[..., T], model: type[BaseModel]) -> Callable[..., T]: + @functools.wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> T: + param = None + if not kwargs and len(args) == 1: + param = args[0] + elif not args and len(kwargs) == 1: + param = kwargs.get("params") + if isinstance(param, model): + warnings.warn( + f"Calling {func.__name__} with {model.__name__} parameter " # type: ignore[attr-defined] + "is deprecated, please update to the new API style.", + DeprecationWarning, + stacklevel=3, + ) + kwargs = {k: getattr(param, k) for k in model.model_fields if k != "field_meta"} + if meta := getattr(param, "field_meta", None): + kwargs.update(meta) + return func(self, **kwargs) # type: ignore[arg-type] + return func(self, *args, **kwargs) + + return wrapped + + +def compatible_class(cls: ClassT) -> ClassT: + """Mark a class as backward compatible with old API style.""" + for attr in dir(cls): + func = getattr(cls, attr) + if not callable(func) or (model := getattr(func, "__param_model__", None)) is None: + continue + if "_" in attr: + setattr(cls, to_camel_case(attr), _make_legacy_func(func, model)) + else: + setattr(cls, attr, _make_compatible_func(func, model)) + return cls diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5f5cd8e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,308 @@ +import asyncio +import contextlib +from collections.abc import AsyncGenerator, Callable +from typing import Any + +import pytest +import pytest_asyncio + +from acp import ( + AuthenticateResponse, + CreateTerminalResponse, + InitializeResponse, + KillTerminalCommandResponse, + LoadSessionResponse, + NewSessionResponse, + PromptRequest, + PromptResponse, + ReadTextFileResponse, + ReleaseTerminalResponse, + RequestError, + RequestPermissionResponse, + SessionNotification, + SetSessionModelResponse, + SetSessionModeResponse, + TerminalOutputResponse, + WaitForTerminalExitResponse, + WriteTextFileResponse, +) +from acp.core import AgentSideConnection, ClientSideConnection +from acp.schema import ( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AllowedOutcome, + AudioContentBlock, + AvailableCommandsUpdate, + ClientCapabilities, + CurrentModeUpdate, + DeniedOutcome, + EmbeddedResourceContentBlock, + EnvVariable, + HttpMcpServer, + ImageContentBlock, + Implementation, + PermissionOption, + ResourceContentBlock, + SseMcpServer, + StdioMcpServer, + TextContentBlock, + ToolCall, + ToolCallProgress, + ToolCallStart, + UserMessageChunk, +) + + +class _Server: + def __init__(self) -> None: + self._server: asyncio.AbstractServer | None = None + self._server_reader: asyncio.StreamReader | None = None + self._server_writer: asyncio.StreamWriter | None = None + self._client_reader: asyncio.StreamReader | None = None + self._client_writer: asyncio.StreamWriter | None = None + + async def __aenter__(self): + async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + self._server_reader = reader + self._server_writer = writer + + self._server = await asyncio.start_server(handle, host="127.0.0.1", port=0) + host, port = self._server.sockets[0].getsockname()[:2] + self._client_reader, self._client_writer = await asyncio.open_connection(host, port) + + # wait until server side is set + for _ in range(100): + if self._server_reader and self._server_writer: + break + await asyncio.sleep(0.01) + assert self._server_reader and self._server_writer + assert self._client_reader and self._client_writer + return self + + async def __aexit__(self, exc_type, exc, tb): + if self._client_writer: + self._client_writer.close() + with contextlib.suppress(Exception): + await self._client_writer.wait_closed() + if self._server_writer: + self._server_writer.close() + with contextlib.suppress(Exception): + await self._server_writer.wait_closed() + if self._server: + self._server.close() + await self._server.wait_closed() + + @property + def server_writer(self) -> asyncio.StreamWriter: + assert self._server_writer is not None + return self._server_writer + + @property + def server_reader(self) -> asyncio.StreamReader: + assert self._server_reader is not None + return self._server_reader + + @property + def client_writer(self) -> asyncio.StreamWriter: + assert self._client_writer is not None + return self._client_writer + + @property + def client_reader(self) -> asyncio.StreamReader: + assert self._client_reader is not None + return self._client_reader + + +@pytest_asyncio.fixture +async def server() -> AsyncGenerator[_Server, None]: + """Provides a server-client connection pair for testing.""" + async with _Server() as server_instance: + yield server_instance + + +class TestClient: + __test__ = False # prevent pytest from collecting this class + + def __init__(self) -> None: + self.permission_outcomes: list[RequestPermissionResponse] = [] + self.files: dict[str, str] = {} + self.notifications: list[SessionNotification] = [] + self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + def queue_permission_cancelled(self) -> None: + self.permission_outcomes.append(RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))) + + def queue_permission_selected(self, option_id: str) -> None: + self.permission_outcomes.append( + RequestPermissionResponse(outcome=AllowedOutcome(option_id=option_id, outcome="selected")) + ) + + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: + if self.permission_outcomes: + return self.permission_outcomes.pop() + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) + + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: + self.files[str(path)] = content + return WriteTextFileResponse() + + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: + content = self.files.get(str(path), "default content") + return ReadTextFileResponse(content=content) + + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: + self.notifications.append(SessionNotification(session_id=session_id, update=update, field_meta=kwargs or None)) + + # Optional terminal methods (not implemented in this test client) + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: + raise NotImplementedError + + async def terminal_output( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> TerminalOutputResponse: # pragma: no cover - placeholder + raise NotImplementedError + + async def release_terminal( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> ReleaseTerminalResponse | None: + raise NotImplementedError + + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> WaitForTerminalExitResponse: + raise NotImplementedError + + async def kill_terminal( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> KillTerminalCommandResponse | None: + raise NotImplementedError + + async def ext_method(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/ping": + return {"response": "pong", "params": params} + raise RequestError.method_not_found(method) + + async def ext_notification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +class TestAgent: + __test__ = False # prevent pytest from collecting this class + + def __init__(self) -> None: + self.prompts: list[PromptRequest] = [] + self.cancellations: list[str] = [] + self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: + # Avoid serializer warnings by omitting defaults + return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[]) + + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any + ) -> NewSessionResponse: + return NewSessionResponse(session_id="test-session-123") + + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any + ) -> LoadSessionResponse | None: + return LoadSessionResponse() + + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: + return AuthenticateResponse() + + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + self.prompts.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) + return PromptResponse(stop_reason="end_turn") + + async def cancel(self, session_id: str, **kwargs: Any) -> None: + self.cancellations.append(session_id) + + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: + return SetSessionModeResponse() + + async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) -> SetSessionModelResponse | None: + return SetSessionModelResponse() + + async def ext_method(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/echo": + return {"echo": params} + raise RequestError.method_not_found(method) + + async def ext_notification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +@pytest.fixture(name="agent") +def agent_fixture() -> TestAgent: + return TestAgent() + + +@pytest.fixture(name="client") +def client_fixture() -> TestClient: + return TestClient() + + +@pytest.fixture(name="connect") +def connect_func(server, agent, client) -> Callable[[bool, bool], tuple[AgentSideConnection, ClientSideConnection]]: + def _connect( + connect_agent: bool = True, connect_client: bool = True + ) -> tuple[AgentSideConnection, ClientSideConnection]: + agent_conn = None + client_conn = None + if connect_agent: + agent_conn = AgentSideConnection(agent, server.server_writer, server.server_reader, listening=True) + if connect_client: + client_conn = ClientSideConnection(client, server.client_writer, server.client_reader) + return agent_conn, client_conn # type: ignore[return-value] + + return _connect diff --git a/tests/real_user/test_cancel_prompt_flow.py b/tests/real_user/test_cancel_prompt_flow.py index 64b76e5..bdd20eb 100644 --- a/tests/real_user/test_cancel_prompt_flow.py +++ b/tests/real_user/test_cancel_prompt_flow.py @@ -3,16 +3,16 @@ import pytest -from acp import AgentSideConnection, ClientSideConnection, PromptResponse from acp.schema import ( AudioContentBlock, EmbeddedResourceContentBlock, ImageContentBlock, PromptRequest, + PromptResponse, ResourceContentBlock, TextContentBlock, ) -from tests.test_rpc import TestAgent, TestClient, _Server +from tests.conftest import TestAgent # Regression from a real user session where cancel needed to interrupt a long-running prompt. @@ -52,27 +52,24 @@ async def cancel(self, session_id: str, **kwargs: Any) -> None: @pytest.mark.asyncio -async def test_cancel_reaches_agent_during_prompt() -> None: - async with _Server() as server: - agent = LongRunningAgent() - client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, server._client_writer, server._client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, server._server_writer, server._server_reader) +@pytest.mark.parametrize("agent", [LongRunningAgent()]) +async def test_cancel_reaches_agent_during_prompt(connect, agent) -> None: + _, agent_conn = connect() - prompt_task = asyncio.create_task( - agent_conn.prompt( - session_id="sess-xyz", - prompt=[TextContentBlock(type="text", text="hello")], - ) + prompt_task = asyncio.create_task( + agent_conn.prompt( + session_id="sess-xyz", + prompt=[TextContentBlock(type="text", text="hello")], ) + ) - await agent.prompt_started.wait() - assert not prompt_task.done(), "Prompt finished before cancel was sent" + await agent.prompt_started.wait() + assert not prompt_task.done(), "Prompt finished before cancel was sent" - await agent_conn.cancel(session_id="sess-xyz") + await agent_conn.cancel(session_id="sess-xyz") - await asyncio.wait_for(agent.cancel_received.wait(), timeout=1.0) + await asyncio.wait_for(agent.cancel_received.wait(), timeout=1.0) - response = await asyncio.wait_for(prompt_task, timeout=1.0) - assert response.stop_reason == "cancelled" - assert agent.cancellations == ["sess-xyz"] + response = await asyncio.wait_for(prompt_task, timeout=1.0) + assert response.stop_reason == "cancelled" + assert agent.cancellations == ["sess-xyz"] diff --git a/tests/real_user/test_permission_flow.py b/tests/real_user/test_permission_flow.py index 6c6d3ca..c2a2494 100644 --- a/tests/real_user/test_permission_flow.py +++ b/tests/real_user/test_permission_flow.py @@ -3,7 +3,8 @@ import pytest -from acp import AgentSideConnection, ClientSideConnection, PromptResponse +from acp import PromptResponse +from acp.core import AgentSideConnection, ClientSideConnection from acp.schema import ( AudioContentBlock, EmbeddedResourceContentBlock, @@ -13,7 +14,7 @@ TextContentBlock, ToolCall, ) -from tests.test_rpc import TestAgent, TestClient, _Server +from tests.conftest import TestAgent, TestClient # Regression from real-world runs where agents paused prompts to obtain user permission. @@ -51,32 +52,32 @@ async def prompt( @pytest.mark.asyncio -async def test_agent_request_permission_roundtrip() -> None: - async with _Server() as server: - client = TestClient() - client.queue_permission_selected("allow") +async def test_agent_request_permission_roundtrip(server) -> None: + client = TestClient() + client.queue_permission_selected("allow") - captured_agent = [] + captured_agent = [] - agent_conn = ClientSideConnection(lambda _conn: client, server._client_writer, server._client_reader) - _agent_conn = AgentSideConnection( - lambda conn: captured_agent.append(PermissionRequestAgent(conn)) or captured_agent[-1], - server._server_writer, - server._server_reader, - ) + agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type] + _agent_conn = AgentSideConnection( + lambda conn: captured_agent.append(PermissionRequestAgent(conn)) or captured_agent[-1], + server._server_writer, + server._server_reader, + listening=True, + ) - response = await asyncio.wait_for( - agent_conn.prompt( - session_id="sess-perm", - prompt=[TextContentBlock(type="text", text="needs approval")], - ), - timeout=1.0, - ) - assert response.stop_reason == "end_turn" + response = await asyncio.wait_for( + agent_conn.prompt( + session_id="sess-perm", + prompt=[TextContentBlock(type="text", text="needs approval")], + ), + timeout=1.0, + ) + assert response.stop_reason == "end_turn" - assert captured_agent, "Agent was not constructed" - [agent] = captured_agent - assert agent.permission_responses, "Agent did not receive permission response" - permission_response = agent.permission_responses[0] - assert permission_response.outcome.outcome == "selected" - assert permission_response.outcome.option_id == "allow" + assert captured_agent, "Agent was not constructed" + [agent] = captured_agent + assert agent.permission_responses, "Agent did not receive permission response" + permission_response = agent.permission_responses[0] + assert permission_response.outcome.outcome == "selected" + assert permission_response.outcome.option_id == "allow" diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py new file mode 100644 index 0000000..013427e --- /dev/null +++ b/tests/test_compatibility.py @@ -0,0 +1,169 @@ +import pytest + +from acp import ( + AuthenticateResponse, + InitializeResponse, + LoadSessionResponse, + NewSessionResponse, + PromptRequest, + PromptResponse, + ReadTextFileResponse, + RequestError, + RequestPermissionResponse, + SessionNotification, + SetSessionModelResponse, + SetSessionModeResponse, + WriteTextFileResponse, +) +from acp.schema import ( + AllowedOutcome, + AuthenticateRequest, + CancelNotification, + DeniedOutcome, + InitializeRequest, + LoadSessionRequest, + NewSessionRequest, + ReadTextFileRequest, + RequestPermissionRequest, + SetSessionModelRequest, + SetSessionModeRequest, + WriteTextFileRequest, +) + + +class LegacyAgent: + def __init__(self) -> None: + self.prompts: list[PromptRequest] = [] + self.cancellations: list[str] = [] + self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + async def initialize(self, params: InitializeRequest) -> InitializeResponse: + # Avoid serializer warnings by omitting defaults + return InitializeResponse(protocol_version=params.protocol_version, agent_capabilities=None, auth_methods=[]) + + async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: + return NewSessionResponse(session_id="test-session-123") + + async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: + return LoadSessionResponse() + + async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: + return AuthenticateResponse() + + async def prompt(self, params: PromptRequest) -> PromptResponse: + self.prompts.append(params) + return PromptResponse(stop_reason="end_turn") + + async def cancel(self, params: CancelNotification) -> None: + self.cancellations.append(params.session_id) + + async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: + return SetSessionModeResponse() + + async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse | None: + return SetSessionModelResponse() + + async def extMethod(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/echo": + return {"echo": params} + raise RequestError.method_not_found(method) + + async def extNotification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +class LegacyClient: + __test__ = False # prevent pytest from collecting this class + + def __init__(self) -> None: + self.permission_outcomes: list[RequestPermissionResponse] = [] + self.files: dict[str, str] = {} + self.notifications: list[SessionNotification] = [] + self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + def queue_permission_cancelled(self) -> None: + self.permission_outcomes.append(RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))) + + def queue_permission_selected(self, option_id: str) -> None: + self.permission_outcomes.append( + RequestPermissionResponse(outcome=AllowedOutcome(option_id=option_id, outcome="selected")) + ) + + async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + if self.permission_outcomes: + return self.permission_outcomes.pop() + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) + + async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse: + self.files[str(params.path)] = params.content + return WriteTextFileResponse() + + async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: + content = self.files.get(str(params.path), "default content") + return ReadTextFileResponse(content=content) + + async def sessionUpdate(self, params: SessionNotification) -> None: + self.notifications.append(params) + + # Optional terminal methods (not implemented in this test client) + async def createTerminal(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def terminalOutput(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def releaseTerminal(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def waitForTerminalExit(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def killTerminal(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def extMethod(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/ping": + return {"response": "pong", "params": params} + raise RequestError.method_not_found(method) + + async def extNotification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("agent,client", [(LegacyAgent(), LegacyClient())]) +async def test_initialize_and_new_session_compat(connect, client): + client_conn, agent_conn = connect() + + with pytest.warns(DeprecationWarning) as record: + resp = await agent_conn.newSession(NewSessionRequest(cwd="/home/tmp", mcp_servers=[])) + + assert len(record) == 2 + assert "Calling new_session with NewSessionRequest parameter is deprecated" in str(record[0].message) + assert "The old style method LegacyAgent.newSession is deprecated" in str(record[1].message) + + assert isinstance(resp, NewSessionResponse) + assert resp.session_id == "test-session-123" + + with pytest.warns(DeprecationWarning) as record: + resp = await agent_conn.new_session(cwd="/home/tmp", mcp_servers=[]) + assert len(record) == 1 + assert "The old style method LegacyAgent.newSession is deprecated" in str(record[0].message) + + with pytest.warns(DeprecationWarning) as record: + await client_conn.writeTextFile( + WriteTextFileRequest(path="test.txt", content="Hello, World!", session_id="test-session-123") + ) + + assert len(record) == 2 + assert client.files["test.txt"] == "Hello, World!" + + with pytest.warns(DeprecationWarning) as record: + resp = await client_conn.read_text_file(path="test.txt", session_id="test-session-123") + + assert len(record) == 1 + assert resp.content == "Hello, World!" diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 112e0c3..e45a881 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -1,5 +1,4 @@ import asyncio -import contextlib import json import sys from pathlib import Path @@ -9,27 +8,17 @@ from acp import ( Agent, - AgentSideConnection, AuthenticateResponse, Client, - ClientSideConnection, - CreateTerminalResponse, InitializeResponse, - KillTerminalCommandResponse, LoadSessionResponse, NewSessionResponse, PromptRequest, PromptResponse, - ReadTextFileResponse, - ReleaseTerminalResponse, - RequestError, RequestPermissionRequest, RequestPermissionResponse, - SessionNotification, SetSessionModelResponse, SetSessionModeResponse, - TerminalOutputResponse, - WaitForTerminalExitResponse, WriteTextFileResponse, spawn_agent_process, start_tool_call, @@ -38,16 +27,11 @@ ) from acp.schema import ( AgentMessageChunk, - AgentPlanUpdate, - AgentThoughtChunk, AllowedOutcome, AudioContentBlock, - AvailableCommandsUpdate, ClientCapabilities, - CurrentModeUpdate, DeniedOutcome, EmbeddedResourceContentBlock, - EnvVariable, HttpMcpServer, ImageContentBlock, Implementation, @@ -62,456 +46,196 @@ ToolCallStart, UserMessageChunk, ) - -# --------------------- Test Utilities --------------------- - - -class _Server: - def __init__(self) -> None: - self._server: asyncio.AbstractServer | None = None - self._server_reader: asyncio.StreamReader | None = None - self._server_writer: asyncio.StreamWriter | None = None - self._client_reader: asyncio.StreamReader | None = None - self._client_writer: asyncio.StreamWriter | None = None - - async def __aenter__(self): - async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - self._server_reader = reader - self._server_writer = writer - - self._server = await asyncio.start_server(handle, host="127.0.0.1", port=0) - host, port = self._server.sockets[0].getsockname()[:2] - self._client_reader, self._client_writer = await asyncio.open_connection(host, port) - - # wait until server side is set - for _ in range(100): - if self._server_reader and self._server_writer: - break - await asyncio.sleep(0.01) - assert self._server_reader and self._server_writer - assert self._client_reader and self._client_writer - return self - - async def __aexit__(self, exc_type, exc, tb): - if self._client_writer: - self._client_writer.close() - with contextlib.suppress(Exception): - await self._client_writer.wait_closed() - if self._server_writer: - self._server_writer.close() - with contextlib.suppress(Exception): - await self._server_writer.wait_closed() - if self._server: - self._server.close() - await self._server.wait_closed() - - @property - def server_writer(self) -> asyncio.StreamWriter: - assert self._server_writer is not None - return self._server_writer - - @property - def server_reader(self) -> asyncio.StreamReader: - assert self._server_reader is not None - return self._server_reader - - @property - def client_writer(self) -> asyncio.StreamWriter: - assert self._client_writer is not None - return self._client_writer - - @property - def client_reader(self) -> asyncio.StreamReader: - assert self._client_reader is not None - return self._client_reader - - -# --------------------- Test Doubles ----------------------- - - -class TestClient(Client): - __test__ = False # prevent pytest from collecting this class - - def __init__(self) -> None: - self.permission_outcomes: list[RequestPermissionResponse] = [] - self.files: dict[str, str] = {} - self.notifications: list[SessionNotification] = [] - self.ext_calls: list[tuple[str, dict]] = [] - self.ext_notes: list[tuple[str, dict]] = [] - - def queue_permission_cancelled(self) -> None: - self.permission_outcomes.append(RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))) - - def queue_permission_selected(self, option_id: str) -> None: - self.permission_outcomes.append( - RequestPermissionResponse(outcome=AllowedOutcome(option_id=option_id, outcome="selected")) - ) - - async def request_permission( - self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any - ) -> RequestPermissionResponse: - if self.permission_outcomes: - return self.permission_outcomes.pop() - return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) - - async def write_text_file( - self, content: str, path: str, session_id: str, **kwargs: Any - ) -> WriteTextFileResponse | None: - self.files[str(path)] = content - return WriteTextFileResponse() - - async def read_text_file( - self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any - ) -> ReadTextFileResponse: - content = self.files.get(str(path), "default content") - return ReadTextFileResponse(content=content) - - async def session_update( - self, - session_id: str, - update: UserMessageChunk - | AgentMessageChunk - | AgentThoughtChunk - | ToolCallStart - | ToolCallProgress - | AgentPlanUpdate - | AvailableCommandsUpdate - | CurrentModeUpdate, - **kwargs: Any, - ) -> None: - self.notifications.append(SessionNotification(session_id=session_id, update=update, field_meta=kwargs or None)) - - # Optional terminal methods (not implemented in this test client) - async def create_terminal( - self, - command: str, - session_id: str, - args: list[str] | None = None, - cwd: str | None = None, - env: list[EnvVariable] | None = None, - output_byte_limit: int | None = None, - **kwargs: Any, - ) -> CreateTerminalResponse: - raise NotImplementedError - - async def terminal_output( - self, session_id: str, terminal_id: str | None = None, **kwargs: Any - ) -> TerminalOutputResponse: # pragma: no cover - placeholder - raise NotImplementedError - - async def release_terminal( - self, session_id: str, terminal_id: str | None = None, **kwargs: Any - ) -> ReleaseTerminalResponse | None: - raise NotImplementedError - - async def wait_for_terminal_exit( - self, session_id: str, terminal_id: str | None = None, **kwargs: Any - ) -> WaitForTerminalExitResponse: - raise NotImplementedError - - async def kill_terminal( - self, session_id: str, terminal_id: str | None = None, **kwargs: Any - ) -> KillTerminalCommandResponse | None: - raise NotImplementedError - - async def ext_method(self, method: str, params: dict) -> dict: - self.ext_calls.append((method, params)) - if method == "example.com/ping": - return {"response": "pong", "params": params} - raise RequestError.method_not_found(method) - - async def ext_notification(self, method: str, params: dict) -> None: - self.ext_notes.append((method, params)) - - -class TestAgent(Agent): - __test__ = False # prevent pytest from collecting this class - - def __init__(self) -> None: - self.prompts: list[PromptRequest] = [] - self.cancellations: list[str] = [] - self.ext_calls: list[tuple[str, dict]] = [] - self.ext_notes: list[tuple[str, dict]] = [] - - async def initialize( - self, - protocol_version: int, - client_capabilities: ClientCapabilities | None = None, - client_info: Implementation | None = None, - **kwargs: Any, - ) -> InitializeResponse: - # Avoid serializer warnings by omitting defaults - return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[]) - - async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any - ) -> NewSessionResponse: - return NewSessionResponse(session_id="test-session-123") - - async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any - ) -> LoadSessionResponse | None: - return LoadSessionResponse() - - async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: - return AuthenticateResponse() - - async def prompt( - self, - prompt: list[ - TextContentBlock - | ImageContentBlock - | AudioContentBlock - | ResourceContentBlock - | EmbeddedResourceContentBlock - ], - session_id: str, - **kwargs: Any, - ) -> PromptResponse: - self.prompts.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) - return PromptResponse(stop_reason="end_turn") - - async def cancel(self, session_id: str, **kwargs: Any) -> None: - self.cancellations.append(session_id) - - async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: - return SetSessionModeResponse() - - async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) -> SetSessionModelResponse | None: - return SetSessionModelResponse() - - async def ext_method(self, method: str, params: dict) -> dict: - self.ext_calls.append((method, params)) - if method == "example.com/echo": - return {"echo": params} - raise RequestError.method_not_found(method) - - async def ext_notification(self, method: str, params: dict) -> None: - self.ext_notes.append((method, params)) - +from tests.conftest import TestClient # ------------------------ Tests -------------------------- @pytest.mark.asyncio -async def test_initialize_and_new_session(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - # server side is agent; client side is client - agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) +async def test_initialize_and_new_session(connect): + _, agent_conn = connect() - resp = await agent_conn.initialize(protocol_version=1) - assert isinstance(resp, InitializeResponse) - assert resp.protocol_version == 1 + resp = await agent_conn.initialize(protocol_version=1) + assert isinstance(resp, InitializeResponse) + assert resp.protocol_version == 1 - new_sess = await agent_conn.new_session(mcp_servers=[], cwd="/test") - assert new_sess.session_id == "test-session-123" + new_sess = await agent_conn.new_session(mcp_servers=[], cwd="/test") + assert new_sess.session_id == "test-session-123" - load_resp = await agent_conn.load_session(session_id=new_sess.session_id, cwd="/test", mcp_servers=[]) - assert isinstance(load_resp, LoadSessionResponse) + load_resp = await agent_conn.load_session(session_id=new_sess.session_id, cwd="/test", mcp_servers=[]) + assert isinstance(load_resp, LoadSessionResponse) - auth_resp = await agent_conn.authenticate(method_id="password") - assert isinstance(auth_resp, AuthenticateResponse) + auth_resp = await agent_conn.authenticate(method_id="password") + assert isinstance(auth_resp, AuthenticateResponse) - mode_resp = await agent_conn.set_session_mode(session_id=new_sess.session_id, mode_id="ask") - assert isinstance(mode_resp, SetSessionModeResponse) + mode_resp = await agent_conn.set_session_mode(session_id=new_sess.session_id, mode_id="ask") + assert isinstance(mode_resp, SetSessionModeResponse) - model_resp = await agent_conn.set_session_model(session_id=new_sess.session_id, model_id="gpt-4o") - assert isinstance(model_resp, SetSessionModelResponse) + model_resp = await agent_conn.set_session_model(session_id=new_sess.session_id, model_id="gpt-4o") + assert isinstance(model_resp, SetSessionModelResponse) @pytest.mark.asyncio -async def test_bidirectional_file_ops(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - client.files["/test/file.txt"] = "Hello, World!" - _agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) +async def test_bidirectional_file_ops(client, connect): + client.files["/test/file.txt"] = "Hello, World!" + client_conn, _ = connect() - # Agent asks client to read - res = await client_conn.read_text_file(session_id="sess", path="/test/file.txt") - assert res.content == "Hello, World!" + # Agent asks client to read + res = await client_conn.read_text_file(session_id="sess", path="/test/file.txt") + assert res.content == "Hello, World!" - # Agent asks client to write - write_result = await client_conn.write_text_file(session_id="sess", path="/test/file.txt", content="Updated") - assert isinstance(write_result, WriteTextFileResponse) - assert client.files["/test/file.txt"] == "Updated" + # Agent asks client to write + write_result = await client_conn.write_text_file(session_id="sess", path="/test/file.txt", content="Updated") + assert isinstance(write_result, WriteTextFileResponse) + assert client.files["/test/file.txt"] == "Updated" @pytest.mark.asyncio -async def test_cancel_notification_and_capture_wire(): - async with _Server() as s: - # Build only agent-side (server) connection. Client side: raw reader to inspect wire - agent = TestAgent() - client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) - - # Send cancel notification from client-side connection to agent - await agent_conn.cancel(session_id="test-123") - - # Read raw line from server peer (it will be consumed by agent receive loop quickly). - # Instead, wait a brief moment and assert agent recorded it. - for _ in range(50): - if agent.cancellations: - break - await asyncio.sleep(0.01) - assert agent.cancellations == ["test-123"] +async def test_cancel_notification_and_capture_wire(connect, agent): + _, agent_conn = connect() + # Send cancel notification from client-side connection to agent + await agent_conn.cancel(session_id="test-123") + # Read raw line from server peer (it will be consumed by agent receive loop quickly). + # Instead, wait a brief moment and assert agent recorded it. + for _ in range(50): + if agent.cancellations: + break + await asyncio.sleep(0.01) + assert agent.cancellations == ["test-123"] -@pytest.mark.asyncio -async def test_session_notifications_flow(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - _agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) - - # Agent -> Client notifications - await client_conn.session_update( - session_id="sess", - update=AgentMessageChunk( - session_update="agent_message_chunk", - content=TextContentBlock(type="text", text="Hello"), - ), - ) - await client_conn.session_update( - session_id="sess", - update=UserMessageChunk( - session_update="user_message_chunk", - content=TextContentBlock(type="text", text="World"), - ), - ) - # Wait for async dispatch - for _ in range(50): - if len(client.notifications) >= 2: - break - await asyncio.sleep(0.01) - assert len(client.notifications) >= 2 - assert client.notifications[0].session_id == "sess" +@pytest.mark.asyncio +async def test_session_notifications_flow(connect, client): + client_conn, _ = connect() + + # Agent -> Client notifications + await client_conn.session_update( + session_id="sess", + update=AgentMessageChunk( + session_update="agent_message_chunk", + content=TextContentBlock(type="text", text="Hello"), + ), + ) + await client_conn.session_update( + session_id="sess", + update=UserMessageChunk( + session_update="user_message_chunk", + content=TextContentBlock(type="text", text="World"), + ), + ) + + # Wait for async dispatch + for _ in range(50): + if len(client.notifications) >= 2: + break + await asyncio.sleep(0.01) + assert len(client.notifications) >= 2 + assert client.notifications[0].session_id == "sess" @pytest.mark.asyncio -async def test_concurrent_reads(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - for i in range(5): - client.files[f"/test/file{i}.txt"] = f"Content {i}" - _agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) +async def test_concurrent_reads(connect, client): + for i in range(5): + client.files[f"/test/file{i}.txt"] = f"Content {i}" + client_conn, _ = connect() - async def read_one(i: int): - return await client_conn.read_text_file(session_id="sess", path=f"/test/file{i}.txt") + async def read_one(i: int): + return await client_conn.read_text_file(session_id="sess", path=f"/test/file{i}.txt") - results = await asyncio.gather(*(read_one(i) for i in range(5))) - for i, res in enumerate(results): - assert res.content == f"Content {i}" + results = await asyncio.gather(*(read_one(i) for i in range(5))) + for i, res in enumerate(results): + assert res.content == f"Content {i}" @pytest.mark.asyncio -async def test_invalid_params_results_in_error_response(): - async with _Server() as s: - # Only start agent-side (server) so we can inject raw request from client socket - agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) +async def test_invalid_params_results_in_error_response(connect, server): + # Only start agent-side (server) so we can inject raw request from client socket + connect(connect_agent=True, connect_client=False) - # Send initialize with wrong param type (protocolVersion should be int) - req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "oops"}} - s.client_writer.write((json.dumps(req) + "\n").encode()) - await s.client_writer.drain() + # Send initialize with wrong param type (protocolVersion should be int) + req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "oops"}} + server.client_writer.write((json.dumps(req) + "\n").encode()) + await server.client_writer.drain() - # Read response - line = await asyncio.wait_for(s.client_reader.readline(), timeout=1) - resp = json.loads(line) - assert resp["id"] == 1 - assert "error" in resp - assert resp["error"]["code"] == -32602 # invalid params + # Read response + line = await asyncio.wait_for(server.client_reader.readline(), timeout=1) + resp = json.loads(line) + assert resp["id"] == 1 + assert "error" in resp + assert resp["error"]["code"] == -32602 # invalid params @pytest.mark.asyncio -async def test_method_not_found_results_in_error_response(): - async with _Server() as s: - agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) +async def test_method_not_found_results_in_error_response(connect, server): + connect(connect_agent=True, connect_client=False) - req = {"jsonrpc": "2.0", "id": 2, "method": "unknown/method", "params": {}} - s.client_writer.write((json.dumps(req) + "\n").encode()) - await s.client_writer.drain() + req = {"jsonrpc": "2.0", "id": 2, "method": "unknown/method", "params": {}} + server.client_writer.write((json.dumps(req) + "\n").encode()) + await server.client_writer.drain() - line = await asyncio.wait_for(s.client_reader.readline(), timeout=1) - resp = json.loads(line) - assert resp["id"] == 2 - assert resp["error"]["code"] == -32601 # method not found + line = await asyncio.wait_for(server.client_reader.readline(), timeout=1) + resp = json.loads(line) + assert resp["id"] == 2 + assert resp["error"]["code"] == -32601 # method not found @pytest.mark.asyncio -async def test_set_session_mode_and_extensions(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) +async def test_set_session_mode_and_extensions(connect, agent, client): + client_conn, agent_conn = connect() - # setSessionMode - resp = await agent_conn.set_session_mode(session_id="sess", mode_id="yolo") - assert isinstance(resp, SetSessionModeResponse) + # setSessionMode + resp = await agent_conn.set_session_mode(session_id="sess", mode_id="yolo") + assert isinstance(resp, SetSessionModeResponse) - model_resp = await agent_conn.set_session_model(session_id="sess", model_id="gpt-4o-mini") - assert isinstance(model_resp, SetSessionModelResponse) + model_resp = await agent_conn.set_session_model(session_id="sess", model_id="gpt-4o-mini") + assert isinstance(model_resp, SetSessionModelResponse) - # extMethod - echo = await agent_conn.ext_method("example.com/echo", {"x": 1}) - assert echo == {"echo": {"x": 1}} + # extMethod + echo = await agent_conn.ext_method("example.com/echo", {"x": 1}) + assert echo == {"echo": {"x": 1}} - # extNotification - await agent_conn.ext_notification("note", {"y": 2}) - # allow dispatch - await asyncio.sleep(0.05) - assert agent.ext_notes and agent.ext_notes[-1][0] == "note" + # extNotification + await agent_conn.ext_notification("note", {"y": 2}) + # allow dispatch + await asyncio.sleep(0.05) + assert agent.ext_notes and agent.ext_notes[-1][0] == "note" - # client extension method - ping = await client_conn.ext_method("example.com/ping", {"k": 3}) - assert ping == {"response": "pong", "params": {"k": 3}} - assert client.ext_calls and client.ext_calls[-1] == ("example.com/ping", {"k": 3}) + # client extension method + ping = await client_conn.ext_method("example.com/ping", {"k": 3}) + assert ping == {"response": "pong", "params": {"k": 3}} + assert client.ext_calls and client.ext_calls[-1] == ("example.com/ping", {"k": 3}) @pytest.mark.asyncio -async def test_ignore_invalid_messages(): - async with _Server() as s: - agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) +async def test_ignore_invalid_messages(connect, server): + connect(connect_agent=True, connect_client=False) - # Message without id and method - msg1 = {"jsonrpc": "2.0"} - s.client_writer.write((json.dumps(msg1) + "\n").encode()) - await s.client_writer.drain() + # Message without id and method + msg1 = {"jsonrpc": "2.0"} + server.client_writer.write((json.dumps(msg1) + "\n").encode()) + await server.client_writer.drain() - # Message without jsonrpc and without id/method - msg2 = {"foo": "bar"} - s.client_writer.write((json.dumps(msg2) + "\n").encode()) - await s.client_writer.drain() + # Message without jsonrpc and without id/method + msg2 = {"foo": "bar"} + server.client_writer.write((json.dumps(msg2) + "\n").encode()) + await server.client_writer.drain() - # Should not receive any response lines - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(s.client_reader.readline(), timeout=0.1) + # Should not receive any response lines + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(server.client_reader.readline(), timeout=0.1) class _ExampleAgent(Agent): __test__ = False def __init__(self) -> None: - self._conn: AgentSideConnection | None = None + self._conn: Client | None = None self.permission_response: RequestPermissionResponse | None = None self.prompt_requests: list[PromptRequest] = [] - def bind(self, conn: AgentSideConnection) -> "_ExampleAgent": + def on_connect(self, conn: Client) -> None: self._conn = conn - return self async def initialize( self, @@ -626,61 +350,57 @@ async def request_permission( @pytest.mark.asyncio -async def test_example_agent_permission_flow(): - async with _Server() as s: - agent = _ExampleAgent() - client = _ExampleClient() - - agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - AgentSideConnection(lambda conn: agent.bind(conn), s._server_writer, s._server_reader) - - init = await agent_conn.initialize(protocol_version=1) - assert init.protocol_version == 1 - - session = await agent_conn.new_session(mcp_servers=[], cwd="/workspace") - assert session.session_id == "sess_demo" - - resp = await agent_conn.prompt( - session_id=session.session_id, - prompt=[TextContentBlock(type="text", text="Please edit config")], - ) - assert resp.stop_reason == "end_turn" - for _ in range(50): - if len(client.notifications) >= 4: - break - await asyncio.sleep(0.02) - - assert len(client.notifications) >= 4 - session_updates = [getattr(note.update, "session_update", None) for note in client.notifications] - assert session_updates[:4] == ["agent_message_chunk", "tool_call", "tool_call_update", "agent_message_chunk"] - - first_message = client.notifications[0].update - assert isinstance(first_message, AgentMessageChunk) - assert isinstance(first_message.content, TextContentBlock) - assert first_message.content.text == "I'll help you with that." - - tool_call = client.notifications[1].update - assert isinstance(tool_call, ToolCallStart) - assert tool_call.title == "Modifying configuration" - assert tool_call.status == "pending" - - tool_update = client.notifications[2].update - assert isinstance(tool_update, ToolCallProgress) - assert tool_update.status == "completed" - assert tool_update.raw_output == {"success": True} - - final_message = client.notifications[3].update - assert isinstance(final_message, AgentMessageChunk) - assert isinstance(final_message.content, TextContentBlock) - assert final_message.content.text == "Done." - - assert len(client.permission_requests) == 1 - options = client.permission_requests[0].options - assert [opt.option_id for opt in options] == ["allow", "reject"] - - assert agent.permission_response is not None - assert isinstance(agent.permission_response.outcome, AllowedOutcome) - assert agent.permission_response.outcome.option_id == "allow" +@pytest.mark.parametrize("agent,client", [(_ExampleAgent(), _ExampleClient())]) +async def test_example_agent_permission_flow(connect, client, agent): + _, agent_conn = connect() + + init = await agent_conn.initialize(protocol_version=1) + assert init.protocol_version == 1 + + session = await agent_conn.new_session(mcp_servers=[], cwd="/workspace") + assert session.session_id == "sess_demo" + + resp = await agent_conn.prompt( + session_id=session.session_id, + prompt=[TextContentBlock(type="text", text="Please edit config")], + ) + assert resp.stop_reason == "end_turn" + for _ in range(50): + if len(client.notifications) >= 4: + break + await asyncio.sleep(0.02) + + assert len(client.notifications) >= 4 + session_updates = [getattr(note.update, "session_update", None) for note in client.notifications] + assert session_updates[:4] == ["agent_message_chunk", "tool_call", "tool_call_update", "agent_message_chunk"] + + first_message = client.notifications[0].update + assert isinstance(first_message, AgentMessageChunk) + assert isinstance(first_message.content, TextContentBlock) + assert first_message.content.text == "I'll help you with that." + + tool_call = client.notifications[1].update + assert isinstance(tool_call, ToolCallStart) + assert tool_call.title == "Modifying configuration" + assert tool_call.status == "pending" + + tool_update = client.notifications[2].update + assert isinstance(tool_update, ToolCallProgress) + assert tool_update.status == "completed" + assert tool_update.raw_output == {"success": True} + + final_message = client.notifications[3].update + assert isinstance(final_message, AgentMessageChunk) + assert isinstance(final_message.content, TextContentBlock) + assert final_message.content.text == "Done." + + assert len(client.permission_requests) == 1 + options = client.permission_requests[0].options + assert [opt.option_id for opt in options] == ["allow", "reject"] + + assert agent.permission_response is not None + assert isinstance(agent.permission_response.outcome, AllowedOutcome) + assert agent.permission_response.outcome.option_id == "allow" @pytest.mark.asyncio @@ -690,7 +410,7 @@ async def test_spawn_agent_process_roundtrip(tmp_path): test_client = TestClient() - async with spawn_agent_process(lambda _agent: test_client, sys.executable, str(script)) as (client_conn, process): + async with spawn_agent_process(test_client, sys.executable, str(script)) as (client_conn, process): init = await client_conn.initialize(protocol_version=1) assert isinstance(init, InitializeResponse) session = await client_conn.new_session(mcp_servers=[], cwd=str(tmp_path)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4ac2a75..47706d9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,5 @@ +import pytest + from acp.schema import AgentMessageChunk, TextContentBlock from acp.utils import serialize_params @@ -36,3 +38,18 @@ def test_field_meta_can_be_set_by_name_on_models() -> None: assert chunk.field_meta == {"outer": "value"} assert chunk.content.field_meta == {"inner": "value"} + + +@pytest.mark.parametrize( + "original, expected", + [ + ("simple_test", "simpleTest"), + ("another_example_here", "anotherExampleHere"), + ("lowercase", "lowercase"), + ("alreadyCamelCase", "alreadyCamelCase"), + ], +) +def test_to_camel_case(original, expected) -> None: + from acp.utils import to_camel_case + + assert to_camel_case(original) == expected