From 370988e25b1b75e52885ddd654f375fcc37ab29f Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 19 Nov 2025 18:16:18 +0800 Subject: [PATCH 1/2] feat: Auto-generate agent/client methods based on the schema Signed-off-by: Frost Ming --- .pre-commit-config.yaml | 1 - Makefile | 2 + docs/quickstart.md | 25 +- examples/agent.py | 90 ++++-- examples/client.py | 92 ++++-- examples/duet.py | 6 +- examples/echo_agent.py | 53 +++- examples/gemini.py | 166 +++++----- pyproject.toml | 4 - scripts/gen_all.py | 17 +- scripts/gen_schema.py | 39 +-- scripts/gen_signature.py | 133 ++++++++ src/acp/agent/connection.py | 128 +++++--- src/acp/agent/router.py | 41 ++- src/acp/client/connection.py | 115 ++++--- src/acp/client/router.py | 49 ++- src/acp/interfaces.py | 175 ++++++++--- src/acp/meta.py | 23 +- src/acp/router.py | 231 ++++++-------- src/acp/utils.py | 14 + tests/real_user/test_cancel_prompt_flow.py | 42 ++- tests/real_user/test_permission_flow.py | 50 +-- tests/test_rpc.py | 338 +++++++++++++-------- 23 files changed, 1137 insertions(+), 697 deletions(-) create mode 100644 scripts/gen_signature.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c30c852..40b46dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,4 +21,3 @@ repos: args: [ --exit-non-zero-on-fix ] exclude: ^src/acp/(meta|schema)\.py$ - id: ruff-format - exclude: ^src/acp/(meta|schema)\.py$ diff --git a/Makefile b/Makefile index f995835..a6f88bf 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ install: ## Install the virtual environment and install the pre-commit hooks gen-all: ## Generate all code from schema @echo "šŸš€ Generating all code" @uv run scripts/gen_all.py + @uv run ruff check --fix + @uv run ruff format . .PHONY: check check: ## Run code quality tools. diff --git a/docs/quickstart.md b/docs/quickstart.md index 824315c..cd15df8 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -76,27 +76,26 @@ from pathlib import Path from acp import spawn_agent_process, text_block from acp.interfaces import Client -from acp.schema import InitializeRequest, NewSessionRequest, PromptRequest, SessionNotification class SimpleClient(Client): - async def requestPermission(self, params): # pragma: no cover - minimal stub + async def request_permission( + self, options, session_id, tool_call, **kwargs: Any + ) return {"outcome": {"outcome": "cancelled"}} - async def sessionUpdate(self, params: SessionNotification) -> None: - print("update:", params.session_id, params.update) + async def session_update(self, session_id, update, **kwargs): + print("update:", session_id, update) async def main() -> None: script = Path("examples/echo_agent.py") async with spawn_agent_process(lambda _agent: SimpleClient(), sys.executable, str(script)) as (conn, _proc): - await conn.initialize(InitializeRequest(protocol_version=1)) - session = await conn.newSession(NewSessionRequest(cwd=str(script.parent), mcp_servers=[])) + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd=str(script.parent), mcp_servers=[]) await conn.prompt( - PromptRequest( - session_id=session.session_id, - prompt=[text_block("Hello from spawn!")], - ) + session_id=session.session_id, + prompt=[text_block("Hello from spawn!")], ) asyncio.run(main()) @@ -111,12 +110,12 @@ _Swap the echo demo for your own `Agent` subclass._ Create your own agent by subclassing `acp.Agent`. The pattern mirrors the echo example: ```python -from acp import Agent, PromptRequest, PromptResponse +from acp import Agent, PromptResponse class MyAgent(Agent): - async def prompt(self, params: PromptRequest) -> PromptResponse: - # inspect params.prompt, stream updates, then finish the turn + async def prompt(self, prompt, session_id, **kwargs) -> PromptResponse: + # inspect prompt, stream updates, then finish the turn return PromptResponse(stop_reason="end_turn") ``` diff --git a/examples/agent.py b/examples/agent.py index ce09e6c..7f66216 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -5,26 +5,31 @@ from acp import ( Agent, AgentSideConnection, - AuthenticateRequest, AuthenticateResponse, - CancelNotification, - InitializeRequest, InitializeResponse, - LoadSessionRequest, LoadSessionResponse, - NewSessionRequest, NewSessionResponse, - PromptRequest, PromptResponse, - SetSessionModeRequest, SetSessionModeResponse, - session_notification, stdio_streams, text_block, update_agent_message, PROTOCOL_VERSION, ) -from acp.schema import AgentCapabilities, AgentMessageChunk, Implementation +from acp.schema import ( + AgentCapabilities, + AgentMessageChunk, + AudioContentBlock, + ClientCapabilities, + EmbeddedResourceContentBlock, + HttpMcpServer, + ImageContentBlock, + Implementation, + ResourceContentBlock, + SseMcpServer, + StdioMcpServer, + TextContentBlock, +) class ExampleAgent(Agent): @@ -35,9 +40,15 @@ def __init__(self, conn: AgentSideConnection) -> None: async def _send_agent_message(self, session_id: str, content: Any) -> None: update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content) - await self._conn.sessionUpdate(session_notification(session_id, update)) - - async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002 + await self._conn.session_update(session_id, update) + + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: logging.info("Received initialize request") return InitializeResponse( protocol_version=PROTOCOL_VERSION, @@ -45,44 +56,59 @@ async def initialize(self, params: InitializeRequest) -> InitializeResponse: # agent_info=Implementation(name="example-agent", title="Example Agent", version="0.1.0"), ) - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002 - logging.info("Received authenticate request %s", params.method_id) + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: + logging.info("Received authenticate request %s", method_id) return AuthenticateResponse() - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002 + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any + ) -> NewSessionResponse: logging.info("Received new session request") session_id = str(self._next_session_id) self._next_session_id += 1 self._sessions.add(session_id) return NewSessionResponse(session_id=session_id, modes=None) - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002 - logging.info("Received load session request %s", params.session_id) - self._sessions.add(params.session_id) + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any + ) -> LoadSessionResponse | None: + logging.info("Received load session request %s", session_id) + self._sessions.add(session_id) return LoadSessionResponse() - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002 - logging.info("Received set session mode request %s -> %s", params.session_id, params.mode_id) + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: + logging.info("Received set session mode request %s -> %s", session_id, mode_id) return SetSessionModeResponse() - async def prompt(self, params: PromptRequest) -> PromptResponse: - logging.info("Received prompt request for session %s", params.session_id) - if params.session_id not in self._sessions: - self._sessions.add(params.session_id) - - await self._send_agent_message(params.session_id, text_block("Client sent:")) - for block in params.prompt: - await self._send_agent_message(params.session_id, block) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + logging.info("Received prompt request for session %s", session_id) + if session_id not in self._sessions: + self._sessions.add(session_id) + + await self._send_agent_message(session_id, text_block("Client sent:")) + for block in prompt: + await self._send_agent_message(session_id, block) return PromptResponse(stop_reason="end_turn") - async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002 - logging.info("Received cancel notification for session %s", params.session_id) + async def cancel(self, session_id: str, **kwargs: Any) -> None: + logging.info("Received cancel notification for session %s", session_id) - async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: logging.info("Received extension method call: %s", method) return {"example": "response"} - async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002 + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: logging.info("Received extension notification: %s", method) diff --git a/examples/client.py b/examples/client.py index 69841b5..6da8121 100644 --- a/examples/client.py +++ b/examples/client.py @@ -5,6 +5,7 @@ import os import sys from pathlib import Path +from typing import Any from acp import ( Client, @@ -13,49 +14,98 @@ NewSessionRequest, PromptRequest, RequestError, - SessionNotification, text_block, PROTOCOL_VERSION, ) from acp.schema import ( AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, AudioContentBlock, + AvailableCommandsUpdate, ClientCapabilities, + CreateTerminalResponse, + CurrentModeUpdate, EmbeddedResourceContentBlock, + EnvVariable, ImageContentBlock, Implementation, + KillTerminalCommandResponse, + PermissionOption, + ReadTextFileResponse, + ReleaseTerminalResponse, + RequestPermissionResponse, ResourceContentBlock, + TerminalOutputResponse, TextContentBlock, + ToolCall, + ToolCallProgress, + ToolCallStart, + UserMessageChunk, + WaitForTerminalExitResponse, + WriteTextFileResponse, ) class ExampleClient(Client): - async def requestPermission(self, params): # type: ignore[override] + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: raise RequestError.method_not_found("session/request_permission") - async def writeTextFile(self, params): # type: ignore[override] + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: raise RequestError.method_not_found("fs/write_text_file") - async def readTextFile(self, params): # type: ignore[override] + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: raise RequestError.method_not_found("fs/read_text_file") - async def createTerminal(self, params): # type: ignore[override] + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: raise RequestError.method_not_found("terminal/create") - async def terminalOutput(self, params): # type: ignore[override] + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: raise RequestError.method_not_found("terminal/output") - async def releaseTerminal(self, params): # type: ignore[override] + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: raise RequestError.method_not_found("terminal/release") - async def waitForTerminalExit(self, params): # type: ignore[override] + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: raise RequestError.method_not_found("terminal/wait_for_exit") - async def killTerminal(self, params): # type: ignore[override] + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: raise RequestError.method_not_found("terminal/kill") - async def sessionUpdate(self, params: SessionNotification) -> None: - update = params.update + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: if not isinstance(update, AgentMessageChunk): return @@ -76,10 +126,10 @@ async def sessionUpdate(self, params: SessionNotification) -> None: print(f"| Agent: {text}") - async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 + async def ext_method(self, method: str, params: dict) -> dict: raise RequestError.method_not_found(method) - async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002 + async def ext_notification(self, method: str, params: dict) -> None: raise RequestError.method_not_found(method) @@ -103,10 +153,8 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: try: await conn.prompt( - PromptRequest( - session_id=session_id, - prompt=[text_block(line)], - ) + session_id=session_id, + prompt=[text_block(line)], ) except Exception as exc: # noqa: BLE001 logging.error("Prompt failed: %s", exc) @@ -145,13 +193,11 @@ async def main(argv: list[str]) -> int: conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout) await conn.initialize( - InitializeRequest( - protocol_version=PROTOCOL_VERSION, - client_capabilities=ClientCapabilities(), - client_info=Implementation(name="example-client", title="Example Client", version="0.1.0"), - ) + protocol_version=PROTOCOL_VERSION, + client_capabilities=ClientCapabilities(), + client_info=Implementation(name="example-client", title="Example Client", version="0.1.0"), ) - session = await conn.newSession(NewSessionRequest(mcp_servers=[], cwd=os.getcwd())) + session = await conn.new_session(mcp_servers=[], cwd=os.getcwd()) await interactive_loop(conn, session.session_id) diff --git a/examples/duet.py b/examples/duet.py index de8d9ca..e9c5e2f 100644 --- a/examples/duet.py +++ b/examples/duet.py @@ -34,9 +34,9 @@ async def main() -> int: conn, process, ): - await conn.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION, clientCapabilities=None)) - session = await conn.newSession(NewSessionRequest(mcpServers=[], cwd=str(root))) - await client_module.interactive_loop(conn, session.sessionId) + await conn.initialize(protocol_version=PROTOCOL_VERSION, client_capabilities=None) + session = await conn.new_session(mcp_servers=[], cwd=str(root)) + await client_module.interactive_loop(conn, session.session_id) return process.returncode or 0 diff --git a/examples/echo_agent.py b/examples/echo_agent.py index 6528f51..f99c539 100644 --- a/examples/echo_agent.py +++ b/examples/echo_agent.py @@ -1,43 +1,68 @@ import asyncio +from typing import Any from uuid import uuid4 from acp import ( Agent, AgentSideConnection, - InitializeRequest, InitializeResponse, - NewSessionRequest, NewSessionResponse, - PromptRequest, PromptResponse, - session_notification, stdio_streams, text_block, update_agent_message, ) +from acp.schema import ( + AudioContentBlock, + ClientCapabilities, + EmbeddedResourceContentBlock, + HttpMcpServer, + ImageContentBlock, + Implementation, + ResourceContentBlock, + SseMcpServer, + StdioMcpServer, + TextContentBlock, +) class EchoAgent(Agent): - def __init__(self, conn): + def __init__(self, conn: AgentSideConnection) -> None: self._conn = conn - async def initialize(self, params: InitializeRequest) -> InitializeResponse: - return InitializeResponse(protocol_version=params.protocol_version) + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: + return InitializeResponse(protocol_version=protocol_version) - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any + ) -> NewSessionResponse: return NewSessionResponse(session_id=uuid4().hex) - async def prompt(self, params: PromptRequest) -> PromptResponse: - for block in params.prompt: + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + for block in prompt: text = block.get("text", "") if isinstance(block, dict) else getattr(block, "text", "") chunk = update_agent_message(text_block(text)) chunk.field_meta = {"echo": True} chunk.content.field_meta = {"echo": True} - notification = session_notification(params.session_id, chunk) - notification.field_meta = {"source": "echo_agent"} - - await self._conn.sessionUpdate(notification) + await self._conn.session_update(session_id=session_id, update=chunk, source="echo_agent") return PromptResponse(stop_reason="end_turn") diff --git a/examples/gemini.py b/examples/gemini.py index adea53f..d1214fb 100644 --- a/examples/gemini.py +++ b/examples/gemini.py @@ -9,7 +9,7 @@ import shutil import sys from pathlib import Path -from typing import Iterable +from typing import Any, Iterable from acp import ( Client, @@ -23,38 +23,33 @@ AgentPlanUpdate, AgentThoughtChunk, AllowedOutcome, + AvailableCommandsUpdate, CancelNotification, ClientCapabilities, + CurrentModeUpdate, + EnvVariable, FileEditToolCallContent, FileSystemCapability, - CreateTerminalRequest, CreateTerminalResponse, DeniedOutcome, EmbeddedResourceContentBlock, - KillTerminalCommandRequest, KillTerminalCommandResponse, InitializeRequest, NewSessionRequest, PermissionOption, PromptRequest, - ReadTextFileRequest, ReadTextFileResponse, - RequestPermissionRequest, RequestPermissionResponse, ResourceContentBlock, - ReleaseTerminalRequest, ReleaseTerminalResponse, - SessionNotification, TerminalToolCallContent, - TerminalOutputRequest, TerminalOutputResponse, TextContentBlock, + ToolCall, ToolCallProgress, ToolCallStart, UserMessageChunk, - WaitForTerminalExitRequest, WaitForTerminalExitResponse, - WriteTextFileRequest, WriteTextFileResponse, ) @@ -65,22 +60,21 @@ class GeminiClient(Client): def __init__(self, auto_approve: bool) -> None: self._auto_approve = auto_approve - async def requestPermission( - self, - params: RequestPermissionRequest, - ) -> RequestPermissionResponse: # type: ignore[override] + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: if self._auto_approve: - option = _pick_preferred_option(params.options) + option = _pick_preferred_option(options) if option is None: return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) return RequestPermissionResponse(outcome=AllowedOutcome(option_id=option.option_id, outcome="selected")) - title = params.tool_call.title or "" - if not params.options: + title = tool_call.title or "" + if not options: print(f"\nšŸ” Permission requested: {title} (no options, cancelling)") return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) print(f"\nšŸ” Permission requested: {title}") - for idx, opt in enumerate(params.options, start=1): + for idx, opt in enumerate(options, start=1): print(f" {idx}. {opt.name} ({opt.kind})") loop = asyncio.get_running_loop() @@ -90,43 +84,49 @@ async def requestPermission( continue if choice.isdigit(): idx = int(choice) - 1 - if 0 <= idx < len(params.options): - opt = params.options[idx] + if 0 <= idx < len(options): + opt = options[idx] return RequestPermissionResponse( outcome=AllowedOutcome(option_id=opt.option_id, outcome="selected") ) print("Invalid selection, try again.") - async def writeTextFile( - self, - params: WriteTextFileRequest, - ) -> WriteTextFileResponse: # type: ignore[override] - path = Path(params.path) - if not path.is_absolute(): - raise RequestError.invalid_params({"path": params.path, "reason": "path must be absolute"}) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(params.content) - print(f"[Client] Wrote {path} ({len(params.content)} bytes)") + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: + pathlib_path = Path(path) + if not pathlib_path.is_absolute(): + raise RequestError.invalid_params({"path": pathlib_path, "reason": "path must be absolute"}) + pathlib_path.parent.mkdir(parents=True, exist_ok=True) + pathlib_path.write_text(content) + print(f"[Client] Wrote {pathlib_path} ({len(content)} bytes)") return WriteTextFileResponse() - async def readTextFile( - self, - params: ReadTextFileRequest, - ) -> ReadTextFileResponse: # type: ignore[override] - path = Path(params.path) - if not path.is_absolute(): - raise RequestError.invalid_params({"path": params.path, "reason": "path must be absolute"}) - text = path.read_text() - print(f"[Client] Read {path} ({len(text)} bytes)") - if params.line is not None or params.limit is not None: - text = _slice_text(text, params.line, params.limit) + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: + pathlib_path = Path(path) + if not pathlib_path.is_absolute(): + raise RequestError.invalid_params({"path": pathlib_path, "reason": "path must be absolute"}) + text = pathlib_path.read_text() + print(f"[Client] Read {pathlib_path} ({len(text)} bytes)") + if line is not None or limit is not None: + text = _slice_text(text, line, limit) return ReadTextFileResponse(content=text) - async def sessionUpdate( + async def session_update( self, - params: SessionNotification, - ) -> None: # type: ignore[override] - update = params.update + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: if isinstance(update, AgentMessageChunk): _print_text_content(update.content) elif isinstance(update, AgentThoughtChunk): @@ -156,39 +156,39 @@ async def sessionUpdate( print(f"\n[session update] {update}") # Optional / terminal-related methods --------------------------------- - async def createTerminal( + async def create_terminal( self, - params: CreateTerminalRequest, - ) -> CreateTerminalResponse: # type: ignore[override] - print(f"[Client] createTerminal: {params}") + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: + print(f"[Client] createTerminal: {command} {args or []} (cwd={cwd})") return CreateTerminalResponse(terminal_id="term-1") - async def terminalOutput( - self, - params: TerminalOutputRequest, - ) -> TerminalOutputResponse: # type: ignore[override] - print(f"[Client] terminalOutput: {params}") + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: + print(f"[Client] terminalOutput: {session_id} {terminal_id}") return TerminalOutputResponse(output="", truncated=False) - async def releaseTerminal( - self, - params: ReleaseTerminalRequest, - ) -> ReleaseTerminalResponse: # type: ignore[override] - print(f"[Client] releaseTerminal: {params}") + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: + print(f"[Client] releaseTerminal: {session_id} {terminal_id}") return ReleaseTerminalResponse() - async def waitForTerminalExit( - self, - params: WaitForTerminalExitRequest, - ) -> WaitForTerminalExitResponse: # type: ignore[override] - print(f"[Client] waitForTerminalExit: {params}") + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: + print(f"[Client] waitForTerminalExit: {session_id} {terminal_id}") return WaitForTerminalExitResponse() - async def killTerminal( - self, - params: KillTerminalCommandRequest, - ) -> KillTerminalCommandResponse: # type: ignore[override] - print(f"[Client] killTerminal: {params}") + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: + print(f"[Client] killTerminal: {session_id} {terminal_id}") return KillTerminalCommandResponse() @@ -248,15 +248,13 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: if line in {":exit", ":quit"}: break if line == ":cancel": - await conn.cancel(CancelNotification(session_id=session_id)) + await conn.cancel(session_id=session_id) continue try: await conn.prompt( - PromptRequest( - session_id=session_id, - prompt=[text_block(line)], - ) + session_id=session_id, + prompt=[text_block(line)], ) except RequestError as err: _print_request_error("prompt", err) @@ -322,13 +320,11 @@ async def run(argv: list[str]) -> int: try: init_resp = await conn.initialize( - InitializeRequest( - protocol_version=PROTOCOL_VERSION, - client_capabilities=ClientCapabilities( - fs=FileSystemCapability(read_text_file=True, write_text_file=True), - terminal=True, - ), - ) + protocol_version=PROTOCOL_VERSION, + client_capabilities=ClientCapabilities( + fs=FileSystemCapability(read_text_file=True, write_text_file=True), + terminal=True, + ), ) except RequestError as err: _print_request_error("initialize", err) @@ -342,11 +338,9 @@ async def run(argv: list[str]) -> int: print(f"āœ… Connected to Gemini (protocol v{init_resp.protocol_version})") try: - session = await conn.newSession( - NewSessionRequest( - cwd=os.getcwd(), - mcp_servers=[], - ) + session = await conn.new_session( + cwd=os.getcwd(), + mcp_servers=[], ) except RequestError as err: _print_request_error("new_session", err) diff --git a/pyproject.toml b/pyproject.toml index bed99a9..2fc88b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,10 +113,6 @@ ignore = [ [tool.ruff.format] preview = true -exclude = [ - "src/acp/meta.py", - "src/acp/schema.py", -] [tool.deptry.package_module_name_map] opentelemetry-sdk = "opentelemetry" diff --git a/scripts/gen_all.py b/scripts/gen_all.py index 63e8d8a..872a98a 100644 --- a/scripts/gen_all.py +++ b/scripts/gen_all.py @@ -14,7 +14,7 @@ if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) -from scripts import gen_meta, gen_schema # noqa: E402 pylint: disable=wrong-import-position +from scripts import gen_meta, gen_schema, gen_signature # noqa: E402 pylint: disable=wrong-import-position SCHEMA_DIR = ROOT / "schema" SCHEMA_JSON = SCHEMA_DIR / "schema.json" @@ -44,18 +44,6 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Skip downloading schema files even when a version is provided.", ) - parser.add_argument( - "--format", - dest="format_output", - action="store_true", - help="Format generated files with 'uv run ruff format'.", - ) - parser.add_argument( - "--no-format", - dest="format_output", - action="store_false", - help="Disable formatting with ruff.", - ) parser.set_defaults(format_output=True) parser.add_argument( "--force", @@ -82,8 +70,9 @@ def main() -> None: print("schema/schema.json or schema/meta.json missing; run with --version to fetch them.", file=sys.stderr) sys.exit(1) - gen_schema.generate_schema(format_output=args.format_output) + gen_schema.generate_schema() gen_meta.generate_meta() + gen_signature.gen_signature(ROOT / "src" / "acp") if ref: print(f"Generated schema using ref: {ref}") diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 6e0c43b..033abe2 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -1,11 +1,9 @@ #!/usr/bin/env python3 from __future__ import annotations -import argparse import ast import json import re -import shutil import subprocess import sys from collections.abc import Callable @@ -126,30 +124,11 @@ class _ProcessingStep: apply: Callable[[str], str] -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate src/acp/schema.py from the ACP JSON schema.") - parser.add_argument( - "--format", - dest="format_output", - action="store_true", - help="Format generated files with 'uv run ruff format'.", - ) - parser.add_argument( - "--no-format", - dest="format_output", - action="store_false", - help="Disable formatting with ruff.", - ) - parser.set_defaults(format_output=True) - return parser.parse_args() - - def main() -> None: - args = parse_args() - generate_schema(format_output=args.format_output) + generate_schema() -def generate_schema(*, format_output: bool = True) -> None: +def generate_schema() -> None: if not SCHEMA_JSON.exists(): print( "Schema file missing. Ensure schema/schema.json exists (run gen_all.py --version to download).", @@ -181,9 +160,6 @@ def generate_schema(*, format_output: bool = True) -> None: for warning in warnings: print(f"Warning: {warning}", file=sys.stderr) - if format_output: - format_with_ruff(SCHEMA_OUT) - def postprocess_generated_schema(output_path: Path) -> list[str]: if not output_path.exists(): @@ -529,16 +505,5 @@ def _inject_enum_aliases(content: str) -> str: return content[:insertion_point] + block + content[insertion_point:] -def format_with_ruff(file_path: Path) -> None: - uv_executable = shutil.which("uv") - if uv_executable is None: - print("Warning: 'uv' executable not found; skipping formatting.", file=sys.stderr) - return - try: - subprocess.check_call([uv_executable, "run", "ruff", "format", str(file_path)]) # noqa: S603 - except (FileNotFoundError, subprocess.CalledProcessError) as exc: # pragma: no cover - best effort - print(f"Warning: failed to format {file_path}: {exc}", file=sys.stderr) - - if __name__ == "__main__": main() diff --git a/scripts/gen_signature.py b/scripts/gen_signature.py new file mode 100644 index 0000000..635a7a1 --- /dev/null +++ b/scripts/gen_signature.py @@ -0,0 +1,133 @@ +import ast +import inspect +import typing as t +from pathlib import Path + +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined + +from acp import schema + + +class NodeTransformer(ast.NodeTransformer): + def __init__(self) -> None: + self._type_import_node: ast.ImportFrom | None = None + self._schema_import_node: ast.ImportFrom | None = None + self._should_rewrite = False + self._literals = {name: value for name, value in schema.__dict__.items() if t.get_origin(value) is t.Literal} + + def _add_typing_import(self, name: str) -> None: + if not self._type_import_node: + return + if not any(alias.name == name for alias in self._type_import_node.names): + self._type_import_node.names.append(ast.alias(name=name)) + self._should_rewrite = True + + def _add_schema_import(self, name: str) -> None: + if not self._schema_import_node: + return + if not any(alias.name == name for alias in self._schema_import_node.names): + self._schema_import_node.names.append(ast.alias(name=name)) + self._should_rewrite = True + + def transform(self, source_file: Path) -> None: + with source_file.open("r", encoding="utf-8") as f: + source_code = f.read() + tree = ast.parse(source_code) + self.visit(tree) + if self._should_rewrite: + print("Rewriting signatures in", source_file) + new_code = ast.unparse(tree) + with source_file.open("w", encoding="utf-8") as f: + f.write(new_code) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST: + if node.module == "schema": + self._schema_import_node = node + elif node.module == "typing": + self._type_import_node = node + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: + return self.visit_func(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: + return self.visit_func(node) + + def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST: + decorator = next( + ( + decorator + for decorator in node.decorator_list + if isinstance(decorator, ast.Call) + and isinstance(decorator.func, ast.Name) + and decorator.func.id == "param_model" + ), + None, + ) + if not decorator: + return self.generic_visit(node) + self._should_rewrite = True + model_name = t.cast(ast.Name, decorator.args[0]).id + model = t.cast(type[schema.BaseModel], getattr(schema, model_name)) + param_defaults = [ + self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta" + ] + param_defaults.sort(key=lambda x: x[1] is not None) + node.args.args[1:] = [param for param, _ in param_defaults] + node.args.defaults = [default for _, default in param_defaults if default is not None] + if "field_meta" in model.model_fields: + node.args.kwarg = ast.arg(arg="kwargs", annotation=ast.Name(id="Any")) + return self.generic_visit(node) + + def _to_param_def(self, name: str, field: FieldInfo) -> tuple[ast.arg, ast.expr | None]: + arg = ast.arg(arg=name) + ann = field.annotation + if field.default is PydanticUndefined: + default = None + elif isinstance(field.default, dict): + default = ast.Constant(None) + else: + default = ast.Constant(value=field.default) + if ann is not None: + arg.annotation = self._format_annotation(ann) + return arg, default + + def _format_annotation(self, annotation: t.Any) -> ast.expr: + if t.get_origin(annotation) is t.Literal and annotation in self._literals.values(): + name = next(name for name, value in self._literals.items() if value is annotation) + self._add_schema_import(name) + return ast.Name(id=name) + elif ( + inspect.isclass(annotation) and issubclass(annotation, BaseModel) and annotation.__module__ == "acp.schema" + ): + self._add_schema_import(annotation.__name__) + return ast.Name(id=annotation.__name__) + elif args := t.get_args(annotation): + origin = t.get_origin(annotation) + return ast.Subscript( + value=self._format_annotation(origin), + slice=ast.Tuple(elts=[self._format_annotation(arg) for arg in args], ctx=ast.Load()) + if len(args) > 1 + else self._format_annotation(args[0]), + ctx=ast.Load(), + ) + elif annotation.__module__ == "typing": + name = annotation.__name__ + self._add_typing_import(name) + return ast.Name(id=name) + elif annotation is None or annotation is type(None): + return ast.Constant(value=None) + elif annotation in __builtins__.values(): + return ast.Name(id=annotation.__name__) + else: + print(f"Warning: Unhandled annotation type: {annotation}") + self._add_typing_import("Any") + return ast.Name(id="Any") + + +def gen_signature(source_dir: Path) -> None: + for source_file in source_dir.rglob("*.py"): + transformer = NodeTransformer() + transformer.transform(source_file) diff --git a/src/acp/agent/connection.py b/src/acp/agent/connection.py index 5f992b1..0b4d542 100644 --- a/src/acp/agent/connection.py +++ b/src/acp/agent/connection.py @@ -4,14 +4,21 @@ from collections.abc import Callable from typing import Any -from ..connection import Connection, MethodHandler +from ..connection import Connection from ..interfaces import Agent from ..meta import CLIENT_METHODS from ..schema import ( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AvailableCommandsUpdate, CreateTerminalRequest, CreateTerminalResponse, + CurrentModeUpdate, + EnvVariable, KillTerminalCommandRequest, KillTerminalCommandResponse, + PermissionOption, ReadTextFileRequest, ReadTextFileResponse, ReleaseTerminalRequest, @@ -21,17 +28,20 @@ SessionNotification, TerminalOutputRequest, TerminalOutputResponse, + ToolCall, + ToolCallProgress, + ToolCallStart, + UserMessageChunk, WaitForTerminalExitRequest, WaitForTerminalExitResponse, WriteTextFileRequest, WriteTextFileResponse, ) from ..terminal import TerminalHandle -from ..utils import notify_model, request_model, request_optional_model +from ..utils import notify_model, param_model, request_model, request_optional_model from .router import build_agent_router __all__ = ["AgentSideConnection"] - _AGENT_CONNECTION_ERROR = "AgentSideConnection requires asyncio StreamWriter/StreamReader" @@ -46,95 +56,139 @@ def __init__( **connection_kwargs: Any, ) -> None: agent = to_agent(self) - handler = self._create_handler(agent) - + handler = build_agent_router(agent) if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_AGENT_CONNECTION_ERROR) self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) - def _create_handler(self, agent: Agent) -> MethodHandler: - router = build_agent_router(agent) - - async def handler(method: str, params: Any | None, is_notification: bool) -> Any: - if is_notification: - await router.dispatch_notification(method, params) - return None - return await router.dispatch_request(method, params) - - return handler - - async def sessionUpdate(self, params: SessionNotification) -> None: - await notify_model(self._conn, CLIENT_METHODS["session_update"], params) + @param_model(SessionNotification) + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: + await notify_model( + self._conn, + CLIENT_METHODS["session_update"], + SessionNotification(session_id=session_id, update=update, field_meta=kwargs or None), + ) - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + @param_model(RequestPermissionRequest) + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: return await request_model( self._conn, CLIENT_METHODS["session_request_permission"], - params, + RequestPermissionRequest( + options=options, session_id=session_id, tool_call=tool_call, field_meta=kwargs or None + ), RequestPermissionResponse, ) - async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: + @param_model(ReadTextFileRequest) + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: return await request_model( self._conn, CLIENT_METHODS["fs_read_text_file"], - params, + ReadTextFileRequest(path=path, session_id=session_id, limit=limit, line=line, field_meta=kwargs or None), ReadTextFileResponse, ) - async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse | None: + @param_model(WriteTextFileRequest) + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: return await request_optional_model( self._conn, CLIENT_METHODS["fs_write_text_file"], - params, + WriteTextFileRequest(content=content, path=path, session_id=session_id, field_meta=kwargs or None), WriteTextFileResponse, ) - async def createTerminal(self, params: CreateTerminalRequest) -> TerminalHandle: + @param_model(CreateTerminalRequest) + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> TerminalHandle: create_response = await request_model( self._conn, CLIENT_METHODS["terminal_create"], - params, + CreateTerminalRequest( + command=command, + session_id=session_id, + args=args, + cwd=cwd, + env=env, + output_byte_limit=output_byte_limit, + field_meta=kwargs or None, + ), CreateTerminalResponse, ) - return TerminalHandle(create_response.terminal_id, params.session_id, self._conn) + return TerminalHandle(create_response.terminal_id, session_id, self._conn) - async def terminalOutput(self, params: TerminalOutputRequest) -> TerminalOutputResponse: + @param_model(TerminalOutputRequest) + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: return await request_model( self._conn, CLIENT_METHODS["terminal_output"], - params, + TerminalOutputRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), TerminalOutputResponse, ) - async def releaseTerminal(self, params: ReleaseTerminalRequest) -> ReleaseTerminalResponse | None: + @param_model(ReleaseTerminalRequest) + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: return await request_optional_model( self._conn, CLIENT_METHODS["terminal_release"], - params, + ReleaseTerminalRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), ReleaseTerminalResponse, ) - async def waitForTerminalExit(self, params: WaitForTerminalExitRequest) -> WaitForTerminalExitResponse: + @param_model(WaitForTerminalExitRequest) + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: return await request_model( self._conn, CLIENT_METHODS["terminal_wait_for_exit"], - params, + WaitForTerminalExitRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), WaitForTerminalExitResponse, ) - async def killTerminal(self, params: KillTerminalCommandRequest) -> KillTerminalCommandResponse | None: + @param_model(KillTerminalCommandRequest) + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: return await request_optional_model( self._conn, CLIENT_METHODS["terminal_kill"], - params, + KillTerminalCommandRequest(session_id=session_id, terminal_id=terminal_id, field_meta=kwargs or None), KillTerminalCommandResponse, ) - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: return await self._conn.send_request(f"_{method}", params) - async def extNotification(self, method: str, params: dict[str, Any]) -> None: + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: await self._conn.send_notification(f"_{method}", params) async def close(self) -> None: diff --git a/src/acp/agent/router.py b/src/acp/agent/router.py index 515d804..9934b40 100644 --- a/src/acp/agent/router.py +++ b/src/acp/agent/router.py @@ -5,7 +5,7 @@ from ..exceptions import RequestError from ..interfaces import Agent from ..meta import AGENT_METHODS -from ..router import MessageRouter, RouterBuilder +from ..router import MessageRouter from ..schema import ( AuthenticateRequest, CancelNotification, @@ -22,33 +22,33 @@ def build_agent_router(agent: Agent) -> MessageRouter: - builder = RouterBuilder() + router = MessageRouter() - builder.request_attr(AGENT_METHODS["initialize"], InitializeRequest, agent, "initialize") - builder.request_attr(AGENT_METHODS["session_new"], NewSessionRequest, agent, "newSession") - builder.request_attr( + router.route_request(AGENT_METHODS["initialize"], InitializeRequest, agent, "initialize") + router.route_request(AGENT_METHODS["session_new"], NewSessionRequest, agent, "new_session") + router.route_request( AGENT_METHODS["session_load"], LoadSessionRequest, agent, - "loadSession", + "load_session", adapt_result=normalize_result, ) - builder.request_attr( + router.route_request( AGENT_METHODS["session_set_mode"], SetSessionModeRequest, agent, - "setSessionMode", + "set_session_mode", adapt_result=normalize_result, ) - builder.request_attr(AGENT_METHODS["session_prompt"], PromptRequest, agent, "prompt") - builder.request_attr( + router.route_request(AGENT_METHODS["session_prompt"], PromptRequest, agent, "prompt") + router.route_request( AGENT_METHODS["session_set_model"], SetSessionModelRequest, agent, - "setSessionModel", + "set_session_model", adapt_result=normalize_result, ) - builder.request_attr( + router.route_request( AGENT_METHODS["authenticate"], AuthenticateRequest, agent, @@ -56,21 +56,20 @@ def build_agent_router(agent: Agent) -> MessageRouter: adapt_result=normalize_result, ) - builder.notification_attr(AGENT_METHODS["session_cancel"], CancelNotification, agent, "cancel") + router.route_notification(AGENT_METHODS["session_cancel"], CancelNotification, agent, "cancel") - async def handle_extension_request(name: str, payload: dict[str, Any]) -> Any: - ext = getattr(agent, "extMethod", None) + @router.handle_extension_request + async def _handle_extension_request(name: str, payload: dict[str, Any]) -> Any: + ext = getattr(agent, "ext_method", None) if ext is None: raise RequestError.method_not_found(f"_{name}") return await ext(name, payload) - async def handle_extension_notification(name: str, payload: dict[str, Any]) -> None: - ext = getattr(agent, "extNotification", None) + @router.handle_extension_notification + async def _handle_extension_notification(name: str, payload: dict[str, Any]) -> None: + ext = getattr(agent, "ext_notification", None) if ext is None: return await ext(name, payload) - return builder.build( - request_extensions=handle_extension_request, - notification_extensions=handle_extension_notification, - ) + return router diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index f97ff25..857e7d8 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -4,13 +4,19 @@ from collections.abc import Callable from typing import Any -from ..connection import Connection, MethodHandler +from ..connection import Connection from ..interfaces import Agent, Client from ..meta import AGENT_METHODS from ..schema import ( + AudioContentBlock, AuthenticateRequest, AuthenticateResponse, CancelNotification, + ClientCapabilities, + EmbeddedResourceContentBlock, + HttpMcpServer, + ImageContentBlock, + Implementation, InitializeRequest, InitializeResponse, LoadSessionRequest, @@ -19,20 +25,19 @@ NewSessionResponse, PromptRequest, PromptResponse, + ResourceContentBlock, SetSessionModelRequest, SetSessionModelResponse, SetSessionModeRequest, SetSessionModeResponse, + SseMcpServer, + StdioMcpServer, + TextContentBlock, ) -from ..utils import ( - notify_model, - request_model, - request_model_from_dict, -) +from ..utils import notify_model, param_model, request_model, request_model_from_dict from .router import build_client_router __all__ = ["ClientSideConnection"] - _CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader" @@ -40,93 +45,115 @@ class ClientSideConnection: """Client-side connection wrapper that dispatches JSON-RPC messages to an Agent implementation.""" def __init__( - self, - to_client: Callable[[Agent], Client], - input_stream: Any, - output_stream: Any, - **connection_kwargs: Any, + self, to_client: Callable[[Agent], Client], input_stream: Any, output_stream: Any, **connection_kwargs: Any ) -> None: if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_CLIENT_CONNECTION_ERROR) - - client = to_client(self) # type: ignore[arg-type] - handler = self._create_handler(client) + client = to_client(self) + handler = build_client_router(client) self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) - def _create_handler(self, client: Client) -> MethodHandler: - router = build_client_router(client) - - async def handler(method: str, params: Any | None, is_notification: bool) -> Any: - if is_notification: - await router.dispatch_notification(method, params) - return None - return await router.dispatch_request(method, params) - - return handler - - async def initialize(self, params: InitializeRequest) -> InitializeResponse: + @param_model(InitializeRequest) + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: return await request_model( self._conn, AGENT_METHODS["initialize"], - params, + InitializeRequest( + protocol_version=protocol_version, + client_capabilities=client_capabilities or ClientCapabilities(), + client_info=client_info, + field_meta=kwargs or None, + ), InitializeResponse, ) - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: + @param_model(NewSessionRequest) + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any + ) -> NewSessionResponse: return await request_model( self._conn, AGENT_METHODS["session_new"], - params, + NewSessionRequest(cwd=cwd, mcp_servers=mcp_servers, field_meta=kwargs or None), NewSessionResponse, ) - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse: + @param_model(LoadSessionRequest) + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any + ) -> LoadSessionResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["session_load"], - params, + LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None), LoadSessionResponse, ) - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse: + @param_model(SetSessionModeRequest) + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["session_set_mode"], - params, + SetSessionModeRequest(mode_id=mode_id, session_id=session_id, field_meta=kwargs or None), SetSessionModeResponse, ) - async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse: + @param_model(SetSessionModelRequest) + async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) -> SetSessionModelResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["session_set_model"], - params, + SetSessionModelRequest(model_id=model_id, session_id=session_id, field_meta=kwargs or None), SetSessionModelResponse, ) - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse: + @param_model(AuthenticateRequest) + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse: return await request_model_from_dict( self._conn, AGENT_METHODS["authenticate"], - params, + AuthenticateRequest(method_id=method_id, field_meta=kwargs or None), AuthenticateResponse, ) - async def prompt(self, params: PromptRequest) -> PromptResponse: + @param_model(PromptRequest) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: return await request_model( self._conn, AGENT_METHODS["session_prompt"], - params, + PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None), PromptResponse, ) - async def cancel(self, params: CancelNotification) -> None: - await notify_model(self._conn, AGENT_METHODS["session_cancel"], params) + @param_model(CancelNotification) + async def cancel(self, session_id: str, **kwargs: Any) -> None: + await notify_model( + self._conn, + AGENT_METHODS["session_cancel"], + CancelNotification(session_id=session_id, field_meta=kwargs or None), + ) - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: return await self._conn.send_request(f"_{method}", params) - async def extNotification(self, method: str, params: dict[str, Any]) -> None: + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: await self._conn.send_notification(f"_{method}", params) async def close(self) -> None: diff --git a/src/acp/client/router.py b/src/acp/client/router.py index 9f0b85f..488d42c 100644 --- a/src/acp/client/router.py +++ b/src/acp/client/router.py @@ -5,7 +5,7 @@ from ..exceptions import RequestError from ..interfaces import Client from ..meta import CLIENT_METHODS -from ..router import MessageRouter, RouterBuilder +from ..router import MessageRouter from ..schema import ( CreateTerminalRequest, KillTerminalCommandRequest, @@ -23,74 +23,73 @@ def build_client_router(client: Client) -> MessageRouter: - builder = RouterBuilder() + router = MessageRouter() - builder.request_attr(CLIENT_METHODS["fs_write_text_file"], WriteTextFileRequest, client, "writeTextFile") - builder.request_attr(CLIENT_METHODS["fs_read_text_file"], ReadTextFileRequest, client, "readTextFile") - builder.request_attr( + router.route_request(CLIENT_METHODS["fs_write_text_file"], WriteTextFileRequest, client, "write_text_file") + router.route_request(CLIENT_METHODS["fs_read_text_file"], ReadTextFileRequest, client, "read_text_file") + router.route_request( CLIENT_METHODS["session_request_permission"], RequestPermissionRequest, client, - "requestPermission", + "request_permission", ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_create"], CreateTerminalRequest, client, - "createTerminal", + "create_terminal", optional=True, default_result=None, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_output"], TerminalOutputRequest, client, - "terminalOutput", + "terminal_output", optional=True, default_result=None, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_release"], ReleaseTerminalRequest, client, - "releaseTerminal", + "release_terminal", optional=True, default_result={}, adapt_result=normalize_result, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_wait_for_exit"], WaitForTerminalExitRequest, client, - "waitForTerminalExit", + "wait_for_terminal_exit", optional=True, default_result=None, ) - builder.request_attr( + router.route_request( CLIENT_METHODS["terminal_kill"], KillTerminalCommandRequest, client, - "killTerminal", + "kill_terminal", optional=True, default_result={}, adapt_result=normalize_result, ) - builder.notification_attr(CLIENT_METHODS["session_update"], SessionNotification, client, "sessionUpdate") + router.route_notification(CLIENT_METHODS["session_update"], SessionNotification, client, "session_update") - async def handle_extension_request(name: str, payload: dict[str, Any]) -> Any: - ext = getattr(client, "extMethod", None) + @router.handle_extension_request + async def _handle_extension_request(name: str, payload: dict[str, Any]) -> Any: + ext = getattr(client, "ext_method", None) if ext is None: raise RequestError.method_not_found(f"_{name}") return await ext(name, payload) - async def handle_extension_notification(name: str, payload: dict[str, Any]) -> None: - ext = getattr(client, "extNotification", None) + @router.handle_extension_notification + async def _handle_extension_notification(name: str, payload: dict[str, Any]) -> None: + ext = getattr(client, "ext_notification", None) if ext is None: return await ext(name, payload) - return builder.build( - request_extensions=handle_extension_request, - notification_extensions=handle_extension_notification, - ) + return router diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 11d04d3..61c4f27 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -3,11 +3,23 @@ from typing import Any, Protocol from .schema import ( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AudioContentBlock, AuthenticateRequest, AuthenticateResponse, + AvailableCommandsUpdate, CancelNotification, + ClientCapabilities, CreateTerminalRequest, CreateTerminalResponse, + CurrentModeUpdate, + EmbeddedResourceContentBlock, + EnvVariable, + HttpMcpServer, + ImageContentBlock, + Implementation, InitializeRequest, InitializeResponse, KillTerminalCommandRequest, @@ -16,6 +28,7 @@ LoadSessionResponse, NewSessionRequest, NewSessionResponse, + PermissionOption, PromptRequest, PromptResponse, ReadTextFileRequest, @@ -24,63 +37,145 @@ ReleaseTerminalResponse, RequestPermissionRequest, RequestPermissionResponse, + ResourceContentBlock, SessionNotification, SetSessionModelRequest, SetSessionModelResponse, SetSessionModeRequest, SetSessionModeResponse, + SseMcpServer, + StdioMcpServer, TerminalOutputRequest, TerminalOutputResponse, + TextContentBlock, + ToolCall, + ToolCallProgress, + ToolCallStart, + UserMessageChunk, WaitForTerminalExitRequest, WaitForTerminalExitResponse, WriteTextFileRequest, WriteTextFileResponse, ) +from .utils import param_model __all__ = ["Agent", "Client"] class Client(Protocol): - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: ... - - async def sessionUpdate(self, params: SessionNotification) -> None: ... - - async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse | None: ... - - async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: ... - - async def createTerminal(self, params: CreateTerminalRequest) -> CreateTerminalResponse: ... - - async def terminalOutput(self, params: TerminalOutputRequest) -> TerminalOutputResponse: ... - - async def releaseTerminal(self, params: ReleaseTerminalRequest) -> ReleaseTerminalResponse | None: ... - - async def waitForTerminalExit(self, params: WaitForTerminalExitRequest) -> WaitForTerminalExitResponse: ... - - async def killTerminal(self, params: KillTerminalCommandRequest) -> KillTerminalCommandResponse | None: ... - - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... - - async def extNotification(self, method: str, params: dict[str, Any]) -> None: ... + @param_model(RequestPermissionRequest) + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: ... + + @param_model(SessionNotification) + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: ... + + @param_model(WriteTextFileRequest) + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: ... + + @param_model(ReadTextFileRequest) + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: ... + + @param_model(CreateTerminalRequest) + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: ... + + @param_model(TerminalOutputRequest) + async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: ... + + @param_model(ReleaseTerminalRequest) + async def release_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> ReleaseTerminalResponse | None: ... + + @param_model(WaitForTerminalExitRequest) + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> WaitForTerminalExitResponse: ... + + @param_model(KillTerminalCommandRequest) + async def kill_terminal( + self, session_id: str, terminal_id: str, **kwargs: Any + ) -> KillTerminalCommandResponse | None: ... + + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... + + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: ... class Agent(Protocol): - async def initialize(self, params: InitializeRequest) -> InitializeResponse: ... - - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: ... - - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: ... - - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: ... - - async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse | None: ... - - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: ... - - async def prompt(self, params: PromptRequest) -> PromptResponse: ... - - async def cancel(self, params: CancelNotification) -> None: ... - - async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... - - async def extNotification(self, method: str, params: dict[str, Any]) -> None: ... + @param_model(InitializeRequest) + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: ... + + @param_model(NewSessionRequest) + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any + ) -> NewSessionResponse: ... + + @param_model(LoadSessionRequest) + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any + ) -> LoadSessionResponse | None: ... + + @param_model(SetSessionModeRequest) + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: ... + + @param_model(SetSessionModelRequest) + async def set_session_model( + self, model_id: str, session_id: str, **kwargs: Any + ) -> SetSessionModelResponse | None: ... + + @param_model(AuthenticateRequest) + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: ... + + @param_model(PromptRequest) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: ... + + @param_model(CancelNotification) + async def cancel(self, session_id: str, **kwargs: Any) -> None: ... + + async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: ... + + async def ext_notification(self, method: str, params: dict[str, Any]) -> None: ... diff --git a/src/acp/meta.py b/src/acp/meta.py index 2b67512..82deb0e 100644 --- a/src/acp/meta.py +++ b/src/acp/meta.py @@ -1,5 +1,24 @@ # Generated from schema/meta.json. Do not edit by hand. # Schema ref: refs/tags/v0.6.3 -AGENT_METHODS = {'authenticate': 'authenticate', 'initialize': 'initialize', 'session_cancel': 'session/cancel', 'session_load': 'session/load', 'session_new': 'session/new', 'session_prompt': 'session/prompt', 'session_set_mode': 'session/set_mode', 'session_set_model': 'session/set_model'} -CLIENT_METHODS = {'fs_read_text_file': 'fs/read_text_file', 'fs_write_text_file': 'fs/write_text_file', 'session_request_permission': 'session/request_permission', 'session_update': 'session/update', 'terminal_create': 'terminal/create', 'terminal_kill': 'terminal/kill', 'terminal_output': 'terminal/output', 'terminal_release': 'terminal/release', 'terminal_wait_for_exit': 'terminal/wait_for_exit'} +AGENT_METHODS = { + "authenticate": "authenticate", + "initialize": "initialize", + "session_cancel": "session/cancel", + "session_load": "session/load", + "session_new": "session/new", + "session_prompt": "session/prompt", + "session_set_mode": "session/set_mode", + "session_set_model": "session/set_model", +} +CLIENT_METHODS = { + "fs_read_text_file": "fs/read_text_file", + "fs_write_text_file": "fs/write_text_file", + "session_request_permission": "session/request_permission", + "session_update": "session/update", + "terminal_create": "terminal/create", + "terminal_kill": "terminal/kill", + "terminal_output": "terminal/output", + "terminal_release": "terminal/release", + "terminal_wait_for_exit": "terminal/wait_for_exit", +} PROTOCOL_VERSION = 1 diff --git a/src/acp/router.py b/src/acp/router.py index f50f9b7..1d53473 100644 --- a/src/acp/router.py +++ b/src/acp/router.py @@ -1,192 +1,135 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Literal, TypeVar from pydantic import BaseModel from .exceptions import RequestError -__all__ = [ - "MessageRouter", - "Route", - "RouterBuilder", - "attribute_handler", -] +__all__ = ["MessageRouter", "Route"] AsyncHandler = Callable[[Any], Awaitable[Any | None]] +RequestHandler = Callable[[str, dict[str, Any]], Awaitable[Any]] +HandlerT = TypeVar("HandlerT", bound=RequestHandler) @dataclass(slots=True) class Route: method: str - model: type[BaseModel] - handle: Callable[[], AsyncHandler | None] + func: AsyncHandler | None kind: Literal["request", "notification"] optional: bool = False default_result: Any = None adapt_result: Callable[[Any | None], Any] | None = None - -class MessageRouter: - def __init__( - self, - routes: Sequence[Route], - *, - request_extensions: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None, - notification_extensions: Callable[[str, dict[str, Any]], Awaitable[None]] | None = None, - ) -> None: - self._requests: Mapping[str, Route] = {route.method: route for route in routes if route.kind == "request"} - self._notifications: Mapping[str, Route] = { - route.method: route for route in routes if route.kind == "notification" - } - self._request_extensions = request_extensions - self._notification_extensions = notification_extensions - - async def dispatch_request(self, method: str, params: Any | None) -> Any: - if isinstance(method, str) and method.startswith("_"): - if self._request_extensions is None: - raise RequestError.method_not_found(method) - payload = params if isinstance(params, dict) else {} - return await self._request_extensions(method[1:], payload) - - route = self._requests.get(method) - if route is None: - raise RequestError.method_not_found(method) - model = route.model - parsed = model.model_validate(params) - - handler = route.handle() - if handler is None: - if route.optional: - return route.default_result - raise RequestError.method_not_found(method) - - result = await handler(parsed) - if route.adapt_result is not None: - return route.adapt_result(result) + async def handle(self, params: Any) -> Any: + if self.func is None: + if self.optional: + return self.default_result + raise RequestError.method_not_found(self.method) + result = await self.func(params) + if self.adapt_result is not None and self.kind == "request": + return self.adapt_result(result) return result - async def dispatch_notification(self, method: str, params: Any | None) -> None: - if isinstance(method, str) and method.startswith("_"): - if self._notification_extensions is None: - return - payload = params if isinstance(params, dict) else {} - await self._notification_extensions(method[1:], payload) - return - - route = self._notifications.get(method) - if route is None: - raise RequestError.method_not_found(method) - model = route.model - parsed = model.model_validate(params) - handler = route.handle() - if handler is None: - if route.optional: - return - raise RequestError.method_not_found(method) - await handler(parsed) - - -class RouterBuilder: +class MessageRouter: def __init__(self) -> None: - self._routes: list[Route] = [] + self._requests: dict[str, Route] = {} + self._notifications: dict[str, Route] = {} + self._request_extensions: RequestHandler | None = None + self._notification_extensions: RequestHandler | None = None + + def add_route(self, route: Route) -> None: + if route.kind == "request": + self._requests[route.method] = route + else: + self._notifications[route.method] = route + + def _make_func(self, model: type[BaseModel], obj: Any, attr: str) -> AsyncHandler | None: + func = getattr(obj, attr, None) + if func is None or not callable(func): + return None - def request( - self, - method: str, - model: type[BaseModel], - *, - optional: bool = False, - default_result: Any = None, - adapt_result: Callable[[Any | None], Any] | None = None, - ) -> Callable[[Callable[[], AsyncHandler | None]], Callable[[], AsyncHandler | None]]: - def decorator(factory: Callable[[], AsyncHandler | None]) -> Callable[[], AsyncHandler | None]: - self._routes.append( - Route( - method=method, - model=model, - handle=factory, - kind="request", - optional=optional, - default_result=default_result, - adapt_result=adapt_result, - ) - ) - return factory - - return decorator - - def notification( - self, - method: str, - model: type[BaseModel], - *, - optional: bool = False, - ) -> Callable[[Callable[[], AsyncHandler | None]], Callable[[], AsyncHandler | None]]: - def decorator(factory: Callable[[], AsyncHandler | None]) -> Callable[[], AsyncHandler | None]: - self._routes.append( - Route( - method=method, - model=model, - handle=factory, - kind="notification", - optional=optional, - ) - ) - return factory - - return decorator - - def build( - self, - *, - request_extensions: Callable[[str, dict[str, Any]], Awaitable[Any]] | None = None, - notification_extensions: Callable[[str, dict[str, Any]], Awaitable[None]] | None = None, - ) -> MessageRouter: - return MessageRouter( - routes=self._routes, - request_extensions=request_extensions, - notification_extensions=notification_extensions, - ) + async def wrapper(params: Any) -> Any: + model_obj = model.model_validate(params) + params = {k: getattr(model_obj, k) for k in model.model_fields if k != "field_meta"} + if meta := getattr(model_obj, "field_meta", None): + params.update(meta) + return await func(**params) # type: ignore[arg-type] + + return wrapper - def request_attr( + def route_request( self, method: str, model: type[BaseModel], obj: Any, attr: str, - *, optional: bool = False, default_result: Any = None, adapt_result: Callable[[Any | None], Any] | None = None, - ) -> None: - self.request( - method, - model, + ) -> Route: + """Register a request route with obj and attribute name.""" + route = Route( + method=method, + func=self._make_func(model, obj, attr), + kind="request", optional=optional, default_result=default_result, adapt_result=adapt_result, - )(attribute_handler(obj, attr)) + ) + self.add_route(route) + return route - def notification_attr( + def route_notification( self, method: str, model: type[BaseModel], obj: Any, attr: str, - *, optional: bool = False, - ) -> None: - self.notification(method, model, optional=optional)(attribute_handler(obj, attr)) + ) -> Route: + """Register a notification route with obj and attribute name.""" + route = Route( + method=method, + func=self._make_func(model, obj, attr), + kind="notification", + optional=optional, + ) + self.add_route(route) + return route + + def handle_extension_request(self, handler: HandlerT) -> HandlerT: + """Register a handler for extension requests.""" + self._request_extensions = handler + return handler + + def handle_extension_notification(self, handler: HandlerT) -> HandlerT: + """Register a handler for extension notifications.""" + self._notification_extensions = handler + return handler + + async def __call__(self, method: str, params: Any | None, is_notification: bool) -> Any: + """The main router call to handle a request or notification.""" + if is_notification: + ext_handler = self._notification_extensions + routes = self._notifications + else: + ext_handler = self._request_extensions + routes = self._requests + if isinstance(method, str) and method.startswith("_"): + if ext_handler is None: + raise RequestError.method_not_found(method) + payload = params if isinstance(params, dict) else {} + return await ext_handler(method[1:], payload) -def attribute_handler(obj: Any, attr: str) -> Callable[[], AsyncHandler | None]: - def factory() -> AsyncHandler | None: - func = getattr(obj, attr, None) - return func if callable(func) else None + route = routes.get(method) + if route is None: + raise RequestError.method_not_found(method) - return factory + return await route.handle(params) diff --git a/src/acp/utils.py b/src/acp/utils.py index e81d7ba..e32f467 100644 --- a/src/acp/utils.py +++ b/src/acp/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from typing import Any, TypeVar from pydantic import BaseModel @@ -20,6 +21,8 @@ ] ModelT = TypeVar("ModelT", bound=BaseModel) +MethodT = TypeVar("MethodT", bound=Callable) +ClassT = TypeVar("ClassT", bound=type) def serialize_params(params: BaseModel) -> dict[str, Any]: @@ -94,3 +97,14 @@ async def request_optional_model( async def notify_model(conn: Connection, method: str, params: BaseModel) -> None: """Send a notification with serialized params.""" await conn.send_notification(method, serialize_params(params)) + + +def param_model(param_cls: type[BaseModel]) -> Callable[[MethodT], MethodT]: + """Decorator to map the method parameters to a Pydantic model. + It is just a marker and does nothing at runtime. + """ + + def decorator(func: MethodT) -> MethodT: + return func + + return decorator diff --git a/tests/real_user/test_cancel_prompt_flow.py b/tests/real_user/test_cancel_prompt_flow.py index 1189e17..64b76e5 100644 --- a/tests/real_user/test_cancel_prompt_flow.py +++ b/tests/real_user/test_cancel_prompt_flow.py @@ -1,9 +1,17 @@ import asyncio +from typing import Any import pytest -from acp import AgentSideConnection, CancelNotification, ClientSideConnection, PromptRequest, PromptResponse -from acp.schema import TextContentBlock +from acp import AgentSideConnection, ClientSideConnection, PromptResponse +from acp.schema import ( + AudioContentBlock, + EmbeddedResourceContentBlock, + ImageContentBlock, + PromptRequest, + ResourceContentBlock, + TextContentBlock, +) from tests.test_rpc import TestAgent, TestClient, _Server # Regression from a real user session where cancel needed to interrupt a long-running prompt. @@ -17,8 +25,19 @@ def __init__(self) -> None: self.prompt_started = asyncio.Event() self.cancel_received = asyncio.Event() - async def prompt(self, params: PromptRequest) -> PromptResponse: - self.prompts.append(params) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + self.prompts.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) self.prompt_started.set() try: await asyncio.wait_for(self.cancel_received.wait(), timeout=1.0) @@ -27,8 +46,8 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: raise AssertionError(msg) from exc return PromptResponse(stop_reason="cancelled") - async def cancel(self, params: CancelNotification) -> None: - await super().cancel(params) + async def cancel(self, session_id: str, **kwargs: Any) -> None: + await super().cancel(session_id, **kwargs) self.cancel_received.set() @@ -40,16 +59,17 @@ async def test_cancel_reaches_agent_during_prompt() -> None: agent_conn = ClientSideConnection(lambda _conn: client, server._client_writer, server._client_reader) _client_conn = AgentSideConnection(lambda _conn: agent, server._server_writer, server._server_reader) - prompt_request = PromptRequest( - session_id="sess-xyz", - prompt=[TextContentBlock(type="text", text="hello")], + prompt_task = asyncio.create_task( + agent_conn.prompt( + session_id="sess-xyz", + prompt=[TextContentBlock(type="text", text="hello")], + ) ) - prompt_task = asyncio.create_task(agent_conn.prompt(prompt_request)) await agent.prompt_started.wait() assert not prompt_task.done(), "Prompt finished before cancel was sent" - await agent_conn.cancel(CancelNotification(session_id="sess-xyz")) + await agent_conn.cancel(session_id="sess-xyz") await asyncio.wait_for(agent.cancel_received.wait(), timeout=1.0) diff --git a/tests/real_user/test_permission_flow.py b/tests/real_user/test_permission_flow.py index 3b051af..6c6d3ca 100644 --- a/tests/real_user/test_permission_flow.py +++ b/tests/real_user/test_permission_flow.py @@ -1,9 +1,18 @@ import asyncio +from typing import Any import pytest -from acp import AgentSideConnection, ClientSideConnection, PromptRequest, PromptResponse, RequestPermissionRequest -from acp.schema import PermissionOption, TextContentBlock, ToolCall +from acp import AgentSideConnection, ClientSideConnection, PromptResponse +from acp.schema import ( + AudioContentBlock, + EmbeddedResourceContentBlock, + ImageContentBlock, + PermissionOption, + ResourceContentBlock, + TextContentBlock, + ToolCall, +) from tests.test_rpc import TestAgent, TestClient, _Server # Regression from real-world runs where agents paused prompts to obtain user permission. @@ -17,19 +26,28 @@ def __init__(self, conn: AgentSideConnection) -> None: self._conn = conn self.permission_responses = [] - async def prompt(self, params: PromptRequest) -> PromptResponse: - permission = await self._conn.requestPermission( - RequestPermissionRequest( - session_id=params.session_id, - options=[ - PermissionOption(option_id="allow", name="Allow", kind="allow_once"), - PermissionOption(option_id="deny", name="Deny", kind="reject_once"), - ], - tool_call=ToolCall(tool_call_id="call-1", title="Write File"), - ) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + permission = await self._conn.request_permission( + session_id=session_id, + options=[ + PermissionOption(option_id="allow", name="Allow", kind="allow_once"), + PermissionOption(option_id="deny", name="Deny", kind="reject_once"), + ], + tool_call=ToolCall(tool_call_id="call-1", title="Write File"), ) self.permission_responses.append(permission) - return await super().prompt(params) + return await super().prompt(prompt, session_id, **kwargs) @pytest.mark.asyncio @@ -49,10 +67,8 @@ async def test_agent_request_permission_roundtrip() -> None: response = await asyncio.wait_for( agent_conn.prompt( - PromptRequest( - session_id="sess-perm", - prompt=[TextContentBlock(type="text", text="needs approval")], - ) + session_id="sess-perm", + prompt=[TextContentBlock(type="text", text="needs approval")], ), timeout=1.0, ) diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 7fe3653..112e0c3 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -3,38 +3,34 @@ import json import sys from pathlib import Path +from typing import Any import pytest from acp import ( Agent, AgentSideConnection, - AuthenticateRequest, AuthenticateResponse, - CancelNotification, Client, ClientSideConnection, - InitializeRequest, + CreateTerminalResponse, InitializeResponse, - LoadSessionRequest, + KillTerminalCommandResponse, LoadSessionResponse, - NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse, - ReadTextFileRequest, ReadTextFileResponse, + ReleaseTerminalResponse, RequestError, RequestPermissionRequest, RequestPermissionResponse, SessionNotification, - SetSessionModelRequest, SetSessionModelResponse, - SetSessionModeRequest, SetSessionModeResponse, - WriteTextFileRequest, + TerminalOutputResponse, + WaitForTerminalExitResponse, WriteTextFileResponse, - session_notification, spawn_agent_process, start_tool_call, update_agent_message_text, @@ -42,9 +38,23 @@ ) from acp.schema import ( AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, AllowedOutcome, + AudioContentBlock, + AvailableCommandsUpdate, + ClientCapabilities, + CurrentModeUpdate, DeniedOutcome, + EmbeddedResourceContentBlock, + EnvVariable, + HttpMcpServer, + ImageContentBlock, + Implementation, PermissionOption, + ResourceContentBlock, + SseMcpServer, + StdioMcpServer, TextContentBlock, ToolCall, ToolCallLocation, @@ -137,45 +147,80 @@ def queue_permission_selected(self, option_id: str) -> None: RequestPermissionResponse(outcome=AllowedOutcome(option_id=option_id, outcome="selected")) ) - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + async def request_permission( + self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any + ) -> RequestPermissionResponse: if self.permission_outcomes: return self.permission_outcomes.pop() return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) - async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse: - self.files[str(params.path)] = params.content + async def write_text_file( + self, content: str, path: str, session_id: str, **kwargs: Any + ) -> WriteTextFileResponse | None: + self.files[str(path)] = content return WriteTextFileResponse() - async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse: - content = self.files.get(str(params.path), "default content") + async def read_text_file( + self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any + ) -> ReadTextFileResponse: + content = self.files.get(str(path), "default content") return ReadTextFileResponse(content=content) - async def sessionUpdate(self, params: SessionNotification) -> None: - self.notifications.append(params) + async def session_update( + self, + session_id: str, + update: UserMessageChunk + | AgentMessageChunk + | AgentThoughtChunk + | ToolCallStart + | ToolCallProgress + | AgentPlanUpdate + | AvailableCommandsUpdate + | CurrentModeUpdate, + **kwargs: Any, + ) -> None: + self.notifications.append(SessionNotification(session_id=session_id, update=update, field_meta=kwargs or None)) # Optional terminal methods (not implemented in this test client) - async def createTerminal(self, params): # pragma: no cover - placeholder + async def create_terminal( + self, + command: str, + session_id: str, + args: list[str] | None = None, + cwd: str | None = None, + env: list[EnvVariable] | None = None, + output_byte_limit: int | None = None, + **kwargs: Any, + ) -> CreateTerminalResponse: raise NotImplementedError - async def terminalOutput(self, params): # pragma: no cover - placeholder + async def terminal_output( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> TerminalOutputResponse: # pragma: no cover - placeholder raise NotImplementedError - async def releaseTerminal(self, params): # pragma: no cover - placeholder + async def release_terminal( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> ReleaseTerminalResponse | None: raise NotImplementedError - async def waitForTerminalExit(self, params): # pragma: no cover - placeholder + async def wait_for_terminal_exit( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> WaitForTerminalExitResponse: raise NotImplementedError - async def killTerminal(self, params): # pragma: no cover - placeholder + async def kill_terminal( + self, session_id: str, terminal_id: str | None = None, **kwargs: Any + ) -> KillTerminalCommandResponse | None: raise NotImplementedError - async def extMethod(self, method: str, params: dict) -> dict: + async def ext_method(self, method: str, params: dict) -> dict: self.ext_calls.append((method, params)) if method == "example.com/ping": return {"response": "pong", "params": params} raise RequestError.method_not_found(method) - async def extNotification(self, method: str, params: dict) -> None: + async def ext_notification(self, method: str, params: dict) -> None: self.ext_notes.append((method, params)) @@ -188,39 +233,60 @@ def __init__(self) -> None: self.ext_calls: list[tuple[str, dict]] = [] self.ext_notes: list[tuple[str, dict]] = [] - async def initialize(self, params: InitializeRequest) -> InitializeResponse: + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: # Avoid serializer warnings by omitting defaults - return InitializeResponse(protocol_version=params.protocol_version, agent_capabilities=None, auth_methods=[]) + return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[]) - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any + ) -> NewSessionResponse: return NewSessionResponse(session_id="test-session-123") - async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse: + async def load_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any + ) -> LoadSessionResponse | None: return LoadSessionResponse() - async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse: + async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None: return AuthenticateResponse() - async def prompt(self, params: PromptRequest) -> PromptResponse: - self.prompts.append(params) + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + self.prompts.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) return PromptResponse(stop_reason="end_turn") - async def cancel(self, params: CancelNotification) -> None: - self.cancellations.append(params.session_id) + async def cancel(self, session_id: str, **kwargs: Any) -> None: + self.cancellations.append(session_id) - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse: + async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None: return SetSessionModeResponse() - async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse: + async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) -> SetSessionModelResponse | None: return SetSessionModelResponse() - async def extMethod(self, method: str, params: dict) -> dict: + async def ext_method(self, method: str, params: dict) -> dict: self.ext_calls.append((method, params)) if method == "example.com/echo": return {"echo": params} raise RequestError.method_not_found(method) - async def extNotification(self, method: str, params: dict) -> None: + async def ext_notification(self, method: str, params: dict) -> None: self.ext_notes.append((method, params)) @@ -234,31 +300,25 @@ async def test_initialize_and_new_session(): client = TestClient() # server side is agent; client side is client agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) + AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) - resp = await agent_conn.initialize(InitializeRequest(protocol_version=1)) + resp = await agent_conn.initialize(protocol_version=1) assert isinstance(resp, InitializeResponse) assert resp.protocol_version == 1 - new_sess = await agent_conn.newSession(NewSessionRequest(mcp_servers=[], cwd="/test")) + new_sess = await agent_conn.new_session(mcp_servers=[], cwd="/test") assert new_sess.session_id == "test-session-123" - load_resp = await agent_conn.loadSession( - LoadSessionRequest(session_id=new_sess.session_id, cwd="/test", mcp_servers=[]) - ) + load_resp = await agent_conn.load_session(session_id=new_sess.session_id, cwd="/test", mcp_servers=[]) assert isinstance(load_resp, LoadSessionResponse) - auth_resp = await agent_conn.authenticate(AuthenticateRequest(method_id="password")) + auth_resp = await agent_conn.authenticate(method_id="password") assert isinstance(auth_resp, AuthenticateResponse) - mode_resp = await agent_conn.setSessionMode( - SetSessionModeRequest(session_id=new_sess.session_id, mode_id="ask") - ) + mode_resp = await agent_conn.set_session_mode(session_id=new_sess.session_id, mode_id="ask") assert isinstance(mode_resp, SetSessionModeResponse) - model_resp = await agent_conn.setSessionModel( - SetSessionModelRequest(session_id=new_sess.session_id, model_id="gpt-4o") - ) + model_resp = await agent_conn.set_session_model(session_id=new_sess.session_id, model_id="gpt-4o") assert isinstance(model_resp, SetSessionModelResponse) @@ -272,13 +332,11 @@ async def test_bidirectional_file_ops(): client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Agent asks client to read - res = await client_conn.readTextFile(ReadTextFileRequest(session_id="sess", path="/test/file.txt")) + res = await client_conn.read_text_file(session_id="sess", path="/test/file.txt") assert res.content == "Hello, World!" # Agent asks client to write - write_result = await client_conn.writeTextFile( - WriteTextFileRequest(session_id="sess", path="/test/file.txt", content="Updated") - ) + write_result = await client_conn.write_text_file(session_id="sess", path="/test/file.txt", content="Updated") assert isinstance(write_result, WriteTextFileResponse) assert client.files["/test/file.txt"] == "Updated" @@ -293,7 +351,7 @@ async def test_cancel_notification_and_capture_wire(): _client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Send cancel notification from client-side connection to agent - await agent_conn.cancel(CancelNotification(session_id="test-123")) + await agent_conn.cancel(session_id="test-123") # Read raw line from server peer (it will be consumed by agent receive loop quickly). # Instead, wait a brief moment and assert agent recorded it. @@ -313,23 +371,19 @@ async def test_session_notifications_flow(): client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Agent -> Client notifications - await client_conn.sessionUpdate( - SessionNotification( - session_id="sess", - update=AgentMessageChunk( - session_update="agent_message_chunk", - content=TextContentBlock(type="text", text="Hello"), - ), - ) + await client_conn.session_update( + session_id="sess", + update=AgentMessageChunk( + session_update="agent_message_chunk", + content=TextContentBlock(type="text", text="Hello"), + ), ) - await client_conn.sessionUpdate( - SessionNotification( - session_id="sess", - update=UserMessageChunk( - session_update="user_message_chunk", - content=TextContentBlock(type="text", text="World"), - ), - ) + await client_conn.session_update( + session_id="sess", + update=UserMessageChunk( + session_update="user_message_chunk", + content=TextContentBlock(type="text", text="World"), + ), ) # Wait for async dispatch @@ -352,7 +406,7 @@ async def test_concurrent_reads(): client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) async def read_one(i: int): - return await client_conn.readTextFile(ReadTextFileRequest(session_id="sess", path=f"/test/file{i}.txt")) + return await client_conn.read_text_file(session_id="sess", path=f"/test/file{i}.txt") results = await asyncio.gather(*(read_one(i) for i in range(5))) for i, res in enumerate(results): @@ -404,24 +458,24 @@ async def test_set_session_mode_and_extensions(): client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # setSessionMode - resp = await agent_conn.setSessionMode(SetSessionModeRequest(session_id="sess", mode_id="yolo")) + resp = await agent_conn.set_session_mode(session_id="sess", mode_id="yolo") assert isinstance(resp, SetSessionModeResponse) - model_resp = await agent_conn.setSessionModel(SetSessionModelRequest(session_id="sess", model_id="gpt-4o-mini")) + model_resp = await agent_conn.set_session_model(session_id="sess", model_id="gpt-4o-mini") assert isinstance(model_resp, SetSessionModelResponse) # extMethod - echo = await agent_conn.extMethod("example.com/echo", {"x": 1}) + echo = await agent_conn.ext_method("example.com/echo", {"x": 1}) assert echo == {"echo": {"x": 1}} # extNotification - await agent_conn.extNotification("note", {"y": 2}) + await agent_conn.ext_notification("note", {"y": 2}) # allow dispatch await asyncio.sleep(0.05) assert agent.ext_notes and agent.ext_notes[-1][0] == "note" # client extension method - ping = await client_conn.extMethod("example.com/ping", {"k": 3}) + ping = await client_conn.ext_method("example.com/ping", {"k": 3}) assert ping == {"response": "pong", "params": {"k": 3}} assert client.ext_calls and client.ext_calls[-1] == ("example.com/ping", {"k": 3}) @@ -459,40 +513,55 @@ def bind(self, conn: AgentSideConnection) -> "_ExampleAgent": self._conn = conn return self - async def initialize(self, params: InitializeRequest) -> InitializeResponse: - return InitializeResponse(protocol_version=params.protocol_version) - - async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: + return InitializeResponse(protocol_version=protocol_version) + + async def new_session( + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any + ) -> NewSessionResponse: return NewSessionResponse(session_id="sess_demo") - async def prompt(self, params: PromptRequest) -> PromptResponse: + async def prompt( + self, + prompt: list[ + TextContentBlock + | ImageContentBlock + | AudioContentBlock + | ResourceContentBlock + | EmbeddedResourceContentBlock + ], + session_id: str, + **kwargs: Any, + ) -> PromptResponse: assert self._conn is not None - self.prompt_requests.append(params) + self.prompt_requests.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) - await self._conn.sessionUpdate( - session_notification( - params.session_id, - update_agent_message_text("I'll help you with that."), - ) + await self._conn.session_update( + session_id, + update_agent_message_text("I'll help you with that."), ) - await self._conn.sessionUpdate( - session_notification( - params.session_id, - start_tool_call( - "call_1", - "Modifying configuration", - kind="edit", - status="pending", - locations=[ToolCallLocation(path="/project/config.json")], - raw_input={"path": "/project/config.json"}, - ), - ) + await self._conn.session_update( + session_id, + start_tool_call( + "call_1", + "Modifying configuration", + kind="edit", + status="pending", + locations=[ToolCallLocation(path="/project/config.json")], + raw_input={"path": "/project/config.json"}, + ), ) - permission_request = RequestPermissionRequest( - session_id=params.session_id, - tool_call=ToolCall( + permission_request = { + "session_id": session_id, + "tool_call": ToolCall( tool_call_id="call_1", title="Modifying configuration", kind="edit", @@ -500,30 +569,26 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: locations=[ToolCallLocation(path="/project/config.json")], raw_input={"path": "/project/config.json"}, ), - options=[ + "options": [ PermissionOption(kind="allow_once", name="Allow", option_id="allow"), PermissionOption(kind="reject_once", name="Reject", option_id="reject"), ], - ) - response = await self._conn.requestPermission(permission_request) + } + response = await self._conn.request_permission(**permission_request) self.permission_response = response if isinstance(response.outcome, AllowedOutcome) and response.outcome.option_id == "allow": - await self._conn.sessionUpdate( - session_notification( - params.session_id, - update_tool_call( - "call_1", - status="completed", - raw_output={"success": True}, - ), - ) + await self._conn.session_update( + session_id, + update_tool_call( + "call_1", + status="completed", + raw_output={"success": True}, + ), ) - await self._conn.sessionUpdate( - session_notification( - params.session_id, - update_agent_message_text("Done."), - ) + await self._conn.session_update( + session_id, + update_agent_message_text("Done."), ) return PromptResponse(stop_reason="end_turn") @@ -536,7 +601,23 @@ def __init__(self) -> None: super().__init__() self.permission_requests: list[RequestPermissionRequest] = [] - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + async def request_permission( + self, + options: list[PermissionOption] | RequestPermissionRequest, + session_id: str | None = None, + tool_call: ToolCall | None = None, + **kwargs: Any, + ) -> RequestPermissionResponse: + if isinstance(options, RequestPermissionRequest): + params = options + else: + assert session_id is not None and tool_call is not None + params = RequestPermissionRequest( + options=options, + session_id=session_id, + tool_call=tool_call, + field_meta=kwargs or None, + ) self.permission_requests.append(params) if not params.options: return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) @@ -553,16 +634,16 @@ async def test_example_agent_permission_flow(): agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) AgentSideConnection(lambda conn: agent.bind(conn), s._server_writer, s._server_reader) - init = await agent_conn.initialize(InitializeRequest(protocol_version=1)) + init = await agent_conn.initialize(protocol_version=1) assert init.protocol_version == 1 - session = await agent_conn.newSession(NewSessionRequest(mcp_servers=[], cwd="/workspace")) + session = await agent_conn.new_session(mcp_servers=[], cwd="/workspace") assert session.session_id == "sess_demo" - prompt = PromptRequest( + + resp = await agent_conn.prompt( session_id=session.session_id, prompt=[TextContentBlock(type="text", text="Please edit config")], ) - resp = await agent_conn.prompt(prompt) assert resp.stop_reason == "end_turn" for _ in range(50): if len(client.notifications) >= 4: @@ -610,14 +691,13 @@ async def test_spawn_agent_process_roundtrip(tmp_path): test_client = TestClient() async with spawn_agent_process(lambda _agent: test_client, sys.executable, str(script)) as (client_conn, process): - init = await client_conn.initialize(InitializeRequest(protocol_version=1)) + init = await client_conn.initialize(protocol_version=1) assert isinstance(init, InitializeResponse) - session = await client_conn.newSession(NewSessionRequest(cwd=str(tmp_path), mcp_servers=[])) - prompt = PromptRequest( + session = await client_conn.new_session(mcp_servers=[], cwd=str(tmp_path)) + await client_conn.prompt( session_id=session.session_id, prompt=[TextContentBlock(type="text", text="hi spawn")], ) - await client_conn.prompt(prompt) # Wait for echo agent notification to arrive for _ in range(50): From 9a930c66ea464fc3a2168bff1c35767b9391bd44 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Thu, 20 Nov 2025 11:59:55 +0800 Subject: [PATCH 2/2] fix: overrides in gen_schema.py Signed-off-by: Frost Ming --- scripts/gen_schema.py | 18 +++++++++--------- scripts/gen_signature.py | 5 ++++- src/acp/schema.py | 20 ++++++-------------- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 033abe2..700fad5 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -85,7 +85,7 @@ ("PermissionOption", "kind", "PermissionOptionKind", False), ("PlanEntry", "priority", "PlanEntryPriority", False), ("PlanEntry", "status", "PlanEntryStatus", False), - ("PromptResponse", "stopReason", "StopReason", False), + ("PromptResponse", "stop_reason", "StopReason", False), ("ToolCallProgress", "kind", "ToolKind", True), ("ToolCallProgress", "status", "ToolCallStatus", True), ("ToolCallStart", "kind", "ToolKind", True), @@ -95,23 +95,23 @@ ) DEFAULT_VALUE_OVERRIDES: tuple[tuple[str, str, str], ...] = ( - ("AgentCapabilities", "mcpCapabilities", "McpCapabilities(http=False, sse=False)"), + ("AgentCapabilities", "mcp_capabilities", "McpCapabilities()"), ( "AgentCapabilities", - "promptCapabilities", - "PromptCapabilities(audio=False, embeddedContext=False, image=False)", + "prompt_capabilities", + "PromptCapabilities()", ), - ("ClientCapabilities", "fs", "FileSystemCapability(readTextFile=False, writeTextFile=False)"), + ("ClientCapabilities", "fs", "FileSystemCapability()"), ("ClientCapabilities", "terminal", "False"), ( "InitializeRequest", - "clientCapabilities", - "ClientCapabilities(fs=FileSystemCapability(readTextFile=False, writeTextFile=False), terminal=False)", + "client_capabilities", + "ClientCapabilities()", ), ( "InitializeResponse", - "agentCapabilities", - "AgentCapabilities(loadSession=False, mcpCapabilities=McpCapabilities(http=False, sse=False), promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False))", + "agent_capabilities", + "AgentCapabilities()", ), ) diff --git a/scripts/gen_signature.py b/scripts/gen_signature.py index 635a7a1..b3a7add 100644 --- a/scripts/gen_signature.py +++ b/scripts/gen_signature.py @@ -86,7 +86,7 @@ def _to_param_def(self, name: str, field: FieldInfo) -> tuple[ast.arg, ast.expr ann = field.annotation if field.default is PydanticUndefined: default = None - elif isinstance(field.default, dict): + elif isinstance(field.default, dict | BaseModel): default = ast.Constant(None) else: default = ast.Constant(value=field.default) @@ -128,6 +128,9 @@ def _format_annotation(self, annotation: t.Any) -> ast.expr: def gen_signature(source_dir: Path) -> None: + import importlib + + importlib.reload(schema) # Ensure schema is up to date for source_file in source_dir.rglob("*.py"): transformer = NodeTransformer() transformer.transform(source_file) diff --git a/src/acp/schema.py b/src/acp/schema.py index 0e9c9cb..1d8b91e 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -623,7 +623,7 @@ class AgentCapabilities(BaseModel): alias="mcpCapabilities", description="MCP capabilities supported by the agent.", ), - ] = {"http": False, "sse": False} + ] = McpCapabilities() # Prompt capabilities supported by the agent. prompt_capabilities: Annotated[ Optional[PromptCapabilities], @@ -631,7 +631,7 @@ class AgentCapabilities(BaseModel): alias="promptCapabilities", description="Prompt capabilities supported by the agent.", ), - ] = {"audio": False, "embeddedContext": False, "image": False} + ] = PromptCapabilities() class AgentErrorMessage(BaseModel): @@ -731,7 +731,7 @@ class ClientCapabilities(BaseModel): Field( description="File system capabilities supported by the client.\nDetermines which file operations the agent can request." ), - ] = FileSystemCapability(readTextFile=False, writeTextFile=False) + ] = FileSystemCapability() # Whether the Client support all `terminal/*` methods. terminal: Annotated[ Optional[bool], @@ -870,7 +870,7 @@ class InitializeRequest(BaseModel): alias="clientCapabilities", description="Capabilities supported by the client.", ), - ] = {"fs": {"readTextFile": False, "writeTextFile": False}, "terminal": False} + ] = ClientCapabilities() # Information about the Client name and version sent to the Agent. # # Note: in future versions of the protocol, this will be required. @@ -906,15 +906,7 @@ class InitializeResponse(BaseModel): alias="agentCapabilities", description="Capabilities supported by the agent.", ), - ] = { - "loadSession": False, - "mcpCapabilities": {"http": False, "sse": False}, - "promptCapabilities": { - "audio": False, - "embeddedContext": False, - "image": False, - }, - } + ] = AgentCapabilities() # Information about the Agent name and version sent to the Client. # # Note: in future versions of the protocol, this will be required. @@ -1031,7 +1023,7 @@ class PromptResponse(BaseModel): ] = None # Indicates why the agent stopped processing the turn. stop_reason: Annotated[ - str, + StopReason, Field( alias="stopReason", description="Indicates why the agent stopped processing the turn.",