diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c30c852..40b46dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,4 +21,3 @@ repos: args: [ --exit-non-zero-on-fix ] exclude: ^src/acp/(meta|schema)\.py$ - id: ruff-format - exclude: ^src/acp/(meta|schema)\.py$ diff --git a/Makefile b/Makefile index f995835..a6f88bf 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ install: ## Install the virtual environment and install the pre-commit hooks gen-all: ## Generate all code from schema @echo "🚀 Generating all code" @uv run scripts/gen_all.py + @uv run ruff check --fix + @uv run ruff format . .PHONY: check check: ## Run code quality tools. diff --git a/docs/migration-guide-0.7.md b/docs/migration-guide-0.7.md new file mode 100644 index 0000000..63aa906 --- /dev/null +++ b/docs/migration-guide-0.7.md @@ -0,0 +1,109 @@ +# Migrating to ACP Python SDK 0.7 + +ACP 0.7 reshapes the public surface so that Python-facing names, runtime helpers, and schema models line up with the evolving Agent Client Protocol schema. This guide covers the major changes in 0.7.0 and calls out the mechanical steps you need to apply in downstream agents, clients, and transports. + +## 1. `acp.schema` models now expose `snake_case` fields + +- Every generated model in `acp.schema` (see `src/acp/schema.py`) now uses Pythonic attribute names such as `session_id`, `stop_reason`, and `field_meta`. The JSON aliases (e.g., `alias="sessionId"`) stay intact so over-the-wire payloads remain camelCase. +- Instantiating a model or accessing response values must now use the `snake_case` form: + +```python +# Before (0.6 and earlier) +PromptResponse(stopReason="end_turn") +params.sessionId + +# After (0.7 and later) +PromptResponse(stop_reason="end_turn") +params.session_id +``` + +- If you relied on `model_dump()` to emit camelCase keys automatically, switch to `model_dump(by_alias=True)` (or use helpers such as `text_block`, `start_tool_call`, etc.) 
so responses continue to match the protocol. +- `field_meta` stays available for extension data. Any extra keys that were nested under `_meta` should now be provided via keyword arguments when constructing the schema models (see section 3). + +## 2. `acp.run_agent` and `acp.connect_to_agent` replace manual connection wiring + +`AgentSideConnection` and `ClientSideConnection` still exist internally, but the top-level entry points now prefer the helper functions implemented in `src/acp/core.py`. + +### Updating agents + +- Old pattern: + +```python +conn = AgentSideConnection(lambda conn: Agent(), writer, reader) +await asyncio.Event().wait() # keep running +``` + +- New pattern: + +```python +await run_agent(MyAgent(), input_stream=writer, output_stream=reader) +``` + +- When your agent just runs over stdio, call `await run_agent(MyAgent())` and the helper will acquire asyncio streams via `stdio_streams()` for you. + +### Updating clients and tests + +- Old pattern: + +```python +conn = ClientSideConnection(lambda conn: MyClient(), proc.stdin, proc.stdout) +``` + +- New pattern: + +```python +conn = connect_to_agent(MyClient(), proc.stdin, proc.stdout) +``` + +- `spawn_agent_process` / `spawn_client_process` now accept concrete `Agent`/`Client` instances instead of factories that received the connection. Instantiate your implementation first and pass it in. +- Importing the legacy connection classes via `acp.AgentSideConnection` / `acp.ClientSideConnection` issues a `DeprecationWarning` (see `src/acp/__init__.py:82-96`). Update your imports to `run_agent` and `connect_to_agent` to silence the warning. + +## 3. `Agent` and `Client` interface methods take explicit parameters + +Both interfaces in `src/acp/interfaces.py` now look like idiomatic Python protocols: methods use `snake_case` names and receive the individual schema fields rather than a single request model. 
+ +### What changed + +- Method names follow `snake_case` (`request_permission`, `session_update`, `new_session`, `set_session_model`, etc.). +- Parameters represent the schema fields, so there is no need to unpack `params` manually. +- Each method is decorated with `@param_model(...)`. Combined with the `compatible_class` helper (see `src/acp/utils.py`), this keeps the camelCase wrappers alive for callers that still pass a full Pydantic request object—but those wrappers now emit `DeprecationWarning`s to encourage migration. + +### How to update your implementations + +1. Rename your method overrides to their `snake_case` equivalents. +2. Replace `params: Model` arguments with the concrete fields plus `**kwargs` to collect future `_meta` keys. +3. Access schema data directly via those parameters. + +Example migration for an agent: + +```python +# Before +class EchoAgent: + async def prompt(self, params: PromptRequest) -> PromptResponse: + text = params.prompt[0].text + return PromptResponse(stopReason="end_turn") + +# After +class EchoAgent: + async def prompt(self, prompt, session_id, **kwargs) -> PromptResponse: + text = prompt[0].text + return PromptResponse(stop_reason="end_turn") +``` + +Similarly, a client method such as `requestPermission` becomes: + +```python +class RecordingClient(Client): + async def request_permission(self, options, session_id, tool_call, **kwargs): + ... +``` + +### Additional notes + +- The connection layers automatically assemble the right request/response models using the `param_model` metadata, so callers do not need to build Pydantic objects manually anymore. +- For extension points (`field_meta`), pass keyword arguments from the connection into your handler signature: they arrive inside `**kwargs`. + +### Backward compatibility + +- The change should be 100% backward compatible as long as you update your method names and signatures. 
The `compatible_class` wrapper ensures that existing callers passing full request models continue to work. The old-style API will remain functional until the next major release (1.0). +- Because camelCase wrappers remain for now, you can migrate file-by-file while still running against ACP 0.7. Just watch for the new deprecation warnings in your logs/tests. diff --git a/docs/quickstart.md b/docs/quickstart.md index 3a54147..ba5b80c 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -4,13 +4,13 @@ Spin up a working ACP agent/client loop in minutes. Keep this page beside the te ## Quick checklist -| Goal | Command / Link | -| --- | --- | -| Install the SDK | `pip install agent-client-protocol` or `uv add agent-client-protocol` | -| Run the echo agent | `python examples/echo_agent.py` | -| Point Zed (or another client) at it | Update `settings.json` as shown below | -| Programmatically drive an agent | Copy the `spawn_agent_process` example | -| Run tests before hacking further | `make check && make test` | +| Goal | Command / Link | +| ----------------------------------- | --------------------------------------------------------------------- | +| Install the SDK | `pip install agent-client-protocol` or `uv add agent-client-protocol` | +| Run the echo agent | `python examples/echo_agent.py` | +| Point Zed (or another client) at it | Update `settings.json` as shown below | +| Programmatically drive an agent | Copy the `spawn_agent_process` example | +| Run tests before hacking further | `make check && make test` | ## Before you begin @@ -76,27 +76,26 @@ from pathlib import Path from acp import spawn_agent_process, text_block from acp.interfaces import Client -from acp.schema import InitializeRequest, NewSessionRequest, PromptRequest, SessionNotification class SimpleClient(Client): - async def requestPermission(self, params): # pragma: no cover - minimal stub + async def request_permission( + self, options, session_id, tool_call, **kwargs + ): return 
{"outcome": {"outcome": "cancelled"}} - async def sessionUpdate(self, params: SessionNotification) -> None: - print("update:", params.sessionId, params.update) + async def session_update(self, session_id, update, **kwargs): + print("update:", session_id, update) async def main() -> None: script = Path("examples/echo_agent.py") - async with spawn_agent_process(lambda _agent: SimpleClient(), sys.executable, str(script)) as (conn, _proc): - await conn.initialize(InitializeRequest(protocolVersion=1)) - session = await conn.newSession(NewSessionRequest(cwd=str(script.parent), mcpServers=[])) + async with spawn_agent_process(SimpleClient(), sys.executable, str(script)) as (conn, _proc): + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd=str(script.parent), mcp_servers=[]) await conn.prompt( - PromptRequest( - sessionId=session.sessionId, - prompt=[text_block("Hello from spawn!")], - ) + session_id=session.session_id, + prompt=[text_block("Hello from spawn!")], ) asyncio.run(main()) @@ -111,16 +110,16 @@ _Swap the echo demo for your own `Agent` subclass._ Create your own agent by subclassing `acp.Agent`. The pattern mirrors the echo example: ```python -from acp import Agent, PromptRequest, PromptResponse +from acp import Agent, PromptResponse class MyAgent(Agent): - async def prompt(self, params: PromptRequest) -> PromptResponse: - # inspect params.prompt, stream updates, then finish the turn - return PromptResponse(stopReason="end_turn") + async def prompt(self, prompt, session_id, **kwargs) -> PromptResponse: + # inspect prompt, stream updates, then finish the turn + return PromptResponse(stop_reason="end_turn") ``` -Hook it up with `AgentSideConnection` inside an async entrypoint and wire it to your client. Refer to: +Run it with `run_agent()` inside an async entrypoint and wire it to your client. 
Refer to: - [`examples/echo_agent.py`](https://github.com/agentclientprotocol/python-sdk/blob/main/examples/echo_agent.py) for the smallest streaming agent - [`examples/agent.py`](https://github.com/agentclientprotocol/python-sdk/blob/main/examples/agent.py) for an implementation that negotiates capabilities and streams richer updates diff --git a/examples/agent.py b/examples/agent.py index a75e1a5..c53ad56 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -5,93 +5,121 @@ from acp import ( Agent, AgentSideConnection, - AuthenticateRequest, AuthenticateResponse, - CancelNotification, - InitializeRequest, InitializeResponse, - LoadSessionRequest, LoadSessionResponse, - NewSessionRequest, NewSessionResponse, - PromptRequest, PromptResponse, - SetSessionModeRequest, SetSessionModeResponse, - session_notification, - stdio_streams, + run_agent, text_block, update_agent_message, PROTOCOL_VERSION, ) -from acp.schema import AgentCapabilities, AgentMessageChunk, Implementation +from acp.interfaces import Client +from acp.schema import ( + AgentCapabilities, + AgentMessageChunk, + AudioContentBlock, + ClientCapabilities, + EmbeddedResourceContentBlock, + HttpMcpServer, + ImageContentBlock, + Implementation, + ResourceContentBlock, + SseMcpServer, + McpServerStdio, + TextContentBlock, +) class ExampleAgent(Agent): - def __init__(self, conn: AgentSideConnection) -> None: - self._conn = conn + _conn: Client + + def __init__(self) -> None: self._next_session_id = 0 self._sessions: set[str] = set() + def on_connect(self, conn: Client) -> None: + self._conn = conn + async def _send_agent_message(self, session_id: str, content: Any) -> None: update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content) - await self._conn.sessionUpdate(session_notification(session_id, update)) - - async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002 + await self._conn.session_update(session_id, update) + + async def 
initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: logging.info("Received initialize request") return InitializeResponse( - protocolVersion=PROTOCOL_VERSION, - agentCapabilities=AgentCapabilities(), - agentInfo=Implementation(name="example-agent", title="Example Agent", version="0.1.0"), + protocol_version=PROTOCOL_VERSION, + agent_capabilities=AgentCapabilities(), + agent_info=Implementation(name="example-agent", title="Example Agent", version="0.1.0"), ) - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002 - logging.info("Received authenticate request %s", params.methodId) + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: + logging.info("Received authenticate request %s", method_id) return AuthenticateResponse() - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002 + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + ) -> NewSessionResponse: logging.info("Received new session request") session_id = str(self._next_session_id) self._next_session_id += 1 self._sessions.add(session_id) - return NewSessionResponse(sessionId=session_id, modes=None) + return NewSessionResponse(session_id=session_id, modes=None) - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002 - logging.info("Received load session request %s", params.sessionId) - self._sessions.add(params.sessionId) + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + ) -> LoadSessionResponse | None: + logging.info("Received load session request %s", session_id) + self._sessions.add(session_id) return LoadSessionResponse() - async def 
setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002 - logging.info("Received set session mode request %s -> %s", params.sessionId, params.modeId) + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: + logging.info("Received set session mode request %s -> %s", session_id, mode_id) return SetSessionModeResponse() - async def prompt(self, params: PromptRequest) -> PromptResponse: - logging.info("Received prompt request for session %s", params.sessionId) - if params.sessionId not in self._sessions: - self._sessions.add(params.sessionId) - - await self._send_agent_message(params.sessionId, text_block("Client sent:")) - for block in params.prompt: - await self._send_agent_message(params.sessionId, block) - - return PromptResponse(stopReason="end_turn") - - async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002 - logging.info("Received cancel notification for session %s", params.sessionId) - - async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + logging.info("Received prompt request for session %s", session_id) + if session_id not in self._sessions: + self._sessions.add(session_id) + + await self._send_agent_message(session_id, text_block("Client sent:")) + for block in prompt: + await self._send_agent_message(session_id, block) + return PromptResponse(stop_reason="end_turn") + + async def cancel(self, session_id: str, **kwargs: Any) -> None: + logging.info("Received cancel notification for session %s", session_id) + + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: logging.info("Received extension method call: %s", method) return {"example": "response"} - 
async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002 + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: logging.info("Received extension notification: %s", method) async def main() -> None: logging.basicConfig(level=logging.INFO) - reader, writer = await stdio_streams() - AgentSideConnection(ExampleAgent, writer, reader) - await asyncio.Event().wait() + await run_agent(ExampleAgent()) if __name__ == "__main__": diff --git a/examples/client.py b/examples/client.py index 8c62462..7a0cc27 100644 --- a/examples/client.py +++ b/examples/client.py @@ -5,57 +5,105 @@ import os import sys from pathlib import Path +from typing import Any from acp import ( Client, - ClientSideConnection, - InitializeRequest, - NewSessionRequest, - PromptRequest, + connect_to_agent, RequestError, - SessionNotification, text_block, PROTOCOL_VERSION, ) +from acp.core import ClientSideConnection from acp.schema import ( AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, AudioContentBlock, + AvailableCommandsUpdate, ClientCapabilities, + CreateTerminalResponse, + CurrentModeUpdate, EmbeddedResourceContentBlock, + EnvVariable, ImageContentBlock, Implementation, + KillTerminalCommandResponse, + PermissionOption, + ReadTextFileResponse, + ReleaseTerminalResponse, + RequestPermissionResponse, ResourceContentBlock, + TerminalOutputResponse, TextContentBlock, + ToolCall, + ToolCallProgress, + ToolCallStart, + UserMessageChunk, + WaitForTerminalExitResponse, + WriteTextFileResponse, ) class ExampleClient(Client): - async def requestPermission(self, params): # type: ignore[override] + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: raise RequestError.method_not_found("session/request_permission") - async def writeTextFile(self, params): # type: ignore[override] + async def write_text_file( + self, content: str, path: str, 
session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: raise RequestError.method_not_found("fs/write_text_file") - async def readTextFile(self, params): # type: ignore[override] + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: raise RequestError.method_not_found("fs/read_text_file") - async def createTerminal(self, params): # type: ignore[override] + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: raise RequestError.method_not_found("terminal/create") - async def terminalOutput(self, params): # type: ignore[override] + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: raise RequestError.method_not_found("terminal/output") - async def releaseTerminal(self, params): # type: ignore[override] + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: raise RequestError.method_not_found("terminal/release") - async def waitForTerminalExit(self, params): # type: ignore[override] + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: raise RequestError.method_not_found("terminal/wait_for_exit") - async def killTerminal(self, params): # type: ignore[override] + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: raise RequestError.method_not_found("terminal/kill") - async def sessionUpdate(self, params: SessionNotification) -> None: - update = params.update + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk 
+ | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: if not isinstance(update, AgentMessageChunk): return @@ -76,10 +124,10 @@ async def sessionUpdate(self, params: SessionNotification) -> None: print(f"| Agent: {text}") - async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 + async def ext_method(self, method: str, params: dict) -> dict: raise RequestError.method_not_found(method) - async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002 + async def ext_notification(self, method: str, params: dict) -> None: raise RequestError.method_not_found(method) @@ -103,10 +151,8 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: try: await conn.prompt( - PromptRequest( - sessionId=session_id, - prompt=[text_block(line)], - ) + session_id=session_id, + prompt=[text_block(line)], ) except Exception as exc: # noqa: BLE001 logging.error("Prompt failed: %s", exc) @@ -142,18 +188,16 @@ async def main(argv: list[str]) -> int: return 1 client_impl = ExampleClient() - conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout) + conn = connect_to_agent(client_impl, proc.stdin, proc.stdout) await conn.initialize( - InitializeRequest( - protocolVersion=PROTOCOL_VERSION, - clientCapabilities=ClientCapabilities(), - clientInfo=Implementation(name="example-client", title="Example Client", version="0.1.0"), - ) + protocol_version=PROTOCOL_VERSION, + client_capabilities=ClientCapabilities(), + client_info=Implementation(name="example-client", title="Example Client", version="0.1.0"), ) - session = await conn.newSession(NewSessionRequest(mcpServers=[], cwd=os.getcwd())) + session = await conn.new_session(mcp_servers=[], cwd=os.getcwd()) - await interactive_loop(conn, session.sessionId) + await interactive_loop(conn, session.session_id) if proc.returncode is None: proc.terminate() diff --git 
a/examples/duet.py b/examples/duet.py index de8d9ca..f2c2871 100644 --- a/examples/duet.py +++ b/examples/duet.py @@ -30,13 +30,13 @@ async def main() -> int: client_module = _load_client_module(root / "client.py") client = client_module.ExampleClient() - async with spawn_agent_process(lambda _agent: client, sys.executable, str(agent_path), env=env) as ( + async with spawn_agent_process(client, sys.executable, str(agent_path), env=env) as ( conn, process, ): - await conn.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION, clientCapabilities=None)) - session = await conn.newSession(NewSessionRequest(mcpServers=[], cwd=str(root))) - await client_module.interactive_loop(conn, session.sessionId) + await conn.initialize(protocol_version=PROTOCOL_VERSION, client_capabilities=None) + session = await conn.new_session(mcp_servers=[], cwd=str(root)) + await client_module.interactive_loop(conn, session.session_id) return process.returncode or 0 diff --git a/examples/echo_agent.py b/examples/echo_agent.py index 657eb28..a096677 100644 --- a/examples/echo_agent.py +++ b/examples/echo_agent.py @@ -1,50 +1,76 @@ import asyncio +from typing import Any from uuid import uuid4 from acp import ( Agent, AgentSideConnection, - InitializeRequest, InitializeResponse, - NewSessionRequest, NewSessionResponse, - PromptRequest, PromptResponse, - session_notification, - stdio_streams, + run_agent, text_block, update_agent_message, ) +from acp.interfaces import Client +from acp.schema import ( + AudioContentBlock, + ClientCapabilities, + EmbeddedResourceContentBlock, + HttpMcpServer, + ImageContentBlock, + Implementation, + ResourceContentBlock, + SseMcpServer, + McpServerStdio, + TextContentBlock, +) class EchoAgent(Agent): - def __init__(self, conn): + _conn: Client + + def on_connect(self, conn: Client) -> None: self._conn = conn - async def initialize(self, params: InitializeRequest) -> InitializeResponse: - return InitializeResponse(protocolVersion=params.protocolVersion) + 
async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: + return InitializeResponse(protocol_version=protocol_version) - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: - return NewSessionResponse(sessionId=uuid4().hex) + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + ) -> NewSessionResponse: + return NewSessionResponse(session_id=uuid4().hex) - async def prompt(self, params: PromptRequest) -> PromptResponse: - for block in params.prompt: + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + for block in prompt: text = block.get("text", "") if isinstance(block, dict) else getattr(block, "text", "") chunk = update_agent_message(text_block(text)) chunk.field_meta = {"echo": True} chunk.content.field_meta = {"echo": True} - notification = session_notification(params.sessionId, chunk) - notification.field_meta = {"source": "echo_agent"} - - await self._conn.sessionUpdate(notification) - return PromptResponse(stopReason="end_turn") + await self._conn.session_update(session_id=session_id, update=chunk, source="echo_agent") + return PromptResponse(stop_reason="end_turn") async def main() -> None: - reader, writer = await stdio_streams() - AgentSideConnection(lambda conn: EchoAgent(conn), writer, reader) - await asyncio.Event().wait() + await run_agent(EchoAgent()) if __name__ == "__main__": diff --git a/examples/gemini.py b/examples/gemini.py index f1fe9a9..e9ec79c 100644 --- a/examples/gemini.py +++ b/examples/gemini.py @@ -9,52 +9,48 @@ import shutil import sys from pathlib import Path -from typing import Iterable +from typing import Any, Iterable 
from acp import ( Client, - ClientSideConnection, + connect_to_agent, PROTOCOL_VERSION, RequestError, text_block, ) +from acp.core import ClientSideConnection from acp.schema import ( AgentMessageChunk, AgentPlanUpdate, AgentThoughtChunk, AllowedOutcome, + AvailableCommandsUpdate, CancelNotification, ClientCapabilities, + CurrentModeUpdate, + EnvVariable, FileEditToolCallContent, FileSystemCapability, - CreateTerminalRequest, CreateTerminalResponse, DeniedOutcome, EmbeddedResourceContentBlock, - KillTerminalCommandRequest, KillTerminalCommandResponse, InitializeRequest, NewSessionRequest, PermissionOption, PromptRequest, - ReadTextFileRequest, ReadTextFileResponse, - RequestPermissionRequest, RequestPermissionResponse, ResourceContentBlock, - ReleaseTerminalRequest, ReleaseTerminalResponse, - SessionNotification, TerminalToolCallContent, - TerminalOutputRequest, TerminalOutputResponse, TextContentBlock, + ToolCall, ToolCallProgress, ToolCallStart, UserMessageChunk, - WaitForTerminalExitRequest, WaitForTerminalExitResponse, - WriteTextFileRequest, WriteTextFileResponse, ) @@ -65,22 +61,21 @@ class GeminiClient(Client): def __init__(self, auto_approve: bool) -> None: self._auto_approve = auto_approve - async def requestPermission( - self, - params: RequestPermissionRequest, - ) -> RequestPermissionResponse: # type: ignore[override] + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: if self._auto_approve: - option = _pick_preferred_option(params.options) + option = _pick_preferred_option(options) if option is None: return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) - return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected")) + return RequestPermissionResponse(outcome=AllowedOutcome(option_id=option.option_id, outcome="selected")) - title = params.toolCall.title or "" - if not params.options: + 
title = tool_call.title or "" + if not options: print(f"\n🔐 Permission requested: {title} (no options, cancelling)") return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) print(f"\n🔐 Permission requested: {title}") - for idx, opt in enumerate(params.options, start=1): + for idx, opt in enumerate(options, start=1): print(f" {idx}. {opt.name} ({opt.kind})") loop = asyncio.get_running_loop() @@ -90,41 +85,49 @@ async def requestPermission( continue if choice.isdigit(): idx = int(choice) - 1 - if 0 <= idx < len(params.options): - opt = params.options[idx] - return RequestPermissionResponse(outcome=AllowedOutcome(optionId=opt.optionId, outcome="selected")) + if 0 <= idx < len(options): + opt = options[idx] + return RequestPermissionResponse( + outcome=AllowedOutcome(option_id=opt.option_id, outcome="selected") + ) print("Invalid selection, try again.") - async def writeTextFile( - self, - params: WriteTextFileRequest, - ) -> WriteTextFileResponse: # type: ignore[override] - path = Path(params.path) - if not path.is_absolute(): - raise RequestError.invalid_params({"path": params.path, "reason": "path must be absolute"}) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(params.content) - print(f"[Client] Wrote {path} ({len(params.content)} bytes)") + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: + pathlib_path = Path(path) + if not pathlib_path.is_absolute(): + raise RequestError.invalid_params({"path": pathlib_path, "reason": "path must be absolute"}) + pathlib_path.parent.mkdir(parents=True, exist_ok=True) + pathlib_path.write_text(content) + print(f"[Client] Wrote {pathlib_path} ({len(content)} bytes)") return WriteTextFileResponse() - async def readTextFile( - self, - params: ReadTextFileRequest, - ) -> ReadTextFileResponse: # type: ignore[override] - path = Path(params.path) - if not path.is_absolute(): - raise RequestError.invalid_params({"path": 
params.path, "reason": "path must be absolute"}) - text = path.read_text() - print(f"[Client] Read {path} ({len(text)} bytes)") - if params.line is not None or params.limit is not None: - text = _slice_text(text, params.line, params.limit) + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: + pathlib_path = Path(path) + if not pathlib_path.is_absolute(): + raise RequestError.invalid_params({"path": pathlib_path, "reason": "path must be absolute"}) + text = pathlib_path.read_text() + print(f"[Client] Read {pathlib_path} ({len(text)} bytes)") + if line is not None or limit is not None: + text = _slice_text(text, line, limit) return ReadTextFileResponse(content=text) - async def sessionUpdate( + async def session_update( self, - params: SessionNotification, - ) -> None: # type: ignore[override] - update = params.update + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: if isinstance(update, AgentMessageChunk): _print_text_content(update.content) elif isinstance(update, AgentThoughtChunk): @@ -141,52 +144,52 @@ async def sessionUpdate( print(f"\n🔧 {update.title} ({update.status or 'pending'})") elif isinstance(update, ToolCallProgress): status = update.status or "in_progress" - print(f"\n🔧 Tool call `{update.toolCallId}` → {status}") + print(f"\n🔧 Tool call `{update.tool_call_id}` → {status}") if update.content: for item in update.content: if isinstance(item, FileEditToolCallContent): print(f" diff: {item.path}") elif isinstance(item, TerminalToolCallContent): - print(f" terminal: {item.terminalId}") + print(f" terminal: {item.terminal_id}") elif isinstance(item, dict): print(f" content: {json.dumps(item, indent=2)}") else: print(f"\n[session update] {update}") # Optional / 
terminal-related methods --------------------------------- - async def createTerminal( - self, - params: CreateTerminalRequest, - ) -> CreateTerminalResponse: # type: ignore[override] - print(f"[Client] createTerminal: {params}") - return CreateTerminalResponse(terminalId="term-1") - - async def terminalOutput( + async def create_terminal( self, - params: TerminalOutputRequest, - ) -> TerminalOutputResponse: # type: ignore[override] - print(f"[Client] terminalOutput: {params}") + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: + print(f"[Client] createTerminal: {command} {args or []} (cwd={cwd})") + return CreateTerminalResponse(terminal_id="term-1") + + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: + print(f"[Client] terminalOutput: {session_id} {terminal_id}") return TerminalOutputResponse(output="", truncated=False) - async def releaseTerminal( - self, - params: ReleaseTerminalRequest, - ) -> ReleaseTerminalResponse: # type: ignore[override] - print(f"[Client] releaseTerminal: {params}") + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: + print(f"[Client] releaseTerminal: {session_id} {terminal_id}") return ReleaseTerminalResponse() - async def waitForTerminalExit( - self, - params: WaitForTerminalExitRequest, - ) -> WaitForTerminalExitResponse: # type: ignore[override] - print(f"[Client] waitForTerminalExit: {params}") + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: + print(f"[Client] waitForTerminalExit: {session_id} {terminal_id}") return WaitForTerminalExitResponse() - async def killTerminal( - self, - params: KillTerminalCommandRequest, - ) -> KillTerminalCommandResponse: 
# type: ignore[override] - print(f"[Client] killTerminal: {params}") + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: + print(f"[Client] killTerminal: {session_id} {terminal_id}") return KillTerminalCommandResponse() @@ -246,15 +249,13 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: if line in {":exit", ":quit"}: break if line == ":cancel": - await conn.cancel(CancelNotification(sessionId=session_id)) + await conn.cancel(session_id=session_id) continue try: await conn.prompt( - PromptRequest( - sessionId=session_id, - prompt=[text_block(line)], - ) + session_id=session_id, + prompt=[text_block(line)], ) except RequestError as err: _print_request_error("prompt", err) @@ -316,17 +317,15 @@ async def run(argv: list[str]) -> int: return 1 client_impl = GeminiClient(auto_approve=args.yolo) - conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout) + conn = connect_to_agent(client_impl, proc.stdin, proc.stdout) try: init_resp = await conn.initialize( - InitializeRequest( - protocolVersion=PROTOCOL_VERSION, - clientCapabilities=ClientCapabilities( - fs=FileSystemCapability(readTextFile=True, writeTextFile=True), - terminal=True, - ), - ) + protocol_version=PROTOCOL_VERSION, + client_capabilities=ClientCapabilities( + fs=FileSystemCapability(read_text_file=True, write_text_file=True), + terminal=True, + ), ) except RequestError as err: _print_request_error("initialize", err) @@ -337,14 +336,12 @@ async def run(argv: list[str]) -> int: await _shutdown(proc, conn) return 1 - print(f"✅ Connected to Gemini (protocol v{init_resp.protocolVersion})") + print(f"✅ Connected to Gemini (protocol v{init_resp.protocol_version})") try: - session = await conn.newSession( - NewSessionRequest( - cwd=os.getcwd(), - mcpServers=[], - ) + session = await conn.new_session( + cwd=os.getcwd(), + mcp_servers=[], ) except RequestError as err: 
_print_request_error("new_session", err) @@ -355,10 +352,10 @@ async def run(argv: list[str]) -> int: await _shutdown(proc, conn) return 1 - print(f"📝 Created session: {session.sessionId}") + print(f"📝 Created session: {session.session_id}") try: - await interactive_loop(conn, session.sessionId) + await interactive_loop(conn, session.session_id) finally: await _shutdown(proc, conn) diff --git a/mkdocs.yml b/mkdocs.yml index 0464ad0..e74cd49 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,6 +13,7 @@ nav: - Use Cases: use-cases.md - Experimental Contrib: contrib.md - Releasing: releasing.md + - 0.7 Migration Guide: migration-guide-0.7.md plugins: - search - mkdocstrings: diff --git a/pyproject.toml b/pyproject.toml index bed99a9..4ad791a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "A Python implement of Agent Client Protocol (ACP, by Zed Industri authors = [{ name = "Chojan Shang", email = "psiace@apache.org" }] readme = "README.md" keywords = ['python'] -requires-python = ">=3.10,<=3.14" +requires-python = ">=3.10,<3.15" classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", @@ -113,10 +113,6 @@ ignore = [ [tool.ruff.format] preview = true -exclude = [ - "src/acp/meta.py", - "src/acp/schema.py", -] [tool.deptry.package_module_name_map] opentelemetry-sdk = "opentelemetry" diff --git a/schema/VERSION b/schema/VERSION index 75451e3..4803c54 100644 --- a/schema/VERSION +++ b/schema/VERSION @@ -1 +1 @@ -refs/tags/v0.6.3 +refs/tags/v0.7.0 diff --git a/schema/meta.json b/schema/meta.json index 0f0c6c4..7fad892 100644 --- a/schema/meta.json +++ b/schema/meta.json @@ -3,6 +3,7 @@ "authenticate": "authenticate", "initialize": "initialize", "session_cancel": "session/cancel", + "session_list": "session/list", "session_load": "session/load", "session_new": "session/new", "session_prompt": "session/prompt", diff --git a/schema/schema.json b/schema/schema.json index 9b39020..88f286e 100644 --- a/schema/schema.json 
+++ b/schema/schema.json @@ -12,7 +12,11 @@ "type": "boolean" }, "mcpCapabilities": { - "$ref": "#/$defs/McpCapabilities", + "allOf": [ + { + "$ref": "#/$defs/McpCapabilities" + } + ], "default": { "http": false, "sse": false @@ -20,13 +24,25 @@ "description": "MCP capabilities supported by the agent." }, "promptCapabilities": { - "$ref": "#/$defs/PromptCapabilities", + "allOf": [ + { + "$ref": "#/$defs/PromptCapabilities" + } + ], "default": { "audio": false, "embeddedContext": false, "image": false }, "description": "Prompt capabilities supported by the agent." + }, + "sessionCapabilities": { + "allOf": [ + { + "$ref": "#/$defs/SessionCapabilities" + } + ], + "default": {} } }, "type": "object" @@ -34,13 +50,15 @@ "AgentNotification": { "anyOf": [ { - "$ref": "#/$defs/SessionNotification", - "description": "Handles session update notifications from the agent.\n\nThis is a notification endpoint (no response expected) that receives\nreal-time updates about session progress, including message chunks,\ntool calls, and execution plans.\n\nNote: Clients SHOULD continue accepting tool call updates even after\nsending a `session/cancel` notification, as the agent may send final\nupdates before responding with the cancelled stop reason.\n\nSee protocol docs: [Agent Reports Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output)", - "title": "SessionNotification" + "allOf": [ + { + "$ref": "#/$defs/SessionNotification" + } + ], + "description": "Handles session update notifications from the agent.\n\nThis is a notification endpoint (no response expected) that receives\nreal-time updates about session progress, including message chunks,\ntool calls, and execution plans.\n\nNote: Clients SHOULD continue accepting tool call updates even after\nsending a `session/cancel` notification, as the agent may send final\nupdates before responding with the cancelled stop reason.\n\nSee protocol docs: [Agent Reports 
Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output)" }, { - "description": "Handles extension notifications from the agent.\n\nAllows the Agent to send an arbitrary notification that is not part of the ACP spec.\nExtension notifications provide a way to send one-way messages for custom functionality\nwhile maintaining protocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)", - "title": "ExtNotification" + "description": "Handles extension notifications from the agent.\n\nAllows the Agent to send an arbitrary notification that is not part of the ACP spec.\nExtension notifications provide a way to send one-way messages for custom functionality\nwhile maintaining protocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)" } ], "description": "All possible notifications that an agent can send to a client.\n\nThis enum is used internally for routing RPC notifications. 
You typically won't need\nto use this directly - use the notification methods on the [`Client`] trait instead.\n\nNotifications do not expect a response.", @@ -53,16 +71,13 @@ "id": { "anyOf": [ { - "title": "null", "type": "null" }, { "format": "int64", - "title": "number", "type": "integer" }, { - "title": "string", "type": "string" } ], @@ -86,7 +101,6 @@ "id", "method" ], - "title": "Request", "type": "object" }, { @@ -118,16 +132,13 @@ "id": { "anyOf": [ { - "title": "null", "type": "null" }, { "format": "int64", - "title": "number", "type": "integer" }, { - "title": "string", "type": "string" } ], @@ -137,7 +148,6 @@ "required": [ "id" ], - "title": "Response", "type": "object" }, { @@ -159,7 +169,6 @@ "required": [ "method" ], - "title": "Notification", "type": "object" } ], @@ -181,48 +190,71 @@ "AgentRequest": { "anyOf": [ { - "$ref": "#/$defs/WriteTextFileRequest", - "description": "Writes content to a text file in the client's file system.\n\nOnly available if the client advertises the `fs.writeTextFile` capability.\nAllows the agent to create or modify files within the client's environment.\n\nSee protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client)", - "title": "WriteTextFileRequest" + "allOf": [ + { + "$ref": "#/$defs/WriteTextFileRequest" + } + ], + "description": "Writes content to a text file in the client's file system.\n\nOnly available if the client advertises the `fs.writeTextFile` capability.\nAllows the agent to create or modify files within the client's environment.\n\nSee protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client)" }, { - "$ref": "#/$defs/ReadTextFileRequest", - "description": "Reads content from a text file in the client's file system.\n\nOnly available if the client advertises the `fs.readTextFile` capability.\nAllows the agent to access file contents within the client's environment.\n\nSee protocol docs: 
[Client](https://agentclientprotocol.com/protocol/overview#client)", - "title": "ReadTextFileRequest" + "allOf": [ + { + "$ref": "#/$defs/ReadTextFileRequest" + } + ], + "description": "Reads content from a text file in the client's file system.\n\nOnly available if the client advertises the `fs.readTextFile` capability.\nAllows the agent to access file contents within the client's environment.\n\nSee protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client)" }, { - "$ref": "#/$defs/RequestPermissionRequest", - "description": "Requests permission from the user for a tool call operation.\n\nCalled by the agent when it needs user authorization before executing\na potentially sensitive operation. The client should present the options\nto the user and return their decision.\n\nIf the client cancels the prompt turn via `session/cancel`, it MUST\nrespond to this request with `RequestPermissionOutcome::Cancelled`.\n\nSee protocol docs: [Requesting Permission](https://agentclientprotocol.com/protocol/tool-calls#requesting-permission)", - "title": "RequestPermissionRequest" + "allOf": [ + { + "$ref": "#/$defs/RequestPermissionRequest" + } + ], + "description": "Requests permission from the user for a tool call operation.\n\nCalled by the agent when it needs user authorization before executing\na potentially sensitive operation. 
The client should present the options\nto the user and return their decision.\n\nIf the client cancels the prompt turn via `session/cancel`, it MUST\nrespond to this request with `RequestPermissionOutcome::Cancelled`.\n\nSee protocol docs: [Requesting Permission](https://agentclientprotocol.com/protocol/tool-calls#requesting-permission)" }, { - "$ref": "#/$defs/CreateTerminalRequest", - "description": "Executes a command in a new terminal\n\nOnly available if the `terminal` Client capability is set to `true`.\n\nReturns a `TerminalId` that can be used with other terminal methods\nto get the current output, wait for exit, and kill the command.\n\nThe `TerminalId` can also be used to embed the terminal in a tool call\nby using the `ToolCallContent::Terminal` variant.\n\nThe Agent is responsible for releasing the terminal by using the `terminal/release`\nmethod.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)", - "title": "CreateTerminalRequest" + "allOf": [ + { + "$ref": "#/$defs/CreateTerminalRequest" + } + ], + "description": "Executes a command in a new terminal\n\nOnly available if the `terminal` Client capability is set to `true`.\n\nReturns a `TerminalId` that can be used with other terminal methods\nto get the current output, wait for exit, and kill the command.\n\nThe `TerminalId` can also be used to embed the terminal in a tool call\nby using the `ToolCallContent::Terminal` variant.\n\nThe Agent is responsible for releasing the terminal by using the `terminal/release`\nmethod.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)" }, { - "$ref": "#/$defs/TerminalOutputRequest", - "description": "Gets the terminal output and exit status\n\nReturns the current content in the terminal without waiting for the command to exit.\nIf the command has already exited, the exit status is included.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)", - "title": 
"TerminalOutputRequest" + "allOf": [ + { + "$ref": "#/$defs/TerminalOutputRequest" + } + ], + "description": "Gets the terminal output and exit status\n\nReturns the current content in the terminal without waiting for the command to exit.\nIf the command has already exited, the exit status is included.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)" }, { - "$ref": "#/$defs/ReleaseTerminalRequest", - "description": "Releases a terminal\n\nThe command is killed if it hasn't exited yet. Use `terminal/wait_for_exit`\nto wait for the command to exit before releasing the terminal.\n\nAfter release, the `TerminalId` can no longer be used with other `terminal/*` methods,\nbut tool calls that already contain it, continue to display its output.\n\nThe `terminal/kill` method can be used to terminate the command without releasing\nthe terminal, allowing the Agent to call `terminal/output` and other methods.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)", - "title": "ReleaseTerminalRequest" + "allOf": [ + { + "$ref": "#/$defs/ReleaseTerminalRequest" + } + ], + "description": "Releases a terminal\n\nThe command is killed if it hasn't exited yet. 
Use `terminal/wait_for_exit`\nto wait for the command to exit before releasing the terminal.\n\nAfter release, the `TerminalId` can no longer be used with other `terminal/*` methods,\nbut tool calls that already contain it, continue to display its output.\n\nThe `terminal/kill` method can be used to terminate the command without releasing\nthe terminal, allowing the Agent to call `terminal/output` and other methods.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)" }, { - "$ref": "#/$defs/WaitForTerminalExitRequest", - "description": "Waits for the terminal command to exit and return its exit status\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)", - "title": "WaitForTerminalExitRequest" + "allOf": [ + { + "$ref": "#/$defs/WaitForTerminalExitRequest" + } + ], + "description": "Waits for the terminal command to exit and return its exit status\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)" }, { - "$ref": "#/$defs/KillTerminalCommandRequest", - "description": "Kills the terminal command without releasing the terminal\n\nWhile `terminal/release` will also kill the command, this method will keep\nthe `TerminalId` valid so it can be used with other methods.\n\nThis method can be helpful when implementing command timeouts which terminate\nthe command as soon as elapsed, and then get the final output so it can be sent\nto the model.\n\nNote: `terminal/release` when `TerminalId` is no longer needed.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)", - "title": "KillTerminalCommandRequest" + "allOf": [ + { + "$ref": "#/$defs/KillTerminalCommandRequest" + } + ], + "description": "Kills the terminal command without releasing the terminal\n\nWhile `terminal/release` will also kill the command, this method will keep\nthe `TerminalId` valid so it can be used with other methods.\n\nThis method can be helpful when implementing 
command timeouts which terminate\nthe command as soon as elapsed, and then get the final output so it can be sent\nto the model.\n\nNote: `terminal/release` when `TerminalId` is no longer needed.\n\nSee protocol docs: [Terminals](https://agentclientprotocol.com/protocol/terminals)" }, { - "description": "Handles extension method requests from the agent.\n\nAllows the Agent to send an arbitrary request that is not part of the ACP spec.\nExtension methods provide a way to add custom functionality while maintaining\nprotocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)", - "title": "ExtMethodRequest" + "description": "Handles extension method requests from the agent.\n\nAllows the Agent to send an arbitrary request that is not part of the ACP spec.\nExtension methods provide a way to add custom functionality while maintaining\nprotocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)" } ], "description": "All possible requests that an agent can send to a client.\n\nThis enum is used internally for routing RPC requests. 
You typically won't need\nto use this directly - instead, use the methods on the [`Client`] trait.\n\nThis enum encompasses all method calls from agent to client.", @@ -231,36 +263,30 @@ "AgentResponse": { "anyOf": [ { - "$ref": "#/$defs/InitializeResponse", - "title": "InitializeResponse" + "$ref": "#/$defs/InitializeResponse" }, { - "$ref": "#/$defs/AuthenticateResponse", - "title": "AuthenticateResponse" + "$ref": "#/$defs/AuthenticateResponse" }, { - "$ref": "#/$defs/NewSessionResponse", - "title": "NewSessionResponse" + "$ref": "#/$defs/NewSessionResponse" }, { - "$ref": "#/$defs/LoadSessionResponse", - "title": "LoadSessionResponse" + "$ref": "#/$defs/LoadSessionResponse" }, { - "$ref": "#/$defs/SetSessionModeResponse", - "title": "SetSessionModeResponse" + "$ref": "#/$defs/ListSessionsResponse" }, { - "$ref": "#/$defs/PromptResponse", - "title": "PromptResponse" + "$ref": "#/$defs/SetSessionModeResponse" }, { - "$ref": "#/$defs/SetSessionModelResponse", - "title": "SetSessionModelResponse" + "$ref": "#/$defs/PromptResponse" }, { - "title": "ExtMethodResponse" - } + "$ref": "#/$defs/SetSessionModelResponse" + }, + {} ], "description": "All possible responses that an agent can send to a client.\n\nThis enum is used internally for routing RPC responses. 
You typically won't need\nto use this directly - the responses are handled automatically by the connection.\n\nThese are responses to the corresponding `ClientRequest` variants.", "x-docs-ignore": true @@ -296,6 +322,35 @@ }, "type": "object" }, + "AudioContent": { + "description": "Audio provided to or from an LLM.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/$defs/Annotations" + }, + { + "type": "null" + } + ] + }, + "data": { + "type": "string" + }, + "mimeType": { + "type": "string" + } + }, + "required": [ + "data", + "mimeType" + ], + "type": "object" + }, "AuthMethod": { "description": "Describes an available authentication method.", "properties": { @@ -310,8 +365,8 @@ ] }, "id": { - "$ref": "#/$defs/AuthMethodId", - "description": "Unique identifier for this authentication method." + "description": "Unique identifier for this authentication method.", + "type": "string" }, "name": { "description": "Human-readable name of the authentication method.", @@ -324,10 +379,6 @@ ], "type": "object" }, - "AuthMethodId": { - "description": "Unique identifier for an authentication method.", - "type": "string" - }, "AuthenticateRequest": { "description": "Request parameters for the authenticate method.\n\nSpecifies which authentication method to use.", "properties": { @@ -335,8 +386,8 @@ "description": "Extension point for implementations" }, "methodId": { - "$ref": "#/$defs/AuthMethodId", - "description": "The ID of the authentication method to use.\nMust be one of the methods advertised in the initialize response." 
+ "description": "The ID of the authentication method to use.\nMust be one of the methods advertised in the initialize response.", + "type": "string" } }, "required": [ @@ -347,7 +398,7 @@ "x-side": "agent" }, "AuthenticateResponse": { - "description": "Response to authenticate method", + "description": "Response to the `authenticate` method.", "properties": { "_meta": { "description": "Extension point for implementations" @@ -392,22 +443,35 @@ "AvailableCommandInput": { "anyOf": [ { - "description": "All text that was typed after the command name is provided as input.", - "properties": { - "hint": { - "description": "A hint to display when the input hasn't been provided yet", - "type": "string" + "allOf": [ + { + "$ref": "#/$defs/UnstructuredCommandInput" } - }, - "required": [ - "hint" ], - "title": "UnstructuredCommandInput", - "type": "object" + "description": "All text that was typed after the command name is provided as input." } ], "description": "The input specification for a command." }, + "AvailableCommandsUpdate": { + "description": "Available commands are ready or have changed", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "availableCommands": { + "description": "Commands the agent can execute", + "items": { + "$ref": "#/$defs/AvailableCommand" + }, + "type": "array" + } + }, + "required": [ + "availableCommands" + ], + "type": "object" + }, "BlobResourceContents": { "description": "Binary resource contents.", "properties": { @@ -440,7 +504,11 @@ "description": "Extension point for implementations" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The ID of the session to cancel operations for." 
} }, @@ -458,7 +526,11 @@ "description": "Extension point for implementations" }, "fs": { - "$ref": "#/$defs/FileSystemCapability", + "allOf": [ + { + "$ref": "#/$defs/FileSystemCapability" + } + ], "default": { "readTextFile": false, "writeTextFile": false @@ -476,13 +548,15 @@ "ClientNotification": { "anyOf": [ { - "$ref": "#/$defs/CancelNotification", - "description": "Cancels ongoing operations for a session.\n\nThis is a notification sent by the client to cancel an ongoing prompt turn.\n\nUpon receiving this notification, the Agent SHOULD:\n- Stop all language model requests as soon as possible\n- Abort all tool call invocations in progress\n- Send any pending `session/update` notifications\n- Respond to the original `session/prompt` request with `StopReason::Cancelled`\n\nSee protocol docs: [Cancellation](https://agentclientprotocol.com/protocol/prompt-turn#cancellation)", - "title": "CancelNotification" + "allOf": [ + { + "$ref": "#/$defs/CancelNotification" + } + ], + "description": "Cancels ongoing operations for a session.\n\nThis is a notification sent by the client to cancel an ongoing prompt turn.\n\nUpon receiving this notification, the Agent SHOULD:\n- Stop all language model requests as soon as possible\n- Abort all tool call invocations in progress\n- Send any pending `session/update` notifications\n- Respond to the original `session/prompt` request with `StopReason::Cancelled`\n\nSee protocol docs: [Cancellation](https://agentclientprotocol.com/protocol/prompt-turn#cancellation)" }, { - "description": "Handles extension notifications from the client.\n\nExtension notifications provide a way to send one-way messages for custom functionality\nwhile maintaining protocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)", - "title": "ExtNotification" + "description": "Handles extension notifications from the client.\n\nExtension notifications provide a way to send one-way messages for custom 
functionality\nwhile maintaining protocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)" } ], "description": "All possible notifications that a client can send to an agent.\n\nThis enum is used internally for routing RPC notifications. You typically won't need\nto use this directly - use the notification methods on the [`Agent`] trait instead.\n\nNotifications do not expect a response.", @@ -495,16 +569,13 @@ "id": { "anyOf": [ { - "title": "null", "type": "null" }, { "format": "int64", - "title": "number", "type": "integer" }, { - "title": "string", "type": "string" } ], @@ -528,7 +599,6 @@ "id", "method" ], - "title": "Request", "type": "object" }, { @@ -560,16 +630,13 @@ "id": { "anyOf": [ { - "title": "null", "type": "null" }, { "format": "int64", - "title": "number", "type": "integer" }, { - "title": "string", "type": "string" } ], @@ -579,7 +646,6 @@ "required": [ "id" ], - "title": "Response", "type": "object" }, { @@ -601,7 +667,6 @@ "required": [ "method" ], - "title": "Notification", "type": "object" } ], @@ -623,43 +688,71 @@ "ClientRequest": { "anyOf": [ { - "$ref": "#/$defs/InitializeRequest", - "description": "Establishes the connection with a client and negotiates protocol capabilities.\n\nThis method is called once at the beginning of the connection to:\n- Negotiate the protocol version to use\n- Exchange capability information between client and agent\n- Determine available authentication methods\n\nThe agent should respond with its supported protocol version and capabilities.\n\nSee protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)", - "title": "InitializeRequest" + "allOf": [ + { + "$ref": "#/$defs/InitializeRequest" + } + ], + "description": "Establishes the connection with a client and negotiates protocol capabilities.\n\nThis method is called once at the beginning of the connection to:\n- Negotiate the protocol version to use\n- Exchange 
capability information between client and agent\n- Determine available authentication methods\n\nThe agent should respond with its supported protocol version and capabilities.\n\nSee protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)" + }, + { + "allOf": [ + { + "$ref": "#/$defs/AuthenticateRequest" + } + ], + "description": "Authenticates the client using the specified authentication method.\n\nCalled when the agent requires authentication before allowing session creation.\nThe client provides the authentication method ID that was advertised during initialization.\n\nAfter successful authentication, the client can proceed to create sessions with\n`new_session` without receiving an `auth_required` error.\n\nSee protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)" }, { - "$ref": "#/$defs/AuthenticateRequest", - "description": "Authenticates the client using the specified authentication method.\n\nCalled when the agent requires authentication before allowing session creation.\nThe client provides the authentication method ID that was advertised during initialization.\n\nAfter successful authentication, the client can proceed to create sessions with\n`new_session` without receiving an `auth_required` error.\n\nSee protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)", - "title": "AuthenticateRequest" + "allOf": [ + { + "$ref": "#/$defs/NewSessionRequest" + } + ], + "description": "Creates a new conversation session with the agent.\n\nSessions represent independent conversation contexts with their own history and state.\n\nThe agent should:\n- Create a new session context\n- Connect to any specified MCP servers\n- Return a unique session ID for future requests\n\nMay return an `auth_required` error if the agent requires authentication.\n\nSee protocol docs: [Session Setup](https://agentclientprotocol.com/protocol/session-setup)" }, { - "$ref": 
"#/$defs/NewSessionRequest", - "description": "Creates a new conversation session with the agent.\n\nSessions represent independent conversation contexts with their own history and state.\n\nThe agent should:\n- Create a new session context\n- Connect to any specified MCP servers\n- Return a unique session ID for future requests\n\nMay return an `auth_required` error if the agent requires authentication.\n\nSee protocol docs: [Session Setup](https://agentclientprotocol.com/protocol/session-setup)", - "title": "NewSessionRequest" + "allOf": [ + { + "$ref": "#/$defs/LoadSessionRequest" + } + ], + "description": "Loads an existing session to resume a previous conversation.\n\nThis method is only available if the agent advertises the `loadSession` capability.\n\nThe agent should:\n- Restore the session context and conversation history\n- Connect to the specified MCP servers\n- Stream the entire conversation history back to the client via notifications\n\nSee protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions)" }, { - "$ref": "#/$defs/LoadSessionRequest", - "description": "Loads an existing session to resume a previous conversation.\n\nThis method is only available if the agent advertises the `loadSession` capability.\n\nThe agent should:\n- Restore the session context and conversation history\n- Connect to the specified MCP servers\n- Stream the entire conversation history back to the client via notifications\n\nSee protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions)", - "title": "LoadSessionRequest" + "allOf": [ + { + "$ref": "#/$defs/ListSessionsRequest" + } + ], + "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nLists existing sessions known to the agent.\n\nThis method is only available if the agent advertises the `listSessions` capability.\n\nThe agent should return metadata about 
sessions with optional filtering and pagination support." }, { - "$ref": "#/$defs/SetSessionModeRequest", - "description": "Sets the current mode for a session.\n\nAllows switching between different agent modes (e.g., \"ask\", \"architect\", \"code\")\nthat affect system prompts, tool availability, and permission behaviors.\n\nThe mode must be one of the modes advertised in `availableModes` during session\ncreation or loading. Agents may also change modes autonomously and notify the\nclient via `current_mode_update` notifications.\n\nThis method can be called at any time during a session, whether the Agent is\nidle or actively generating a response.\n\nSee protocol docs: [Session Modes](https://agentclientprotocol.com/protocol/session-modes)", - "title": "SetSessionModeRequest" + "allOf": [ + { + "$ref": "#/$defs/SetSessionModeRequest" + } + ], + "description": "Sets the current mode for a session.\n\nAllows switching between different agent modes (e.g., \"ask\", \"architect\", \"code\")\nthat affect system prompts, tool availability, and permission behaviors.\n\nThe mode must be one of the modes advertised in `availableModes` during session\ncreation or loading. 
Agents may also change modes autonomously and notify the\nclient via `current_mode_update` notifications.\n\nThis method can be called at any time during a session, whether the Agent is\nidle or actively generating a response.\n\nSee protocol docs: [Session Modes](https://agentclientprotocol.com/protocol/session-modes)" }, { - "$ref": "#/$defs/PromptRequest", - "description": "Processes a user prompt within a session.\n\nThis method handles the whole lifecycle of a prompt:\n- Receives user messages with optional context (files, images, etc.)\n- Processes the prompt using language models\n- Reports language model content and tool calls to the Clients\n- Requests permission to run tools\n- Executes any requested tool calls\n- Returns when the turn is complete with a stop reason\n\nSee protocol docs: [Prompt Turn](https://agentclientprotocol.com/protocol/prompt-turn)", - "title": "PromptRequest" + "allOf": [ + { + "$ref": "#/$defs/PromptRequest" + } + ], + "description": "Processes a user prompt within a session.\n\nThis method handles the whole lifecycle of a prompt:\n- Receives user messages with optional context (files, images, etc.)\n- Processes the prompt using language models\n- Reports language model content and tool calls to the Clients\n- Requests permission to run tools\n- Executes any requested tool calls\n- Returns when the turn is complete with a stop reason\n\nSee protocol docs: [Prompt Turn](https://agentclientprotocol.com/protocol/prompt-turn)" }, { - "$ref": "#/$defs/SetSessionModelRequest", - "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nSelect a model for a given session.", - "title": "SetSessionModelRequest" + "allOf": [ + { + "$ref": "#/$defs/SetSessionModelRequest" + } + ], + "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nSelect a model for a given session." 
}, { - "description": "Handles extension method requests from the client.\n\nExtension methods provide a way to add custom functionality while maintaining\nprotocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)", - "title": "ExtMethodRequest" + "description": "Handles extension method requests from the client.\n\nExtension methods provide a way to add custom functionality while maintaining\nprotocol compatibility.\n\nSee protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility)" } ], "description": "All possible requests that a client can send to an agent.\n\nThis enum is used internally for routing RPC requests. You typically won't need\nto use this directly - instead, use the methods on the [`Agent`] trait.\n\nThis enum encompasses all method calls from client to agent.", @@ -668,44 +761,54 @@ "ClientResponse": { "anyOf": [ { - "$ref": "#/$defs/WriteTextFileResponse", - "title": "WriteTextFileResponse" + "$ref": "#/$defs/WriteTextFileResponse" }, { - "$ref": "#/$defs/ReadTextFileResponse", - "title": "ReadTextFileResponse" + "$ref": "#/$defs/ReadTextFileResponse" }, { - "$ref": "#/$defs/RequestPermissionResponse", - "title": "RequestPermissionResponse" + "$ref": "#/$defs/RequestPermissionResponse" }, { - "$ref": "#/$defs/CreateTerminalResponse", - "title": "CreateTerminalResponse" + "$ref": "#/$defs/CreateTerminalResponse" }, { - "$ref": "#/$defs/TerminalOutputResponse", - "title": "TerminalOutputResponse" + "$ref": "#/$defs/TerminalOutputResponse" }, { - "$ref": "#/$defs/ReleaseTerminalResponse", - "title": "ReleaseTerminalResponse" + "$ref": "#/$defs/ReleaseTerminalResponse" }, { - "$ref": "#/$defs/WaitForTerminalExitResponse", - "title": "WaitForTerminalExitResponse" + "$ref": "#/$defs/WaitForTerminalExitResponse" }, { - "$ref": "#/$defs/KillTerminalCommandResponse", - "title": "KillTerminalResponse" + "$ref": "#/$defs/KillTerminalCommandResponse" }, - { - "title": 
"ExtMethodResponse" - } + {} ], "description": "All possible responses that a client can send to an agent.\n\nThis enum is used internally for routing RPC responses. You typically won't need\nto use this directly - the responses are handled automatically by the connection.\n\nThese are responses to the corresponding `AgentRequest` variants.", "x-docs-ignore": true }, + "Content": { + "description": "Standard content block (text, images, resources).", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "content": { + "allOf": [ + { + "$ref": "#/$defs/ContentBlock" + } + ], + "description": "The actual content block." + } + }, + "required": [ + "content" + ], + "type": "object" + }, "ContentBlock": { "description": "Content blocks represent displayable information in the Agent Client Protocol.\n\nThey provide a structured way to handle various types of user-facing content\u2014whether\nit's text from language models, images for analysis, or embedded resources for context.\n\nContent blocks appear in:\n- User prompts sent via `session/prompt`\n- Language model output streamed through `session/update` notifications\n- Progress updates and results from tool calls\n\nThis structure is compatible with the Model Context Protocol (MCP), enabling\nagents to seamlessly forward content from MCP tool outputs without transformation.\n\nSee protocol docs: [Content](https://agentclientprotocol.com/protocol/content)", "discriminator": { @@ -713,200 +816,117 @@ }, "oneOf": [ { + "allOf": [ + { + "$ref": "#/$defs/TextContent" + } + ], "description": "Text content. 
May be plain text or formatted with Markdown.\n\nAll agents MUST support text content blocks in prompts.\nClients SHOULD render this text as Markdown.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "annotations": { - "anyOf": [ - { - "$ref": "#/$defs/Annotations" - }, - { - "type": "null" - } - ] - }, - "text": { - "type": "string" - }, "type": { "const": "text", "type": "string" } }, "required": [ - "type", - "text" + "type" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/ImageContent" + } + ], "description": "Images for visual context or analysis.\n\nRequires the `image` prompt capability when included in prompts.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "annotations": { - "anyOf": [ - { - "$ref": "#/$defs/Annotations" - }, - { - "type": "null" - } - ] - }, - "data": { - "type": "string" - }, - "mimeType": { - "type": "string" - }, "type": { "const": "image", "type": "string" - }, - "uri": { - "type": [ - "string", - "null" - ] } }, "required": [ - "type", - "data", - "mimeType" + "type" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/AudioContent" + } + ], "description": "Audio data for transcription or analysis.\n\nRequires the `audio` prompt capability when included in prompts.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "annotations": { - "anyOf": [ - { - "$ref": "#/$defs/Annotations" - }, - { - "type": "null" - } - ] - }, - "data": { - "type": "string" - }, - "mimeType": { - "type": "string" - }, "type": { "const": "audio", "type": "string" } }, "required": [ - "type", - "data", - "mimeType" + "type" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/ResourceLink" + } + ], "description": "References to resources that the agent can access.\n\nAll agents MUST support resource links in prompts.", "properties": { - "_meta": { - "description": "Extension point for implementations" - 
}, - "annotations": { - "anyOf": [ - { - "$ref": "#/$defs/Annotations" - }, - { - "type": "null" - } - ] - }, - "description": { - "type": [ - "string", - "null" - ] - }, - "mimeType": { - "type": [ - "string", - "null" - ] - }, - "name": { - "type": "string" - }, - "size": { - "format": "int64", - "type": [ - "integer", - "null" - ] - }, - "title": { - "type": [ - "string", - "null" - ] - }, "type": { "const": "resource_link", "type": "string" - }, - "uri": { - "type": "string" } }, "required": [ - "type", - "name", - "uri" + "type" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/EmbeddedResource" + } + ], "description": "Complete resource contents embedded directly in the message.\n\nPreferred for including context as it avoids extra round-trips.\n\nRequires the `embeddedContext` prompt capability when included in prompts.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "annotations": { - "anyOf": [ - { - "$ref": "#/$defs/Annotations" - }, - { - "type": "null" - } - ] - }, - "resource": { - "$ref": "#/$defs/EmbeddedResourceResource" - }, "type": { "const": "resource", "type": "string" } }, "required": [ - "type", - "resource" + "type" ], "type": "object" } ] }, + "ContentChunk": { + "description": "A streamed item of content", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "content": { + "allOf": [ + { + "$ref": "#/$defs/ContentBlock" + } + ], + "description": "A single item of content" + } + }, + "required": [ + "content" + ], + "type": "object" + }, "CreateTerminalRequest": { "description": "Request to create a new terminal and execute a command.", "properties": { @@ -948,7 +968,11 @@ ] }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." 
} }, @@ -978,30 +1002,101 @@ "x-method": "terminal/create", "x-side": "client" }, - "EmbeddedResourceResource": { - "anyOf": [ - { - "$ref": "#/$defs/TextResourceContents", - "title": "TextResourceContents" - }, - { - "$ref": "#/$defs/BlobResourceContents", - "title": "BlobResourceContents" - } - ], - "description": "Resource content that can be embedded in a message." - }, - "EnvVariable": { - "description": "An environment variable to set when launching an MCP server.", + "CurrentModeUpdate": { + "description": "The current mode of the session has changed\n\nSee protocol docs: [Session Modes](https://agentclientprotocol.com/protocol/session-modes)", "properties": { "_meta": { "description": "Extension point for implementations" }, - "name": { - "description": "The name of the environment variable.", - "type": "string" - }, - "value": { + "currentModeId": { + "allOf": [ + { + "$ref": "#/$defs/SessionModeId" + } + ], + "description": "The ID of the current mode" + } + }, + "required": [ + "currentModeId" + ], + "type": "object" + }, + "Diff": { + "description": "A diff representing file modifications.\n\nShows changes to files in a format suitable for display in the client UI.\n\nSee protocol docs: [Content](https://agentclientprotocol.com/protocol/tool-calls#content)", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "newText": { + "description": "The new content after modification.", + "type": "string" + }, + "oldText": { + "description": "The original content (None for new files).", + "type": [ + "string", + "null" + ] + }, + "path": { + "description": "The file path being modified.", + "type": "string" + } + }, + "required": [ + "path", + "newText" + ], + "type": "object" + }, + "EmbeddedResource": { + "description": "The contents of a resource, embedded into a prompt or tool call result.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "annotations": { + "anyOf": [ + { + 
"$ref": "#/$defs/Annotations" + }, + { + "type": "null" + } + ] + }, + "resource": { + "$ref": "#/$defs/EmbeddedResourceResource" + } + }, + "required": [ + "resource" + ], + "type": "object" + }, + "EmbeddedResourceResource": { + "anyOf": [ + { + "$ref": "#/$defs/TextResourceContents" + }, + { + "$ref": "#/$defs/BlobResourceContents" + } + ], + "description": "Resource content that can be embedded in a message." + }, + "EnvVariable": { + "description": "An environment variable to set when launching an MCP server.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "name": { + "description": "The name of the environment variable.", + "type": "string" + }, + "value": { "description": "The value to set for the environment variable.", "type": "string" } @@ -1035,7 +1130,7 @@ "type": "object" }, "FileSystemCapability": { - "description": "File system capabilities that a client may support.\n\nSee protocol docs: [FileSystem](https://agentclientprotocol.com/protocol/initialization#filesystem)", + "description": "Filesystem capabilities supported by the client.\nFile system capabilities that a client may support.\n\nSee protocol docs: [FileSystem](https://agentclientprotocol.com/protocol/initialization#filesystem)", "properties": { "_meta": { "description": "Extension point for implementations" @@ -1074,9 +1169,47 @@ ], "type": "object" }, + "ImageContent": { + "description": "An image provided to or from an LLM.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/$defs/Annotations" + }, + { + "type": "null" + } + ] + }, + "data": { + "type": "string" + }, + "mimeType": { + "type": "string" + }, + "uri": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "data", + "mimeType" + ], + "type": "object" + }, "Implementation": { - "description": "Describes the name and version of an MCP implementation, with an optional\ntitle for UI 
representation.", + "description": "Metadata about the implementation of the client or agent.\nDescribes the name and version of an MCP implementation, with an optional\ntitle for UI representation.", "properties": { + "_meta": { + "description": "Extension point for implementations" + }, "name": { "description": "Intended for programmatic or logical use, but can be used as a display\nname fallback if title isn\u2019t present.", "type": "string" @@ -1089,7 +1222,7 @@ ] }, "version": { - "description": "Version of the implementation. Can be displayed to the user or used\nfor debugging or metrics purposes.", + "description": "Version of the implementation. Can be displayed to the user or used\nfor debugging or metrics purposes. (e.g. \"1.0.0\").", "type": "string" } }, @@ -1106,7 +1239,11 @@ "description": "Extension point for implementations" }, "clientCapabilities": { - "$ref": "#/$defs/ClientCapabilities", + "allOf": [ + { + "$ref": "#/$defs/ClientCapabilities" + } + ], "default": { "fs": { "readTextFile": false, @@ -1128,7 +1265,11 @@ "description": "Information about the Client name and version sent to the Agent.\n\nNote: in future versions of the protocol, this will be required." }, "protocolVersion": { - "$ref": "#/$defs/ProtocolVersion", + "allOf": [ + { + "$ref": "#/$defs/ProtocolVersion" + } + ], "description": "The latest protocol version supported by the client." 
} }, @@ -1140,13 +1281,17 @@ "x-side": "agent" }, "InitializeResponse": { - "description": "Response from the initialize method.\n\nContains the negotiated protocol version and agent capabilities.\n\nSee protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)", + "description": "Response to the `initialize` method.\n\nContains the negotiated protocol version and agent capabilities.\n\nSee protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)", "properties": { "_meta": { "description": "Extension point for implementations" }, "agentCapabilities": { - "$ref": "#/$defs/AgentCapabilities", + "allOf": [ + { + "$ref": "#/$defs/AgentCapabilities" + } + ], "default": { "loadSession": false, "mcpCapabilities": { @@ -1157,7 +1302,8 @@ "audio": false, "embeddedContext": false, "image": false - } + }, + "sessionCapabilities": {} }, "description": "Capabilities supported by the agent." }, @@ -1181,7 +1327,11 @@ "type": "array" }, "protocolVersion": { - "$ref": "#/$defs/ProtocolVersion", + "allOf": [ + { + "$ref": "#/$defs/ProtocolVersion" + } + ], "description": "The protocol version the client specified if supported by the agent,\nor the latest protocol version supported by the agent.\n\nThe client should disconnect, if it doesn't support this version." } }, @@ -1199,7 +1349,11 @@ "description": "Extension point for implementations" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." 
}, "terminalId": { @@ -1226,6 +1380,59 @@ "x-method": "terminal/kill", "x-side": "client" }, + "ListSessionsRequest": { + "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nRequest parameters for listing existing sessions.\n\nOnly available if the Agent supports the `listSessions` capability.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "cursor": { + "description": "Opaque cursor token from a previous response's nextCursor field for cursor-based pagination", + "type": [ + "string", + "null" + ] + }, + "cwd": { + "description": "Filter sessions by working directory. Must be an absolute path.", + "type": [ + "string", + "null" + ] + } + }, + "type": "object", + "x-method": "session/list", + "x-side": "agent" + }, + "ListSessionsResponse": { + "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nResponse from listing sessions.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "nextCursor": { + "description": "Opaque cursor token. If present, pass this in the next request's cursor parameter\nto fetch the next page. 
If absent, there are no more results.", + "type": [ + "string", + "null" + ] + }, + "sessions": { + "description": "Array of session information objects", + "items": { + "$ref": "#/$defs/SessionInfo" + }, + "type": "array" + } + }, + "required": [ + "sessions" + ], + "type": "object", + "x-method": "session/list", + "x-side": "agent" + }, "LoadSessionRequest": { "description": "Request parameters for loading an existing session.\n\nOnly available if the Agent supports the `loadSession` capability.\n\nSee protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions)", "properties": { @@ -1244,7 +1451,11 @@ "type": "array" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The ID of the session to load." } }, @@ -1312,101 +1523,48 @@ "McpServer": { "anyOf": [ { + "allOf": [ + { + "$ref": "#/$defs/McpServerHttp" + } + ], "description": "HTTP transport configuration\n\nOnly available when the Agent capabilities indicate `mcp_capabilities.http` is `true`.", "properties": { - "headers": { - "description": "HTTP headers to set when making requests to the MCP server.", - "items": { - "$ref": "#/$defs/HttpHeader" - }, - "type": "array" - }, - "name": { - "description": "Human-readable name identifying this MCP server.", - "type": "string" - }, "type": { "const": "http", "type": "string" - }, - "url": { - "description": "URL to the MCP server.", - "type": "string" } }, "required": [ - "type", - "name", - "url", - "headers" + "type" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/McpServerSse" + } + ], "description": "SSE transport configuration\n\nOnly available when the Agent capabilities indicate `mcp_capabilities.sse` is `true`.", "properties": { - "headers": { - "description": "HTTP headers to set when making requests to the MCP server.", - "items": { - "$ref": "#/$defs/HttpHeader" - }, - "type": "array" - }, - "name": { - "description": 
"Human-readable name identifying this MCP server.", - "type": "string" - }, "type": { "const": "sse", "type": "string" - }, - "url": { - "description": "URL to the MCP server.", - "type": "string" } }, "required": [ - "type", - "name", - "url", - "headers" + "type" ], "type": "object" }, { - "description": "Stdio transport configuration\n\nAll Agents MUST support this transport.", - "properties": { - "args": { - "description": "Command-line arguments to pass to the MCP server.", - "items": { - "type": "string" - }, - "type": "array" - }, - "command": { - "description": "Path to the MCP server executable.", - "type": "string" - }, - "env": { - "description": "Environment variables to set when launching the MCP server.", - "items": { - "$ref": "#/$defs/EnvVariable" - }, - "type": "array" - }, - "name": { - "description": "Human-readable name identifying this MCP server.", - "type": "string" + "allOf": [ + { + "$ref": "#/$defs/McpServerStdio" } - }, - "required": [ - "name", - "command", - "args", - "env" ], - "title": "stdio", - "type": "object" + "description": "Stdio transport configuration\n\nAll Agents MUST support this transport." 
} ], "description": "Configuration for connecting to an MCP (Model Context Protocol) server.\n\nMCP servers provide tools and context that the agent can use when\nprocessing prompts.\n\nSee protocol docs: [MCP Servers](https://agentclientprotocol.com/protocol/session-setup#mcp-servers)", @@ -1414,6 +1572,101 @@ "propertyName": "type" } }, + "McpServerHttp": { + "description": "HTTP transport configuration for MCP.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "headers": { + "description": "HTTP headers to set when making requests to the MCP server.", + "items": { + "$ref": "#/$defs/HttpHeader" + }, + "type": "array" + }, + "name": { + "description": "Human-readable name identifying this MCP server.", + "type": "string" + }, + "url": { + "description": "URL to the MCP server.", + "type": "string" + } + }, + "required": [ + "name", + "url", + "headers" + ], + "type": "object" + }, + "McpServerSse": { + "description": "SSE transport configuration for MCP.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "headers": { + "description": "HTTP headers to set when making requests to the MCP server.", + "items": { + "$ref": "#/$defs/HttpHeader" + }, + "type": "array" + }, + "name": { + "description": "Human-readable name identifying this MCP server.", + "type": "string" + }, + "url": { + "description": "URL to the MCP server.", + "type": "string" + } + }, + "required": [ + "name", + "url", + "headers" + ], + "type": "object" + }, + "McpServerStdio": { + "description": "Stdio transport configuration for MCP.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "args": { + "description": "Command-line arguments to pass to the MCP server.", + "items": { + "type": "string" + }, + "type": "array" + }, + "command": { + "description": "Path to the MCP server executable.", + "type": "string" + }, + "env": { + "description": "Environment variables 
to set when launching the MCP server.", + "items": { + "$ref": "#/$defs/EnvVariable" + }, + "type": "array" + }, + "name": { + "description": "Human-readable name identifying this MCP server.", + "type": "string" + } + }, + "required": [ + "name", + "command", + "args", + "env" + ], + "type": "object" + }, "ModelId": { "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nA unique identifier for a model.", "type": "string" @@ -1432,7 +1685,11 @@ ] }, "modelId": { - "$ref": "#/$defs/ModelId", + "allOf": [ + { + "$ref": "#/$defs/ModelId" + } + ], "description": "Unique identifier for the model." }, "name": { @@ -1501,7 +1758,11 @@ "description": "Initial mode state if supported by the Agent\n\nSee protocol docs: [Session Modes](https://agentclientprotocol.com/protocol/session-modes)" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "Unique identifier for the created session.\n\nUsed in all subsequent requests for this conversation." } }, @@ -1519,7 +1780,11 @@ "description": "Extension point for implementations" }, "kind": { - "$ref": "#/$defs/PermissionOptionKind", + "allOf": [ + { + "$ref": "#/$defs/PermissionOptionKind" + } + ], "description": "Hint about the nature of this permission option." }, "name": { @@ -1527,7 +1792,11 @@ "type": "string" }, "optionId": { - "$ref": "#/$defs/PermissionOptionId", + "allOf": [ + { + "$ref": "#/$defs/PermissionOptionId" + } + ], "description": "Unique identifier for this permission option." 
} }, @@ -1567,6 +1836,25 @@ } ] }, + "Plan": { + "description": "An execution plan for accomplishing complex tasks.\n\nPlans consist of multiple entries representing individual tasks or goals.\nAgents report plans to clients to provide visibility into their execution strategy.\nPlans can evolve during execution as the agent discovers new requirements or completes tasks.\n\nSee protocol docs: [Agent Plan](https://agentclientprotocol.com/protocol/agent-plan)", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "entries": { + "description": "The list of tasks to be accomplished.\n\nWhen updating a plan, the agent must send a complete list of all entries\nwith their current status. The client replaces the entire plan with each update.", + "items": { + "$ref": "#/$defs/PlanEntry" + }, + "type": "array" + } + }, + "required": [ + "entries" + ], + "type": "object" + }, "PlanEntry": { "description": "A single entry in the execution plan.\n\nRepresents a task or goal that the assistant intends to accomplish\nas part of fulfilling the user's request.\nSee protocol docs: [Plan Entries](https://agentclientprotocol.com/protocol/agent-plan#plan-entries)", "properties": { @@ -1578,11 +1866,19 @@ "type": "string" }, "priority": { - "$ref": "#/$defs/PlanEntryPriority", + "allOf": [ + { + "$ref": "#/$defs/PlanEntryPriority" + } + ], "description": "The relative importance of this task.\nUsed to indicate which tasks are most critical to the overall goal." }, "status": { - "$ref": "#/$defs/PlanEntryStatus", + "allOf": [ + { + "$ref": "#/$defs/PlanEntryStatus" + } + ], "description": "Current execution status of this task." 
} }, @@ -1671,7 +1967,11 @@ "type": "array" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The ID of the session to send this user message to" } }, @@ -1690,7 +1990,11 @@ "description": "Extension point for implementations" }, "stopReason": { - "$ref": "#/$defs/StopReason", + "allOf": [ + { + "$ref": "#/$defs/StopReason" + } + ], "description": "Indicates why the agent stopped processing the turn." } }, @@ -1737,7 +2041,11 @@ "type": "string" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." } }, @@ -1773,7 +2081,11 @@ "description": "Extension point for implementations" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." }, "terminalId": { @@ -1820,20 +2132,20 @@ "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/SelectedPermissionOutcome" + } + ], "description": "The user selected one of the provided options.", "properties": { - "optionId": { - "$ref": "#/$defs/PermissionOptionId", - "description": "The ID of the option the user selected." - }, "outcome": { "const": "selected", "type": "string" } }, "required": [ - "outcome", - "optionId" + "outcome" ], "type": "object" } @@ -1853,79 +2165,20 @@ "type": "array" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." 
}, "toolCall": { - "description": "Details about the tool call requiring permission.", - "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "content": { - "description": "Replace the content collection.", - "items": { - "$ref": "#/$defs/ToolCallContent" - }, - "type": [ - "array", - "null" - ] - }, - "kind": { - "anyOf": [ - { - "$ref": "#/$defs/ToolKind" - }, - { - "type": "null" - } - ], - "description": "Update the tool kind." - }, - "locations": { - "description": "Replace the locations collection.", - "items": { - "$ref": "#/$defs/ToolCallLocation" - }, - "type": [ - "array", - "null" - ] - }, - "rawInput": { - "description": "Update the raw input." - }, - "rawOutput": { - "description": "Update the raw output." - }, - "status": { - "anyOf": [ - { - "$ref": "#/$defs/ToolCallStatus" - }, - { - "type": "null" - } - ], - "description": "Update the execution status." - }, - "title": { - "description": "Update the human-readable title.", - "type": [ - "string", - "null" - ] - }, - "toolCallId": { - "$ref": "#/$defs/ToolCallId", - "description": "The ID of the tool call being updated." + "allOf": [ + { + "$ref": "#/$defs/ToolCallUpdate" } - }, - "required": [ - "toolCallId" ], - "type": "object" + "description": "Details about the tool call requiring permission." } }, "required": [ @@ -1944,7 +2197,11 @@ "description": "Extension point for implementations" }, "outcome": { - "$ref": "#/$defs/RequestPermissionOutcome", + "allOf": [ + { + "$ref": "#/$defs/RequestPermissionOutcome" + } + ], "description": "The user's decision on the permission request." 
} }, @@ -1955,6 +2212,60 @@ "x-method": "session/request_permission", "x-side": "client" }, + "ResourceLink": { + "description": "A resource that the server is capable of reading, included in a prompt or tool call result.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/$defs/Annotations" + }, + { + "type": "null" + } + ] + }, + "description": { + "type": [ + "string", + "null" + ] + }, + "mimeType": { + "type": [ + "string", + "null" + ] + }, + "name": { + "type": "string" + }, + "size": { + "format": "int64", + "type": [ + "integer", + "null" + ] + }, + "title": { + "type": [ + "string", + "null" + ] + }, + "uri": { + "type": "string" + } + }, + "required": [ + "name", + "uri" + ], + "type": "object" + }, "Role": { "description": "The sender or recipient of messages and data in a conversation.", "enum": [ @@ -1963,10 +2274,98 @@ ], "type": "string" }, + "SelectedPermissionOutcome": { + "description": "The user selected one of the provided options.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "optionId": { + "allOf": [ + { + "$ref": "#/$defs/PermissionOptionId" + } + ], + "description": "The ID of the option the user selected." + } + }, + "required": [ + "optionId" + ], + "type": "object" + }, + "SessionCapabilities": { + "description": "Session capabilities supported by the agent.\n\nAs a baseline, all Agents **MUST** support `session/new`, `session/prompt`, `session/cancel`, and `session/update`.\n\nOptionally, they **MAY** support other session methods and notifications by specifying additional capabilities.\n\nNote: `session/load` is still handled by the top-level `load_session` capability. 
This will be unified in future versions of the protocol.\n\nSee protocol docs: [Session Capabilities](https://agentclientprotocol.com/protocol/initialization#session-capabilities)", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "list": { + "anyOf": [ + { + "$ref": "#/$defs/SessionListCapabilities" + }, + { + "type": "null" + } + ], + "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nWhether the agent supports `session/list`." + } + }, + "type": "object" + }, "SessionId": { "description": "A unique identifier for a conversation session between a client and agent.\n\nSessions maintain their own context, conversation history, and state,\nallowing multiple independent interactions with the same agent.\n\n# Example\n\n```\nuse agent_client_protocol::SessionId;\nuse std::sync::Arc;\n\nlet session_id = SessionId(Arc::from(\"sess_abc123def456\"));\n```\n\nSee protocol docs: [Session ID](https://agentclientprotocol.com/protocol/session-setup#session-id)", "type": "string" }, + "SessionInfo": { + "description": "**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nInformation about a session returned by session/list", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "cwd": { + "description": "The working directory for this session. 
Must be an absolute path.", + "type": "string" + }, + "sessionId": { + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], + "description": "Unique identifier for the session" + }, + "title": { + "description": "Human-readable title for the session", + "type": [ + "string", + "null" + ] + }, + "updatedAt": { + "description": "ISO 8601 timestamp of last activity", + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "sessionId", + "cwd" + ], + "type": "object" + }, + "SessionListCapabilities": { + "description": "Capabilities for the `session/list` method.\n\nBy supplying `{}` it means that the agent supports listing of sessions.\n\nFurther capabilities can be added in the future for other means of filtering or searching the list.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + } + }, + "type": "object" + }, "SessionMode": { "description": "A mode the agent can operate in.\n\nSee protocol docs: [Session Modes](https://agentclientprotocol.com/protocol/session-modes)", "properties": { @@ -2010,7 +2409,11 @@ "type": "array" }, "currentModeId": { - "$ref": "#/$defs/SessionModeId", + "allOf": [ + { + "$ref": "#/$defs/SessionModeId" + } + ], "description": "The current mode the Agent is in." } }, @@ -2034,7 +2437,11 @@ "type": "array" }, "currentModelId": { - "$ref": "#/$defs/ModelId", + "allOf": [ + { + "$ref": "#/$defs/ModelId" + } + ], "description": "The current model the Agent is in." } }, @@ -2051,11 +2458,19 @@ "description": "Extension point for implementations" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The ID of the session this update pertains to." }, "update": { - "$ref": "#/$defs/SessionUpdate", + "allOf": [ + { + "$ref": "#/$defs/SessionUpdate" + } + ], "description": "The actual update content." 
} }, @@ -2074,264 +2489,146 @@ }, "oneOf": [ { + "allOf": [ + { + "$ref": "#/$defs/ContentChunk" + } + ], "description": "A chunk of the user's message being streamed.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "content": { - "$ref": "#/$defs/ContentBlock", - "description": "A single item of content" - }, "sessionUpdate": { "const": "user_message_chunk", "type": "string" } }, "required": [ - "sessionUpdate", - "content" + "sessionUpdate" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/ContentChunk" + } + ], "description": "A chunk of the agent's response being streamed.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "content": { - "$ref": "#/$defs/ContentBlock", - "description": "A single item of content" - }, "sessionUpdate": { "const": "agent_message_chunk", "type": "string" } }, "required": [ - "sessionUpdate", - "content" + "sessionUpdate" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/ContentChunk" + } + ], "description": "A chunk of the agent's internal reasoning being streamed.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "content": { - "$ref": "#/$defs/ContentBlock", - "description": "A single item of content" - }, "sessionUpdate": { "const": "agent_thought_chunk", "type": "string" } }, "required": [ - "sessionUpdate", - "content" + "sessionUpdate" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/ToolCall" + } + ], "description": "Notification that a new tool call has been initiated.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "content": { - "description": "Content produced by the tool call.", - "items": { - "$ref": "#/$defs/ToolCallContent" - }, - "type": "array" - }, - "kind": { - "$ref": "#/$defs/ToolKind", - "description": "The category of tool being invoked.\nHelps clients choose appropriate icons and UI treatment." 
- }, - "locations": { - "description": "File locations affected by this tool call.\nEnables \"follow-along\" features in clients.", - "items": { - "$ref": "#/$defs/ToolCallLocation" - }, - "type": "array" - }, - "rawInput": { - "description": "Raw input parameters sent to the tool." - }, - "rawOutput": { - "description": "Raw output returned by the tool." - }, "sessionUpdate": { "const": "tool_call", "type": "string" - }, - "status": { - "$ref": "#/$defs/ToolCallStatus", - "description": "Current execution status of the tool call." - }, - "title": { - "description": "Human-readable title describing what the tool is doing.", - "type": "string" - }, - "toolCallId": { - "$ref": "#/$defs/ToolCallId", - "description": "Unique identifier for this tool call within the session." } }, "required": [ - "sessionUpdate", - "toolCallId", - "title" + "sessionUpdate" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/ToolCallUpdate" + } + ], "description": "Update on the status or results of a tool call.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "content": { - "description": "Replace the content collection.", - "items": { - "$ref": "#/$defs/ToolCallContent" - }, - "type": [ - "array", - "null" - ] - }, - "kind": { - "anyOf": [ - { - "$ref": "#/$defs/ToolKind" - }, - { - "type": "null" - } - ], - "description": "Update the tool kind." - }, - "locations": { - "description": "Replace the locations collection.", - "items": { - "$ref": "#/$defs/ToolCallLocation" - }, - "type": [ - "array", - "null" - ] - }, - "rawInput": { - "description": "Update the raw input." - }, - "rawOutput": { - "description": "Update the raw output." - }, "sessionUpdate": { "const": "tool_call_update", "type": "string" - }, - "status": { - "anyOf": [ - { - "$ref": "#/$defs/ToolCallStatus" - }, - { - "type": "null" - } - ], - "description": "Update the execution status." 
- }, - "title": { - "description": "Update the human-readable title.", - "type": [ - "string", - "null" - ] - }, - "toolCallId": { - "$ref": "#/$defs/ToolCallId", - "description": "The ID of the tool call being updated." } }, "required": [ - "sessionUpdate", - "toolCallId" + "sessionUpdate" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/Plan" + } + ], "description": "The agent's execution plan for complex tasks.\nSee protocol docs: [Agent Plan](https://agentclientprotocol.com/protocol/agent-plan)", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "entries": { - "description": "The list of tasks to be accomplished.\n\nWhen updating a plan, the agent must send a complete list of all entries\nwith their current status. The client replaces the entire plan with each update.", - "items": { - "$ref": "#/$defs/PlanEntry" - }, - "type": "array" - }, "sessionUpdate": { "const": "plan", "type": "string" } }, "required": [ - "sessionUpdate", - "entries" + "sessionUpdate" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/AvailableCommandsUpdate" + } + ], "description": "Available commands are ready or have changed", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "availableCommands": { - "description": "Commands the agent can execute", - "items": { - "$ref": "#/$defs/AvailableCommand" - }, - "type": "array" - }, "sessionUpdate": { "const": "available_commands_update", "type": "string" } }, "required": [ - "sessionUpdate", - "availableCommands" + "sessionUpdate" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/CurrentModeUpdate" + } + ], "description": "The current mode of the session has changed\n\nSee protocol docs: [Session Modes](https://agentclientprotocol.com/protocol/session-modes)", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "currentModeId": { - "$ref": "#/$defs/SessionModeId", - "description": "The ID 
of the current mode" - }, "sessionUpdate": { "const": "current_mode_update", "type": "string" } }, "required": [ - "sessionUpdate", - "currentModeId" + "sessionUpdate" ], "type": "object" } @@ -2344,11 +2641,19 @@ "description": "Extension point for implementations" }, "modeId": { - "$ref": "#/$defs/SessionModeId", + "allOf": [ + { + "$ref": "#/$defs/SessionModeId" + } + ], "description": "The ID of the mode to set." }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The ID of the session to set the mode for." } }, @@ -2363,7 +2668,7 @@ "SetSessionModeResponse": { "description": "Response to `session/set_mode` method.", "properties": { - "_meta": true + "_meta": {} }, "type": "object", "x-method": "session/set_mode", @@ -2376,11 +2681,19 @@ "description": "Extension point for implementations" }, "modelId": { - "$ref": "#/$defs/ModelId", + "allOf": [ + { + "$ref": "#/$defs/ModelId" + } + ], "description": "The ID of the model to set." }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The ID of the session to set the model for." } }, @@ -2433,6 +2746,21 @@ } ] }, + "Terminal": { + "description": "Embed a terminal created with `terminal/create` by its id.\n\nThe terminal must be added before calling `terminal/release`.\n\nSee protocol docs: [Terminal](https://agentclientprotocol.com/protocol/terminals)", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "terminalId": { + "type": "string" + } + }, + "required": [ + "terminalId" + ], + "type": "object" + }, "TerminalExitStatus": { "description": "Exit status of a terminal command.", "properties": { @@ -2465,7 +2793,11 @@ "description": "Extension point for implementations" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." 
}, "terminalId": { @@ -2515,6 +2847,31 @@ "x-method": "terminal/output", "x-side": "client" }, + "TextContent": { + "description": "Text provided to or from an LLM.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "annotations": { + "anyOf": [ + { + "$ref": "#/$defs/Annotations" + }, + { + "type": "null" + } + ] + }, + "text": { + "type": "string" + } + }, + "required": [ + "text" + ], + "type": "object" + }, "TextResourceContents": { "description": "Text-based resource contents.", "properties": { @@ -2540,6 +2897,67 @@ ], "type": "object" }, + "ToolCall": { + "description": "Represents a tool call that the language model has requested.\n\nTool calls are actions that the agent executes on behalf of the language model,\nsuch as reading files, executing code, or fetching data from external sources.\n\nSee protocol docs: [Tool Calls](https://agentclientprotocol.com/protocol/tool-calls)", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "content": { + "description": "Content produced by the tool call.", + "items": { + "$ref": "#/$defs/ToolCallContent" + }, + "type": "array" + }, + "kind": { + "allOf": [ + { + "$ref": "#/$defs/ToolKind" + } + ], + "description": "The category of tool being invoked.\nHelps clients choose appropriate icons and UI treatment." + }, + "locations": { + "description": "File locations affected by this tool call.\nEnables \"follow-along\" features in clients.", + "items": { + "$ref": "#/$defs/ToolCallLocation" + }, + "type": "array" + }, + "rawInput": { + "description": "Raw input parameters sent to the tool." + }, + "rawOutput": { + "description": "Raw output returned by the tool." + }, + "status": { + "allOf": [ + { + "$ref": "#/$defs/ToolCallStatus" + } + ], + "description": "Current execution status of the tool call." 
+ }, + "title": { + "description": "Human-readable title describing what the tool is doing.", + "type": "string" + }, + "toolCallId": { + "allOf": [ + { + "$ref": "#/$defs/ToolCallId" + } + ], + "description": "Unique identifier for this tool call within the session." + } + }, + "required": [ + "toolCallId", + "title" + ], + "type": "object" + }, "ToolCallContent": { "description": "Content produced by a tool call.\n\nTool calls can produce different types of content including\nstandard content blocks (text, images) or file diffs.\n\nSee protocol docs: [Content](https://agentclientprotocol.com/protocol/tool-calls#content)", "discriminator": { @@ -2547,70 +2965,56 @@ }, "oneOf": [ { + "allOf": [ + { + "$ref": "#/$defs/Content" + } + ], "description": "Standard content block (text, images, resources).", "properties": { - "content": { - "$ref": "#/$defs/ContentBlock", - "description": "The actual content block." - }, "type": { "const": "content", "type": "string" } }, "required": [ - "type", - "content" + "type" ], "type": "object" }, { + "allOf": [ + { + "$ref": "#/$defs/Diff" + } + ], "description": "File modification shown as a diff.", "properties": { - "_meta": { - "description": "Extension point for implementations" - }, - "newText": { - "description": "The new content after modification.", - "type": "string" - }, - "oldText": { - "description": "The original content (None for new files).", - "type": [ - "string", - "null" - ] - }, - "path": { - "description": "The file path being modified.", - "type": "string" - }, "type": { "const": "diff", "type": "string" } }, "required": [ - "type", - "path", - "newText" + "type" ], "type": "object" }, { - "description": "Embed a terminal created with `terminal/create` by its id.\n\nThe terminal must be added before calling `terminal/release`.\n\nSee protocol docs: [Terminal](https://agentclientprotocol.com/protocol/terminal)", + "allOf": [ + { + "$ref": "#/$defs/Terminal" + } + ], + "description": "Embed a terminal created 
with `terminal/create` by its id.\n\nThe terminal must be added before calling `terminal/release`.\n\nSee protocol docs: [Terminal](https://agentclientprotocol.com/protocol/terminals)", "properties": { - "terminalId": { - "type": "string" - }, "type": { "const": "terminal", "type": "string" } }, "required": [ - "type", - "terminalId" + "type" ], "type": "object" } @@ -2670,6 +3074,81 @@ } ] }, + "ToolCallUpdate": { + "description": "An update to an existing tool call.\n\nUsed to report progress and results as tools execute. All fields except\nthe tool call ID are optional - only changed fields need to be included.\n\nSee protocol docs: [Updating](https://agentclientprotocol.com/protocol/tool-calls#updating)", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "content": { + "description": "Replace the content collection.", + "items": { + "$ref": "#/$defs/ToolCallContent" + }, + "type": [ + "array", + "null" + ] + }, + "kind": { + "anyOf": [ + { + "$ref": "#/$defs/ToolKind" + }, + { + "type": "null" + } + ], + "description": "Update the tool kind." + }, + "locations": { + "description": "Replace the locations collection.", + "items": { + "$ref": "#/$defs/ToolCallLocation" + }, + "type": [ + "array", + "null" + ] + }, + "rawInput": { + "description": "Update the raw input." + }, + "rawOutput": { + "description": "Update the raw output." + }, + "status": { + "anyOf": [ + { + "$ref": "#/$defs/ToolCallStatus" + }, + { + "type": "null" + } + ], + "description": "Update the execution status." + }, + "title": { + "description": "Update the human-readable title.", + "type": [ + "string", + "null" + ] + }, + "toolCallId": { + "allOf": [ + { + "$ref": "#/$defs/ToolCallId" + } + ], + "description": "The ID of the tool call being updated." 
+ } + }, + "required": [ + "toolCallId" + ], + "type": "object" + }, "ToolKind": { "description": "Categories of tools that can be invoked.\n\nTool kinds help clients choose appropriate icons and optimize how they\ndisplay tool execution progress.\n\nSee protocol docs: [Creating](https://agentclientprotocol.com/protocol/tool-calls#creating)", "oneOf": [ @@ -2725,6 +3204,22 @@ } ] }, + "UnstructuredCommandInput": { + "description": "All text that was typed after the command name is provided as input.", + "properties": { + "_meta": { + "description": "Extension point for implementations" + }, + "hint": { + "description": "A hint to display when the input hasn't been provided yet", + "type": "string" + } + }, + "required": [ + "hint" + ], + "type": "object" + }, "WaitForTerminalExitRequest": { "description": "Request to wait for a terminal command to exit.", "properties": { @@ -2732,7 +3227,11 @@ "description": "Extension point for implementations" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." }, "terminalId": { @@ -2790,7 +3289,11 @@ "type": "string" }, "sessionId": { - "$ref": "#/$defs/SessionId", + "allOf": [ + { + "$ref": "#/$defs/SessionId" + } + ], "description": "The session ID for this request." 
} }, @@ -2818,12 +3321,10 @@ "$schema": "https://json-schema.org/draft/2020-12/schema", "anyOf": [ { - "$ref": "#/$defs/AgentOutgoingMessage", - "title": "AgentOutgoingMessage" + "$ref": "#/$defs/AgentOutgoingMessage" }, { - "$ref": "#/$defs/ClientOutgoingMessage", - "title": "ClientOutgoingMessage" + "$ref": "#/$defs/ClientOutgoingMessage" } ] } diff --git a/scripts/gen_all.py b/scripts/gen_all.py index 63e8d8a..3414e12 100644 --- a/scripts/gen_all.py +++ b/scripts/gen_all.py @@ -14,7 +14,7 @@ if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) -from scripts import gen_meta, gen_schema # noqa: E402 pylint: disable=wrong-import-position +from scripts import gen_meta, gen_schema, gen_signature # noqa: E402 pylint: disable=wrong-import-position SCHEMA_DIR = ROOT / "schema" SCHEMA_JSON = SCHEMA_DIR / "schema.json" @@ -44,18 +44,6 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Skip downloading schema files even when a version is provided.", ) - parser.add_argument( - "--format", - dest="format_output", - action="store_true", - help="Format generated files with 'uv run ruff format'.", - ) - parser.add_argument( - "--no-format", - dest="format_output", - action="store_false", - help="Disable formatting with ruff.", - ) parser.set_defaults(format_output=True) parser.add_argument( "--force", @@ -82,8 +70,9 @@ def main() -> None: print("schema/schema.json or schema/meta.json missing; run with --version to fetch them.", file=sys.stderr) sys.exit(1) - gen_schema.generate_schema(format_output=args.format_output) + gen_schema.generate_schema() gen_meta.generate_meta() + gen_signature.gen_signature(ROOT / "src" / "acp") if ref: print(f"Generated schema using ref: {ref}") @@ -120,8 +109,8 @@ def resolve_ref(version: str | None) -> str: def download_schema(repo: str, ref: str) -> None: SCHEMA_DIR.mkdir(parents=True, exist_ok=True) - schema_url = f"https://raw.githubusercontent.com/{repo}/{ref}/schema/schema.json" - meta_url = 
f"https://raw.githubusercontent.com/{repo}/{ref}/schema/meta.json" + schema_url = f"https://raw.githubusercontent.com/{repo}/{ref}/schema/schema.unstable.json" + meta_url = f"https://raw.githubusercontent.com/{repo}/{ref}/schema/meta.unstable.json" try: schema_data = fetch_json(schema_url) meta_data = fetch_json(meta_url) diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 0badf30..b8f9e9a 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 from __future__ import annotations -import argparse import ast import json import re -import shutil import subprocess import sys +import textwrap from collections.abc import Callable from dataclasses import dataclass from pathlib import Path @@ -29,7 +28,7 @@ STDIO_TYPE_LITERAL = 'Literal["2#-datamodel-code-generator-#-object-#-special-#"]' STDIO_TYPE_PATTERN = re.compile( - r"^ type:\s*Literal\[['\"]2#-datamodel-code-generator-#-object-#-special-#['\"]\]" + r"^ type:\s*Literal\[['\"]McpServerStdio['\"]\]" r"(?:\s*=\s*['\"][^'\"]+['\"])?\s*$", re.MULTILINE, ) @@ -41,7 +40,6 @@ "AgentOutgoingMessage2": "AgentResponseMessage", "AgentOutgoingMessage3": "AgentErrorMessage", "AgentOutgoingMessage4": "AgentNotificationMessage", - "AvailableCommandInput1": "CommandInputHint", "ClientOutgoingMessage1": "ClientRequestMessage", "ClientOutgoingMessage2": "ClientResponseMessage", "ClientOutgoingMessage3": "ClientErrorMessage", @@ -53,7 +51,6 @@ "ContentBlock5": "EmbeddedResourceContentBlock", "McpServer1": "HttpMcpServer", "McpServer2": "SseMcpServer", - "McpServer3": "StdioMcpServer", "RequestPermissionOutcome1": "DeniedOutcome", "RequestPermissionOutcome2": "AllowedOutcome", "SessionUpdate1": "UserMessageChunk", @@ -69,6 +66,10 @@ "ToolCallContent3": "TerminalToolCallContent", } +ALIASES_MAP = { + "StdioMcpServer": "McpServerStdio", +} + ENUM_LITERAL_MAP: dict[str, tuple[str, ...]] = { "PermissionOptionKind": ( "allow_once", @@ -87,33 +88,32 @@ ("PermissionOption", "kind", 
"PermissionOptionKind", False), ("PlanEntry", "priority", "PlanEntryPriority", False), ("PlanEntry", "status", "PlanEntryStatus", False), - ("PromptResponse", "stopReason", "StopReason", False), - ("ToolCallProgress", "kind", "ToolKind", True), - ("ToolCallProgress", "status", "ToolCallStatus", True), - ("ToolCallStart", "kind", "ToolKind", True), - ("ToolCallStart", "status", "ToolCallStatus", True), + ("PromptResponse", "stop_reason", "StopReason", False), ("ToolCall", "kind", "ToolKind", True), ("ToolCall", "status", "ToolCallStatus", True), + ("ToolCallUpdate", "kind", "ToolKind", True), + ("ToolCallUpdate", "status", "ToolCallStatus", True), ) DEFAULT_VALUE_OVERRIDES: tuple[tuple[str, str, str], ...] = ( - ("AgentCapabilities", "mcpCapabilities", "McpCapabilities(http=False, sse=False)"), + ("AgentCapabilities", "mcp_capabilities", "McpCapabilities()"), + ("AgentCapabilities", "session_capabilities", "SessionCapabilities()"), ( "AgentCapabilities", - "promptCapabilities", - "PromptCapabilities(audio=False, embeddedContext=False, image=False)", + "prompt_capabilities", + "PromptCapabilities()", ), - ("ClientCapabilities", "fs", "FileSystemCapability(readTextFile=False, writeTextFile=False)"), + ("ClientCapabilities", "fs", "FileSystemCapability()"), ("ClientCapabilities", "terminal", "False"), ( "InitializeRequest", - "clientCapabilities", - "ClientCapabilities(fs=FileSystemCapability(readTextFile=False, writeTextFile=False), terminal=False)", + "client_capabilities", + "ClientCapabilities()", ), ( "InitializeResponse", - "agentCapabilities", - "AgentCapabilities(loadSession=False, mcpCapabilities=McpCapabilities(http=False, sse=False), promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False))", + "agent_capabilities", + "AgentCapabilities()", ), ) @@ -126,30 +126,11 @@ class _ProcessingStep: apply: Callable[[str], str] -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate 
src/acp/schema.py from the ACP JSON schema.") - parser.add_argument( - "--format", - dest="format_output", - action="store_true", - help="Format generated files with 'uv run ruff format'.", - ) - parser.add_argument( - "--no-format", - dest="format_output", - action="store_false", - help="Disable formatting with ruff.", - ) - parser.set_defaults(format_output=True) - return parser.parse_args() - - def main() -> None: - args = parse_args() - generate_schema(format_output=args.format_output) + generate_schema() -def generate_schema(*, format_output: bool = True) -> None: +def generate_schema() -> None: if not SCHEMA_JSON.exists(): print( "Schema file missing. Ensure schema/schema.json exists (run gen_all.py --version to download).", @@ -173,6 +154,7 @@ def generate_schema(*, format_output: bool = True) -> None: "--output-model-type", "pydantic_v2.BaseModel", "--use-annotated", + "--snake-case-field", ] subprocess.check_call(cmd) # noqa: S603 @@ -180,9 +162,6 @@ def generate_schema(*, format_output: bool = True) -> None: for warning in warnings: print(f"Warning: {warning}", file=sys.stderr) - if format_output: - format_with_ruff(SCHEMA_OUT) - def postprocess_generated_schema(output_path: Path) -> list[str]: if not output_path.exists(): @@ -244,6 +223,7 @@ def _build_header_block() -> str: def _build_alias_block() -> str: alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())] + alias_lines += [f"{old} = {new}" for old, new in sorted(ALIASES_MAP.items())] return BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n" @@ -350,11 +330,20 @@ def _ensure_custom_base_model(content: str) -> str: if not has_config: new_imports.append("ConfigDict") lines[idx] = "from pydantic import " + ", ".join(new_imports) + to_insert = textwrap.dedent("""\ + class BaseModel(_BaseModel): + model_config = ConfigDict(populate_by_name=True) + + def __getattr__(self, item: str) -> Any: + if item.lower() != item: + snake_cased = "".join("_" + c.lower() if c.isupper() and i > 
0 else c.lower() for i, c in enumerate(item)) + return getattr(self, snake_cased) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + """) insert_idx = idx + 1 lines.insert(insert_idx, "") - lines.insert(insert_idx + 1, "class BaseModel(_BaseModel):") - lines.insert(insert_idx + 2, " model_config = ConfigDict(populate_by_name=True)") - lines.insert(insert_idx + 3, "") + for offset, line in enumerate(to_insert.splitlines(), 1): + lines.insert(insert_idx + offset, line) break return "\n".join(lines) + "\n" @@ -434,6 +423,7 @@ def _normalize_stdio_model(content: str) -> str: replacement_line = ' type: Literal["stdio"] = "stdio"' new_content, count = STDIO_TYPE_PATTERN.subn(replacement_line, content) if count == 0: + print("Warning: stdio type placeholder not found; no replacements made.", file=sys.stderr) return content if count > 1: print( @@ -528,16 +518,5 @@ def _inject_enum_aliases(content: str) -> str: return content[:insertion_point] + block + content[insertion_point:] -def format_with_ruff(file_path: Path) -> None: - uv_executable = shutil.which("uv") - if uv_executable is None: - print("Warning: 'uv' executable not found; skipping formatting.", file=sys.stderr) - return - try: - subprocess.check_call([uv_executable, "run", "ruff", "format", str(file_path)]) # noqa: S603 - except (FileNotFoundError, subprocess.CalledProcessError) as exc: # pragma: no cover - best effort - print(f"Warning: failed to format {file_path}: {exc}", file=sys.stderr) - - if __name__ == "__main__": main() diff --git a/scripts/gen_signature.py b/scripts/gen_signature.py new file mode 100644 index 0000000..b3a7add --- /dev/null +++ b/scripts/gen_signature.py @@ -0,0 +1,136 @@ +import ast +import inspect +import typing as t +from pathlib import Path + +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined + +from acp import schema + + +class NodeTransformer(ast.NodeTransformer): + def 
__init__(self) -> None: + self._type_import_node: ast.ImportFrom | None = None + self._schema_import_node: ast.ImportFrom | None = None + self._should_rewrite = False + self._literals = {name: value for name, value in schema.__dict__.items() if t.get_origin(value) is t.Literal} + + def _add_typing_import(self, name: str) -> None: + if not self._type_import_node: + return + if not any(alias.name == name for alias in self._type_import_node.names): + self._type_import_node.names.append(ast.alias(name=name)) + self._should_rewrite = True + + def _add_schema_import(self, name: str) -> None: + if not self._schema_import_node: + return + if not any(alias.name == name for alias in self._schema_import_node.names): + self._schema_import_node.names.append(ast.alias(name=name)) + self._should_rewrite = True + + def transform(self, source_file: Path) -> None: + with source_file.open("r", encoding="utf-8") as f: + source_code = f.read() + tree = ast.parse(source_code) + self.visit(tree) + if self._should_rewrite: + print("Rewriting signatures in", source_file) + new_code = ast.unparse(tree) + with source_file.open("w", encoding="utf-8") as f: + f.write(new_code) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST: + if node.module == "schema": + self._schema_import_node = node + elif node.module == "typing": + self._type_import_node = node + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: + return self.visit_func(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: + return self.visit_func(node) + + def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST: + decorator = next( + ( + decorator + for decorator in node.decorator_list + if isinstance(decorator, ast.Call) + and isinstance(decorator.func, ast.Name) + and decorator.func.id == "param_model" + ), + None, + ) + if not decorator: + return self.generic_visit(node) + self._should_rewrite = True + model_name = t.cast(ast.Name, 
decorator.args[0]).id + model = t.cast(type[schema.BaseModel], getattr(schema, model_name)) + param_defaults = [ + self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta" + ] + param_defaults.sort(key=lambda x: x[1] is not None) + node.args.args[1:] = [param for param, _ in param_defaults] + node.args.defaults = [default for _, default in param_defaults if default is not None] + if "field_meta" in model.model_fields: + node.args.kwarg = ast.arg(arg="kwargs", annotation=ast.Name(id="Any")) + return self.generic_visit(node) + + def _to_param_def(self, name: str, field: FieldInfo) -> tuple[ast.arg, ast.expr | None]: + arg = ast.arg(arg=name) + ann = field.annotation + if field.default is PydanticUndefined: + default = None + elif isinstance(field.default, dict | BaseModel): + default = ast.Constant(None) + else: + default = ast.Constant(value=field.default) + if ann is not None: + arg.annotation = self._format_annotation(ann) + return arg, default + + def _format_annotation(self, annotation: t.Any) -> ast.expr: + if t.get_origin(annotation) is t.Literal and annotation in self._literals.values(): + name = next(name for name, value in self._literals.items() if value is annotation) + self._add_schema_import(name) + return ast.Name(id=name) + elif ( + inspect.isclass(annotation) and issubclass(annotation, BaseModel) and annotation.__module__ == "acp.schema" + ): + self._add_schema_import(annotation.__name__) + return ast.Name(id=annotation.__name__) + elif args := t.get_args(annotation): + origin = t.get_origin(annotation) + return ast.Subscript( + value=self._format_annotation(origin), + slice=ast.Tuple(elts=[self._format_annotation(arg) for arg in args], ctx=ast.Load()) + if len(args) > 1 + else self._format_annotation(args[0]), + ctx=ast.Load(), + ) + elif annotation.__module__ == "typing": + name = annotation.__name__ + self._add_typing_import(name) + return ast.Name(id=name) + elif annotation is None or annotation is 
type(None): + return ast.Constant(value=None) + elif annotation in __builtins__.values(): + return ast.Name(id=annotation.__name__) + else: + print(f"Warning: Unhandled annotation type: {annotation}") + self._add_typing_import("Any") + return ast.Name(id="Any") + + +def gen_signature(source_dir: Path) -> None: + import importlib + + importlib.reload(schema) # Ensure schema is up to date + for source_file in source_dir.rglob("*.py"): + transformer = NodeTransformer() + transformer.transform(source_file) diff --git a/src/acp/__init__.py b/src/acp/__init__.py index 3f5e72f..a25e664 100644 --- a/src/acp/__init__.py +++ b/src/acp/__init__.py @@ -1,10 +1,12 @@ +from typing import Any + from .core import ( Agent, - AgentSideConnection, Client, - ClientSideConnection, RequestError, TerminalHandle, + connect_to_agent, + run_agent, ) from .helpers import ( audio_block, @@ -73,6 +75,19 @@ from .stdio import spawn_agent_process, spawn_client_process, spawn_stdio_connection, stdio_streams from .transports import default_environment, spawn_stdio_transport +_DEPRECATED_NAMES = [ + ( + "AgentSideConnection", + "acp.core:AgentSideConnection", + "Using `AgentSideConnection` directly is deprecated, please use `acp.run_agent` instead.", + ), + ( + "ClientSideConnection", + "acp.core:ClientSideConnection", + "Using `ClientSideConnection` directly is deprecated, please use `acp.connect_to_agent` instead.", + ), +] + __all__ = [ # noqa: RUF022 # constants "PROTOCOL_VERSION", @@ -113,8 +128,8 @@ "ReleaseTerminalRequest", "ReleaseTerminalResponse", # core - "AgentSideConnection", - "ClientSideConnection", + "run_agent", + "connect_to_agent", "RequestError", "Agent", "Client", @@ -151,3 +166,16 @@ "start_edit_tool_call", "update_tool_call", ] + + +def __getattr__(name: str) -> Any: + import warnings + from importlib import import_module + + for deprecated_name, new_path, warning in _DEPRECATED_NAMES: + if name == deprecated_name: + warnings.warn(warning, DeprecationWarning, stacklevel=2) + 
module_name, attr_name = new_path.split(":") + module = import_module(module_name) + return getattr(module, attr_name) + raise AttributeError(f"module {__name__} has no attribute {name}") # noqa: TRY003 diff --git a/src/acp/agent/connection.py b/src/acp/agent/connection.py index eab6766..5a96af5 100644 --- a/src/acp/agent/connection.py +++ b/src/acp/agent/connection.py @@ -2,16 +2,23 @@ import asyncio from collections.abc import Callable -from typing import Any +from typing import Any, cast, final -from ..connection import Connection, MethodHandler -from ..interfaces import Agent +from ..connection import Connection +from ..interfaces import Agent, Client from ..meta import CLIENT_METHODS from ..schema import ( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AvailableCommandsUpdate, CreateTerminalRequest, CreateTerminalResponse, + CurrentModeUpdate, + EnvVariable, KillTerminalCommandRequest, KillTerminalCommandResponse, + PermissionOption, ReadTextFileRequest, ReadTextFileResponse, ReleaseTerminalRequest, @@ -21,120 +28,178 @@ SessionNotification, TerminalOutputRequest, TerminalOutputResponse, + ToolCallProgress, + ToolCallStart, + ToolCallUpdate, + UserMessageChunk, WaitForTerminalExitRequest, WaitForTerminalExitResponse, WriteTextFileRequest, WriteTextFileResponse, ) from ..terminal import TerminalHandle -from ..utils import notify_model, request_model, request_optional_model +from ..utils import compatible_class, notify_model, param_model, request_model, request_optional_model from .router import build_agent_router __all__ = ["AgentSideConnection"] - _AGENT_CONNECTION_ERROR = "AgentSideConnection requires asyncio StreamWriter/StreamReader" +@final +@compatible_class class AgentSideConnection: """Agent-side connection wrapper that dispatches JSON-RPC messages to a Client implementation.""" def __init__( self, - to_agent: Callable[[AgentSideConnection], Agent], + to_agent: Callable[[Client], Agent] | Agent, input_stream: Any, output_stream: Any, + 
listening: bool = True, + *, + use_unstable_protocol: bool = False, **connection_kwargs: Any, ) -> None: - agent = to_agent(self) - handler = self._create_handler(agent) - + agent = to_agent(cast(Client, self)) if callable(to_agent) else to_agent if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_AGENT_CONNECTION_ERROR) - self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) - - def _create_handler(self, agent: Agent) -> MethodHandler: - router = build_agent_router(agent) + handler = build_agent_router(cast(Agent, agent), use_unstable_protocol=use_unstable_protocol) + self._conn = Connection(handler, input_stream, output_stream, listening=listening, **connection_kwargs) + if on_connect := getattr(agent, "on_connect", None): + on_connect(self) - async def handler(method: str, params: Any | None, is_notification: bool) -> Any: - if is_notification: - await router.dispatch_notification(method, params) - return None - return await router.dispatch_request(method, params) + async def listen(self) -> None: + """Start listening for incoming messages.""" + await self._conn.main_loop() - return handler - - async def sessionUpdate(self, params: SessionNotification) -> None: - await notify_model(self._conn, CLIENT_METHODS["session_update"], params) + @param_model(SessionNotification) + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: + await notify_model( + self._conn, + CLIENT_METHODS["session_update"], + SessionNotification(session_id=session_id, update=update, field_meta=kwargs or None), + ) - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + @param_model(RequestPermissionRequest) + async def 
request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCallUpdate, **kwargs: Any + ) -> RequestPermissionResponse: return await request_model( self._conn, CLIENT_METHODS["session_request_permission"], - params, + RequestPermissionRequest( + options=options, session_id=session_id, tool_call=tool_call, field_meta=kwargs or None + ), RequestPermissionResponse, ) - async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: + @param_model(ReadTextFileRequest) + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: return await request_model( self._conn, CLIENT_METHODS["fs_read_text_file"], - params, + ReadTextFileRequest(path=path, session_id=session_id, limit=limit, line=line, field_meta=kwargs or None), ReadTextFileResponse, ) - async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse | None: + @param_model(WriteTextFileRequest) + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: return await request_optional_model( self._conn, CLIENT_METHODS["fs_write_text_file"], - params, + WriteTextFileRequest(content=content, path=path, session_id=session_id, field_meta=kwargs or None), WriteTextFileResponse, ) - async def createTerminal(self, params: CreateTerminalRequest) -> TerminalHandle: + @param_model(CreateTerminalRequest) + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> TerminalHandle: create_response = await request_model( self._conn, CLIENT_METHODS["terminal_create"], - params, + CreateTerminalRequest( + command=command, + session_id=session_id, + args=args, + cwd=cwd, + env=env, + output_byte_limit=output_byte_limit, + 
field_meta=kwargs or None, + ), CreateTerminalResponse, ) - return TerminalHandle(create_response.terminalId, params.sessionId, self._conn) + return TerminalHandle(create_response.terminal_id, session_id, self._conn) - async def terminalOutput(self, params: TerminalOutputRequest) -> TerminalOutputResponse: + @param_model(TerminalOutputRequest) + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: return await request_model( self._conn, CLIENT_METHODS["terminal_output"], - params, + TerminalOutputRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), TerminalOutputResponse, ) - async def releaseTerminal(self, params: ReleaseTerminalRequest) -> ReleaseTerminalResponse | None: + @param_model(ReleaseTerminalRequest) + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: return await request_optional_model( self._conn, CLIENT_METHODS["terminal_release"], - params, + ReleaseTerminalRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), ReleaseTerminalResponse, ) - async def waitForTerminalExit(self, params: WaitForTerminalExitRequest) -> WaitForTerminalExitResponse: + @param_model(WaitForTerminalExitRequest) + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: return await request_model( self._conn, CLIENT_METHODS["terminal_wait_for_exit"], - params, + WaitForTerminalExitRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), WaitForTerminalExitResponse, ) - async def killTerminal(self, params: KillTerminalCommandRequest) -> KillTerminalCommandResponse | None: + @param_model(KillTerminalCommandRequest) + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: return await request_optional_model( self._conn, 
CLIENT_METHODS["terminal_kill"], - params, + KillTerminalCommandRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), KillTerminalCommandResponse, ) - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: return await self._conn.send_request(f"_{method}", params) - async def extNotification(self, method: str, params: dict[str, Any]) -> None: + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: await self._conn.send_notification(f"_{method}", params) async def close(self) -> None: diff --git a/src/acp/agent/router.py b/src/acp/agent/router.py index 515d804..c80b42c 100644 --- a/src/acp/agent/router.py +++ b/src/acp/agent/router.py @@ -5,11 +5,12 @@ from ..exceptions import RequestError from ..interfaces import Agent from ..meta import AGENT_METHODS -from ..router import MessageRouter, RouterBuilder +from ..router import MessageRouter from ..schema import ( AuthenticateRequest, CancelNotification, InitializeRequest, + ListSessionsRequest, LoadSessionRequest, NewSessionRequest, PromptRequest, @@ -21,34 +22,36 @@ __all__ = ["build_agent_router"] -def build_agent_router(agent: Agent) -> MessageRouter: - builder = RouterBuilder() +def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> MessageRouter: + router = MessageRouter(use_unstable_protocol=use_unstable_protocol) - builder.request_attr(AGENT_METHODS["initialize"], InitializeRequest, agent, "initialize") - builder.request_attr(AGENT_METHODS["session_new"], NewSessionRequest, agent, "newSession") - builder.request_attr( + router.route_request(AGENT_METHODS["initialize"], InitializeRequest, agent, "initialize") + router.route_request(AGENT_METHODS["session_new"], NewSessionRequest, agent, "new_session") + router.route_request( AGENT_METHODS["session_load"], LoadSessionRequest, agent, - "loadSession", + "load_session", 
adapt_result=normalize_result, ) - builder.request_attr( + router.route_request(AGENT_METHODS["session_list"], ListSessionsRequest, agent, "list_sessions", unstable=True) + router.route_request( AGENT_METHODS["session_set_mode"], SetSessionModeRequest, agent, - "setSessionMode", + "set_session_mode", adapt_result=normalize_result, ) - builder.request_attr(AGENT_METHODS["session_prompt"], PromptRequest, agent, "prompt") - builder.request_attr( + router.route_request(AGENT_METHODS["session_prompt"], PromptRequest, agent, "prompt") + router.route_request( AGENT_METHODS["session_set_model"], SetSessionModelRequest, agent, - "setSessionModel", + "set_session_model", adapt_result=normalize_result, + unstable=True, ) - builder.request_attr( + router.route_request( AGENT_METHODS["authenticate"], AuthenticateRequest, agent, @@ -56,21 +59,20 @@ def build_agent_router(agent: Agent) -> MessageRouter: adapt_result=normalize_result, ) - builder.notification_attr(AGENT_METHODS["session_cancel"], CancelNotification, agent, "cancel") + router.route_notification(AGENT_METHODS["session_cancel"], CancelNotification, agent, "cancel") - async def handle_extension_request(name: str, payload: dict[str, Any]) -> Any: - ext = getattr(agent, "extMethod", None) + @router.handle_extension_request + async def _handle_extension_request(name: str, payload: dict[str, Any]) -> Any: + ext = getattr(agent, "ext_method", None) if ext is None: raise RequestError.method_not_found(f"_{name}") return await ext(name, payload) - async def handle_extension_notification(name: str, payload: dict[str, Any]) -> None: - ext = getattr(agent, "extNotification", None) + @router.handle_extension_notification + async def _handle_extension_notification(name: str, payload: dict[str, Any]) -> None: + ext = getattr(agent, "ext_notification", None) if ext is None: return await ext(name, payload) - return builder.build( - request_extensions=handle_extension_request, - notification_extensions=handle_extension_notification, - 
) + return router diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index f97ff25..88cf2ac 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -2,131 +2,181 @@ import asyncio from collections.abc import Callable -from typing import Any +from typing import Any, cast, final -from ..connection import Connection, MethodHandler +from ..connection import Connection from ..interfaces import Agent, Client from ..meta import AGENT_METHODS from ..schema import ( + AudioContentBlock, AuthenticateRequest, AuthenticateResponse, CancelNotification, + ClientCapabilities, + EmbeddedResourceContentBlock, + HttpMcpServer, + ImageContentBlock, + Implementation, InitializeRequest, InitializeResponse, + ListSessionsRequest, + ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, + McpServerStdio, NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse, + ResourceContentBlock, SetSessionModelRequest, SetSessionModelResponse, SetSessionModeRequest, SetSessionModeResponse, + SseMcpServer, + TextContentBlock, ) -from ..utils import ( - notify_model, - request_model, - request_model_from_dict, -) +from ..utils import compatible_class, notify_model, param_model, request_model, request_model_from_dict from .router import build_client_router __all__ = ["ClientSideConnection"] - _CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader" +@final +@compatible_class class ClientSideConnection: """Client-side connection wrapper that dispatches JSON-RPC messages to an Agent implementation.""" def __init__( self, - to_client: Callable[[Agent], Client], + to_client: Callable[[Agent], Client] | Client, input_stream: Any, output_stream: Any, + *, + use_unstable_protocol: bool = False, **connection_kwargs: Any, ) -> None: if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_CLIENT_CONNECTION_ERROR) - - client = to_client(self) 
# type: ignore[arg-type] - handler = self._create_handler(client) + client = to_client(cast(Agent, self)) if callable(to_client) else to_client + handler = build_client_router(cast(Client, client), use_unstable_protocol=use_unstable_protocol) self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) + if on_connect := getattr(client, "on_connect", None): + on_connect(self) - def _create_handler(self, client: Client) -> MethodHandler: - router = build_client_router(client) - - async def handler(method: str, params: Any | None, is_notification: bool) -> Any: - if is_notification: - await router.dispatch_notification(method, params) - return None - return await router.dispatch_request(method, params) - - return handler - - async def initialize(self, params: InitializeRequest) -> InitializeResponse: + @param_model(InitializeRequest) + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: return await request_model( self._conn, AGENT_METHODS["initialize"], - params, + InitializeRequest( + protocol_version=protocol_version, + client_capabilities=client_capabilities or ClientCapabilities(), + client_info=client_info, + field_meta=kwargs or None, + ), InitializeResponse, ) - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: + @param_model(NewSessionRequest) + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + ) -> NewSessionResponse: return await request_model( self._conn, AGENT_METHODS["session_new"], - params, + NewSessionRequest(cwd=cwd, mcp_servers=mcp_servers, field_meta=kwargs or None), NewSessionResponse, ) - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse: + @param_model(LoadSessionRequest) + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | 
SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + ) -> LoadSessionResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["session_load"], - params, + LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None), LoadSessionResponse, ) - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse: + @param_model(ListSessionsRequest) + async def list_sessions( + self, cursor: str | None = None, cwd: str | None = None, **kwargs: Any + ) -> ListSessionsResponse: + return await request_model_from_dict( + self._conn, + AGENT_METHODS["session_list"], + ListSessionsRequest(cursor=cursor, cwd=cwd, field_meta=kwargs or None), + ListSessionsResponse, + ) + + @param_model(SetSessionModeRequest) + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["session_set_mode"], - params, + SetSessionModeRequest(mode_id=mode_id, session_id=session_id, field_meta=kwargs or None), SetSessionModeResponse, ) - async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse: + @param_model(SetSessionModelRequest) + async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) -> SetSessionModelResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["session_set_model"], - params, + SetSessionModelRequest(model_id=model_id, session_id=session_id, field_meta=kwargs or None), SetSessionModelResponse, ) - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse: + @param_model(AuthenticateRequest) + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["authenticate"], - params, + AuthenticateRequest(method_id=method_id, field_meta=kwargs or None), AuthenticateResponse, ) - async def prompt(self, 
params: PromptRequest) -> PromptResponse: + @param_model(PromptRequest) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: return await request_model( self._conn, AGENT_METHODS["session_prompt"], - params, + PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None), PromptResponse, ) - async def cancel(self, params: CancelNotification) -> None: - await notify_model(self._conn, AGENT_METHODS["session_cancel"], params) + @param_model(CancelNotification) + async def cancel(self, session_id: str, **kwargs: Any) -> None: + await notify_model( + self._conn, + AGENT_METHODS["session_cancel"], + CancelNotification(session_id=session_id, field_meta=kwargs or None), + ) - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: return await self._conn.send_request(f"_{method}", params) - async def extNotification(self, method: str, params: dict[str, Any]) -> None: + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: await self._conn.send_notification(f"_{method}", params) async def close(self) -> None: diff --git a/src/acp/client/router.py b/src/acp/client/router.py index 9f0b85f..4bab2d9 100644 --- a/src/acp/client/router.py +++ b/src/acp/client/router.py @@ -5,7 +5,7 @@ from ..exceptions import RequestError from ..interfaces import Client from ..meta import CLIENT_METHODS -from ..router import MessageRouter, RouterBuilder +from ..router import MessageRouter from ..schema import ( CreateTerminalRequest, KillTerminalCommandRequest, @@ -22,75 +22,74 @@ __all__ = ["build_client_router"] -def build_client_router(client: Client) -> MessageRouter: - builder = RouterBuilder() +def build_client_router(client: Client, use_unstable_protocol: bool = 
False) -> MessageRouter: + router = MessageRouter(use_unstable_protocol=use_unstable_protocol) - builder.request_attr(CLIENT_METHODS["fs_write_text_file"], WriteTextFileRequest, client, "writeTextFile") - builder.request_attr(CLIENT_METHODS["fs_read_text_file"], ReadTextFileRequest, client, "readTextFile") - builder.request_attr( + router.route_request(CLIENT_METHODS["fs_write_text_file"], WriteTextFileRequest, client, "write_text_file") + router.route_request(CLIENT_METHODS["fs_read_text_file"], ReadTextFileRequest, client, "read_text_file") + router.route_request( CLIENT_METHODS["session_request_permission"], RequestPermissionRequest, client, - "requestPermission", + "request_permission", ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_create"], CreateTerminalRequest, client, - "createTerminal", + "create_terminal", optional=True, default_result=None, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_output"], TerminalOutputRequest, client, - "terminalOutput", + "terminal_output", optional=True, default_result=None, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_release"], ReleaseTerminalRequest, client, - "releaseTerminal", + "release_terminal", optional=True, default_result={}, adapt_result=normalize_result, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_wait_for_exit"], WaitForTerminalExitRequest, client, - "waitForTerminalExit", + "wait_for_terminal_exit", optional=True, default_result=None, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_kill"], KillTerminalCommandRequest, client, - "killTerminal", + "kill_terminal", optional=True, default_result={}, adapt_result=normalize_result, ) - builder.notification_attr(CLIENT_METHODS["session_update"], SessionNotification, client, "sessionUpdate") + router.route_notification(CLIENT_METHODS["session_update"], SessionNotification, client, "session_update") - async def 
handle_extension_request(name: str, payload: dict[str, Any]) -> Any: - ext = getattr(client, "extMethod", None) + @router.handle_extension_request + async def _handle_extension_request(name: str, payload: dict[str, Any]) -> Any: + ext = getattr(client, "ext_method", None) if ext is None: raise RequestError.method_not_found(f"_{name}") return await ext(name, payload) - async def handle_extension_notification(name: str, payload: dict[str, Any]) -> None: - ext = getattr(client, "extNotification", None) + @router.handle_extension_notification + async def _handle_extension_notification(name: str, payload: dict[str, Any]) -> None: + ext = getattr(client, "ext_notification", None) if ext is None: return await ext(name, payload) - return builder.build( - request_extensions=handle_extension_request, - notification_extensions=handle_extension_notification, - ) + return router diff --git a/src/acp/connection.py b/src/acp/connection.py index 34142d7..aca1c19 100644 --- a/src/acp/connection.py +++ b/src/acp/connection.py @@ -72,6 +72,7 @@ def __init__( dispatcher_factory: DispatcherFactory | None = None, sender_factory: SenderFactory | None = None, observers: list[StreamObserver] | None = None, + listening: bool = True, ) -> None: self._handler = handler self._writer = writer @@ -83,11 +84,14 @@ def __init__( self._queue = queue or InMemoryMessageQueue() self._closed = False self._sender = (sender_factory or self._default_sender_factory)(self._writer, self._tasks) - self._recv_task = self._tasks.create( - self._receive_loop(), - name="acp.Connection.receive", - on_error=self._on_receive_error, - ) + if listening: + self._recv_task = self._tasks.create( + self._receive_loop(), + name="acp.Connection.receive", + on_error=self._on_receive_error, + ) + else: + self._recv_task = None dispatcher_factory = dispatcher_factory or self._default_dispatcher_factory self._dispatcher = dispatcher_factory( self._queue, @@ -109,6 +113,14 @@ async def close(self) -> None: await 
self._tasks.shutdown() self._state.reject_all_outgoing(ConnectionError("Connection closed")) + async def main_loop(self) -> None: + try: + await self._receive_loop() + except Exception as exc: + logging.exception("Connection main loop failed", exc_info=exc) + self._on_receive_error(None, exc) # type: ignore[arg-type] + raise + async def __aenter__(self) -> Connection: return self diff --git a/src/acp/contrib/permissions.py b/src/acp/contrib/permissions.py index 5008092..7dfc8e8 100644 --- a/src/acp/contrib/permissions.py +++ b/src/acp/contrib/permissions.py @@ -4,7 +4,7 @@ from typing import Any from ..helpers import text_block, tool_content -from ..schema import PermissionOption, RequestPermissionRequest, RequestPermissionResponse, ToolCall +from ..schema import PermissionOption, RequestPermissionRequest, RequestPermissionResponse, ToolCallUpdate from .tool_calls import ToolCallTracker, _copy_model_list @@ -29,9 +29,9 @@ def __init__(self) -> None: def default_permission_options() -> tuple[PermissionOption, PermissionOption, PermissionOption]: """Return a standard approval/reject option set.""" return ( - PermissionOption(optionId="approve", name="Approve", kind="allow_once"), - PermissionOption(optionId="approve_for_session", name="Approve for session", kind="allow_always"), - PermissionOption(optionId="reject", name="Reject", kind="reject_once"), + PermissionOption(option_id="approve", name="Approve", kind="allow_once"), + PermissionOption(option_id="approve_for_session", name="Approve for session", kind="allow_always"), + PermissionOption(option_id="reject", name="Reject", kind="reject_once"), ) @@ -60,7 +60,7 @@ async def request_for( description: str | None = None, options: Sequence[PermissionOption] | None = None, content: Sequence[Any] | None = None, - tool_call: ToolCall | None = None, + tool_call: ToolCallUpdate | None = None, ) -> RequestPermissionResponse: """Request user approval for a tool call.""" if tool_call is None: @@ -83,8 +83,8 @@ async def 
request_for( raise MissingPermissionOptionsError() request = RequestPermissionRequest( - sessionId=self._session_id, - toolCall=tool_call, + session_id=self._session_id, + tool_call=tool_call, options=list(option_set), ) return await self._requester(request) diff --git a/src/acp/contrib/session_state.py b/src/acp/contrib/session_state.py index 7933be6..ee56125 100644 --- a/src/acp/contrib/session_state.py +++ b/src/acp/contrib/session_state.py @@ -62,8 +62,8 @@ def apply_start(self, update: ToolCallStart) -> None: self.status = update.status self.content = _copy_model_list(update.content) self.locations = _copy_model_list(update.locations) - self.raw_input = update.rawInput - self.raw_output = update.rawOutput + self.raw_input = update.raw_input + self.raw_output = update.raw_output def apply_progress(self, update: ToolCallProgress) -> None: if update.title is not None: @@ -76,10 +76,10 @@ def apply_progress(self, update: ToolCallProgress) -> None: self.content = _copy_model_list(update.content) if update.locations is not None: self.locations = _copy_model_list(update.locations) - if update.rawInput is not None: - self.raw_input = update.rawInput - if update.rawOutput is not None: - self.raw_output = update.rawOutput + if update.raw_input is not None: + self.raw_input = update.raw_input + if update.raw_output is not None: + self.raw_output = update.raw_output def snapshot(self) -> ToolCallView: return ToolCallView( @@ -185,11 +185,11 @@ def apply(self, notification: SessionNotification) -> SessionSnapshot: def _ensure_session(self, notification: SessionNotification) -> None: if self.session_id is None: - self.session_id = notification.sessionId + self.session_id = notification.session_id return - if notification.sessionId != self.session_id: - self._handle_session_change(notification.sessionId) + if notification.session_id != self.session_id: + self._handle_session_change(notification.session_id) def _handle_session_change(self, session_id: str) -> None: expected = 
self.session_id @@ -206,14 +206,14 @@ def _handle_session_change(self, session_id: str) -> None: def _apply_update(self, update: Any) -> None: if isinstance(update, ToolCallStart): state = self._tool_calls.setdefault( - update.toolCallId, _MutableToolCallState(tool_call_id=update.toolCallId) + update.tool_call_id, _MutableToolCallState(tool_call_id=update.tool_call_id) ) state.apply_start(update) return if isinstance(update, ToolCallProgress): state = self._tool_calls.setdefault( - update.toolCallId, _MutableToolCallState(tool_call_id=update.toolCallId) + update.tool_call_id, _MutableToolCallState(tool_call_id=update.tool_call_id) ) state.apply_progress(update) return @@ -223,11 +223,11 @@ def _apply_update(self, update: Any) -> None: return if isinstance(update, CurrentModeUpdate): - self._current_mode_id = update.currentModeId + self._current_mode_id = update.current_mode_id return if isinstance(update, AvailableCommandsUpdate): - self._available_commands = _copy_model_list(update.availableCommands) or [] + self._available_commands = _copy_model_list(update.available_commands) or [] return if isinstance(update, UserMessageChunk): diff --git a/src/acp/contrib/tool_calls.py b/src/acp/contrib/tool_calls.py index 5907485..107d134 100644 --- a/src/acp/contrib/tool_calls.py +++ b/src/acp/contrib/tool_calls.py @@ -7,7 +7,14 @@ from pydantic import BaseModel, ConfigDict from ..helpers import text_block, tool_content -from ..schema import ToolCall, ToolCallLocation, ToolCallProgress, ToolCallStart, ToolCallStatus, ToolKind +from ..schema import ( + ToolCallLocation, + ToolCallProgress, + ToolCallStart, + ToolCallStatus, + ToolCallUpdate, + ToolKind, +) class _MissingToolCallTitleError(ValueError): @@ -91,31 +98,31 @@ def to_view(self) -> TrackedToolCallView: raw_output=self.raw_output, ) - def to_tool_call_model(self) -> ToolCall: - return ToolCall( - toolCallId=self.tool_call_id, + def to_tool_call_model(self) -> ToolCallUpdate: + return ToolCallUpdate( + 
tool_call_id=self.tool_call_id, title=self.title, kind=self.kind, status=self.status, content=_copy_model_list(self.content), locations=_copy_model_list(self.locations), - rawInput=self.raw_input, - rawOutput=self.raw_output, + raw_input=self.raw_input, + raw_output=self.raw_output, ) def to_start_model(self) -> ToolCallStart: if self.title is None: raise _MissingToolCallTitleError() return ToolCallStart( - sessionUpdate="tool_call", - toolCallId=self.tool_call_id, + session_update="tool_call", + tool_call_id=self.tool_call_id, title=self.title, kind=self.kind, status=self.status, content=_copy_model_list(self.content), locations=_copy_model_list(self.locations), - rawInput=self.raw_input, - rawOutput=self.raw_output, + raw_input=self.raw_input, + raw_output=self.raw_output, ) def update( @@ -155,8 +162,8 @@ def update( - kwargs["rawInput"] = raw_input + kwargs["raw_input"] = raw_input if raw_output is not UNSET: self.raw_output = raw_output - kwargs["rawOutput"] = raw_output - return ToolCallProgress(sessionUpdate="tool_call_update", toolCallId=self.tool_call_id, **kwargs) + kwargs["raw_output"] = raw_output + return ToolCallProgress(session_update="tool_call_update", tool_call_id=self.tool_call_id, **kwargs) def append_stream_text( self, @@ -249,7 +256,7 @@ def view(self, external_id: str) -> TrackedToolCallView: state = self._require_call(external_id) return state.to_view() - def tool_call_model(self, external_id: str) -> ToolCall: + def tool_call_model(self, external_id: str) -> ToolCallUpdate: """Return a deep copy of the tool call suitable for permission requests.""" state = self._require_call(external_id) return state.to_tool_call_model() diff --git a/src/acp/core.py b/src/acp/core.py index 8afa468..1d440de 100644 --- a/src/acp/core.py +++ b/src/acp/core.py @@ -7,6 +7,8 @@ from __future__ import annotations +from typing import Any + from .agent.connection import AgentSideConnection from .client.connection import ClientSideConnection from .connection import Connection, JsonValue, MethodHandler
@@ -24,4 +26,68 @@ "MethodHandler", "RequestError", "TerminalHandle", + "connect_to_agent", + "run_agent", ] + + +async def run_agent( + agent: Agent, + input_stream: Any = None, + output_stream: Any = None, + *, + use_unstable_protocol: bool = False, + **connection_kwargs: Any, +) -> None: + """Run an ACP agent over the given input/output streams. + + This is a convenience function that creates an :class:`AgentSideConnection` + and starts listening for incoming messages. + + Args: + agent: The agent implementation to run. + input_stream: The (client) input stream to write to (defaults: ``sys.stdin``). + output_stream: The (client) output stream to read from (defaults: ``sys.stdout``). + use_unstable_protocol: Whether to enable unstable protocol features. + **connection_kwargs: Additional keyword arguments to pass to the + :class:`AgentSideConnection` constructor. + """ + from .stdio import stdio_streams + + if input_stream is None and output_stream is None: + output_stream, input_stream = await stdio_streams() + conn = AgentSideConnection( + agent, + input_stream, + output_stream, + listening=False, + use_unstable_protocol=use_unstable_protocol, + **connection_kwargs, + ) + await conn.listen() + + +def connect_to_agent( + client: Client, + input_stream: Any, + output_stream: Any, + *, + use_unstable_protocol: bool = False, + **connection_kwargs: Any, +) -> ClientSideConnection: + """Create a ClientSideConnection to an ACP agent over the given input/output streams. + + Args: + client: The client implementation to use. + input_stream: The (agent) input stream to write to (default: ``sys.stdin``). + output_stream: The (agent) output stream to read from (default: ``sys.stdout``). + use_unstable_protocol: Whether to enable unstable protocol features. + **connection_kwargs: Additional keyword arguments to pass to the + :class:`ClientSideConnection` constructor. + + Returns: + A :class:`ClientSideConnection` instance connected to the agent. 
+ """ + return ClientSideConnection( + client, input_stream, output_stream, use_unstable_protocol=use_unstable_protocol, **connection_kwargs + ) diff --git a/src/acp/exceptions.py b/src/acp/exceptions.py index 898ca40..06098dd 100644 --- a/src/acp/exceptions.py +++ b/src/acp/exceptions.py @@ -13,34 +13,34 @@ def __init__(self, code: int, message: str, data: Any | None = None) -> None: self.code = code self.data = data - @staticmethod - def parse_error(data: dict[str, Any] | None = None) -> RequestError: - return RequestError(-32700, "Parse error", data) + @classmethod + def parse_error(cls, data: dict[str, Any] | None = None) -> RequestError: + return cls(-32700, "Parse error", data) - @staticmethod - def invalid_request(data: dict[str, Any] | None = None) -> RequestError: - return RequestError(-32600, "Invalid request", data) + @classmethod + def invalid_request(cls, data: dict[str, Any] | None = None) -> RequestError: + return cls(-32600, "Invalid request", data) - @staticmethod - def method_not_found(method: str) -> RequestError: - return RequestError(-32601, "Method not found", {"method": method}) + @classmethod + def method_not_found(cls, method: str) -> RequestError: + return cls(-32601, "Method not found", {"method": method}) - @staticmethod - def invalid_params(data: dict[str, Any] | None = None) -> RequestError: - return RequestError(-32602, "Invalid params", data) + @classmethod + def invalid_params(cls, data: dict[str, Any] | None = None) -> RequestError: + return cls(-32602, "Invalid params", data) - @staticmethod - def internal_error(data: dict[str, Any] | None = None) -> RequestError: - return RequestError(-32603, "Internal error", data) + @classmethod + def internal_error(cls, data: dict[str, Any] | None = None) -> RequestError: + return cls(-32603, "Internal error", data) - @staticmethod - def auth_required(data: dict[str, Any] | None = None) -> RequestError: - return RequestError(-32000, "Authentication required", data) + @classmethod + def 
auth_required(cls, data: dict[str, Any] | None = None) -> RequestError: + return cls(-32000, "Authentication required", data) - @staticmethod - def resource_not_found(uri: str | None = None) -> RequestError: + @classmethod + def resource_not_found(cls, uri: str | None = None) -> RequestError: data = {"uri": uri} if uri is not None else None - return RequestError(-32002, "Resource not found", data) + return cls(-32002, "Resource not found", data) def to_error_obj(self) -> dict[str, Any]: return {"code": self.code, "message": str(self), "data": self.data} diff --git a/src/acp/helpers.py b/src/acp/helpers.py index d5a473f..701cda7 100644 --- a/src/acp/helpers.py +++ b/src/acp/helpers.py @@ -83,11 +83,11 @@ def text_block(text: str) -> TextContentBlock: def image_block(data: str, mime_type: str, *, uri: str | None = None) -> ImageContentBlock: - return ImageContentBlock(type="image", data=data, mimeType=mime_type, uri=uri) + return ImageContentBlock(type="image", data=data, mime_type=mime_type, uri=uri) def audio_block(data: str, mime_type: str) -> AudioContentBlock: - return AudioContentBlock(type="audio", data=data, mimeType=mime_type) + return AudioContentBlock(type="audio", data=data, mime_type=mime_type) def resource_link_block( @@ -103,7 +103,7 @@ def resource_link_block( type="resource_link", name=name, uri=uri, - mimeType=mime_type, + mime_type=mime_type, size=size, description=description, title=title, @@ -111,11 +111,11 @@ def resource_link_block( def embedded_text_resource(uri: str, text: str, *, mime_type: str | None = None) -> TextResourceContents: - return TextResourceContents(uri=uri, text=text, mimeType=mime_type) + return TextResourceContents(uri=uri, text=text, mime_type=mime_type) def embedded_blob_resource(uri: str, blob: str, *, mime_type: str | None = None) -> BlobResourceContents: - return BlobResourceContents(uri=uri, blob=blob, mimeType=mime_type) + return BlobResourceContents(uri=uri, blob=blob, mime_type=mime_type) def resource_block( @@ 
-129,11 +129,11 @@ def tool_content(block: ContentBlock) -> ContentToolCallContent: def tool_diff_content(path: str, new_text: str, old_text: str | None = None) -> FileEditToolCallContent: - return FileEditToolCallContent(type="diff", path=path, newText=new_text, oldText=old_text) + return FileEditToolCallContent(type="diff", path=path, new_text=new_text, old_text=old_text) def tool_terminal_ref(terminal_id: str) -> TerminalToolCallContent: - return TerminalToolCallContent(type="terminal", terminalId=terminal_id) + return TerminalToolCallContent(type="terminal", terminal_id=terminal_id) def plan_entry( @@ -146,11 +146,11 @@ def plan_entry( def update_plan(entries: Iterable[PlanEntry]) -> AgentPlanUpdate: - return AgentPlanUpdate(sessionUpdate="plan", entries=list(entries)) + return AgentPlanUpdate(session_update="plan", entries=list(entries)) def update_user_message(content: ContentBlock) -> UserMessageChunk: - return UserMessageChunk(sessionUpdate="user_message_chunk", content=content) + return UserMessageChunk(session_update="user_message_chunk", content=content) def update_user_message_text(text: str) -> UserMessageChunk: @@ -158,7 +158,7 @@ def update_user_message_text(text: str) -> UserMessageChunk: def update_agent_message(content: ContentBlock) -> AgentMessageChunk: - return AgentMessageChunk(sessionUpdate="agent_message_chunk", content=content) + return AgentMessageChunk(session_update="agent_message_chunk", content=content) def update_agent_message_text(text: str) -> AgentMessageChunk: @@ -166,7 +166,7 @@ def update_agent_message_text(text: str) -> AgentMessageChunk: def update_agent_thought(content: ContentBlock) -> AgentThoughtChunk: - return AgentThoughtChunk(sessionUpdate="agent_thought_chunk", content=content) + return AgentThoughtChunk(session_update="agent_thought_chunk", content=content) def update_agent_thought_text(text: str) -> AgentThoughtChunk: @@ -175,17 +175,17 @@ def update_agent_thought_text(text: str) -> AgentThoughtChunk: def 
update_available_commands(commands: Iterable[AvailableCommand]) -> AvailableCommandsUpdate: return AvailableCommandsUpdate( - sessionUpdate="available_commands_update", - availableCommands=list(commands), + session_update="available_commands_update", + available_commands=list(commands), ) def update_current_mode(current_mode_id: str) -> CurrentModeUpdate: - return CurrentModeUpdate(sessionUpdate="current_mode_update", currentModeId=current_mode_id) + return CurrentModeUpdate(session_update="current_mode_update", current_mode_id=current_mode_id) def session_notification(session_id: str, update: SessionUpdate) -> SessionNotification: - return SessionNotification(sessionId=session_id, update=update) + return SessionNotification(session_id=session_id, update=update) def start_tool_call( @@ -200,15 +200,15 @@ def start_tool_call( raw_output: Any | None = None, ) -> ToolCallStart: return ToolCallStart( - sessionUpdate="tool_call", - toolCallId=tool_call_id, + session_update="tool_call", + tool_call_id=tool_call_id, title=title, kind=kind, status=status, content=list(content) if content is not None else None, locations=list(locations) if locations is not None else None, - rawInput=raw_input, - rawOutput=raw_output, + raw_input=raw_input, + raw_output=raw_output, ) @@ -266,13 +266,13 @@ def update_tool_call( raw_output: Any | None = None, ) -> ToolCallProgress: return ToolCallProgress( - sessionUpdate="tool_call_update", - toolCallId=tool_call_id, + session_update="tool_call_update", + tool_call_id=tool_call_id, title=title, kind=kind, status=status, content=list(content) if content is not None else None, locations=list(locations) if locations is not None else None, - rawInput=raw_input, - rawOutput=raw_output, + raw_input=raw_input, + raw_output=raw_output, ) diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 11d04d3..f0049e0 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -3,19 +3,35 @@ from typing import Any, Protocol from .schema import 
( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AudioContentBlock, AuthenticateRequest, AuthenticateResponse, + AvailableCommandsUpdate, CancelNotification, + ClientCapabilities, CreateTerminalRequest, CreateTerminalResponse, + CurrentModeUpdate, + EmbeddedResourceContentBlock, + EnvVariable, + HttpMcpServer, + ImageContentBlock, + Implementation, InitializeRequest, InitializeResponse, KillTerminalCommandRequest, KillTerminalCommandResponse, + ListSessionsRequest, + ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, + McpServerStdio, NewSessionRequest, NewSessionResponse, + PermissionOption, PromptRequest, PromptResponse, ReadTextFileRequest, @@ -24,63 +40,153 @@ ReleaseTerminalResponse, RequestPermissionRequest, RequestPermissionResponse, + ResourceContentBlock, SessionNotification, SetSessionModelRequest, SetSessionModelResponse, SetSessionModeRequest, SetSessionModeResponse, + SseMcpServer, TerminalOutputRequest, TerminalOutputResponse, + TextContentBlock, + ToolCallProgress, + ToolCallStart, + ToolCallUpdate, + UserMessageChunk, WaitForTerminalExitRequest, WaitForTerminalExitResponse, WriteTextFileRequest, WriteTextFileResponse, ) +from .utils import param_model __all__ = ["Agent", "Client"] class Client(Protocol): - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: ... - - async def sessionUpdate(self, params: SessionNotification) -> None: ... - - async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse | None: ... - - async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: ... - - async def createTerminal(self, params: CreateTerminalRequest) -> CreateTerminalResponse: ... - - async def terminalOutput(self, params: TerminalOutputRequest) -> TerminalOutputResponse: ... - - async def releaseTerminal(self, params: ReleaseTerminalRequest) -> ReleaseTerminalResponse | None: ... 
- - async def waitForTerminalExit(self, params: WaitForTerminalExitRequest) -> WaitForTerminalExitResponse: ... - - async def killTerminal(self, params: KillTerminalCommandRequest) -> KillTerminalCommandResponse | None: ... - - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... - - async def extNotification(self, method: str, params: dict[str, Any]) -> None: ... + @param_model(RequestPermissionRequest) + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCallUpdate, **kwargs: Any + ) -> RequestPermissionResponse: ... + + @param_model(SessionNotification) + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: ... + + @param_model(WriteTextFileRequest) + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: ... + + @param_model(ReadTextFileRequest) + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: ... + + @param_model(CreateTerminalRequest) + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: ... + + @param_model(TerminalOutputRequest) + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: ... + + @param_model(ReleaseTerminalRequest) + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: ... 
+ + @param_model(WaitForTerminalExitRequest) + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: ... + + @param_model(KillTerminalCommandRequest) + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: ... + + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... + + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: ... + + def on_connect(self, conn: Agent) -> None: ... class Agent(Protocol): - async def initialize(self, params: InitializeRequest) -> InitializeResponse: ... - - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: ... - - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: ... - - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: ... - - async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse | None: ... - - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: ... - - async def prompt(self, params: PromptRequest) -> PromptResponse: ... - - async def cancel(self, params: CancelNotification) -> None: ... - - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... - - async def extNotification(self, method: str, params: dict[str, Any]) -> None: ... + @param_model(InitializeRequest) + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: ... + + @param_model(NewSessionRequest) + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + ) -> NewSessionResponse: ... 
+ + @param_model(LoadSessionRequest) + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + ) -> LoadSessionResponse | None: ... + + @param_model(ListSessionsRequest) + async def list_sessions( + self, cursor: str | None = None, cwd: str | None = None, **kwargs: Any + ) -> ListSessionsResponse: ... + + @param_model(SetSessionModeRequest) + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: ... + + @param_model(SetSessionModelRequest) + async def set_session_model( + self, model_id: str, session_id: str, **kwargs: Any + ) -> SetSessionModelResponse | None: ... + + @param_model(AuthenticateRequest) + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: ... + + @param_model(PromptRequest) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: ... + + @param_model(CancelNotification) + async def cancel(self, session_id: str, **kwargs: Any) -> None: ... + + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... + + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: ... + + def on_connect(self, conn: Client) -> None: ... diff --git a/src/acp/meta.py b/src/acp/meta.py index 2b67512..900e195 100644 --- a/src/acp/meta.py +++ b/src/acp/meta.py @@ -1,5 +1,25 @@ # Generated from schema/meta.json. Do not edit by hand. 
-# Schema ref: refs/tags/v0.6.3 -AGENT_METHODS = {'authenticate': 'authenticate', 'initialize': 'initialize', 'session_cancel': 'session/cancel', 'session_load': 'session/load', 'session_new': 'session/new', 'session_prompt': 'session/prompt', 'session_set_mode': 'session/set_mode', 'session_set_model': 'session/set_model'} -CLIENT_METHODS = {'fs_read_text_file': 'fs/read_text_file', 'fs_write_text_file': 'fs/write_text_file', 'session_request_permission': 'session/request_permission', 'session_update': 'session/update', 'terminal_create': 'terminal/create', 'terminal_kill': 'terminal/kill', 'terminal_output': 'terminal/output', 'terminal_release': 'terminal/release', 'terminal_wait_for_exit': 'terminal/wait_for_exit'} +# Schema ref: refs/tags/v0.7.0 +AGENT_METHODS = { + "authenticate": "authenticate", + "initialize": "initialize", + "session_cancel": "session/cancel", + "session_list": "session/list", + "session_load": "session/load", + "session_new": "session/new", + "session_prompt": "session/prompt", + "session_set_mode": "session/set_mode", + "session_set_model": "session/set_model", +} +CLIENT_METHODS = { + "fs_read_text_file": "fs/read_text_file", + "fs_write_text_file": "fs/write_text_file", + "session_request_permission": "session/request_permission", + "session_update": "session/update", + "terminal_create": "terminal/create", + "terminal_kill": "terminal/kill", + "terminal_output": "terminal/output", + "terminal_release": "terminal/release", + "terminal_wait_for_exit": "terminal/wait_for_exit", +} PROTOCOL_VERSION = 1 diff --git a/src/acp/router.py b/src/acp/router.py index f50f9b7..2aa3c24 100644 --- a/src/acp/router.py +++ b/src/acp/router.py @@ -1,192 +1,174 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, Mapping, Sequence +import inspect +import warnings +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Literal, TypeVar 
from pydantic import BaseModel +from acp.utils import to_camel_case + from .exceptions import RequestError -__all__ = [ - "MessageRouter", - "Route", - "RouterBuilder", - "attribute_handler", -] +__all__ = ["MessageRouter", "Route"] AsyncHandler = Callable[[Any], Awaitable[Any | None]] +RequestHandler = Callable[[str, dict[str, Any]], Awaitable[Any]] +HandlerT = TypeVar("HandlerT", bound=RequestHandler) @dataclass(slots=True) class Route: method: str - model: type[BaseModel] - handle: Callable[[], AsyncHandler | None] + func: AsyncHandler | None kind: Literal["request", "notification"] optional: bool = False default_result: Any = None adapt_result: Callable[[Any | None], Any] | None = None - - -class MessageRouter: - def __init__( - self, - routes: Sequence[Route], - *, - request_extensions: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None, - notification_extensions: Callable[[str, dict[str, Any]], Awaitable[None]] | None = None, - ) -> None: - self._requests: Mapping[str, Route] = {route.method: route for route in routes if route.kind == "request"} - self._notifications: Mapping[str, Route] = { - route.method: route for route in routes if route.kind == "notification" - } - self._request_extensions = request_extensions - self._notification_extensions = notification_extensions - - async def dispatch_request(self, method: str, params: Any | None) -> Any: - if isinstance(method, str) and method.startswith("_"): - if self._request_extensions is None: - raise RequestError.method_not_found(method) - payload = params if isinstance(params, dict) else {} - return await self._request_extensions(method[1:], payload) - - route = self._requests.get(method) - if route is None: - raise RequestError.method_not_found(method) - model = route.model - parsed = model.model_validate(params) - - handler = route.handle() - if handler is None: - if route.optional: - return route.default_result - raise RequestError.method_not_found(method) - - result = await handler(parsed) - 
if route.adapt_result is not None: - return route.adapt_result(result) - return result - - async def dispatch_notification(self, method: str, params: Any | None) -> None: - if isinstance(method, str) and method.startswith("_"): - if self._notification_extensions is None: - return - payload = params if isinstance(params, dict) else {} - await self._notification_extensions(method[1:], payload) - return - - route = self._notifications.get(method) - if route is None: - raise RequestError.method_not_found(method) - model = route.model - parsed = model.model_validate(params) - - handler = route.handle() - if handler is None: - if route.optional: - return - raise RequestError.method_not_found(method) - await handler(parsed) - - -class RouterBuilder: - def __init__(self) -> None: - self._routes: list[Route] = [] - - def request( - self, - method: str, - model: type[BaseModel], - *, - optional: bool = False, - default_result: Any = None, - adapt_result: Callable[[Any | None], Any] | None = None, - ) -> Callable[[Callable[[], AsyncHandler | None]], Callable[[], AsyncHandler | None]]: - def decorator(factory: Callable[[], AsyncHandler | None]) -> Callable[[], AsyncHandler | None]: - self._routes.append( - Route( - method=method, - model=model, - handle=factory, - kind="request", - optional=optional, - default_result=default_result, - adapt_result=adapt_result, - ) + warn_unstable: bool = False + + async def handle(self, params: Any) -> Any: + if self.func is None: + if self.optional: + return self.default_result + raise RequestError.method_not_found(self.method) + if self.warn_unstable: + warnings.warn( + f"The method {self.method} is part of the unstable protocol, please enable `use_unstable_protocol` flag to use it.", + UserWarning, + stacklevel=3, ) - return factory + raise RequestError.method_not_found(self.method) + result = await self.func(params) + if self.adapt_result is not None and self.kind == "request": + return self.adapt_result(result) + return result - return 
decorator - def notification( - self, - method: str, - model: type[BaseModel], - *, - optional: bool = False, - ) -> Callable[[Callable[[], AsyncHandler | None]], Callable[[], AsyncHandler | None]]: - def decorator(factory: Callable[[], AsyncHandler | None]) -> Callable[[], AsyncHandler | None]: - self._routes.append( - Route( - method=method, - model=model, - handle=factory, - kind="notification", - optional=optional, +class MessageRouter: + def __init__(self, use_unstable_protocol: bool = False) -> None: + self._requests: dict[str, Route] = {} + self._notifications: dict[str, Route] = {} + self._request_extensions: RequestHandler | None = None + self._notification_extensions: RequestHandler | None = None + self._use_unstable_protocol = use_unstable_protocol + + def add_route(self, route: Route) -> None: + if route.kind == "request": + self._requests[route.method] = route + else: + self._notifications[route.method] = route + + def _make_func(self, model: type[BaseModel], obj: Any, attr: str) -> AsyncHandler | None: + legacy_api = False + func = getattr(obj, attr, None) + if func is None and "_" in attr: + attr = to_camel_case(attr) + func = getattr(obj, attr, None) + legacy_api = True + elif callable(func) and "_" not in attr: + original_func = func + if hasattr(func, "__func__"): + original_func = func.__func__ + parameters = inspect.signature(original_func).parameters + if len(parameters) == 2 and "params" in parameters: + legacy_api = True + + if func is None or not callable(func): + return None + + async def wrapper(params: Any) -> Any: + if legacy_api: + warnings.warn( + f"The old style method {type(obj).__name__}.{attr} is deprecated, " + "please update to the snake-cased form.", + DeprecationWarning, + stacklevel=3, ) - ) - return factory - - return decorator + model_obj = model.model_validate(params) + if legacy_api: + return await func(model_obj) # type: ignore[arg-type] + params = {k: getattr(model_obj, k) for k in model.model_fields if k != 
"field_meta"} + if meta := getattr(model_obj, "field_meta", None): + params.update(meta) + return await func(**params) # type: ignore[arg-type] - def build( - self, - *, - request_extensions: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None, - notification_extensions: Callable[[str, dict[str, Any]], Awaitable[None]] | None = None, - ) -> MessageRouter: - return MessageRouter( - routes=self._routes, - request_extensions=request_extensions, - notification_extensions=notification_extensions, - ) + return wrapper - def request_attr( + def route_request( self, method: str, model: type[BaseModel], obj: Any, attr: str, - *, optional: bool = False, default_result: Any = None, adapt_result: Callable[[Any | None], Any] | None = None, - ) -> None: - self.request( - method, - model, + unstable: bool = False, + ) -> Route: + """Register a request route with obj and attribute name.""" + route = Route( + method=method, + func=self._make_func(model, obj, attr), + kind="request", optional=optional, default_result=default_result, adapt_result=adapt_result, - )(attribute_handler(obj, attr)) + warn_unstable=unstable and not self._use_unstable_protocol, + ) + self.add_route(route) + return route - def notification_attr( + def route_notification( self, method: str, model: type[BaseModel], obj: Any, attr: str, - *, optional: bool = False, - ) -> None: - self.notification(method, model, optional=optional)(attribute_handler(obj, attr)) + unstable: bool = False, + ) -> Route: + """Register a notification route with obj and attribute name.""" + route = Route( + method=method, + func=self._make_func(model, obj, attr), + kind="notification", + optional=optional, + warn_unstable=unstable and not self._use_unstable_protocol, + ) + self.add_route(route) + return route + + def handle_extension_request(self, handler: HandlerT) -> HandlerT: + """Register a handler for extension requests.""" + self._request_extensions = handler + return handler + + def handle_extension_notification(self, 
handler: HandlerT) -> HandlerT: + """Register a handler for extension notifications.""" + self._notification_extensions = handler + return handler + + async def __call__(self, method: str, params: Any | None, is_notification: bool) -> Any: + """The main router call to handle a request or notification.""" + if is_notification: + ext_handler = self._notification_extensions + routes = self._notifications + else: + ext_handler = self._request_extensions + routes = self._requests + if isinstance(method, str) and method.startswith("_"): + if ext_handler is None: + raise RequestError.method_not_found(method) + payload = params if isinstance(params, dict) else {} + return await ext_handler(method[1:], payload) -def attribute_handler(obj: Any, attr: str) -> Callable[[], AsyncHandler | None]: - def factory() -> AsyncHandler | None: - func = getattr(obj, attr, None) - return func if callable(func) else None + route = routes.get(method) + if route is None: + raise RequestError.method_not_found(method) - return factory + return await route.handle(params) diff --git a/src/acp/schema.py b/src/acp/schema.py index 4814f08..2576f48 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -1,5 +1,5 @@ # Generated from schema/schema.json. Do not edit by hand. 
-# Schema ref: refs/tags/v0.6.3 +# Schema ref: refs/tags/v0.7.0 from __future__ import annotations @@ -19,11 +19,34 @@ class BaseModel(_BaseModel): model_config = ConfigDict(populate_by_name=True) + def __getattr__(self, item: str) -> Any: + if item.lower() != item: + snake_cased = "".join("_" + c.lower() if c.isupper() and i > 0 else c.lower() for i, c in enumerate(item)) + return getattr(self, snake_cased) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + class Jsonrpc(Enum): field_2_0 = "2.0" +class AuthMethod(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # Optional description providing more details about this authentication method. + description: Annotated[ + Optional[str], + Field(description="Optional description providing more details about this authentication method."), + ] = None + # Unique identifier for this authentication method. + id: Annotated[str, Field(description="Unique identifier for this authentication method.")] + # Human-readable name of the authentication method. + name: Annotated[str, Field(description="Human-readable name of the authentication method.")] + + class AuthenticateRequest(BaseModel): # Extension point for implementations field_meta: Annotated[ @@ -32,10 +55,11 @@ class AuthenticateRequest(BaseModel): ] = None # The ID of the authentication method to use. # Must be one of the methods advertised in the initialize response. - methodId: Annotated[ + method_id: Annotated[ str, Field( - description="The ID of the authentication method to use.\nMust be one of the methods advertised in the initialize response." 
+ alias="methodId", + description="The ID of the authentication method to use.\nMust be one of the methods advertised in the initialize response.", ), ] @@ -48,22 +72,6 @@ class AuthenticateResponse(BaseModel): ] = None -class CommandInputHint(BaseModel): - # A hint to display when the input hasn't been provided yet - hint: Annotated[ - str, - Field(description="A hint to display when the input hasn't been provided yet"), - ] - - -class AvailableCommandInput(RootModel[CommandInputHint]): - # The input specification for a command. - root: Annotated[ - CommandInputHint, - Field(description="The input specification for a command."), - ] - - class BlobResourceContents(BaseModel): # Extension point for implementations field_meta: Annotated[ @@ -71,7 +79,7 @@ class BlobResourceContents(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None blob: str - mimeType: Optional[str] = None + mime_type: Annotated[Optional[str], Field(alias="mimeType")] = None uri: str @@ -82,7 +90,30 @@ class CreateTerminalResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The unique identifier for the created terminal. - terminalId: Annotated[str, Field(description="The unique identifier for the created terminal.")] + terminal_id: Annotated[ + str, + Field( + alias="terminalId", + description="The unique identifier for the created terminal.", + ), + ] + + +class Diff(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # The new content after modification. + new_text: Annotated[str, Field(alias="newText", description="The new content after modification.")] + # The original content (None for new files). + old_text: Annotated[ + Optional[str], + Field(alias="oldText", description="The original content (None for new files)."), + ] = None + # The file path being modified. 
+ path: Annotated[str, Field(description="The file path being modified.")] class EnvVariable(BaseModel): @@ -131,14 +162,20 @@ class FileSystemCapability(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Whether the Client supports `fs/read_text_file` requests. - readTextFile: Annotated[ + read_text_file: Annotated[ Optional[bool], - Field(description="Whether the Client supports `fs/read_text_file` requests."), + Field( + alias="readTextFile", + description="Whether the Client supports `fs/read_text_file` requests.", + ), ] = False # Whether the Client supports `fs/write_text_file` requests. - writeTextFile: Annotated[ + write_text_file: Annotated[ Optional[bool], - Field(description="Whether the Client supports `fs/write_text_file` requests."), + Field( + alias="writeTextFile", + description="Whether the Client supports `fs/write_text_file` requests.", + ), ] = False @@ -155,6 +192,11 @@ class HttpHeader(BaseModel): class Implementation(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None # Intended for programmatic or logical use, but can be used as a display # name fallback if title isn’t present. name: Annotated[ @@ -174,11 +216,11 @@ class Implementation(BaseModel): ), ] = None # Version of the implementation. Can be displayed to the user or used - # for debugging or metrics purposes. + # for debugging or metrics purposes. (e.g. "1.0.0"). version: Annotated[ str, Field( - description="Version of the implementation. Can be displayed to the user or used\nfor debugging or metrics purposes." + description='Version of the implementation. Can be displayed to the user or used\nfor debugging or metrics purposes. (e.g. "1.0.0").' 
), ] @@ -191,6 +233,26 @@ class KillTerminalCommandResponse(BaseModel): ] = None +class ListSessionsRequest(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # Opaque cursor token from a previous response's nextCursor field for cursor-based pagination + cursor: Annotated[ + Optional[str], + Field( + description="Opaque cursor token from a previous response's nextCursor field for cursor-based pagination" + ), + ] = None + # Filter sessions by working directory. Must be an absolute path. + cwd: Annotated[ + Optional[str], + Field(description="Filter sessions by working directory. Must be an absolute path."), + ] = None + + class McpCapabilities(BaseModel): # Extension point for implementations field_meta: Annotated[ @@ -203,7 +265,12 @@ class McpCapabilities(BaseModel): sse: Annotated[Optional[bool], Field(description="Agent supports [`McpServer::Sse`].")] = False -class HttpMcpServer(BaseModel): +class McpServerHttp(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None # HTTP headers to set when making requests to the MCP server. headers: Annotated[ List[HttpHeader], @@ -211,12 +278,16 @@ class HttpMcpServer(BaseModel): ] # Human-readable name identifying this MCP server. name: Annotated[str, Field(description="Human-readable name identifying this MCP server.")] - type: Literal["http"] # URL to the MCP server. url: Annotated[str, Field(description="URL to the MCP server.")] -class SseMcpServer(BaseModel): +class McpServerSse(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None # HTTP headers to set when making requests to the MCP server. 
headers: Annotated[ List[HttpHeader], @@ -224,12 +295,16 @@ class SseMcpServer(BaseModel): ] # Human-readable name identifying this MCP server. name: Annotated[str, Field(description="Human-readable name identifying this MCP server.")] - type: Literal["sse"] # URL to the MCP server. url: Annotated[str, Field(description="URL to the MCP server.")] -class StdioMcpServer(BaseModel): +class McpServerStdio(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None # Command-line arguments to pass to the MCP server. args: Annotated[ List[str], @@ -256,29 +331,11 @@ class ModelInfo(BaseModel): # Optional description of the model. description: Annotated[Optional[str], Field(description="Optional description of the model.")] = None # Unique identifier for the model. - modelId: Annotated[str, Field(description="Unique identifier for the model.")] + model_id: Annotated[str, Field(alias="modelId", description="Unique identifier for the model.")] # Human-readable name of the model. name: Annotated[str, Field(description="Human-readable name of the model.")] -class NewSessionRequest(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - # The working directory for this session. Must be an absolute path. - cwd: Annotated[ - str, - Field(description="The working directory for this session. Must be an absolute path."), - ] - # List of MCP (Model Context Protocol) servers the agent should connect to. 
- mcpServers: Annotated[ - List[Union[HttpMcpServer, SseMcpServer, StdioMcpServer]], - Field(description="List of MCP (Model Context Protocol) servers the agent should connect to."), - ] - - class PromptCapabilities(BaseModel): # Extension point for implementations field_meta: Annotated[ @@ -291,10 +348,11 @@ class PromptCapabilities(BaseModel): # # When enabled, the Client is allowed to include [`ContentBlock::Resource`] # in prompt requests for pieces of context that are referenced in the message. - embeddedContext: Annotated[ + embedded_context: Annotated[ Optional[bool], Field( - description="Agent supports embedded context in `session/prompt` requests.\n\nWhen enabled, the Client is allowed to include [`ContentBlock::Resource`]\nin prompt requests for pieces of context that are referenced in the message." + alias="embeddedContext", + description="Agent supports embedded context in `session/prompt` requests.\n\nWhen enabled, the Client is allowed to include [`ContentBlock::Resource`]\nin prompt requests for pieces of context that are referenced in the message.", ), ] = False # Agent supports [`ContentBlock::Image`]. @@ -322,54 +380,73 @@ class DeniedOutcome(BaseModel): outcome: Literal["cancelled"] -class AllowedOutcome(BaseModel): - # The ID of the option the user selected. - optionId: Annotated[str, Field(description="The ID of the option the user selected.")] - outcome: Literal["selected"] +class Role(Enum): + assistant = "assistant" + user = "user" -class RequestPermissionResponse(BaseModel): +class SelectedPermissionOutcome(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # The user's decision on the permission request. - outcome: Annotated[ - Union[DeniedOutcome, AllowedOutcome], - Field( - description="The user's decision on the permission request.", - discriminator="outcome", - ), + # The ID of the option the user selected. 
+ option_id: Annotated[ + str, + Field(alias="optionId", description="The ID of the option the user selected."), ] -class Role(Enum): - assistant = "assistant" - user = "user" +class SessionInfo(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # The working directory for this session. Must be an absolute path. + cwd: Annotated[ + str, + Field(description="The working directory for this session. Must be an absolute path."), + ] + # Unique identifier for the session + session_id: Annotated[str, Field(alias="sessionId", description="Unique identifier for the session")] + # Human-readable title for the session + title: Annotated[Optional[str], Field(description="Human-readable title for the session")] = None + # ISO 8601 timestamp of last activity + updated_at: Annotated[ + Optional[str], + Field(alias="updatedAt", description="ISO 8601 timestamp of last activity"), + ] = None -class SessionModelState(BaseModel): +class SessionListCapabilities(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # The set of models that the Agent can use - availableModels: Annotated[List[ModelInfo], Field(description="The set of models that the Agent can use")] - # The current model the Agent is in. 
- currentModelId: Annotated[str, Field(description="The current model the Agent is in.")] -class CurrentModeUpdate(BaseModel): +class SessionModelState(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # The ID of the current mode - currentModeId: Annotated[str, Field(description="The ID of the current mode")] - sessionUpdate: Literal["current_mode_update"] + # The set of models that the Agent can use + available_models: Annotated[ + List[ModelInfo], + Field( + alias="availableModels", + description="The set of models that the Agent can use", + ), + ] + # The current model the Agent is in. + current_model_id: Annotated[ + str, + Field(alias="currentModelId", description="The current model the Agent is in."), + ] class SetSessionModeRequest(BaseModel): @@ -379,9 +456,12 @@ class SetSessionModeRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the mode to set. - modeId: Annotated[str, Field(description="The ID of the mode to set.")] + mode_id: Annotated[str, Field(alias="modeId", description="The ID of the mode to set.")] # The ID of the session to set the mode for. - sessionId: Annotated[str, Field(description="The ID of the session to set the mode for.")] + session_id: Annotated[ + str, + Field(alias="sessionId", description="The ID of the session to set the mode for."), + ] class SetSessionModeResponse(BaseModel): @@ -395,9 +475,12 @@ class SetSessionModelRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the model to set. - modelId: Annotated[str, Field(description="The ID of the model to set.")] + model_id: Annotated[str, Field(alias="modelId", description="The ID of the model to set.")] # The ID of the session to set the model for. 
- sessionId: Annotated[str, Field(description="The ID of the session to set the model for.")] + session_id: Annotated[ + str, + Field(alias="sessionId", description="The ID of the session to set the model for."), + ] class SetSessionModelResponse(BaseModel): @@ -408,6 +491,15 @@ class SetSessionModelResponse(BaseModel): ] = None +class Terminal(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + terminal_id: Annotated[str, Field(alias="terminalId")] + + class TerminalExitStatus(BaseModel): # Extension point for implementations field_meta: Annotated[ @@ -415,9 +507,10 @@ class TerminalExitStatus(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The process exit code (may be null if terminated by signal). - exitCode: Annotated[ + exit_code: Annotated[ Optional[int], Field( + alias="exitCode", description="The process exit code (may be null if terminated by signal).", ge=0, ), @@ -436,9 +529,12 @@ class TerminalOutputRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # The ID of the terminal to get output from. - terminalId: Annotated[str, Field(description="The ID of the terminal to get output from.")] + terminal_id: Annotated[ + str, + Field(alias="terminalId", description="The ID of the terminal to get output from."), + ] class TerminalOutputResponse(BaseModel): @@ -448,9 +544,9 @@ class TerminalOutputResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Exit status if the command has completed. 
- exitStatus: Annotated[ + exit_status: Annotated[ Optional[TerminalExitStatus], - Field(description="Exit status if the command has completed."), + Field(alias="exitStatus", description="Exit status if the command has completed."), ] = None # The terminal output captured so far. output: Annotated[str, Field(description="The terminal output captured so far.")] @@ -464,28 +560,16 @@ class TextResourceContents(BaseModel): Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - mimeType: Optional[str] = None + mime_type: Annotated[Optional[str], Field(alias="mimeType")] = None text: str uri: str -class FileEditToolCallContent(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - # The new content after modification. - newText: Annotated[str, Field(description="The new content after modification.")] - # The original content (None for new files). - oldText: Annotated[Optional[str], Field(description="The original content (None for new files).")] = None - # The file path being modified. 
- path: Annotated[str, Field(description="The file path being modified.")] +class FileEditToolCallContent(Diff): type: Literal["diff"] -class TerminalToolCallContent(BaseModel): - terminalId: str +class TerminalToolCallContent(Terminal): type: Literal["terminal"] @@ -501,6 +585,19 @@ class ToolCallLocation(BaseModel): path: Annotated[str, Field(description="The file path being accessed or modified.")] +class UnstructuredCommandInput(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # A hint to display when the input hasn't been provided yet + hint: Annotated[ + str, + Field(description="A hint to display when the input hasn't been provided yet"), + ] + + class WaitForTerminalExitRequest(BaseModel): # Extension point for implementations field_meta: Annotated[ @@ -508,9 +605,12 @@ class WaitForTerminalExitRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # The ID of the terminal to wait for. - terminalId: Annotated[str, Field(description="The ID of the terminal to wait for.")] + terminal_id: Annotated[ + str, + Field(alias="terminalId", description="The ID of the terminal to wait for."), + ] class WaitForTerminalExitResponse(BaseModel): @@ -520,9 +620,10 @@ class WaitForTerminalExitResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The process exit code (may be null if terminated by signal). 
- exitCode: Annotated[ + exit_code: Annotated[ Optional[int], Field( + alias="exitCode", description="The process exit code (may be null if terminated by signal).", ge=0, ), @@ -545,7 +646,7 @@ class WriteTextFileRequest(BaseModel): # Absolute path to the file to write. path: Annotated[str, Field(description="Absolute path to the file to write.")] # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] class WriteTextFileResponse(BaseModel): @@ -556,26 +657,6 @@ class WriteTextFileResponse(BaseModel): ] = None -class AgentCapabilities(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - # Whether the agent supports `session/load`. - loadSession: Annotated[Optional[bool], Field(description="Whether the agent supports `session/load`.")] = False - # MCP capabilities supported by the agent. - mcpCapabilities: Annotated[ - Optional[McpCapabilities], - Field(description="MCP capabilities supported by the agent."), - ] = McpCapabilities(http=False, sse=False) - # Prompt capabilities supported by the agent. 
- promptCapabilities: Annotated[ - Optional[PromptCapabilities], - Field(description="Prompt capabilities supported by the agent."), - ] = PromptCapabilities(audio=False, embeddedContext=False, image=False) - - class AgentErrorMessage(BaseModel): jsonrpc: Jsonrpc # JSON RPC Request Id @@ -603,44 +684,26 @@ class Annotations(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None audience: Optional[List[Role]] = None - lastModified: Optional[str] = None + last_modified: Annotated[Optional[str], Field(alias="lastModified")] = None priority: Optional[float] = None -class AuthMethod(BaseModel): +class AudioContent(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # Optional description providing more details about this authentication method. - description: Annotated[ - Optional[str], - Field(description="Optional description providing more details about this authentication method."), - ] = None - # Unique identifier for this authentication method. - id: Annotated[str, Field(description="Unique identifier for this authentication method.")] - # Human-readable name of the authentication method. - name: Annotated[str, Field(description="Human-readable name of the authentication method.")] + annotations: Optional[Annotations] = None + data: str + mime_type: Annotated[str, Field(alias="mimeType")] -class AvailableCommand(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - # Human-readable description of what the command does. 
- description: Annotated[str, Field(description="Human-readable description of what the command does.")] - # Input for the command if required - input: Annotated[ - Optional[AvailableCommandInput], - Field(description="Input for the command if required"), - ] = None - # Command name (e.g., `create_plan`, `research_codebase`). - name: Annotated[ - str, - Field(description="Command name (e.g., `create_plan`, `research_codebase`)."), +class AvailableCommandInput(RootModel[UnstructuredCommandInput]): + # The input specification for a command. + root: Annotated[ + UnstructuredCommandInput, + Field(description="The input specification for a command."), ] @@ -651,7 +714,13 @@ class CancelNotification(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the session to cancel operations for. - sessionId: Annotated[str, Field(description="The ID of the session to cancel operations for.")] + session_id: Annotated[ + str, + Field( + alias="sessionId", + description="The ID of the session to cancel operations for.", + ), + ] class ClientCapabilities(BaseModel): @@ -667,7 +736,7 @@ class ClientCapabilities(BaseModel): Field( description="File system capabilities supported by the client.\nDetermines which file operations the agent can request." ), - ] = FileSystemCapability(readTextFile=False, writeTextFile=False) + ] = FileSystemCapability() # Whether the Client support all `terminal/*` methods. 
terminal: Annotated[ Optional[bool], @@ -701,59 +770,11 @@ class ClientNotificationMessage(BaseModel): params: Optional[Union[CancelNotification, Any]] = None -class TextContentBlock(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - annotations: Optional[Annotations] = None - text: str - type: Literal["text"] - - -class ImageContentBlock(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - annotations: Optional[Annotations] = None - data: str - mimeType: str - type: Literal["image"] - uri: Optional[str] = None +class AudioContentBlock(AudioContent): + type: Literal["audio"] -class AudioContentBlock(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - annotations: Optional[Annotations] = None - data: str - mimeType: str - type: Literal["audio"] - - -class ResourceContentBlock(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - annotations: Optional[Annotations] = None - description: Optional[str] = None - mimeType: Optional[str] = None - name: str - size: Optional[int] = None - title: Optional[str] = None - type: Literal["resource_link"] - uri: str - - -class CreateTerminalRequest(BaseModel): +class CreateTerminalRequest(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], @@ -781,15 +802,38 @@ class CreateTerminalRequest(BaseModel): # The Client MUST ensure truncation happens at a character boundary to maintain valid # string output, even if this means the retained output is slightly less than the # specified limit. 
- outputByteLimit: Annotated[ + output_byte_limit: Annotated[ Optional[int], Field( + alias="outputByteLimit", description="Maximum number of output bytes to retain.\n\nWhen the limit is exceeded, the Client truncates from the beginning of the output\nto stay within the limit.\n\nThe Client MUST ensure truncation happens at a character boundary to maintain valid\nstring output, even if this means the retained output is slightly less than the\nspecified limit.", ge=0, ), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] + + +class CurrentModeUpdate(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # The ID of the current mode + current_mode_id: Annotated[str, Field(alias="currentModeId", description="The ID of the current mode")] + + +class ImageContent(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + annotations: Optional[Annotations] = None + data: str + mime_type: Annotated[str, Field(alias="mimeType")] + uri: Optional[str] = None class InitializeRequest(BaseModel): @@ -799,23 +843,28 @@ class InitializeRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Capabilities supported by the client. 
- clientCapabilities: Annotated[ + client_capabilities: Annotated[ Optional[ClientCapabilities], - Field(description="Capabilities supported by the client."), - ] = ClientCapabilities(fs=FileSystemCapability(readTextFile=False, writeTextFile=False), terminal=False) + Field( + alias="clientCapabilities", + description="Capabilities supported by the client.", + ), + ] = ClientCapabilities() # Information about the Client name and version sent to the Agent. # # Note: in future versions of the protocol, this will be required. - clientInfo: Annotated[ + client_info: Annotated[ Optional[Implementation], Field( - description="Information about the Client name and version sent to the Agent.\n\nNote: in future versions of the protocol, this will be required." + alias="clientInfo", + description="Information about the Client name and version sent to the Agent.\n\nNote: in future versions of the protocol, this will be required.", ), ] = None # The latest protocol version supported by the client. - protocolVersion: Annotated[ + protocol_version: Annotated[ int, Field( + alias="protocolVersion", description="The latest protocol version supported by the client.", ge=0, le=65535, @@ -823,76 +872,64 @@ class InitializeRequest(BaseModel): ] -class InitializeResponse(BaseModel): +class KillTerminalCommandRequest(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # Capabilities supported by the agent. - agentCapabilities: Annotated[ - Optional[AgentCapabilities], - Field(description="Capabilities supported by the agent."), - ] = AgentCapabilities( - loadSession=False, - mcpCapabilities=McpCapabilities(http=False, sse=False), - promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False), - ) - # Information about the Agent name and version sent to the Client. - # - # Note: in future versions of the protocol, this will be required. 
- agentInfo: Annotated[ - Optional[Implementation], - Field( - description="Information about the Agent name and version sent to the Client.\n\nNote: in future versions of the protocol, this will be required." - ), - ] = None - # Authentication methods supported by the agent. - authMethods: Annotated[ - Optional[List[AuthMethod]], - Field(description="Authentication methods supported by the agent."), - ] = [] - # The protocol version the client specified if supported by the agent, - # or the latest protocol version supported by the agent. - # - # The client should disconnect, if it doesn't support this version. - protocolVersion: Annotated[ - int, - Field( - description="The protocol version the client specified if supported by the agent,\nor the latest protocol version supported by the agent.\n\nThe client should disconnect, if it doesn't support this version.", - ge=0, - le=65535, - ), - ] + # The session ID for this request. + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] + # The ID of the terminal to kill. + terminal_id: Annotated[str, Field(alias="terminalId", description="The ID of the terminal to kill.")] -class KillTerminalCommandRequest(BaseModel): +class ListSessionsResponse(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] - # The ID of the terminal to kill. - terminalId: Annotated[str, Field(description="The ID of the terminal to kill.")] + # Opaque cursor token. If present, pass this in the next request's cursor parameter + # to fetch the next page. If absent, there are no more results. + next_cursor: Annotated[ + Optional[str], + Field( + alias="nextCursor", + description="Opaque cursor token. 
If present, pass this in the next request's cursor parameter\nto fetch the next page. If absent, there are no more results.", + ), + ] = None + # Array of session information objects + sessions: Annotated[List[SessionInfo], Field(description="Array of session information objects")] -class LoadSessionRequest(BaseModel): +class HttpMcpServer(McpServerHttp): + type: Literal["http"] + + +class SseMcpServer(McpServerSse): + type: Literal["sse"] + + +class NewSessionRequest(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # The working directory for this session. - cwd: Annotated[str, Field(description="The working directory for this session.")] - # List of MCP servers to connect to for this session. - mcpServers: Annotated[ - List[Union[HttpMcpServer, SseMcpServer, StdioMcpServer]], - Field(description="List of MCP servers to connect to for this session."), + # The working directory for this session. Must be an absolute path. + cwd: Annotated[ + str, + Field(description="The working directory for this session. Must be an absolute path."), + ] + # List of MCP (Model Context Protocol) servers the agent should connect to. + mcp_servers: Annotated[ + List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]], + Field( + alias="mcpServers", + description="List of MCP (Model Context Protocol) servers the agent should connect to.", + ), ] - # The ID of the session to load. - sessionId: Annotated[str, Field(description="The ID of the session to load.")] class PermissionOption(BaseModel): @@ -906,7 +943,13 @@ class PermissionOption(BaseModel): # Human-readable label to display to the user. name: Annotated[str, Field(description="Human-readable label to display to the user.")] # Unique identifier for this permission option. 
- optionId: Annotated[str, Field(description="Unique identifier for this permission option.")] + option_id: Annotated[ + str, + Field( + alias="optionId", + description="Unique identifier for this permission option.", + ), + ] class PlanEntry(BaseModel): @@ -939,7 +982,13 @@ class PromptResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Indicates why the agent stopped processing the turn. - stopReason: Annotated[StopReason, Field(description="Indicates why the agent stopped processing the turn.")] + stop_reason: Annotated[ + StopReason, + Field( + alias="stopReason", + description="Indicates why the agent stopped processing the turn.", + ), + ] class ReadTextFileRequest(BaseModel): @@ -958,7 +1007,7 @@ class ReadTextFileRequest(BaseModel): # Absolute path to the file to read. path: Annotated[str, Field(description="Absolute path to the file to read.")] # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] class ReleaseTerminalRequest(BaseModel): @@ -968,9 +1017,63 @@ class ReleaseTerminalRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # The ID of the terminal to release. 
- terminalId: Annotated[str, Field(description="The ID of the terminal to release.")] + terminal_id: Annotated[str, Field(alias="terminalId", description="The ID of the terminal to release.")] + + +class AllowedOutcome(SelectedPermissionOutcome): + outcome: Literal["selected"] + + +class RequestPermissionResponse(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # The user's decision on the permission request. + outcome: Annotated[ + Union[DeniedOutcome, AllowedOutcome], + Field( + description="The user's decision on the permission request.", + discriminator="outcome", + ), + ] + + +class ResourceLink(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + annotations: Optional[Annotations] = None + description: Optional[str] = None + mime_type: Annotated[Optional[str], Field(alias="mimeType")] = None + name: str + size: Optional[int] = None + title: Optional[str] = None + uri: str + + +class SessionCapabilities(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # **UNSTABLE** + # + # This capability is not part of the spec yet, and may be removed or changed at any point. + # + # Whether the agent supports `session/list`. + list: Annotated[ + Optional[SessionListCapabilities], + Field( + description="**UNSTABLE**\n\nThis capability is not part of the spec yet, and may be removed or changed at any point.\n\nWhether the agent supports `session/list`." 
+ ), + ] = None class SessionMode(BaseModel): @@ -992,31 +1095,87 @@ class SessionModeState(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The set of modes that the Agent can operate in - availableModes: Annotated[ + available_modes: Annotated[ List[SessionMode], - Field(description="The set of modes that the Agent can operate in"), + Field( + alias="availableModes", + description="The set of modes that the Agent can operate in", + ), ] # The current mode the Agent is in. - currentModeId: Annotated[str, Field(description="The current mode the Agent is in.")] + current_mode_id: Annotated[ + str, + Field(alias="currentModeId", description="The current mode the Agent is in."), + ] + + +class CurrentModeUpdate(CurrentModeUpdate): + session_update: Annotated[Literal["current_mode_update"], Field(alias="sessionUpdate")] -class AgentPlanUpdate(BaseModel): +class TextContent(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # The list of tasks to be accomplished. - # - # When updating a plan, the agent must send a complete list of all entries - # with their current status. The client replaces the entire plan with each update. - entries: Annotated[ - List[PlanEntry], + annotations: Optional[Annotations] = None + text: str + + +class AgentCapabilities(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # Whether the agent supports `session/load`. + load_session: Annotated[ + Optional[bool], Field( - description="The list of tasks to be accomplished.\n\nWhen updating a plan, the agent must send a complete list of all entries\nwith their current status. The client replaces the entire plan with each update." 
+ alias="loadSession", + description="Whether the agent supports `session/load`.", + ), + ] = False + # MCP capabilities supported by the agent. + mcp_capabilities: Annotated[ + Optional[McpCapabilities], + Field( + alias="mcpCapabilities", + description="MCP capabilities supported by the agent.", + ), + ] = McpCapabilities() + # Prompt capabilities supported by the agent. + prompt_capabilities: Annotated[ + Optional[PromptCapabilities], + Field( + alias="promptCapabilities", + description="Prompt capabilities supported by the agent.", ), + ] = PromptCapabilities() + session_capabilities: Annotated[Optional[SessionCapabilities], Field(alias="sessionCapabilities")] = ( + SessionCapabilities() + ) + + +class AvailableCommand(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # Human-readable description of what the command does. + description: Annotated[str, Field(description="Human-readable description of what the command does.")] + # Input for the command if required + input: Annotated[ + Optional[AvailableCommandInput], + Field(description="Input for the command if required"), + ] = None + # Command name (e.g., `create_plan`, `research_codebase`). 
+ name: Annotated[ + str, + Field(description="Command name (e.g., `create_plan`, `research_codebase`)."), ] - sessionUpdate: Literal["plan"] class AvailableCommandsUpdate(BaseModel): @@ -1026,52 +1185,25 @@ class AvailableCommandsUpdate(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Commands the agent can execute - availableCommands: Annotated[List[AvailableCommand], Field(description="Commands the agent can execute")] - sessionUpdate: Literal["available_commands_update"] + available_commands: Annotated[ + List[AvailableCommand], + Field(alias="availableCommands", description="Commands the agent can execute"), + ] -class ClientResponseMessage(BaseModel): - jsonrpc: Jsonrpc - # JSON RPC Request Id - # - # An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2] - # - # The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects. - # - # [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling. - # - # [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions. - id: Annotated[ - Optional[Union[int, str]], - Field( - description="JSON RPC Request Id\n\nAn identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. 
The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]\n\nThe Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.\n\n[1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.\n\n[2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions." - ), - ] = None - # All possible responses that a client can send to an agent. - # - # This enum is used internally for routing RPC responses. You typically won't need - # to use this directly - the responses are handled automatically by the connection. - # - # These are responses to the corresponding `AgentRequest` variants. - result: Annotated[ - Union[ - WriteTextFileResponse, - ReadTextFileResponse, - RequestPermissionResponse, - CreateTerminalResponse, - TerminalOutputResponse, - ReleaseTerminalResponse, - WaitForTerminalExitResponse, - KillTerminalCommandResponse, - Any, - ], - Field( - description="All possible responses that a client can send to an agent.\n\nThis enum is used internally for routing RPC responses. You typically won't need\nto use this directly - the responses are handled automatically by the connection.\n\nThese are responses to the corresponding `AgentRequest` variants." 
- ), - ] +class TextContentBlock(TextContent): + type: Literal["text"] + + +class ImageContentBlock(ImageContent): + type: Literal["image"] -class EmbeddedResourceContentBlock(BaseModel): +class ResourceContentBlock(ResourceLink): + type: Literal["resource_link"] + + +class EmbeddedResource(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], @@ -1083,17 +1215,83 @@ class EmbeddedResourceContentBlock(BaseModel): Union[TextResourceContents, BlobResourceContents], Field(description="Resource content that can be embedded in a message."), ] - type: Literal["resource"] -class LoadSessionResponse(BaseModel): +class InitializeResponse(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # **UNSTABLE** - # + # Capabilities supported by the agent. + agent_capabilities: Annotated[ + Optional[AgentCapabilities], + Field( + alias="agentCapabilities", + description="Capabilities supported by the agent.", + ), + ] = AgentCapabilities() + # Information about the Agent name and version sent to the Client. + # + # Note: in future versions of the protocol, this will be required. + agent_info: Annotated[ + Optional[Implementation], + Field( + alias="agentInfo", + description="Information about the Agent name and version sent to the Client.\n\nNote: in future versions of the protocol, this will be required.", + ), + ] = None + # Authentication methods supported by the agent. + auth_methods: Annotated[ + Optional[List[AuthMethod]], + Field( + alias="authMethods", + description="Authentication methods supported by the agent.", + ), + ] = [] + # The protocol version the client specified if supported by the agent, + # or the latest protocol version supported by the agent. + # + # The client should disconnect, if it doesn't support this version. 
+ protocol_version: Annotated[ + int, + Field( + alias="protocolVersion", + description="The protocol version the client specified if supported by the agent,\nor the latest protocol version supported by the agent.\n\nThe client should disconnect, if it doesn't support this version.", + ge=0, + le=65535, + ), + ] + + +class LoadSessionRequest(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # The working directory for this session. + cwd: Annotated[str, Field(description="The working directory for this session.")] + # List of MCP servers to connect to for this session. + mcp_servers: Annotated[ + List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]], + Field( + alias="mcpServers", + description="List of MCP servers to connect to for this session.", + ), + ] + # The ID of the session to load. + session_id: Annotated[str, Field(alias="sessionId", description="The ID of the session to load.")] + + +class LoadSessionResponse(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # **UNSTABLE** + # # This capability is not part of the spec yet, and may be removed or changed at any point. # # Initial model state if supported by the Agent @@ -1143,14 +1341,101 @@ class NewSessionResponse(BaseModel): # Unique identifier for the created session. # # Used in all subsequent requests for this conversation. - sessionId: Annotated[ + session_id: Annotated[ str, Field( - description="Unique identifier for the created session.\n\nUsed in all subsequent requests for this conversation." 
+ alias="sessionId", + description="Unique identifier for the created session.\n\nUsed in all subsequent requests for this conversation.", ), ] +class Plan(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # The list of tasks to be accomplished. + # + # When updating a plan, the agent must send a complete list of all entries + # with their current status. The client replaces the entire plan with each update. + entries: Annotated[ + List[PlanEntry], + Field( + description="The list of tasks to be accomplished.\n\nWhen updating a plan, the agent must send a complete list of all entries\nwith their current status. The client replaces the entire plan with each update." + ), + ] + + +class AgentPlanUpdate(Plan): + session_update: Annotated[Literal["plan"], Field(alias="sessionUpdate")] + + +class AvailableCommandsUpdate(AvailableCommandsUpdate): + session_update: Annotated[Literal["available_commands_update"], Field(alias="sessionUpdate")] + + +class ClientResponseMessage(BaseModel): + jsonrpc: Jsonrpc + # JSON RPC Request Id + # + # An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2] + # + # The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects. + # + # [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling. 
+ # + # [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions. + id: Annotated[ + Optional[Union[int, str]], + Field( + description="JSON RPC Request Id\n\nAn identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]\n\nThe Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.\n\n[1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.\n\n[2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions." + ), + ] = None + # All possible responses that a client can send to an agent. + # + # This enum is used internally for routing RPC responses. You typically won't need + # to use this directly - the responses are handled automatically by the connection. + # + # These are responses to the corresponding `AgentRequest` variants. + result: Annotated[ + Union[ + WriteTextFileResponse, + ReadTextFileResponse, + RequestPermissionResponse, + CreateTerminalResponse, + TerminalOutputResponse, + ReleaseTerminalResponse, + WaitForTerminalExitResponse, + KillTerminalCommandResponse, + Any, + ], + Field( + description="All possible responses that a client can send to an agent.\n\nThis enum is used internally for routing RPC responses. You typically won't need\nto use this directly - the responses are handled automatically by the connection.\n\nThese are responses to the corresponding `AgentRequest` variants." 
+ ), + ] + + +class EmbeddedResourceContentBlock(EmbeddedResource): + type: Literal["resource"] + + +class ContentChunk(BaseModel): + # Extension point for implementations + field_meta: Annotated[ + Optional[Any], + Field(alias="_meta", description="Extension point for implementations"), + ] = None + # A single item of content + content: Annotated[ + Union[ + TextContentBlock, ImageContentBlock, AudioContentBlock, ResourceContentBlock, EmbeddedResourceContentBlock + ], + Field(description="A single item of content", discriminator="type"), + ] + + class PromptRequest(BaseModel): # Extension point for implementations field_meta: Annotated[ @@ -1185,58 +1470,74 @@ class PromptRequest(BaseModel): ), ] # The ID of the session to send this user message to - sessionId: Annotated[str, Field(description="The ID of the session to send this user message to")] + session_id: Annotated[ + str, + Field( + alias="sessionId", + description="The ID of the session to send this user message to", + ), + ] -class UserMessageChunk(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - # A single item of content - content: Annotated[ - Union[ - TextContentBlock, ImageContentBlock, AudioContentBlock, ResourceContentBlock, EmbeddedResourceContentBlock - ], - Field(description="A single item of content", discriminator="type"), - ] - sessionUpdate: Literal["user_message_chunk"] +class UserMessageChunk(ContentChunk): + session_update: Annotated[Literal["user_message_chunk"], Field(alias="sessionUpdate")] -class AgentMessageChunk(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), +class AgentMessageChunk(ContentChunk): + session_update: Annotated[Literal["agent_message_chunk"], Field(alias="sessionUpdate")] + + +class AgentThoughtChunk(ContentChunk): + 
session_update: Annotated[Literal["agent_thought_chunk"], Field(alias="sessionUpdate")] + + +class AgentResponseMessage(BaseModel): + jsonrpc: Jsonrpc + # JSON RPC Request Id + # + # An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2] + # + # The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects. + # + # [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling. + # + # [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions. + id: Annotated[ + Optional[Union[int, str]], + Field( + description="JSON RPC Request Id\n\nAn identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]\n\nThe Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.\n\n[1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.\n\n[2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions." 
+ ), ] = None - # A single item of content - content: Annotated[ + # All possible responses that an agent can send to a client. + # + # This enum is used internally for routing RPC responses. You typically won't need + # to use this directly - the responses are handled automatically by the connection. + # + # These are responses to the corresponding `ClientRequest` variants. + result: Annotated[ Union[ - TextContentBlock, ImageContentBlock, AudioContentBlock, ResourceContentBlock, EmbeddedResourceContentBlock + InitializeResponse, + AuthenticateResponse, + NewSessionResponse, + LoadSessionResponse, + ListSessionsResponse, + SetSessionModeResponse, + PromptResponse, + SetSessionModelResponse, + Any, ], - Field(description="A single item of content", discriminator="type"), + Field( + description="All possible responses that an agent can send to a client.\n\nThis enum is used internally for routing RPC responses. You typically won't need\nto use this directly - the responses are handled automatically by the connection.\n\nThese are responses to the corresponding `ClientRequest` variants." + ), ] - sessionUpdate: Literal["agent_message_chunk"] -class AgentThoughtChunk(BaseModel): +class Content(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - # A single item of content - content: Annotated[ - Union[ - TextContentBlock, ImageContentBlock, AudioContentBlock, ResourceContentBlock, EmbeddedResourceContentBlock - ], - Field(description="A single item of content", discriminator="type"), - ] - sessionUpdate: Literal["agent_thought_chunk"] - - -class ContentToolCallContent(BaseModel): # The actual content block. 
content: Annotated[ Union[ @@ -1244,10 +1545,13 @@ class ContentToolCallContent(BaseModel): ], Field(description="The actual content block.", discriminator="type"), ] + + +class ContentToolCallContent(Content): type: Literal["content"] -class ToolCall(BaseModel): +class ToolCallUpdate(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], @@ -1266,15 +1570,51 @@ class ToolCall(BaseModel): Field(description="Replace the locations collection."), ] = None # Update the raw input. - rawInput: Annotated[Optional[Any], Field(description="Update the raw input.")] = None + raw_input: Annotated[Optional[Any], Field(alias="rawInput", description="Update the raw input.")] = None # Update the raw output. - rawOutput: Annotated[Optional[Any], Field(description="Update the raw output.")] = None + raw_output: Annotated[Optional[Any], Field(alias="rawOutput", description="Update the raw output.")] = None # Update the execution status. status: Annotated[Optional[ToolCallStatus], Field(description="Update the execution status.")] = None # Update the human-readable title. title: Annotated[Optional[str], Field(description="Update the human-readable title.")] = None # The ID of the tool call being updated. - toolCallId: Annotated[str, Field(description="The ID of the tool call being updated.")] + tool_call_id: Annotated[ + str, + Field(alias="toolCallId", description="The ID of the tool call being updated."), + ] + + +class ClientRequestMessage(BaseModel): + jsonrpc: Jsonrpc + # JSON RPC Request Id + # + # An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2] + # + # The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects. 
+ # + # [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling. + # + # [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions. + id: Annotated[ + Optional[Union[int, str]], + Field( + description="JSON RPC Request Id\n\nAn identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]\n\nThe Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.\n\n[1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.\n\n[2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions." + ), + ] = None + method: str + params: Optional[ + Union[ + InitializeRequest, + AuthenticateRequest, + NewSessionRequest, + LoadSessionRequest, + ListSessionsRequest, + SetSessionModeRequest, + PromptRequest, + SetSessionModelRequest, + Any, + ] + ] = None class RequestPermissionRequest(BaseModel): @@ -1289,12 +1629,22 @@ class RequestPermissionRequest(BaseModel): Field(description="Available permission options for the user to choose from."), ] # The session ID for this request. 
- sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # Details about the tool call requiring permission. - toolCall: Annotated[ToolCall, Field(description="Details about the tool call requiring permission.")] + tool_call: Annotated[ + ToolCallUpdate, + Field( + alias="toolCall", + description="Details about the tool call requiring permission.", + ), + ] + +class ToolCallProgress(ToolCallUpdate): + session_update: Annotated[Literal["tool_call_update"], Field(alias="sessionUpdate")] -class ToolCallStart(BaseModel): + +class ToolCall(BaseModel): # Extension point for implementations field_meta: Annotated[ Optional[Any], @@ -1320,10 +1670,15 @@ class ToolCallStart(BaseModel): Field(description='File locations affected by this tool call.\nEnables "follow-along" features in clients.'), ] = None # Raw input parameters sent to the tool. - rawInput: Annotated[Optional[Any], Field(description="Raw input parameters sent to the tool.")] = None + raw_input: Annotated[ + Optional[Any], + Field(alias="rawInput", description="Raw input parameters sent to the tool."), + ] = None # Raw output returned by the tool. - rawOutput: Annotated[Optional[Any], Field(description="Raw output returned by the tool.")] = None - sessionUpdate: Literal["tool_call"] + raw_output: Annotated[ + Optional[Any], + Field(alias="rawOutput", description="Raw output returned by the tool."), + ] = None # Current execution status of the tool call. status: Annotated[Optional[ToolCallStatus], Field(description="Current execution status of the tool call.")] = None # Human-readable title describing what the tool is doing. @@ -1332,84 +1687,20 @@ class ToolCallStart(BaseModel): Field(description="Human-readable title describing what the tool is doing."), ] # Unique identifier for this tool call within the session. 
- toolCallId: Annotated[ + tool_call_id: Annotated[ str, - Field(description="Unique identifier for this tool call within the session."), - ] - - -class ToolCallProgress(BaseModel): - # Extension point for implementations - field_meta: Annotated[ - Optional[Any], - Field(alias="_meta", description="Extension point for implementations"), - ] = None - # Replace the content collection. - content: Annotated[ - Optional[List[Union[ContentToolCallContent, FileEditToolCallContent, TerminalToolCallContent]]], - Field(description="Replace the content collection."), - ] = None - # Update the tool kind. - kind: Annotated[Optional[ToolKind], Field(description="Update the tool kind.")] = None - # Replace the locations collection. - locations: Annotated[ - Optional[List[ToolCallLocation]], - Field(description="Replace the locations collection."), - ] = None - # Update the raw input. - rawInput: Annotated[Optional[Any], Field(description="Update the raw input.")] = None - # Update the raw output. - rawOutput: Annotated[Optional[Any], Field(description="Update the raw output.")] = None - sessionUpdate: Literal["tool_call_update"] - # Update the execution status. - status: Annotated[Optional[ToolCallStatus], Field(description="Update the execution status.")] = None - # Update the human-readable title. - title: Annotated[Optional[str], Field(description="Update the human-readable title.")] = None - # The ID of the tool call being updated. - toolCallId: Annotated[str, Field(description="The ID of the tool call being updated.")] - - -class AgentResponseMessage(BaseModel): - jsonrpc: Jsonrpc - # JSON RPC Request Id - # - # An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2] - # - # The Server MUST reply with the same value in the Response object if included. 
This member is used to correlate the context between the two objects. - # - # [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling. - # - # [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions. - id: Annotated[ - Optional[Union[int, str]], - Field( - description="JSON RPC Request Id\n\nAn identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]\n\nThe Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.\n\n[1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.\n\n[2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions." - ), - ] = None - # All possible responses that an agent can send to a client. - # - # This enum is used internally for routing RPC responses. You typically won't need - # to use this directly - the responses are handled automatically by the connection. - # - # These are responses to the corresponding `ClientRequest` variants. 
- result: Annotated[ - Union[ - InitializeResponse, - AuthenticateResponse, - NewSessionResponse, - LoadSessionResponse, - SetSessionModeResponse, - PromptResponse, - SetSessionModelResponse, - Any, - ], Field( - description="All possible responses that an agent can send to a client.\n\nThis enum is used internally for routing RPC responses. You typically won't need\nto use this directly - the responses are handled automatically by the connection.\n\nThese are responses to the corresponding `ClientRequest` variants." + alias="toolCallId", + description="Unique identifier for this tool call within the session.", ), ] -class ClientRequestMessage(BaseModel): +class ToolCallStart(ToolCall): + session_update: Annotated[Literal["tool_call"], Field(alias="sessionUpdate")] + + +class AgentRequestMessage(BaseModel): jsonrpc: Jsonrpc # JSON RPC Request Id # @@ -1429,13 +1720,14 @@ class ClientRequestMessage(BaseModel): method: str params: Optional[ Union[ - InitializeRequest, - AuthenticateRequest, - NewSessionRequest, - LoadSessionRequest, - SetSessionModeRequest, - PromptRequest, - SetSessionModelRequest, + WriteTextFileRequest, + ReadTextFileRequest, + RequestPermissionRequest, + CreateTerminalRequest, + TerminalOutputRequest, + ReleaseTerminalRequest, + WaitForTerminalExitRequest, + KillTerminalCommandRequest, Any, ] ] = None @@ -1448,7 +1740,13 @@ class SessionNotification(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the session this update pertains to. - sessionId: Annotated[str, Field(description="The ID of the session this update pertains to.")] + session_id: Annotated[ + str, + Field( + alias="sessionId", + description="The ID of the session this update pertains to.", + ), + ] # The actual update content. 
update: Annotated[ Union[ @@ -1461,43 +1759,10 @@ class SessionNotification(BaseModel): AvailableCommandsUpdate, CurrentModeUpdate, ], - Field(description="The actual update content.", discriminator="sessionUpdate"), + Field(description="The actual update content.", discriminator="session_update"), ] -class AgentRequestMessage(BaseModel): - jsonrpc: Jsonrpc - # JSON RPC Request Id - # - # An identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2] - # - # The Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects. - # - # [1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling. - # - # [2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions. - id: Annotated[ - Optional[Union[int, str]], - Field( - description="JSON RPC Request Id\n\nAn identifier established by the Client that MUST contain a String, Number, or NULL value if included. If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers SHOULD NOT contain fractional parts [2]\n\nThe Server MUST reply with the same value in the Response object if included. This member is used to correlate the context between the two objects.\n\n[1] The use of Null as a value for the id member in a Request object is discouraged, because this specification uses a value of Null for Responses with an unknown id. 
Also, because JSON-RPC 1.0 uses an id value of Null for Notifications this could cause confusion in handling.\n\n[2] Fractional parts may be problematic, since many decimal fractions cannot be represented exactly as binary fractions." - ), - ] = None - method: str - params: Optional[ - Union[ - WriteTextFileRequest, - ReadTextFileRequest, - RequestPermissionRequest, - CreateTerminalRequest, - TerminalOutputRequest, - ReleaseTerminalRequest, - WaitForTerminalExitRequest, - KillTerminalCommandRequest, - Any, - ] - ] = None - - class AgentNotificationMessage(BaseModel): jsonrpc: Jsonrpc method: str @@ -1539,7 +1804,6 @@ class Model( AgentOutgoingMessage2 = AgentResponseMessage AgentOutgoingMessage3 = AgentErrorMessage AgentOutgoingMessage4 = AgentNotificationMessage -AvailableCommandInput1 = CommandInputHint ClientOutgoingMessage1 = ClientRequestMessage ClientOutgoingMessage2 = ClientResponseMessage ClientOutgoingMessage3 = ClientErrorMessage @@ -1551,7 +1815,6 @@ class Model( ContentBlock5 = EmbeddedResourceContentBlock McpServer1 = HttpMcpServer McpServer2 = SseMcpServer -McpServer3 = StdioMcpServer RequestPermissionOutcome1 = DeniedOutcome RequestPermissionOutcome2 = AllowedOutcome SessionUpdate1 = UserMessageChunk @@ -1565,3 +1828,4 @@ class Model( ToolCallContent1 = ContentToolCallContent ToolCallContent2 = FileEditToolCallContent ToolCallContent3 = TerminalToolCallContent +StdioMcpServer = McpServerStdio diff --git a/src/acp/stdio.py b/src/acp/stdio.py index 88917c9..d58644e 100644 --- a/src/acp/stdio.py +++ b/src/acp/stdio.py @@ -160,7 +160,7 @@ async def spawn_stdio_connection( @asynccontextmanager async def spawn_agent_process( - to_client: Callable[[Agent], Client], + to_client: Callable[[Agent], Client] | Client, command: str, *args: str, env: Mapping[str, str] | None = None, @@ -185,7 +185,7 @@ async def spawn_agent_process( @asynccontextmanager async def spawn_client_process( - to_agent: Callable[[AgentSideConnection], Agent], + to_agent: 
Callable[[Client], Agent] | Agent, command: str, *args: str, env: Mapping[str, str] | None = None, diff --git a/src/acp/utils.py b/src/acp/utils.py index e81d7ba..1be9c19 100644 --- a/src/acp/utils.py +++ b/src/acp/utils.py @@ -1,5 +1,8 @@ from __future__ import annotations +import functools +import warnings +from collections.abc import Callable from typing import Any, TypeVar from pydantic import BaseModel @@ -20,6 +23,9 @@ ] ModelT = TypeVar("ModelT", bound=BaseModel) +MethodT = TypeVar("MethodT", bound=Callable) +ClassT = TypeVar("ClassT", bound=type) +T = TypeVar("T") def serialize_params(params: BaseModel) -> dict[str, Any]: @@ -94,3 +100,75 @@ async def request_optional_model( async def notify_model(conn: Connection, method: str, params: BaseModel) -> None: """Send a notification with serialized params.""" await conn.send_notification(method, serialize_params(params)) + + +def param_model(param_cls: type[BaseModel]) -> Callable[[MethodT], MethodT]: + """Decorator to map the method parameters to a Pydantic model. + It is just a marker and does nothing at runtime. 
+ """ + + def decorator(func: MethodT) -> MethodT: + func.__param_model__ = param_cls # type: ignore[attr-defined] + return func + + return decorator + + +def to_camel_case(snake_str: str) -> str: + """Convert snake_case strings to camelCase.""" + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def _make_legacy_func(func: Callable[..., T], model: type[BaseModel]) -> Callable[[Any, BaseModel], T]: + @functools.wraps(func) + def wrapped(self, params: BaseModel) -> T: + warnings.warn( + f"Calling {func.__name__} with {model.__name__} parameter is " # type: ignore[attr-defined] + "deprecated, please update to the new API style.", + DeprecationWarning, + stacklevel=3, + ) + kwargs = {k: getattr(params, k) for k in model.model_fields if k != "field_meta"} + if meta := getattr(params, "field_meta", None): + kwargs.update(meta) + return func(self, **kwargs) # type: ignore[arg-type] + + return wrapped + + +def _make_compatible_func(func: Callable[..., T], model: type[BaseModel]) -> Callable[..., T]: + @functools.wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> T: + param = None + if not kwargs and len(args) == 1: + param = args[0] + elif not args and len(kwargs) == 1: + param = kwargs.get("params") + if isinstance(param, model): + warnings.warn( + f"Calling {func.__name__} with {model.__name__} parameter " # type: ignore[attr-defined] + "is deprecated, please update to the new API style.", + DeprecationWarning, + stacklevel=3, + ) + kwargs = {k: getattr(param, k) for k in model.model_fields if k != "field_meta"} + if meta := getattr(param, "field_meta", None): + kwargs.update(meta) + return func(self, **kwargs) # type: ignore[arg-type] + return func(self, *args, **kwargs) + + return wrapped + + +def compatible_class(cls: ClassT) -> ClassT: + """Mark a class as backward compatible with old API style.""" + for attr in dir(cls): + func = getattr(cls, attr) + if not callable(func) or (model := getattr(func, 
"__param_model__", None)) is None: + continue + if "_" in attr: + setattr(cls, to_camel_case(attr), _make_legacy_func(func, model)) + else: + setattr(cls, attr, _make_compatible_func(func, model)) + return cls diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..72fb232 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,322 @@ +import asyncio +import contextlib +from collections.abc import AsyncGenerator, Callable +from typing import Any + +import pytest +import pytest_asyncio + +from acp import ( + AuthenticateResponse, + CreateTerminalResponse, + InitializeResponse, + KillTerminalCommandResponse, + LoadSessionResponse, + NewSessionResponse, + PromptRequest, + PromptResponse, + ReadTextFileResponse, + ReleaseTerminalResponse, + RequestError, + RequestPermissionResponse, + SessionNotification, + SetSessionModelResponse, + SetSessionModeResponse, + TerminalOutputResponse, + WaitForTerminalExitResponse, + WriteTextFileResponse, +) +from acp.core import AgentSideConnection, ClientSideConnection +from acp.schema import ( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AllowedOutcome, + AudioContentBlock, + AvailableCommandsUpdate, + ClientCapabilities, + CurrentModeUpdate, + DeniedOutcome, + EmbeddedResourceContentBlock, + EnvVariable, + HttpMcpServer, + ImageContentBlock, + Implementation, + ListSessionsResponse, + McpServerStdio, + PermissionOption, + ResourceContentBlock, + SseMcpServer, + TextContentBlock, + ToolCallProgress, + ToolCallStart, + ToolCallUpdate, + UserMessageChunk, +) + + +class _Server: + def __init__(self) -> None: + self._server: asyncio.AbstractServer | None = None + self._server_reader: asyncio.StreamReader | None = None + self._server_writer: asyncio.StreamWriter | None = None + self._client_reader: asyncio.StreamReader | None = None + self._client_writer: asyncio.StreamWriter | None = None + + async def __aenter__(self): + async def handle(reader: asyncio.StreamReader, writer: 
asyncio.StreamWriter): + self._server_reader = reader + self._server_writer = writer + + self._server = await asyncio.start_server(handle, host="127.0.0.1", port=0) + host, port = self._server.sockets[0].getsockname()[:2] + self._client_reader, self._client_writer = await asyncio.open_connection(host, port) + + # wait until server side is set + for _ in range(100): + if self._server_reader and self._server_writer: + break + await asyncio.sleep(0.01) + assert self._server_reader and self._server_writer + assert self._client_reader and self._client_writer + return self + + async def __aexit__(self, exc_type, exc, tb): + if self._client_writer: + self._client_writer.close() + with contextlib.suppress(Exception): + await self._client_writer.wait_closed() + if self._server_writer: + self._server_writer.close() + with contextlib.suppress(Exception): + await self._server_writer.wait_closed() + if self._server: + self._server.close() + await self._server.wait_closed() + + @property + def server_writer(self) -> asyncio.StreamWriter: + assert self._server_writer is not None + return self._server_writer + + @property + def server_reader(self) -> asyncio.StreamReader: + assert self._server_reader is not None + return self._server_reader + + @property + def client_writer(self) -> asyncio.StreamWriter: + assert self._client_writer is not None + return self._client_writer + + @property + def client_reader(self) -> asyncio.StreamReader: + assert self._client_reader is not None + return self._client_reader + + +@pytest_asyncio.fixture +async def server() -> AsyncGenerator[_Server, None]: + """Provides a server-client connection pair for testing.""" + async with _Server() as server_instance: + yield server_instance + + +class TestClient: + __test__ = False # prevent pytest from collecting this class + + def __init__(self) -> None: + self.permission_outcomes: list[RequestPermissionResponse] = [] + self.files: dict[str, str] = {} + self.notifications: list[SessionNotification] = [] + 
self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + def queue_permission_cancelled(self) -> None: + self.permission_outcomes.append(RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))) + + def queue_permission_selected(self, option_id: str) -> None: + self.permission_outcomes.append( + RequestPermissionResponse(outcome=AllowedOutcome(option_id=option_id, outcome="selected")) + ) + + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCallUpdate, **kwargs: Any + ) -> RequestPermissionResponse: + if self.permission_outcomes: + return self.permission_outcomes.pop() + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) + + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: + self.files[str(path)] = content + return WriteTextFileResponse() + + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: + content = self.files.get(str(path), "default content") + return ReadTextFileResponse(content=content) + + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: + self.notifications.append(SessionNotification(session_id=session_id, update=update, field_meta=kwargs or None)) + + # Optional terminal methods (not implemented in this test client) + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: + raise NotImplementedError + + async def terminal_output( + self, 
session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> TerminalOutputResponse: # pragma: no cover - placeholder + raise NotImplementedError + + async def release_terminal( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> ReleaseTerminalResponse | None: + raise NotImplementedError + + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> WaitForTerminalExitResponse: + raise NotImplementedError + + async def kill_terminal( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> KillTerminalCommandResponse | None: + raise NotImplementedError + + async def ext_method(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/ping": + return {"response": "pong", "params": params} + raise RequestError.method_not_found(method) + + async def ext_notification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +class TestAgent: + __test__ = False # prevent pytest from collecting this class + + def __init__(self) -> None: + self.prompts: list[PromptRequest] = [] + self.cancellations: list[str] = [] + self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: + # Avoid serializer warnings by omitting defaults + return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[]) + + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + ) -> NewSessionResponse: + return NewSessionResponse(session_id="test-session-123") + + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], 
session_id: str, **kwargs: Any + ) -> LoadSessionResponse | None: + return LoadSessionResponse() + + async def list_sessions( + self, cursor: str | None = None, cwd: str | None = None, **kwargs: Any + ) -> ListSessionsResponse: + return ListSessionsResponse(sessions=[], next_cursor=None) + + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: + return AuthenticateResponse() + + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + self.prompts.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) + return PromptResponse(stop_reason="end_turn") + + async def cancel(self, session_id: str, **kwargs: Any) -> None: + self.cancellations.append(session_id) + + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: + return SetSessionModeResponse() + + async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) -> SetSessionModelResponse | None: + return SetSessionModelResponse() + + async def ext_method(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/echo": + return {"echo": params} + raise RequestError.method_not_found(method) + + async def ext_notification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +@pytest.fixture(name="agent") +def agent_fixture() -> TestAgent: + return TestAgent() + + +@pytest.fixture(name="client") +def client_fixture() -> TestClient: + return TestClient() + + +@pytest.fixture(name="connect") +def connect_func(server, agent, client) -> Callable[[bool, bool], tuple[AgentSideConnection, ClientSideConnection]]: + def _connect( + connect_agent: bool = True, connect_client: bool = True, use_unstable_protocol: bool = False + ) 
-> tuple[AgentSideConnection, ClientSideConnection]: + agent_conn = None + client_conn = None + if connect_agent: + agent_conn = AgentSideConnection( + agent, + server.server_writer, + server.server_reader, + listening=True, + use_unstable_protocol=use_unstable_protocol, + ) + if connect_client: + client_conn = ClientSideConnection( + client, server.client_writer, server.client_reader, use_unstable_protocol=use_unstable_protocol + ) + return agent_conn, client_conn # type: ignore[return-value] + + return _connect diff --git a/tests/contrib/test_contrib_permissions.py b/tests/contrib/test_contrib_permissions.py index 6ed32b7..4ad9105 100644 --- a/tests/contrib/test_contrib_permissions.py +++ b/tests/contrib/test_contrib_permissions.py @@ -21,7 +21,7 @@ async def test_permission_broker_uses_tracker_state(): async def fake_requester(request: RequestPermissionRequest): captured["request"] = request return RequestPermissionResponse( - outcome=AllowedOutcome(optionId=request.options[0].optionId, outcome="selected") + outcome=AllowedOutcome(option_id=request.options[0].option_id, outcome="selected") ) tracker = ToolCallTracker(id_factory=lambda: "perm-id") @@ -30,9 +30,9 @@ async def fake_requester(request: RequestPermissionRequest): result = await broker.request_for("external", description="Perform sensitive action") assert isinstance(result.outcome, AllowedOutcome) - assert result.outcome.optionId == captured["request"].options[0].optionId - assert captured["request"].toolCall.content is not None - last_content = captured["request"].toolCall.content[-1] + assert result.outcome.option_id == captured["request"].options[0].option_id + assert captured["request"].tool_call.content is not None + last_content = captured["request"].tool_call.content[-1] assert isinstance(last_content, ContentToolCallContent) assert isinstance(last_content.content, TextContentBlock) assert last_content.content.text.startswith("Perform sensitive action") @@ -43,14 +43,14 @@ async def 
test_permission_broker_accepts_custom_options(): tracker = ToolCallTracker(id_factory=lambda: "custom") tracker.start("external", title="Custom options") options = [ - PermissionOption(optionId="allow", name="Allow once", kind="allow_once"), + PermissionOption(option_id="allow", name="Allow once", kind="allow_once"), ] recorded: list[str] = [] async def requester(request: RequestPermissionRequest): - recorded.append(request.options[0].optionId) + recorded.append(request.options[0].option_id) return RequestPermissionResponse( - outcome=AllowedOutcome(optionId=request.options[0].optionId, outcome="selected") + outcome=AllowedOutcome(option_id=request.options[0].option_id, outcome="selected") ) broker = PermissionBroker("session", requester, tracker=tracker) @@ -61,4 +61,4 @@ async def requester(request: RequestPermissionRequest): def test_default_permission_options_shape(): options = default_permission_options() assert len(options) == 3 - assert {opt.optionId for opt in options} == {"approve", "approve_for_session", "reject"} + assert {opt.option_id for opt in options} == {"approve", "approve_for_session", "reject"} diff --git a/tests/contrib/test_contrib_session_state.py b/tests/contrib/test_contrib_session_state.py index c2339f6..deaa467 100644 --- a/tests/contrib/test_contrib_session_state.py +++ b/tests/contrib/test_contrib_session_state.py @@ -19,21 +19,21 @@ def notification(session_id: str, update): - return SessionNotification(sessionId=session_id, update=update) + return SessionNotification(session_id=session_id, update=update) def test_session_accumulator_merges_tool_calls(): acc = SessionAccumulator() start = ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-1", + session_update="tool_call", + tool_call_id="call-1", title="Read file", status="in_progress", ) acc.apply(notification("s", start)) progress = ToolCallProgress( - sessionUpdate="tool_call_update", - toolCallId="call-1", + session_update="tool_call_update", + tool_call_id="call-1", 
status="completed", content=[ ContentToolCallContent( @@ -55,7 +55,7 @@ def test_session_accumulator_records_plan_and_mode(): notification( "s", AgentPlanUpdate( - sessionUpdate="plan", + session_update="plan", entries=[ PlanEntry(content="Step 1", priority="medium", status="pending"), ], @@ -63,7 +63,7 @@ def test_session_accumulator_records_plan_and_mode(): ) ) snapshot = acc.apply( - notification("s", CurrentModeUpdate(sessionUpdate="current_mode_update", currentModeId="coding")) + notification("s", CurrentModeUpdate(session_update="current_mode_update", current_mode_id="coding")) ) assert snapshot.plan_entries[0].content == "Step 1" assert snapshot.current_mode_id == "coding" @@ -75,8 +75,8 @@ def test_session_accumulator_tracks_messages_and_commands(): notification( "s", AvailableCommandsUpdate( - sessionUpdate="available_commands_update", - availableCommands=[], + session_update="available_commands_update", + available_commands=[], ), ) ) @@ -84,7 +84,7 @@ def test_session_accumulator_tracks_messages_and_commands(): notification( "s", UserMessageChunk( - sessionUpdate="user_message_chunk", + session_update="user_message_chunk", content=TextContentBlock(type="text", text="Hello"), ), ) @@ -93,7 +93,7 @@ def test_session_accumulator_tracks_messages_and_commands(): notification( "s", AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="Hi!"), ), ) @@ -113,8 +113,8 @@ def test_session_accumulator_auto_resets_on_new_session(): notification( "s1", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-1", + session_update="tool_call", + tool_call_id="call-1", title="First", ), ) @@ -123,8 +123,8 @@ def test_session_accumulator_auto_resets_on_new_session(): notification( "s2", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-2", + session_update="tool_call", + tool_call_id="call-2", title="Second", ), ) @@ -142,8 +142,8 @@ def 
test_session_accumulator_rejects_cross_session_when_auto_reset_disabled(): notification( "s1", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-1", + session_update="tool_call", + tool_call_id="call-1", title="First", ), ) @@ -153,8 +153,8 @@ def test_session_accumulator_rejects_cross_session_when_auto_reset_disabled(): notification( "s2", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-2", + session_update="tool_call", + tool_call_id="call-2", title="Second", ), ) diff --git a/tests/contrib/test_contrib_tool_calls.py b/tests/contrib/test_contrib_tool_calls.py index a3fb290..1cfc1f0 100644 --- a/tests/contrib/test_contrib_tool_calls.py +++ b/tests/contrib/test_contrib_tool_calls.py @@ -7,10 +7,10 @@ def test_tool_call_tracker_generates_ids_and_updates(): tracker = ToolCallTracker(id_factory=lambda: "generated-id") start = tracker.start("external", title="Run command") - assert start.toolCallId == "generated-id" + assert start.tool_call_id == "generated-id" progress = tracker.progress("external", status="completed") assert isinstance(progress, ToolCallProgress) - assert progress.toolCallId == "generated-id" + assert progress.tool_call_id == "generated-id" view = tracker.view("external") assert view.status == "completed" diff --git a/tests/real_user/test_cancel_prompt_flow.py b/tests/real_user/test_cancel_prompt_flow.py index 45d7798..bdd20eb 100644 --- a/tests/real_user/test_cancel_prompt_flow.py +++ b/tests/real_user/test_cancel_prompt_flow.py @@ -1,10 +1,18 @@ import asyncio +from typing import Any import pytest -from acp import AgentSideConnection, CancelNotification, ClientSideConnection, PromptRequest, PromptResponse -from acp.schema import TextContentBlock -from tests.test_rpc import TestAgent, TestClient, _Server +from acp.schema import ( + AudioContentBlock, + EmbeddedResourceContentBlock, + ImageContentBlock, + PromptRequest, + PromptResponse, + ResourceContentBlock, + TextContentBlock, +) +from tests.conftest import TestAgent # 
Regression from a real user session where cancel needed to interrupt a long-running prompt. @@ -17,42 +25,51 @@ def __init__(self) -> None: self.prompt_started = asyncio.Event() self.cancel_received = asyncio.Event() - async def prompt(self, params: PromptRequest) -> PromptResponse: - self.prompts.append(params) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + self.prompts.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) self.prompt_started.set() try: await asyncio.wait_for(self.cancel_received.wait(), timeout=1.0) except asyncio.TimeoutError as exc: msg = "Cancel notification did not arrive while prompt pending" raise AssertionError(msg) from exc - return PromptResponse(stopReason="cancelled") + return PromptResponse(stop_reason="cancelled") - async def cancel(self, params: CancelNotification) -> None: - await super().cancel(params) + async def cancel(self, session_id: str, **kwargs: Any) -> None: + await super().cancel(session_id, **kwargs) self.cancel_received.set() @pytest.mark.asyncio -async def test_cancel_reaches_agent_during_prompt() -> None: - async with _Server() as server: - agent = LongRunningAgent() - client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, server.client_writer, server.client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, server.server_writer, server.server_reader) - - prompt_request = PromptRequest( - sessionId="sess-xyz", +@pytest.mark.parametrize("agent", [LongRunningAgent()]) +async def test_cancel_reaches_agent_during_prompt(connect, agent) -> None: + _, agent_conn = connect() + + prompt_task = asyncio.create_task( + agent_conn.prompt( + session_id="sess-xyz", prompt=[TextContentBlock(type="text", text="hello")], ) - prompt_task = 
asyncio.create_task(agent_conn.prompt(prompt_request)) + ) - await agent.prompt_started.wait() - assert not prompt_task.done(), "Prompt finished before cancel was sent" + await agent.prompt_started.wait() + assert not prompt_task.done(), "Prompt finished before cancel was sent" - await agent_conn.cancel(CancelNotification(sessionId="sess-xyz")) + await agent_conn.cancel(session_id="sess-xyz") - await asyncio.wait_for(agent.cancel_received.wait(), timeout=1.0) + await asyncio.wait_for(agent.cancel_received.wait(), timeout=1.0) - response = await asyncio.wait_for(prompt_task, timeout=1.0) - assert response.stopReason == "cancelled" - assert agent.cancellations == ["sess-xyz"] + response = await asyncio.wait_for(prompt_task, timeout=1.0) + assert response.stop_reason == "cancelled" + assert agent.cancellations == ["sess-xyz"] diff --git a/tests/real_user/test_permission_flow.py b/tests/real_user/test_permission_flow.py index b07817c..95b10ce 100644 --- a/tests/real_user/test_permission_flow.py +++ b/tests/real_user/test_permission_flow.py @@ -1,10 +1,20 @@ import asyncio +from typing import Any import pytest -from acp import AgentSideConnection, ClientSideConnection, PromptRequest, PromptResponse, RequestPermissionRequest -from acp.schema import PermissionOption, TextContentBlock, ToolCall -from tests.test_rpc import TestAgent, TestClient, _Server +from acp import PromptResponse +from acp.core import AgentSideConnection, ClientSideConnection +from acp.schema import ( + AudioContentBlock, + EmbeddedResourceContentBlock, + ImageContentBlock, + PermissionOption, + ResourceContentBlock, + TextContentBlock, + ToolCallUpdate, +) +from tests.conftest import TestAgent, TestClient # Regression from real-world runs where agents paused prompts to obtain user permission. 
@@ -17,50 +27,57 @@ def __init__(self, conn: AgentSideConnection) -> None: self._conn = conn self.permission_responses = [] - async def prompt(self, params: PromptRequest) -> PromptResponse: - permission = await self._conn.requestPermission( - RequestPermissionRequest( - sessionId=params.sessionId, - options=[ - PermissionOption(optionId="allow", name="Allow", kind="allow_once"), - PermissionOption(optionId="deny", name="Deny", kind="reject_once"), - ], - toolCall=ToolCall(toolCallId="call-1", title="Write File"), - ) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + permission = await self._conn.request_permission( + session_id=session_id, + options=[ + PermissionOption(option_id="allow", name="Allow", kind="allow_once"), + PermissionOption(option_id="deny", name="Deny", kind="reject_once"), + ], + tool_call=ToolCallUpdate(tool_call_id="call-1", title="Write File"), ) self.permission_responses.append(permission) - return await super().prompt(params) + return await super().prompt(prompt, session_id, **kwargs) @pytest.mark.asyncio -async def test_agent_request_permission_roundtrip() -> None: - async with _Server() as server: - client = TestClient() - client.queue_permission_selected("allow") +async def test_agent_request_permission_roundtrip(server) -> None: + client = TestClient() + client.queue_permission_selected("allow") - captured_agent = [] + captured_agent = [] - agent_conn = ClientSideConnection(lambda _conn: client, server.client_writer, server.client_reader) - _agent_conn = AgentSideConnection( - lambda conn: captured_agent.append(PermissionRequestAgent(conn)) or captured_agent[-1], - server.server_writer, - server.server_reader, - ) + agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type] + _agent_conn = 
AgentSideConnection( + lambda conn: captured_agent.append(PermissionRequestAgent(conn)) or captured_agent[-1], + server._server_writer, + server._server_reader, + listening=True, + ) - response = await asyncio.wait_for( - agent_conn.prompt( - PromptRequest( - sessionId="sess-perm", - prompt=[TextContentBlock(type="text", text="needs approval")], - ) - ), - timeout=1.0, - ) - assert response.stopReason == "end_turn" + response = await asyncio.wait_for( + agent_conn.prompt( + session_id="sess-perm", + prompt=[TextContentBlock(type="text", text="needs approval")], + ), + timeout=1.0, + ) + assert response.stop_reason == "end_turn" - assert captured_agent, "Agent was not constructed" - [agent] = captured_agent - assert agent.permission_responses, "Agent did not receive permission response" - permission_response = agent.permission_responses[0] - assert permission_response.outcome.outcome == "selected" - assert permission_response.outcome.optionId == "allow" + assert captured_agent, "Agent was not constructed" + [agent] = captured_agent + assert agent.permission_responses, "Agent did not receive permission response" + permission_response = agent.permission_responses[0] + assert permission_response.outcome.outcome == "selected" + assert permission_response.outcome.option_id == "allow" diff --git a/tests/real_user/test_stdio_limits.py b/tests/real_user/test_stdio_limits.py index 3de0ef9..eb9be15 100644 --- a/tests/real_user/test_stdio_limits.py +++ b/tests/real_user/test_stdio_limits.py @@ -22,7 +22,7 @@ def _large_line_script(size: int = LARGE_LINE_SIZE) -> str: @pytest.mark.asyncio async def test_spawn_stdio_transport_hits_default_limit() -> None: script = _large_line_script() - async with spawn_stdio_transport(sys.executable, "-c", script) as (reader, writer, _process): + async with spawn_stdio_transport(sys.executable, "-c", script) as (reader, _writer, _process): # readline() re-raises LimitOverrunError as ValueError on CPython 3.12+. 
with pytest.raises(ValueError): await reader.readline() @@ -36,6 +36,6 @@ async def test_spawn_stdio_transport_custom_limit_handles_large_line() -> None: "-c", script, limit=LARGE_LINE_SIZE * 2, - ) as (reader, writer, _process): + ) as (reader, _writer, _process): line = await reader.readline() assert len(line) == LARGE_LINE_SIZE + 1 diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py new file mode 100644 index 0000000..013427e --- /dev/null +++ b/tests/test_compatibility.py @@ -0,0 +1,169 @@ +import pytest + +from acp import ( + AuthenticateResponse, + InitializeResponse, + LoadSessionResponse, + NewSessionResponse, + PromptRequest, + PromptResponse, + ReadTextFileResponse, + RequestError, + RequestPermissionResponse, + SessionNotification, + SetSessionModelResponse, + SetSessionModeResponse, + WriteTextFileResponse, +) +from acp.schema import ( + AllowedOutcome, + AuthenticateRequest, + CancelNotification, + DeniedOutcome, + InitializeRequest, + LoadSessionRequest, + NewSessionRequest, + ReadTextFileRequest, + RequestPermissionRequest, + SetSessionModelRequest, + SetSessionModeRequest, + WriteTextFileRequest, +) + + +class LegacyAgent: + def __init__(self) -> None: + self.prompts: list[PromptRequest] = [] + self.cancellations: list[str] = [] + self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + async def initialize(self, params: InitializeRequest) -> InitializeResponse: + # Avoid serializer warnings by omitting defaults + return InitializeResponse(protocol_version=params.protocol_version, agent_capabilities=None, auth_methods=[]) + + async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: + return NewSessionResponse(session_id="test-session-123") + + async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: + return LoadSessionResponse() + + async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: + return 
AuthenticateResponse() + + async def prompt(self, params: PromptRequest) -> PromptResponse: + self.prompts.append(params) + return PromptResponse(stop_reason="end_turn") + + async def cancel(self, params: CancelNotification) -> None: + self.cancellations.append(params.session_id) + + async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: + return SetSessionModeResponse() + + async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse | None: + return SetSessionModelResponse() + + async def extMethod(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/echo": + return {"echo": params} + raise RequestError.method_not_found(method) + + async def extNotification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +class LegacyClient: + __test__ = False # prevent pytest from collecting this class + + def __init__(self) -> None: + self.permission_outcomes: list[RequestPermissionResponse] = [] + self.files: dict[str, str] = {} + self.notifications: list[SessionNotification] = [] + self.ext_calls: list[tuple[str, dict]] = [] + self.ext_notes: list[tuple[str, dict]] = [] + + def queue_permission_cancelled(self) -> None: + self.permission_outcomes.append(RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))) + + def queue_permission_selected(self, option_id: str) -> None: + self.permission_outcomes.append( + RequestPermissionResponse(outcome=AllowedOutcome(option_id=option_id, outcome="selected")) + ) + + async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + if self.permission_outcomes: + return self.permission_outcomes.pop() + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) + + async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse: + self.files[str(params.path)] = params.content + return 
WriteTextFileResponse() + + async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: + content = self.files.get(str(params.path), "default content") + return ReadTextFileResponse(content=content) + + async def sessionUpdate(self, params: SessionNotification) -> None: + self.notifications.append(params) + + # Optional terminal methods (not implemented in this test client) + async def createTerminal(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def terminalOutput(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def releaseTerminal(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def waitForTerminalExit(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def killTerminal(self, params): # pragma: no cover - placeholder + raise NotImplementedError + + async def extMethod(self, method: str, params: dict) -> dict: + self.ext_calls.append((method, params)) + if method == "example.com/ping": + return {"response": "pong", "params": params} + raise RequestError.method_not_found(method) + + async def extNotification(self, method: str, params: dict) -> None: + self.ext_notes.append((method, params)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("agent,client", [(LegacyAgent(), LegacyClient())]) +async def test_initialize_and_new_session_compat(connect, client): + client_conn, agent_conn = connect() + + with pytest.warns(DeprecationWarning) as record: + resp = await agent_conn.newSession(NewSessionRequest(cwd="/home/tmp", mcp_servers=[])) + + assert len(record) == 2 + assert "Calling new_session with NewSessionRequest parameter is deprecated" in str(record[0].message) + assert "The old style method LegacyAgent.newSession is deprecated" in str(record[1].message) + + assert isinstance(resp, NewSessionResponse) + assert resp.session_id == "test-session-123" + + with pytest.warns(DeprecationWarning) as 
record: + resp = await agent_conn.new_session(cwd="/home/tmp", mcp_servers=[]) + assert len(record) == 1 + assert "The old style method LegacyAgent.newSession is deprecated" in str(record[0].message) + + with pytest.warns(DeprecationWarning) as record: + await client_conn.writeTextFile( + WriteTextFileRequest(path="test.txt", content="Hello, World!", session_id="test-session-123") + ) + + assert len(record) == 2 + assert client.files["test.txt"] == "Hello, World!" + + with pytest.warns(DeprecationWarning) as record: + resp = await client_conn.read_text_file(path="test.txt", session_id="test-session-123") + + assert len(record) == 1 + assert resp.content == "Hello, World!" diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 5373045..25b5457 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -1,510 +1,325 @@ import asyncio -import contextlib import json import sys from pathlib import Path +from typing import Any import pytest from acp import ( Agent, - AgentSideConnection, - AuthenticateRequest, AuthenticateResponse, - CancelNotification, Client, - ClientSideConnection, - InitializeRequest, InitializeResponse, - LoadSessionRequest, LoadSessionResponse, - NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse, - ReadTextFileRequest, - ReadTextFileResponse, - RequestError, RequestPermissionRequest, RequestPermissionResponse, - SessionNotification, - SetSessionModelRequest, - SetSessionModelResponse, - SetSessionModeRequest, SetSessionModeResponse, - WriteTextFileRequest, WriteTextFileResponse, - session_notification, spawn_agent_process, start_tool_call, update_agent_message_text, update_tool_call, ) +from acp.exceptions import RequestError from acp.schema import ( AgentMessageChunk, AllowedOutcome, + AudioContentBlock, + ClientCapabilities, DeniedOutcome, + EmbeddedResourceContentBlock, + HttpMcpServer, + ImageContentBlock, + Implementation, + ListSessionsResponse, + McpServerStdio, PermissionOption, + ResourceContentBlock, + 
SetSessionModelResponse, + SseMcpServer, TextContentBlock, - ToolCall, ToolCallLocation, ToolCallProgress, ToolCallStart, + ToolCallUpdate, UserMessageChunk, ) +from tests.conftest import TestClient -# --------------------- Test Utilities --------------------- - - -class _Server: - def __init__(self) -> None: - self._server: asyncio.AbstractServer | None = None - self.server_reader: asyncio.StreamReader | None = None - self.server_writer: asyncio.StreamWriter | None = None - self.client_reader: asyncio.StreamReader | None = None - self.client_writer: asyncio.StreamWriter | None = None - - async def __aenter__(self): - async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - self.server_reader = reader - self.server_writer = writer - - self._server = await asyncio.start_server(handle, host="127.0.0.1", port=0) - host, port = self._server.sockets[0].getsockname()[:2] - self.client_reader, self.client_writer = await asyncio.open_connection(host, port) - - # wait until server side is set - for _ in range(100): - if self.server_reader and self.server_writer: - break - await asyncio.sleep(0.01) - assert self.server_reader and self.server_writer - assert self.client_reader and self.client_writer - return self - - async def __aexit__(self, exc_type, exc, tb): - if self.client_writer: - self.client_writer.close() - with contextlib.suppress(Exception): - await self.client_writer.wait_closed() - if self.server_writer: - self.server_writer.close() - with contextlib.suppress(Exception): - await self.server_writer.wait_closed() - if self._server: - self._server.close() - await self._server.wait_closed() - - -# --------------------- Test Doubles ----------------------- - - -class TestClient(Client): - __test__ = False # prevent pytest from collecting this class - - def __init__(self) -> None: - self.permission_outcomes: list[RequestPermissionResponse] = [] - self.files: dict[str, str] = {} - self.notifications: list[SessionNotification] = [] - 
self.ext_calls: list[tuple[str, dict]] = [] - self.ext_notes: list[tuple[str, dict]] = [] - - def queue_permission_cancelled(self) -> None: - self.permission_outcomes.append(RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))) - - def queue_permission_selected(self, option_id: str) -> None: - self.permission_outcomes.append( - RequestPermissionResponse(outcome=AllowedOutcome(optionId=option_id, outcome="selected")) - ) - - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: - if self.permission_outcomes: - return self.permission_outcomes.pop() - return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) - - async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse: - self.files[str(params.path)] = params.content - return WriteTextFileResponse() - - async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: - content = self.files.get(str(params.path), "default content") - return ReadTextFileResponse(content=content) - - async def sessionUpdate(self, params: SessionNotification) -> None: - self.notifications.append(params) - - # Optional terminal methods (not implemented in this test client) - async def createTerminal(self, params): # pragma: no cover - placeholder - raise NotImplementedError - - async def terminalOutput(self, params): # pragma: no cover - placeholder - raise NotImplementedError - - async def releaseTerminal(self, params): # pragma: no cover - placeholder - raise NotImplementedError - - async def waitForTerminalExit(self, params): # pragma: no cover - placeholder - raise NotImplementedError - - async def killTerminal(self, params): # pragma: no cover - placeholder - raise NotImplementedError - - async def extMethod(self, method: str, params: dict) -> dict: - self.ext_calls.append((method, params)) - if method == "example.com/ping": - return {"response": "pong", "params": params} - raise 
RequestError.method_not_found(method) - - async def extNotification(self, method: str, params: dict) -> None: - self.ext_notes.append((method, params)) - - -class TestAgent(Agent): - __test__ = False # prevent pytest from collecting this class - - def __init__(self) -> None: - self.prompts: list[PromptRequest] = [] - self.cancellations: list[str] = [] - self.ext_calls: list[tuple[str, dict]] = [] - self.ext_notes: list[tuple[str, dict]] = [] - - async def initialize(self, params: InitializeRequest) -> InitializeResponse: - # Avoid serializer warnings by omitting defaults - return InitializeResponse(protocolVersion=params.protocolVersion, agentCapabilities=None, authMethods=[]) - - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: - return NewSessionResponse(sessionId="test-session-123") - - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse: - return LoadSessionResponse() - - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse: - return AuthenticateResponse() +# ------------------------ Tests -------------------------- - async def prompt(self, params: PromptRequest) -> PromptResponse: - self.prompts.append(params) - return PromptResponse(stopReason="end_turn") - async def cancel(self, params: CancelNotification) -> None: - self.cancellations.append(params.sessionId) +@pytest.mark.asyncio +async def test_initialize_and_new_session(connect): + _, agent_conn = connect() - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse: - return SetSessionModeResponse() + resp = await agent_conn.initialize(protocol_version=1) + assert isinstance(resp, InitializeResponse) + assert resp.protocol_version == 1 - async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse: - return SetSessionModelResponse() + new_sess = await agent_conn.new_session(mcp_servers=[], cwd="/test") + assert new_sess.session_id == "test-session-123" - async 
def extMethod(self, method: str, params: dict) -> dict: - self.ext_calls.append((method, params)) - if method == "example.com/echo": - return {"echo": params} - raise RequestError.method_not_found(method) + load_resp = await agent_conn.load_session(session_id=new_sess.session_id, cwd="/test", mcp_servers=[]) + assert isinstance(load_resp, LoadSessionResponse) - async def extNotification(self, method: str, params: dict) -> None: - self.ext_notes.append((method, params)) + auth_resp = await agent_conn.authenticate(method_id="password") + assert isinstance(auth_resp, AuthenticateResponse) + mode_resp = await agent_conn.set_session_mode(session_id=new_sess.session_id, mode_id="ask") + assert isinstance(mode_resp, SetSessionModeResponse) -# ------------------------ Tests -------------------------- + with pytest.raises(RequestError), pytest.warns(UserWarning) as record: + await agent_conn.set_session_model(session_id=new_sess.session_id, model_id="gpt-4o") + assert len(record) == 1 @pytest.mark.asyncio -async def test_initialize_and_new_session(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - # server side is agent; client side is client - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) - - resp = await agent_conn.initialize(InitializeRequest(protocolVersion=1)) - assert isinstance(resp, InitializeResponse) - assert resp.protocolVersion == 1 - - new_sess = await agent_conn.newSession(NewSessionRequest(mcpServers=[], cwd="/test")) - assert new_sess.sessionId == "test-session-123" - - load_resp = await agent_conn.loadSession( - LoadSessionRequest(sessionId=new_sess.sessionId, cwd="/test", mcpServers=[]) - ) - assert isinstance(load_resp, LoadSessionResponse) - - auth_resp = await agent_conn.authenticate(AuthenticateRequest(methodId="password")) - assert isinstance(auth_resp, AuthenticateResponse) +async def 
test_bidirectional_file_ops(client, connect): + client.files["/test/file.txt"] = "Hello, World!" + client_conn, _ = connect() - mode_resp = await agent_conn.setSessionMode(SetSessionModeRequest(sessionId=new_sess.sessionId, modeId="ask")) - assert isinstance(mode_resp, SetSessionModeResponse) + # Agent asks client to read + res = await client_conn.read_text_file(session_id="sess", path="/test/file.txt") + assert res.content == "Hello, World!" - model_resp = await agent_conn.setSessionModel( - SetSessionModelRequest(sessionId=new_sess.sessionId, modelId="gpt-4o") - ) - assert isinstance(model_resp, SetSessionModelResponse) + # Agent asks client to write + write_result = await client_conn.write_text_file(session_id="sess", path="/test/file.txt", content="Updated") + assert isinstance(write_result, WriteTextFileResponse) + assert client.files["/test/file.txt"] == "Updated" @pytest.mark.asyncio -async def test_bidirectional_file_ops(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - client.files["/test/file.txt"] = "Hello, World!" - _agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) - - # Agent asks client to read - res = await client_conn.readTextFile(ReadTextFileRequest(sessionId="sess", path="/test/file.txt")) - assert res.content == "Hello, World!" 
- - # Agent asks client to write - write_result = await client_conn.writeTextFile( - WriteTextFileRequest(sessionId="sess", path="/test/file.txt", content="Updated") - ) - assert isinstance(write_result, WriteTextFileResponse) - assert client.files["/test/file.txt"] == "Updated" - +async def test_cancel_notification_and_capture_wire(connect, agent): + _, agent_conn = connect() + # Send cancel notification from client-side connection to agent + await agent_conn.cancel(session_id="test-123") -@pytest.mark.asyncio -async def test_cancel_notification_and_capture_wire(): - async with _Server() as s: - # Build only agent-side (server) connection. Client side: raw reader to inspect wire - agent = TestAgent() - client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) - - # Send cancel notification from client-side connection to agent - await agent_conn.cancel(CancelNotification(sessionId="test-123")) - - # Read raw line from server peer (it will be consumed by agent receive loop quickly). - # Instead, wait a brief moment and assert agent recorded it. - for _ in range(50): - if agent.cancellations: - break - await asyncio.sleep(0.01) - assert agent.cancellations == ["test-123"] + # Read raw line from server peer (it will be consumed by agent receive loop quickly). + # Instead, wait a brief moment and assert agent recorded it. 
+ for _ in range(50): + if agent.cancellations: + break + await asyncio.sleep(0.01) + assert agent.cancellations == ["test-123"] @pytest.mark.asyncio -async def test_session_notifications_flow(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - _agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) - - # Agent -> Client notifications - await client_conn.sessionUpdate( - SessionNotification( - sessionId="sess", - update=AgentMessageChunk( - sessionUpdate="agent_message_chunk", - content=TextContentBlock(type="text", text="Hello"), - ), - ) - ) - await client_conn.sessionUpdate( - SessionNotification( - sessionId="sess", - update=UserMessageChunk( - sessionUpdate="user_message_chunk", - content=TextContentBlock(type="text", text="World"), - ), - ) - ) - - # Wait for async dispatch - for _ in range(50): - if len(client.notifications) >= 2: - break - await asyncio.sleep(0.01) - assert len(client.notifications) >= 2 - assert client.notifications[0].sessionId == "sess" +async def test_session_notifications_flow(connect, client): + client_conn, _ = connect() + + # Agent -> Client notifications + await client_conn.session_update( + session_id="sess", + update=AgentMessageChunk( + session_update="agent_message_chunk", + content=TextContentBlock(type="text", text="Hello"), + ), + ) + await client_conn.session_update( + session_id="sess", + update=UserMessageChunk( + session_update="user_message_chunk", + content=TextContentBlock(type="text", text="World"), + ), + ) + + # Wait for async dispatch + for _ in range(50): + if len(client.notifications) >= 2: + break + await asyncio.sleep(0.01) + assert len(client.notifications) >= 2 + assert client.notifications[0].session_id == "sess" @pytest.mark.asyncio -async def test_concurrent_reads(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - for i in 
range(5): - client.files[f"/test/file{i}.txt"] = f"Content {i}" - _agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) +async def test_concurrent_reads(connect, client): + for i in range(5): + client.files[f"/test/file{i}.txt"] = f"Content {i}" + client_conn, _ = connect() - async def read_one(i: int): - return await client_conn.readTextFile(ReadTextFileRequest(sessionId="sess", path=f"/test/file{i}.txt")) + async def read_one(i: int): + return await client_conn.read_text_file(session_id="sess", path=f"/test/file{i}.txt") - results = await asyncio.gather(*(read_one(i) for i in range(5))) - for i, res in enumerate(results): - assert res.content == f"Content {i}" + results = await asyncio.gather(*(read_one(i) for i in range(5))) + for i, res in enumerate(results): + assert res.content == f"Content {i}" @pytest.mark.asyncio -async def test_invalid_params_results_in_error_response(): - async with _Server() as s: - # Only start agent-side (server) so we can inject raw request from client socket - agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) +async def test_invalid_params_results_in_error_response(connect, server): + # Only start agent-side (server) so we can inject raw request from client socket + connect(connect_agent=True, connect_client=False) - # Send initialize with wrong param type (protocolVersion should be int) - req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "oops"}} - s.client_writer.write((json.dumps(req) + "\n").encode()) - await s.client_writer.drain() + # Send initialize with wrong param type (protocolVersion should be int) + req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "oops"}} + server.client_writer.write((json.dumps(req) + "\n").encode()) + await server.client_writer.drain() - # 
Read response - line = await asyncio.wait_for(s.client_reader.readline(), timeout=1) - resp = json.loads(line) - assert resp["id"] == 1 - assert "error" in resp - assert resp["error"]["code"] == -32602 # invalid params + # Read response + line = await asyncio.wait_for(server.client_reader.readline(), timeout=1) + resp = json.loads(line) + assert resp["id"] == 1 + assert "error" in resp + assert resp["error"]["code"] == -32602 # invalid params @pytest.mark.asyncio -async def test_method_not_found_results_in_error_response(): - async with _Server() as s: - agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) +async def test_method_not_found_results_in_error_response(connect, server): + connect(connect_agent=True, connect_client=False) - req = {"jsonrpc": "2.0", "id": 2, "method": "unknown/method", "params": {}} - s.client_writer.write((json.dumps(req) + "\n").encode()) - await s.client_writer.drain() + req = {"jsonrpc": "2.0", "id": 2, "method": "unknown/method", "params": {}} + server.client_writer.write((json.dumps(req) + "\n").encode()) + await server.client_writer.drain() - line = await asyncio.wait_for(s.client_reader.readline(), timeout=1) - resp = json.loads(line) - assert resp["id"] == 2 - assert resp["error"]["code"] == -32601 # method not found + line = await asyncio.wait_for(server.client_reader.readline(), timeout=1) + resp = json.loads(line) + assert resp["id"] == 2 + assert resp["error"]["code"] == -32601 # method not found @pytest.mark.asyncio -async def test_set_session_mode_and_extensions(): - async with _Server() as s: - agent = TestAgent() - client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) +async def test_set_session_mode_and_extensions(connect, agent, client): + client_conn, agent_conn = connect() - # setSessionMode - resp = await 
agent_conn.setSessionMode(SetSessionModeRequest(sessionId="sess", modeId="yolo")) - assert isinstance(resp, SetSessionModeResponse) + # setSessionMode + resp = await agent_conn.set_session_mode(session_id="sess", mode_id="yolo") + assert isinstance(resp, SetSessionModeResponse) - model_resp = await agent_conn.setSessionModel(SetSessionModelRequest(sessionId="sess", modelId="gpt-4o-mini")) - assert isinstance(model_resp, SetSessionModelResponse) + with pytest.raises(RequestError), pytest.warns(UserWarning) as record: + await agent_conn.set_session_model(session_id="sess", model_id="gpt-4o-mini") + assert len(record) == 1 - # extMethod - echo = await agent_conn.extMethod("example.com/echo", {"x": 1}) - assert echo == {"echo": {"x": 1}} + # extMethod + echo = await agent_conn.ext_method("example.com/echo", {"x": 1}) + assert echo == {"echo": {"x": 1}} - # extNotification - await agent_conn.extNotification("note", {"y": 2}) - # allow dispatch - await asyncio.sleep(0.05) - assert agent.ext_notes and agent.ext_notes[-1][0] == "note" + # extNotification + await agent_conn.ext_notification("note", {"y": 2}) + # allow dispatch + await asyncio.sleep(0.05) + assert agent.ext_notes and agent.ext_notes[-1][0] == "note" - # client extension method - ping = await client_conn.extMethod("example.com/ping", {"k": 3}) - assert ping == {"response": "pong", "params": {"k": 3}} - assert client.ext_calls and client.ext_calls[-1] == ("example.com/ping", {"k": 3}) + # client extension method + ping = await client_conn.ext_method("example.com/ping", {"k": 3}) + assert ping == {"response": "pong", "params": {"k": 3}} + assert client.ext_calls and client.ext_calls[-1] == ("example.com/ping", {"k": 3}) @pytest.mark.asyncio -async def test_ignore_invalid_messages(): - async with _Server() as s: - agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) +async def test_ignore_invalid_messages(connect, server): + connect(connect_agent=True, 
connect_client=False) - # Message without id and method - msg1 = {"jsonrpc": "2.0"} - s.client_writer.write((json.dumps(msg1) + "\n").encode()) - await s.client_writer.drain() + # Message without id and method + msg1 = {"jsonrpc": "2.0"} + server.client_writer.write((json.dumps(msg1) + "\n").encode()) + await server.client_writer.drain() - # Message without jsonrpc and without id/method - msg2 = {"foo": "bar"} - s.client_writer.write((json.dumps(msg2) + "\n").encode()) - await s.client_writer.drain() + # Message without jsonrpc and without id/method + msg2 = {"foo": "bar"} + server.client_writer.write((json.dumps(msg2) + "\n").encode()) + await server.client_writer.drain() - # Should not receive any response lines - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(s.client_reader.readline(), timeout=0.1) + # Should not receive any response lines + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(server.client_reader.readline(), timeout=0.1) class _ExampleAgent(Agent): __test__ = False def __init__(self) -> None: - self._conn: AgentSideConnection | None = None + self._conn: Client | None = None self.permission_response: RequestPermissionResponse | None = None self.prompt_requests: list[PromptRequest] = [] - def bind(self, conn: AgentSideConnection) -> "_ExampleAgent": + def on_connect(self, conn: Client) -> None: self._conn = conn - return self - - async def initialize(self, params: InitializeRequest) -> InitializeResponse: - return InitializeResponse(protocolVersion=params.protocolVersion) - - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: - return NewSessionResponse(sessionId="sess_demo") - async def prompt(self, params: PromptRequest) -> PromptResponse: + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: + return 
InitializeResponse(protocol_version=protocol_version) + + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + ) -> NewSessionResponse: + return NewSessionResponse(session_id="sess_demo") + + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: assert self._conn is not None - self.prompt_requests.append(params) + self.prompt_requests.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) - await self._conn.sessionUpdate( - session_notification( - params.sessionId, - update_agent_message_text("I'll help you with that."), - ) + await self._conn.session_update( + session_id, + update_agent_message_text("I'll help you with that."), ) - await self._conn.sessionUpdate( - session_notification( - params.sessionId, - start_tool_call( - "call_1", - "Modifying configuration", - kind="edit", - status="pending", - locations=[ToolCallLocation(path="/project/config.json")], - raw_input={"path": "/project/config.json"}, - ), - ) + await self._conn.session_update( + session_id, + start_tool_call( + "call_1", + "Modifying configuration", + kind="edit", + status="pending", + locations=[ToolCallLocation(path="/project/config.json")], + raw_input={"path": "/project/config.json"}, + ), ) - permission_request = RequestPermissionRequest( - sessionId=params.sessionId, - toolCall=ToolCall( - toolCallId="call_1", + permission_request = { + "session_id": session_id, + "tool_call": ToolCallUpdate( + tool_call_id="call_1", title="Modifying configuration", kind="edit", status="pending", locations=[ToolCallLocation(path="/project/config.json")], - rawInput={"path": "/project/config.json"}, + raw_input={"path": "/project/config.json"}, ), - options=[ - PermissionOption(kind="allow_once", name="Allow", optionId="allow"), - 
PermissionOption(kind="reject_once", name="Reject", optionId="reject"), + "options": [ + PermissionOption(kind="allow_once", name="Allow", option_id="allow"), + PermissionOption(kind="reject_once", name="Reject", option_id="reject"), ], - ) - response = await self._conn.requestPermission(permission_request) + } + response = await self._conn.request_permission(**permission_request) self.permission_response = response - if isinstance(response.outcome, AllowedOutcome) and response.outcome.optionId == "allow": - await self._conn.sessionUpdate( - session_notification( - params.sessionId, - update_tool_call( - "call_1", - status="completed", - raw_output={"success": True}, - ), - ) + if isinstance(response.outcome, AllowedOutcome) and response.outcome.option_id == "allow": + await self._conn.session_update( + session_id, + update_tool_call( + "call_1", + status="completed", + raw_output={"success": True}, + ), ) - await self._conn.sessionUpdate( - session_notification( - params.sessionId, - update_agent_message_text("Done."), - ) + await self._conn.session_update( + session_id, + update_agent_message_text("Done."), ) - return PromptResponse(stopReason="end_turn") + return PromptResponse(stop_reason="end_turn") class _ExampleClient(TestClient): @@ -514,72 +329,82 @@ def __init__(self) -> None: super().__init__() self.permission_requests: list[RequestPermissionRequest] = [] - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + async def request_permission( + self, + options: list[PermissionOption] | RequestPermissionRequest, + session_id: str | None = None, + tool_call: ToolCallUpdate | None = None, + **kwargs: Any, + ) -> RequestPermissionResponse: + if isinstance(options, RequestPermissionRequest): + params = options + else: + assert session_id is not None and tool_call is not None + params = RequestPermissionRequest( + options=options, + session_id=session_id, + tool_call=tool_call, + field_meta=kwargs or None, + ) 
self.permission_requests.append(params) if not params.options: return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) option = params.options[0] - return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected")) + return RequestPermissionResponse(outcome=AllowedOutcome(option_id=option.option_id, outcome="selected")) @pytest.mark.asyncio -async def test_example_agent_permission_flow(): - async with _Server() as s: - agent = _ExampleAgent() - client = _ExampleClient() - - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - AgentSideConnection(lambda conn: agent.bind(conn), s.server_writer, s.server_reader) - - init = await agent_conn.initialize(InitializeRequest(protocolVersion=1)) - assert init.protocolVersion == 1 - - session = await agent_conn.newSession(NewSessionRequest(mcpServers=[], cwd="/workspace")) - assert session.sessionId == "sess_demo" - - prompt = PromptRequest( - sessionId=session.sessionId, - prompt=[TextContentBlock(type="text", text="Please edit config")], - ) - resp = await agent_conn.prompt(prompt) - assert resp.stopReason == "end_turn" - - for _ in range(50): - if len(client.notifications) >= 4: - break - await asyncio.sleep(0.02) - - assert len(client.notifications) >= 4 - session_updates = [getattr(note.update, "sessionUpdate", None) for note in client.notifications] - assert session_updates[:4] == ["agent_message_chunk", "tool_call", "tool_call_update", "agent_message_chunk"] - - first_message = client.notifications[0].update - assert isinstance(first_message, AgentMessageChunk) - assert isinstance(first_message.content, TextContentBlock) - assert first_message.content.text == "I'll help you with that." 
- - tool_call = client.notifications[1].update - assert isinstance(tool_call, ToolCallStart) - assert tool_call.title == "Modifying configuration" - assert tool_call.status == "pending" - - tool_update = client.notifications[2].update - assert isinstance(tool_update, ToolCallProgress) - assert tool_update.status == "completed" - assert tool_update.rawOutput == {"success": True} - - final_message = client.notifications[3].update - assert isinstance(final_message, AgentMessageChunk) - assert isinstance(final_message.content, TextContentBlock) - assert final_message.content.text == "Done." - - assert len(client.permission_requests) == 1 - options = client.permission_requests[0].options - assert [opt.optionId for opt in options] == ["allow", "reject"] - - assert agent.permission_response is not None - assert isinstance(agent.permission_response.outcome, AllowedOutcome) - assert agent.permission_response.outcome.optionId == "allow" +@pytest.mark.parametrize("agent,client", [(_ExampleAgent(), _ExampleClient())]) +async def test_example_agent_permission_flow(connect, client, agent): + _, agent_conn = connect() + + init = await agent_conn.initialize(protocol_version=1) + assert init.protocol_version == 1 + + session = await agent_conn.new_session(mcp_servers=[], cwd="/workspace") + assert session.session_id == "sess_demo" + + resp = await agent_conn.prompt( + session_id=session.session_id, + prompt=[TextContentBlock(type="text", text="Please edit config")], + ) + assert resp.stop_reason == "end_turn" + for _ in range(50): + if len(client.notifications) >= 4: + break + await asyncio.sleep(0.02) + + assert len(client.notifications) >= 4 + session_updates = [getattr(note.update, "session_update", None) for note in client.notifications] + assert session_updates[:4] == ["agent_message_chunk", "tool_call", "tool_call_update", "agent_message_chunk"] + + first_message = client.notifications[0].update + assert isinstance(first_message, AgentMessageChunk) + assert 
isinstance(first_message.content, TextContentBlock) + assert first_message.content.text == "I'll help you with that." + + tool_call = client.notifications[1].update + assert isinstance(tool_call, ToolCallStart) + assert tool_call.title == "Modifying configuration" + assert tool_call.status == "pending" + + tool_update = client.notifications[2].update + assert isinstance(tool_update, ToolCallProgress) + assert tool_update.status == "completed" + assert tool_update.raw_output == {"success": True} + + final_message = client.notifications[3].update + assert isinstance(final_message, AgentMessageChunk) + assert isinstance(final_message.content, TextContentBlock) + assert final_message.content.text == "Done." + + assert len(client.permission_requests) == 1 + options = client.permission_requests[0].options + assert [opt.option_id for opt in options] == ["allow", "reject"] + + assert agent.permission_response is not None + assert isinstance(agent.permission_response.outcome, AllowedOutcome) + assert agent.permission_response.outcome.option_id == "allow" @pytest.mark.asyncio @@ -589,15 +414,14 @@ async def test_spawn_agent_process_roundtrip(tmp_path): test_client = TestClient() - async with spawn_agent_process(lambda _agent: test_client, sys.executable, str(script)) as (client_conn, process): - init = await client_conn.initialize(InitializeRequest(protocolVersion=1)) + async with spawn_agent_process(test_client, sys.executable, str(script)) as (client_conn, process): + init = await client_conn.initialize(protocol_version=1) assert isinstance(init, InitializeResponse) - session = await client_conn.newSession(NewSessionRequest(cwd=str(tmp_path), mcpServers=[])) - prompt = PromptRequest( - sessionId=session.sessionId, + session = await client_conn.new_session(mcp_servers=[], cwd=str(tmp_path)) + await client_conn.prompt( + session_id=session.session_id, prompt=[TextContentBlock(type="text", text="hi spawn")], ) - await client_conn.prompt(prompt) # Wait for echo agent 
notification to arrive for _ in range(50): @@ -608,3 +432,14 @@ async def test_spawn_agent_process_roundtrip(tmp_path): assert test_client.notifications assert process.returncode is not None + + +@pytest.mark.asyncio +async def test_call_unstable_protocol(connect): + _, agent_conn = connect(use_unstable_protocol=True) + + resp = await agent_conn.list_sessions() + assert isinstance(resp, ListSessionsResponse) + + resp = await agent_conn.set_session_model(session_id="sess", model_id="gpt-4o-mini") + assert isinstance(resp, SetSessionModelResponse) diff --git a/tests/test_utils.py b/tests/test_utils.py index fbbf08e..47706d9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,12 @@ +import pytest + from acp.schema import AgentMessageChunk, TextContentBlock from acp.utils import serialize_params def test_serialize_params_uses_meta_aliases() -> None: chunk = AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}), field_meta={"outer": "value"}, ) @@ -17,7 +19,7 @@ def test_serialize_params_uses_meta_aliases() -> None: def test_serialize_params_omits_meta_when_absent() -> None: chunk = AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="demo"), ) @@ -29,10 +31,25 @@ def test_serialize_params_omits_meta_when_absent() -> None: def test_field_meta_can_be_set_by_name_on_models() -> None: chunk = AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}), field_meta={"outer": "value"}, ) assert chunk.field_meta == {"outer": "value"} assert chunk.content.field_meta == {"inner": "value"} + + +@pytest.mark.parametrize( + "original, expected", + [ + ("simple_test", "simpleTest"), + ("another_example_here", "anotherExampleHere"), + 
("lowercase", "lowercase"), + ("alreadyCamelCase", "alreadyCamelCase"), + ], +) +def test_to_camel_case(original, expected) -> None: + from acp.utils import to_camel_case + + assert to_camel_case(original) == expected diff --git a/uv.lock b/uv.lock index 4c6b491..10967c7 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,6 @@ version = 1 revision = 3 -requires-python = ">=3.10, <=3.14" +requires-python = ">=3.10, <3.15" [[package]] name = "agent-client-protocol"