From 9edf3bec4b57c8491fc4bd63ac5f0d81e194b3e3 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 19 Nov 2025 11:02:54 +0800 Subject: [PATCH 1/2] feat: change field names to snake case Signed-off-by: Frost Ming --- docs/quickstart.md | 24 +- examples/agent.py | 31 +- examples/client.py | 12 +- examples/echo_agent.py | 10 +- examples/gemini.py | 32 +- scripts/gen_schema.py | 1 + src/acp/contrib/session_state.py | 26 +- src/acp/schema.py | 350 ++++++++++++++------ tests/contrib/test_contrib_permissions.py | 16 +- tests/contrib/test_contrib_session_state.py | 38 +-- tests/contrib/test_contrib_tool_calls.py | 4 +- tests/real_user/test_cancel_prompt_flow.py | 12 +- tests/real_user/test_permission_flow.py | 20 +- tests/real_user/test_stdio_limits.py | 4 +- tests/test_rpc.py | 194 ++++++----- 15 files changed, 468 insertions(+), 306 deletions(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index 3a54147..824315c 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -4,13 +4,13 @@ Spin up a working ACP agent/client loop in minutes. Keep this page beside the te ## Quick checklist -| Goal | Command / Link | -| --- | --- | -| Install the SDK | `pip install agent-client-protocol` or `uv add agent-client-protocol` | -| Run the echo agent | `python examples/echo_agent.py` | -| Point Zed (or another client) at it | Update `settings.json` as shown below | -| Programmatically drive an agent | Copy the `spawn_agent_process` example | -| Run tests before hacking further | `make check && make test` | +| Goal | Command / Link | +| ----------------------------------- | --------------------------------------------------------------------- | +| Install the SDK | `pip install agent-client-protocol` or `uv add agent-client-protocol` | +| Run the echo agent | `python examples/echo_agent.py` | +| Point Zed (or another client) at it | Update `settings.json` as shown below | +| Programmatically drive an agent | Copy the `spawn_agent_process` example | +| Run tests before hacking further | `make check && make test` | ## Before you begin @@ -84,17 +84,17 @@ class SimpleClient(Client): return {"outcome": {"outcome": "cancelled"}} async def sessionUpdate(self, params: SessionNotification) -> None: - print("update:", params.sessionId, params.update) + print("update:", params.session_id, params.update) async def main() -> None: script = Path("examples/echo_agent.py") async with spawn_agent_process(lambda _agent: SimpleClient(), sys.executable, str(script)) as (conn, _proc): - await conn.initialize(InitializeRequest(protocolVersion=1)) - session = await conn.newSession(NewSessionRequest(cwd=str(script.parent), mcpServers=[])) + await conn.initialize(InitializeRequest(protocol_version=1)) + session = await conn.newSession(NewSessionRequest(cwd=str(script.parent), mcp_servers=[])) await conn.prompt( PromptRequest( - sessionId=session.sessionId, + session_id=session.session_id, prompt=[text_block("Hello from spawn!")], ) ) @@ -117,7 +117,7 @@ from acp import Agent, PromptRequest, PromptResponse class MyAgent(Agent): async def prompt(self, params: PromptRequest) -> PromptResponse: # inspect params.prompt, stream updates, then finish the turn - return PromptResponse(stopReason="end_turn") + return PromptResponse(stop_reason="end_turn") ``` Hook it up with `AgentSideConnection` inside an async entrypoint and wire it to your client. Refer to: diff --git a/examples/agent.py b/examples/agent.py index a75e1a5..ce09e6c 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -40,13 +40,13 @@ async def _send_agent_message(self, session_id: str, content: Any) -> None: async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002 logging.info("Received initialize request") return InitializeResponse( - protocolVersion=PROTOCOL_VERSION, - agentCapabilities=AgentCapabilities(), - agentInfo=Implementation(name="example-agent", title="Example Agent", version="0.1.0"), + protocol_version=PROTOCOL_VERSION, + agent_capabilities=AgentCapabilities(), + agent_info=Implementation(name="example-agent", title="Example Agent", version="0.1.0"), ) async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002 - logging.info("Received authenticate request %s", params.methodId) + logging.info("Received authenticate request %s", params.method_id) return AuthenticateResponse() async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002 @@ -54,30 +54,29 @@ async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # session_id = str(self._next_session_id) self._next_session_id += 1 self._sessions.add(session_id) - return NewSessionResponse(sessionId=session_id, modes=None) + return NewSessionResponse(session_id=session_id, modes=None) async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002 - logging.info("Received load session request %s", params.sessionId) - self._sessions.add(params.sessionId) + logging.info("Received load session request %s", params.session_id) + self._sessions.add(params.session_id) return LoadSessionResponse() async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002 - logging.info("Received set session mode request %s -> %s", params.sessionId, params.modeId) + logging.info("Received set session mode request %s -> %s", params.session_id, params.mode_id) return SetSessionModeResponse() async def prompt(self, params: PromptRequest) -> PromptResponse: - logging.info("Received prompt request for session %s", params.sessionId) - if params.sessionId not in self._sessions: - self._sessions.add(params.sessionId) + logging.info("Received prompt request for session %s", params.session_id) + if params.session_id not in self._sessions: + self._sessions.add(params.session_id) - await self._send_agent_message(params.sessionId, text_block("Client sent:")) + await self._send_agent_message(params.session_id, text_block("Client sent:")) for block in params.prompt: - await self._send_agent_message(params.sessionId, block) - - return PromptResponse(stopReason="end_turn") + await self._send_agent_message(params.session_id, block) + return PromptResponse(stop_reason="end_turn") async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002 - logging.info("Received cancel notification for session %s", params.sessionId) + logging.info("Received cancel notification for session %s", params.session_id) async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 logging.info("Received extension method call: %s", method) diff --git a/examples/client.py b/examples/client.py index 8c62462..69841b5 100644 --- a/examples/client.py +++ b/examples/client.py @@ -104,7 +104,7 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: try: await conn.prompt( PromptRequest( - sessionId=session_id, + session_id=session_id, prompt=[text_block(line)], ) ) @@ -146,14 +146,14 @@ async def main(argv: list[str]) -> int: await conn.initialize( InitializeRequest( - protocolVersion=PROTOCOL_VERSION, - clientCapabilities=ClientCapabilities(), - clientInfo=Implementation(name="example-client", title="Example Client", version="0.1.0"), + protocol_version=PROTOCOL_VERSION, + client_capabilities=ClientCapabilities(), + client_info=Implementation(name="example-client", title="Example Client", version="0.1.0"), ) ) - session = await conn.newSession(NewSessionRequest(mcpServers=[], cwd=os.getcwd())) + session = await conn.newSession(NewSessionRequest(mcp_servers=[], cwd=os.getcwd())) - await interactive_loop(conn, session.sessionId) + await interactive_loop(conn, session.session_id) if proc.returncode is None: proc.terminate() diff --git a/examples/echo_agent.py b/examples/echo_agent.py index 657eb28..6528f51 100644 --- a/examples/echo_agent.py +++ b/examples/echo_agent.py @@ -22,10 +22,10 @@ def __init__(self, conn): self._conn = conn async def initialize(self, params: InitializeRequest) -> InitializeResponse: - return InitializeResponse(protocolVersion=params.protocolVersion) + return InitializeResponse(protocol_version=params.protocol_version) async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: - return NewSessionResponse(sessionId=uuid4().hex) + return NewSessionResponse(session_id=uuid4().hex) async def prompt(self, params: PromptRequest) -> PromptResponse: for block in params.prompt: @@ -34,16 +34,16 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: chunk.field_meta = {"echo": True} chunk.content.field_meta = {"echo": True} - notification = session_notification(params.sessionId, chunk) + notification = session_notification(params.session_id, chunk) notification.field_meta = {"source": "echo_agent"} await self._conn.sessionUpdate(notification) - return PromptResponse(stopReason="end_turn") + return PromptResponse(stop_reason="end_turn") async def main() -> None: reader, writer = await stdio_streams() - AgentSideConnection(lambda conn: EchoAgent(conn), writer, reader) + AgentSideConnection(EchoAgent, writer, reader) await asyncio.Event().wait() diff --git a/examples/gemini.py b/examples/gemini.py index f1fe9a9..adea53f 100644 --- a/examples/gemini.py +++ b/examples/gemini.py @@ -73,9 +73,9 @@ async def requestPermission( option = _pick_preferred_option(params.options) if option is None: return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) - return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected")) + return RequestPermissionResponse(outcome=AllowedOutcome(option_id=option.option_id, outcome="selected")) - title = params.toolCall.title or "" + title = params.tool_call.title or "" if not params.options: print(f"\nšŸ” Permission requested: {title} (no options, cancelling)") return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) @@ -92,7 +92,9 @@ async def requestPermission( idx = int(choice) - 1 if 0 <= idx < len(params.options): opt = params.options[idx] - return RequestPermissionResponse(outcome=AllowedOutcome(optionId=opt.optionId, outcome="selected")) + return RequestPermissionResponse( + outcome=AllowedOutcome(option_id=opt.option_id, outcome="selected") + ) print("Invalid selection, try again.") async def writeTextFile( @@ -141,13 +143,13 @@ async def sessionUpdate( print(f"\nšŸ”§ {update.title} ({update.status or 'pending'})") elif isinstance(update, ToolCallProgress): status = update.status or "in_progress" - print(f"\nšŸ”§ Tool call `{update.toolCallId}` → {status}") + print(f"\nšŸ”§ Tool call `{update.tool_call_id}` → {status}") if update.content: for item in update.content: if isinstance(item, FileEditToolCallContent): print(f" diff: {item.path}") elif isinstance(item, TerminalToolCallContent): - print(f" terminal: {item.terminalId}") + print(f" terminal: {item.terminal_id}") elif isinstance(item, dict): print(f" content: {json.dumps(item, indent=2)}") else: @@ -159,7 +161,7 @@ async def createTerminal( params: CreateTerminalRequest, ) -> CreateTerminalResponse: # type: ignore[override] print(f"[Client] createTerminal: {params}") - return CreateTerminalResponse(terminalId="term-1") + return CreateTerminalResponse(terminal_id="term-1") async def terminalOutput( self, @@ -246,13 +248,13 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: if line in {":exit", ":quit"}: break if line == ":cancel": - await conn.cancel(CancelNotification(sessionId=session_id)) + await conn.cancel(CancelNotification(session_id=session_id)) continue try: await conn.prompt( PromptRequest( - sessionId=session_id, + session_id=session_id, prompt=[text_block(line)], ) ) @@ -321,9 +323,9 @@ async def run(argv: list[str]) -> int: try: init_resp = await conn.initialize( InitializeRequest( - protocolVersion=PROTOCOL_VERSION, - clientCapabilities=ClientCapabilities( - fs=FileSystemCapability(readTextFile=True, writeTextFile=True), + protocol_version=PROTOCOL_VERSION, + client_capabilities=ClientCapabilities( + fs=FileSystemCapability(read_text_file=True, write_text_file=True), terminal=True, ), ) @@ -337,13 +339,13 @@ async def run(argv: list[str]) -> int: await _shutdown(proc, conn) return 1 - print(f"āœ… Connected to Gemini (protocol v{init_resp.protocolVersion})") + print(f"āœ… Connected to Gemini (protocol v{init_resp.protocol_version})") try: session = await conn.newSession( NewSessionRequest( cwd=os.getcwd(), - mcpServers=[], + mcp_servers=[], ) ) except RequestError as err: @@ -355,10 +357,10 @@ async def run(argv: list[str]) -> int: await _shutdown(proc, conn) return 1 - print(f"šŸ“ Created session: {session.sessionId}") + print(f"šŸ“ Created session: {session.session_id}") try: - await interactive_loop(conn, session.sessionId) + await interactive_loop(conn, session.session_id) finally: await _shutdown(proc, conn) diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 0badf30..6e0c43b 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -173,6 +173,7 @@ def generate_schema(*, format_output: bool = True) -> None: "--output-model-type", "pydantic_v2.BaseModel", "--use-annotated", + "--snake-case-field", ] subprocess.check_call(cmd) # noqa: S603 diff --git a/src/acp/contrib/session_state.py b/src/acp/contrib/session_state.py index 7933be6..ee56125 100644 --- a/src/acp/contrib/session_state.py +++ b/src/acp/contrib/session_state.py @@ -62,8 +62,8 @@ def apply_start(self, update: ToolCallStart) -> None: self.status = update.status self.content = _copy_model_list(update.content) self.locations = _copy_model_list(update.locations) - self.raw_input = update.rawInput - self.raw_output = update.rawOutput + self.raw_input = update.raw_input + self.raw_output = update.raw_output def apply_progress(self, update: ToolCallProgress) -> None: if update.title is not None: @@ -76,10 +76,10 @@ def apply_progress(self, update: ToolCallProgress) -> None: self.content = _copy_model_list(update.content) if update.locations is not None: self.locations = _copy_model_list(update.locations) - if update.rawInput is not None: - self.raw_input = update.rawInput - if update.rawOutput is not None: - self.raw_output = update.rawOutput + if update.raw_input is not None: + self.raw_input = update.raw_input + if update.raw_output is not None: + self.raw_output = update.raw_output def snapshot(self) -> ToolCallView: return ToolCallView( @@ -185,11 +185,11 @@ def apply(self, notification: SessionNotification) -> SessionSnapshot: def _ensure_session(self, notification: SessionNotification) -> None: if self.session_id is None: - self.session_id = notification.sessionId + self.session_id = notification.session_id return - if notification.sessionId != self.session_id: - self._handle_session_change(notification.sessionId) + if notification.session_id != self.session_id: + self._handle_session_change(notification.session_id) def _handle_session_change(self, session_id: str) -> None: expected = self.session_id @@ -206,14 +206,14 @@ def _handle_session_change(self, session_id: str) -> None: def _apply_update(self, update: Any) -> None: if isinstance(update, ToolCallStart): state = self._tool_calls.setdefault( - update.toolCallId, _MutableToolCallState(tool_call_id=update.toolCallId) + update.tool_call_id, _MutableToolCallState(tool_call_id=update.tool_call_id) ) state.apply_start(update) return if isinstance(update, ToolCallProgress): state = self._tool_calls.setdefault( - update.toolCallId, _MutableToolCallState(tool_call_id=update.toolCallId) + update.tool_call_id, _MutableToolCallState(tool_call_id=update.tool_call_id) ) state.apply_progress(update) return @@ -223,11 +223,11 @@ def _apply_update(self, update: Any) -> None: return if isinstance(update, CurrentModeUpdate): - self._current_mode_id = update.currentModeId + self._current_mode_id = update.current_mode_id return if isinstance(update, AvailableCommandsUpdate): - self._available_commands = _copy_model_list(update.availableCommands) or [] + self._available_commands = _copy_model_list(update.available_commands) or [] return if isinstance(update, UserMessageChunk): diff --git a/src/acp/schema.py b/src/acp/schema.py index 4814f08..0e9c9cb 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -32,10 +32,11 @@ class AuthenticateRequest(BaseModel): ] = None # The ID of the authentication method to use. # Must be one of the methods advertised in the initialize response. - methodId: Annotated[ + method_id: Annotated[ str, Field( - description="The ID of the authentication method to use.\nMust be one of the methods advertised in the initialize response." + alias="methodId", + description="The ID of the authentication method to use.\nMust be one of the methods advertised in the initialize response.", ), ] @@ -71,7 +72,7 @@ class BlobResourceContents(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None blob: str - mimeType: Optional[str] = None + mime_type: Annotated[Optional[str], Field(alias="mimeType")] = None uri: str @@ -82,7 +83,13 @@ class CreateTerminalResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The unique identifier for the created terminal. - terminalId: Annotated[str, Field(description="The unique identifier for the created terminal.")] + terminal_id: Annotated[ + str, + Field( + alias="terminalId", + description="The unique identifier for the created terminal.", + ), + ] class EnvVariable(BaseModel): @@ -131,14 +138,20 @@ class FileSystemCapability(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Whether the Client supports `fs/read_text_file` requests. - readTextFile: Annotated[ + read_text_file: Annotated[ Optional[bool], - Field(description="Whether the Client supports `fs/read_text_file` requests."), + Field( + alias="readTextFile", + description="Whether the Client supports `fs/read_text_file` requests.", + ), ] = False # Whether the Client supports `fs/write_text_file` requests. - writeTextFile: Annotated[ + write_text_file: Annotated[ Optional[bool], - Field(description="Whether the Client supports `fs/write_text_file` requests."), + Field( + alias="writeTextFile", + description="Whether the Client supports `fs/write_text_file` requests.", + ), ] = False @@ -256,7 +269,7 @@ class ModelInfo(BaseModel): # Optional description of the model. description: Annotated[Optional[str], Field(description="Optional description of the model.")] = None # Unique identifier for the model. - modelId: Annotated[str, Field(description="Unique identifier for the model.")] + model_id: Annotated[str, Field(alias="modelId", description="Unique identifier for the model.")] # Human-readable name of the model. name: Annotated[str, Field(description="Human-readable name of the model.")] @@ -273,9 +286,12 @@ class NewSessionRequest(BaseModel): Field(description="The working directory for this session. Must be an absolute path."), ] # List of MCP (Model Context Protocol) servers the agent should connect to. - mcpServers: Annotated[ + mcp_servers: Annotated[ List[Union[HttpMcpServer, SseMcpServer, StdioMcpServer]], - Field(description="List of MCP (Model Context Protocol) servers the agent should connect to."), + Field( + alias="mcpServers", + description="List of MCP (Model Context Protocol) servers the agent should connect to.", + ), ] @@ -291,10 +307,11 @@ class PromptCapabilities(BaseModel): # # When enabled, the Client is allowed to include [`ContentBlock::Resource`] # in prompt requests for pieces of context that are referenced in the message. - embeddedContext: Annotated[ + embedded_context: Annotated[ Optional[bool], Field( - description="Agent supports embedded context in `session/prompt` requests.\n\nWhen enabled, the Client is allowed to include [`ContentBlock::Resource`]\nin prompt requests for pieces of context that are referenced in the message." + alias="embeddedContext", + description="Agent supports embedded context in `session/prompt` requests.\n\nWhen enabled, the Client is allowed to include [`ContentBlock::Resource`]\nin prompt requests for pieces of context that are referenced in the message.", ), ] = False # Agent supports [`ContentBlock::Image`]. @@ -324,7 +341,10 @@ class DeniedOutcome(BaseModel): class AllowedOutcome(BaseModel): # The ID of the option the user selected. - optionId: Annotated[str, Field(description="The ID of the option the user selected.")] + option_id: Annotated[ + str, + Field(alias="optionId", description="The ID of the option the user selected."), + ] outcome: Literal["selected"] @@ -356,9 +376,18 @@ class SessionModelState(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The set of models that the Agent can use - availableModels: Annotated[List[ModelInfo], Field(description="The set of models that the Agent can use")] + available_models: Annotated[ + List[ModelInfo], + Field( + alias="availableModels", + description="The set of models that the Agent can use", + ), + ] # The current model the Agent is in. - currentModelId: Annotated[str, Field(description="The current model the Agent is in.")] + current_model_id: Annotated[ + str, + Field(alias="currentModelId", description="The current model the Agent is in."), + ] class CurrentModeUpdate(BaseModel): @@ -368,8 +397,8 @@ class CurrentModeUpdate(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the current mode - currentModeId: Annotated[str, Field(description="The ID of the current mode")] - sessionUpdate: Literal["current_mode_update"] + current_mode_id: Annotated[str, Field(alias="currentModeId", description="The ID of the current mode")] + session_update: Annotated[Literal["current_mode_update"], Field(alias="sessionUpdate")] class SetSessionModeRequest(BaseModel): @@ -379,9 +408,12 @@ class SetSessionModeRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the mode to set. - modeId: Annotated[str, Field(description="The ID of the mode to set.")] + mode_id: Annotated[str, Field(alias="modeId", description="The ID of the mode to set.")] # The ID of the session to set the mode for. - sessionId: Annotated[str, Field(description="The ID of the session to set the mode for.")] + session_id: Annotated[ + str, + Field(alias="sessionId", description="The ID of the session to set the mode for."), + ] class SetSessionModeResponse(BaseModel): @@ -395,9 +427,12 @@ class SetSessionModelRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the model to set. - modelId: Annotated[str, Field(description="The ID of the model to set.")] + model_id: Annotated[str, Field(alias="modelId", description="The ID of the model to set.")] # The ID of the session to set the model for. - sessionId: Annotated[str, Field(description="The ID of the session to set the model for.")] + session_id: Annotated[ + str, + Field(alias="sessionId", description="The ID of the session to set the model for."), + ] class SetSessionModelResponse(BaseModel): @@ -415,9 +450,10 @@ class TerminalExitStatus(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The process exit code (may be null if terminated by signal). - exitCode: Annotated[ + exit_code: Annotated[ Optional[int], Field( + alias="exitCode", description="The process exit code (may be null if terminated by signal).", ge=0, ), @@ -436,9 +472,12 @@ class TerminalOutputRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # The ID of the terminal to get output from. - terminalId: Annotated[str, Field(description="The ID of the terminal to get output from.")] + terminal_id: Annotated[ + str, + Field(alias="terminalId", description="The ID of the terminal to get output from."), + ] class TerminalOutputResponse(BaseModel): @@ -448,9 +487,9 @@ class TerminalOutputResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Exit status if the command has completed. - exitStatus: Annotated[ + exit_status: Annotated[ Optional[TerminalExitStatus], - Field(description="Exit status if the command has completed."), + Field(alias="exitStatus", description="Exit status if the command has completed."), ] = None # The terminal output captured so far. output: Annotated[str, Field(description="The terminal output captured so far.")] @@ -464,7 +503,7 @@ class TextResourceContents(BaseModel): Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - mimeType: Optional[str] = None + mime_type: Annotated[Optional[str], Field(alias="mimeType")] = None text: str uri: str @@ -476,16 +515,19 @@ class FileEditToolCallContent(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The new content after modification. - newText: Annotated[str, Field(description="The new content after modification.")] + new_text: Annotated[str, Field(alias="newText", description="The new content after modification.")] # The original content (None for new files). - oldText: Annotated[Optional[str], Field(description="The original content (None for new files).")] = None + old_text: Annotated[ + Optional[str], + Field(alias="oldText", description="The original content (None for new files)."), + ] = None # The file path being modified. path: Annotated[str, Field(description="The file path being modified.")] type: Literal["diff"] class TerminalToolCallContent(BaseModel): - terminalId: str + terminal_id: Annotated[str, Field(alias="terminalId")] type: Literal["terminal"] @@ -508,9 +550,12 @@ class WaitForTerminalExitRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # The ID of the terminal to wait for. - terminalId: Annotated[str, Field(description="The ID of the terminal to wait for.")] + terminal_id: Annotated[ + str, + Field(alias="terminalId", description="The ID of the terminal to wait for."), + ] class WaitForTerminalExitResponse(BaseModel): @@ -520,9 +565,10 @@ class WaitForTerminalExitResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The process exit code (may be null if terminated by signal). - exitCode: Annotated[ + exit_code: Annotated[ Optional[int], Field( + alias="exitCode", description="The process exit code (may be null if terminated by signal).", ge=0, ), @@ -545,7 +591,7 @@ class WriteTextFileRequest(BaseModel): # Absolute path to the file to write. path: Annotated[str, Field(description="Absolute path to the file to write.")] # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] class WriteTextFileResponse(BaseModel): @@ -563,17 +609,29 @@ class AgentCapabilities(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Whether the agent supports `session/load`. - loadSession: Annotated[Optional[bool], Field(description="Whether the agent supports `session/load`.")] = False + load_session: Annotated[ + Optional[bool], + Field( + alias="loadSession", + description="Whether the agent supports `session/load`.", + ), + ] = False # MCP capabilities supported by the agent. - mcpCapabilities: Annotated[ + mcp_capabilities: Annotated[ Optional[McpCapabilities], - Field(description="MCP capabilities supported by the agent."), - ] = McpCapabilities(http=False, sse=False) + Field( + alias="mcpCapabilities", + description="MCP capabilities supported by the agent.", + ), + ] = {"http": False, "sse": False} # Prompt capabilities supported by the agent. - promptCapabilities: Annotated[ + prompt_capabilities: Annotated[ Optional[PromptCapabilities], - Field(description="Prompt capabilities supported by the agent."), - ] = PromptCapabilities(audio=False, embeddedContext=False, image=False) + Field( + alias="promptCapabilities", + description="Prompt capabilities supported by the agent.", + ), + ] = {"audio": False, "embeddedContext": False, "image": False} class AgentErrorMessage(BaseModel): @@ -603,7 +661,7 @@ class Annotations(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None audience: Optional[List[Role]] = None - lastModified: Optional[str] = None + last_modified: Annotated[Optional[str], Field(alias="lastModified")] = None priority: Optional[float] = None @@ -651,7 +709,13 @@ class CancelNotification(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the session to cancel operations for. - sessionId: Annotated[str, Field(description="The ID of the session to cancel operations for.")] + session_id: Annotated[ + str, + Field( + alias="sessionId", + description="The ID of the session to cancel operations for.", + ), + ] class ClientCapabilities(BaseModel): @@ -720,7 +784,7 @@ class ImageContentBlock(BaseModel): ] = None annotations: Optional[Annotations] = None data: str - mimeType: str + mime_type: Annotated[str, Field(alias="mimeType")] type: Literal["image"] uri: Optional[str] = None @@ -733,7 +797,7 @@ class AudioContentBlock(BaseModel): ] = None annotations: Optional[Annotations] = None data: str - mimeType: str + mime_type: Annotated[str, Field(alias="mimeType")] type: Literal["audio"] @@ -745,7 +809,7 @@ class ResourceContentBlock(BaseModel): ] = None annotations: Optional[Annotations] = None description: Optional[str] = None - mimeType: Optional[str] = None + mime_type: Annotated[Optional[str], Field(alias="mimeType")] = None name: str size: Optional[int] = None title: Optional[str] = None @@ -781,15 +845,16 @@ class CreateTerminalRequest(BaseModel): # The Client MUST ensure truncation happens at a character boundary to maintain valid # string output, even if this means the retained output is slightly less than the # specified limit. - outputByteLimit: Annotated[ + output_byte_limit: Annotated[ Optional[int], Field( + alias="outputByteLimit", description="Maximum number of output bytes to retain.\n\nWhen the limit is exceeded, the Client truncates from the beginning of the output\nto stay within the limit.\n\nThe Client MUST ensure truncation happens at a character boundary to maintain valid\nstring output, even if this means the retained output is slightly less than the\nspecified limit.", ge=0, ), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] class InitializeRequest(BaseModel): @@ -799,23 +864,28 @@ class InitializeRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Capabilities supported by the client. - clientCapabilities: Annotated[ + client_capabilities: Annotated[ Optional[ClientCapabilities], - Field(description="Capabilities supported by the client."), - ] = ClientCapabilities(fs=FileSystemCapability(readTextFile=False, writeTextFile=False), terminal=False) + Field( + alias="clientCapabilities", + description="Capabilities supported by the client.", + ), + ] = {"fs": {"readTextFile": False, "writeTextFile": False}, "terminal": False} # Information about the Client name and version sent to the Agent. # # Note: in future versions of the protocol, this will be required. - clientInfo: Annotated[ + client_info: Annotated[ Optional[Implementation], Field( - description="Information about the Client name and version sent to the Agent.\n\nNote: in future versions of the protocol, this will be required." + alias="clientInfo", + description="Information about the Client name and version sent to the Agent.\n\nNote: in future versions of the protocol, this will be required.", ), ] = None # The latest protocol version supported by the client. - protocolVersion: Annotated[ + protocol_version: Annotated[ int, Field( + alias="protocolVersion", description="The latest protocol version supported by the client.", ge=0, le=65535, @@ -830,35 +900,47 @@ class InitializeResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Capabilities supported by the agent. - agentCapabilities: Annotated[ + agent_capabilities: Annotated[ Optional[AgentCapabilities], - Field(description="Capabilities supported by the agent."), - ] = AgentCapabilities( - loadSession=False, - mcpCapabilities=McpCapabilities(http=False, sse=False), - promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False), - ) + Field( + alias="agentCapabilities", + description="Capabilities supported by the agent.", + ), + ] = { + "loadSession": False, + "mcpCapabilities": {"http": False, "sse": False}, + "promptCapabilities": { + "audio": False, + "embeddedContext": False, + "image": False, + }, + } # Information about the Agent name and version sent to the Client. # # Note: in future versions of the protocol, this will be required. - agentInfo: Annotated[ + agent_info: Annotated[ Optional[Implementation], Field( - description="Information about the Agent name and version sent to the Client.\n\nNote: in future versions of the protocol, this will be required." + alias="agentInfo", + description="Information about the Agent name and version sent to the Client.\n\nNote: in future versions of the protocol, this will be required.", ), ] = None # Authentication methods supported by the agent. - authMethods: Annotated[ + auth_methods: Annotated[ Optional[List[AuthMethod]], - Field(description="Authentication methods supported by the agent."), + Field( + alias="authMethods", + description="Authentication methods supported by the agent.", + ), ] = [] # The protocol version the client specified if supported by the agent, # or the latest protocol version supported by the agent. # # The client should disconnect, if it doesn't support this version. - protocolVersion: Annotated[ + protocol_version: Annotated[ int, Field( + alias="protocolVersion", description="The protocol version the client specified if supported by the agent,\nor the latest protocol version supported by the agent.\n\nThe client should disconnect, if it doesn't support this version.", ge=0, le=65535, @@ -873,9 +955,9 @@ class KillTerminalCommandRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # The ID of the terminal to kill. - terminalId: Annotated[str, Field(description="The ID of the terminal to kill.")] + terminal_id: Annotated[str, Field(alias="terminalId", description="The ID of the terminal to kill.")] class LoadSessionRequest(BaseModel): @@ -887,12 +969,15 @@ class LoadSessionRequest(BaseModel): # The working directory for this session. cwd: Annotated[str, Field(description="The working directory for this session.")] # List of MCP servers to connect to for this session. - mcpServers: Annotated[ + mcp_servers: Annotated[ List[Union[HttpMcpServer, SseMcpServer, StdioMcpServer]], - Field(description="List of MCP servers to connect to for this session."), + Field( + alias="mcpServers", + description="List of MCP servers to connect to for this session.", + ), ] # The ID of the session to load. - sessionId: Annotated[str, Field(description="The ID of the session to load.")] + session_id: Annotated[str, Field(alias="sessionId", description="The ID of the session to load.")] class PermissionOption(BaseModel): @@ -906,7 +991,13 @@ class PermissionOption(BaseModel): # Human-readable label to display to the user. name: Annotated[str, Field(description="Human-readable label to display to the user.")] # Unique identifier for this permission option. - optionId: Annotated[str, Field(description="Unique identifier for this permission option.")] + option_id: Annotated[ + str, + Field( + alias="optionId", + description="Unique identifier for this permission option.", + ), + ] class PlanEntry(BaseModel): @@ -939,7 +1030,13 @@ class PromptResponse(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Indicates why the agent stopped processing the turn. - stopReason: Annotated[StopReason, Field(description="Indicates why the agent stopped processing the turn.")] + stop_reason: Annotated[ + str, + Field( + alias="stopReason", + description="Indicates why the agent stopped processing the turn.", + ), + ] class ReadTextFileRequest(BaseModel): @@ -958,7 +1055,7 @@ class ReadTextFileRequest(BaseModel): # Absolute path to the file to read. path: Annotated[str, Field(description="Absolute path to the file to read.")] # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] class ReleaseTerminalRequest(BaseModel): @@ -968,9 +1065,9 @@ class ReleaseTerminalRequest(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # The ID of the terminal to release. - terminalId: Annotated[str, Field(description="The ID of the terminal to release.")] + terminal_id: Annotated[str, Field(alias="terminalId", description="The ID of the terminal to release.")] class SessionMode(BaseModel): @@ -992,12 +1089,18 @@ class SessionModeState(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The set of modes that the Agent can operate in - availableModes: Annotated[ + available_modes: Annotated[ List[SessionMode], - Field(description="The set of modes that the Agent can operate in"), + Field( + alias="availableModes", + description="The set of modes that the Agent can operate in", + ), ] # The current mode the Agent is in. - currentModeId: Annotated[str, Field(description="The current mode the Agent is in.")] + current_mode_id: Annotated[ + str, + Field(alias="currentModeId", description="The current mode the Agent is in."), + ] class AgentPlanUpdate(BaseModel): @@ -1016,7 +1119,7 @@ class AgentPlanUpdate(BaseModel): description="The list of tasks to be accomplished.\n\nWhen updating a plan, the agent must send a complete list of all entries\nwith their current status. The client replaces the entire plan with each update." ), ] - sessionUpdate: Literal["plan"] + session_update: Annotated[Literal["plan"], Field(alias="sessionUpdate")] class AvailableCommandsUpdate(BaseModel): @@ -1026,8 +1129,11 @@ class AvailableCommandsUpdate(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # Commands the agent can execute - availableCommands: Annotated[List[AvailableCommand], Field(description="Commands the agent can execute")] - sessionUpdate: Literal["available_commands_update"] + available_commands: Annotated[ + List[AvailableCommand], + Field(alias="availableCommands", description="Commands the agent can execute"), + ] + session_update: Annotated[Literal["available_commands_update"], Field(alias="sessionUpdate")] class ClientResponseMessage(BaseModel): @@ -1143,10 +1249,11 @@ class NewSessionResponse(BaseModel): # Unique identifier for the created session. # # Used in all subsequent requests for this conversation. - sessionId: Annotated[ + session_id: Annotated[ str, Field( - description="Unique identifier for the created session.\n\nUsed in all subsequent requests for this conversation." + alias="sessionId", + description="Unique identifier for the created session.\n\nUsed in all subsequent requests for this conversation.", ), ] @@ -1185,7 +1292,13 @@ class PromptRequest(BaseModel): ), ] # The ID of the session to send this user message to - sessionId: Annotated[str, Field(description="The ID of the session to send this user message to")] + session_id: Annotated[ + str, + Field( + alias="sessionId", + description="The ID of the session to send this user message to", + ), + ] class UserMessageChunk(BaseModel): @@ -1201,7 +1314,7 @@ class UserMessageChunk(BaseModel): ], Field(description="A single item of content", discriminator="type"), ] - sessionUpdate: Literal["user_message_chunk"] + session_update: Annotated[Literal["user_message_chunk"], Field(alias="sessionUpdate")] class AgentMessageChunk(BaseModel): @@ -1217,7 +1330,7 @@ class AgentMessageChunk(BaseModel): ], Field(description="A single item of content", discriminator="type"), ] - sessionUpdate: Literal["agent_message_chunk"] + session_update: Annotated[Literal["agent_message_chunk"], Field(alias="sessionUpdate")] class AgentThoughtChunk(BaseModel): @@ -1233,7 +1346,7 @@ class AgentThoughtChunk(BaseModel): ], Field(description="A single item of content", discriminator="type"), ] - sessionUpdate: Literal["agent_thought_chunk"] + session_update: Annotated[Literal["agent_thought_chunk"], Field(alias="sessionUpdate")] class ContentToolCallContent(BaseModel): @@ -1266,15 +1379,18 @@ class ToolCall(BaseModel): Field(description="Replace the locations collection."), ] = None # Update the raw input. - rawInput: Annotated[Optional[Any], Field(description="Update the raw input.")] = None + raw_input: Annotated[Optional[Any], Field(alias="rawInput", description="Update the raw input.")] = None # Update the raw output. - rawOutput: Annotated[Optional[Any], Field(description="Update the raw output.")] = None + raw_output: Annotated[Optional[Any], Field(alias="rawOutput", description="Update the raw output.")] = None # Update the execution status. status: Annotated[Optional[ToolCallStatus], Field(description="Update the execution status.")] = None # Update the human-readable title. title: Annotated[Optional[str], Field(description="Update the human-readable title.")] = None # The ID of the tool call being updated. - toolCallId: Annotated[str, Field(description="The ID of the tool call being updated.")] + tool_call_id: Annotated[ + str, + Field(alias="toolCallId", description="The ID of the tool call being updated."), + ] class RequestPermissionRequest(BaseModel): @@ -1289,9 +1405,15 @@ class RequestPermissionRequest(BaseModel): Field(description="Available permission options for the user to choose from."), ] # The session ID for this request. - sessionId: Annotated[str, Field(description="The session ID for this request.")] + session_id: Annotated[str, Field(alias="sessionId", description="The session ID for this request.")] # Details about the tool call requiring permission. - toolCall: Annotated[ToolCall, Field(description="Details about the tool call requiring permission.")] + tool_call: Annotated[ + ToolCall, + Field( + alias="toolCall", + description="Details about the tool call requiring permission.", + ), + ] class ToolCallStart(BaseModel): @@ -1320,10 +1442,16 @@ class ToolCallStart(BaseModel): Field(description='File locations affected by this tool call.\nEnables "follow-along" features in clients.'), ] = None # Raw input parameters sent to the tool. - rawInput: Annotated[Optional[Any], Field(description="Raw input parameters sent to the tool.")] = None + raw_input: Annotated[ + Optional[Any], + Field(alias="rawInput", description="Raw input parameters sent to the tool."), + ] = None # Raw output returned by the tool. - rawOutput: Annotated[Optional[Any], Field(description="Raw output returned by the tool.")] = None - sessionUpdate: Literal["tool_call"] + raw_output: Annotated[ + Optional[Any], + Field(alias="rawOutput", description="Raw output returned by the tool."), + ] = None + session_update: Annotated[Literal["tool_call"], Field(alias="sessionUpdate")] # Current execution status of the tool call. status: Annotated[Optional[ToolCallStatus], Field(description="Current execution status of the tool call.")] = None # Human-readable title describing what the tool is doing. @@ -1332,9 +1460,12 @@ class ToolCallStart(BaseModel): Field(description="Human-readable title describing what the tool is doing."), ] # Unique identifier for this tool call within the session. - toolCallId: Annotated[ + tool_call_id: Annotated[ str, - Field(description="Unique identifier for this tool call within the session."), + Field( + alias="toolCallId", + description="Unique identifier for this tool call within the session.", + ), ] @@ -1357,16 +1488,19 @@ class ToolCallProgress(BaseModel): Field(description="Replace the locations collection."), ] = None # Update the raw input. - rawInput: Annotated[Optional[Any], Field(description="Update the raw input.")] = None + raw_input: Annotated[Optional[Any], Field(alias="rawInput", description="Update the raw input.")] = None # Update the raw output. - rawOutput: Annotated[Optional[Any], Field(description="Update the raw output.")] = None - sessionUpdate: Literal["tool_call_update"] + raw_output: Annotated[Optional[Any], Field(alias="rawOutput", description="Update the raw output.")] = None + session_update: Annotated[Literal["tool_call_update"], Field(alias="sessionUpdate")] # Update the execution status. status: Annotated[Optional[ToolCallStatus], Field(description="Update the execution status.")] = None # Update the human-readable title. title: Annotated[Optional[str], Field(description="Update the human-readable title.")] = None # The ID of the tool call being updated. - toolCallId: Annotated[str, Field(description="The ID of the tool call being updated.")] + tool_call_id: Annotated[ + str, + Field(alias="toolCallId", description="The ID of the tool call being updated."), + ] class AgentResponseMessage(BaseModel): @@ -1448,7 +1582,13 @@ class SessionNotification(BaseModel): Field(alias="_meta", description="Extension point for implementations"), ] = None # The ID of the session this update pertains to. - sessionId: Annotated[str, Field(description="The ID of the session this update pertains to.")] + session_id: Annotated[ + str, + Field( + alias="sessionId", + description="The ID of the session this update pertains to.", + ), + ] # The actual update content. update: Annotated[ Union[ @@ -1461,7 +1601,7 @@ class SessionNotification(BaseModel): AvailableCommandsUpdate, CurrentModeUpdate, ], - Field(description="The actual update content.", discriminator="sessionUpdate"), + Field(description="The actual update content.", discriminator="session_update"), ] diff --git a/tests/contrib/test_contrib_permissions.py b/tests/contrib/test_contrib_permissions.py index 6ed32b7..4ad9105 100644 --- a/tests/contrib/test_contrib_permissions.py +++ b/tests/contrib/test_contrib_permissions.py @@ -21,7 +21,7 @@ async def test_permission_broker_uses_tracker_state(): async def fake_requester(request: RequestPermissionRequest): captured["request"] = request return RequestPermissionResponse( - outcome=AllowedOutcome(optionId=request.options[0].optionId, outcome="selected") + outcome=AllowedOutcome(option_id=request.options[0].option_id, outcome="selected") ) tracker = ToolCallTracker(id_factory=lambda: "perm-id") @@ -30,9 +30,9 @@ async def fake_requester(request: RequestPermissionRequest): result = await broker.request_for("external", description="Perform sensitive action") assert isinstance(result.outcome, AllowedOutcome) - assert result.outcome.optionId == captured["request"].options[0].optionId - assert captured["request"].toolCall.content is not None - last_content = captured["request"].toolCall.content[-1] + assert result.outcome.option_id == captured["request"].options[0].option_id + assert captured["request"].tool_call.content is not None + last_content = captured["request"].tool_call.content[-1] assert isinstance(last_content, ContentToolCallContent) assert isinstance(last_content.content, TextContentBlock) assert last_content.content.text.startswith("Perform sensitive action") @@ -43,14 +43,14 @@ async def test_permission_broker_accepts_custom_options(): tracker = ToolCallTracker(id_factory=lambda: "custom") tracker.start("external", title="Custom options") options = [ - PermissionOption(optionId="allow", name="Allow once", kind="allow_once"), + PermissionOption(option_id="allow", name="Allow once", kind="allow_once"), ] recorded: list[str] = [] async def requester(request: RequestPermissionRequest): - recorded.append(request.options[0].optionId) + recorded.append(request.options[0].option_id) return RequestPermissionResponse( - outcome=AllowedOutcome(optionId=request.options[0].optionId, outcome="selected") + outcome=AllowedOutcome(option_id=request.options[0].option_id, outcome="selected") ) broker = PermissionBroker("session", requester, tracker=tracker) @@ -61,4 +61,4 @@ async def requester(request: RequestPermissionRequest): def test_default_permission_options_shape(): options = default_permission_options() assert len(options) == 3 - assert {opt.optionId for opt in options} == {"approve", "approve_for_session", "reject"} + assert {opt.option_id for opt in options} == {"approve", "approve_for_session", "reject"} diff --git a/tests/contrib/test_contrib_session_state.py b/tests/contrib/test_contrib_session_state.py index c2339f6..deaa467 100644 --- a/tests/contrib/test_contrib_session_state.py +++ b/tests/contrib/test_contrib_session_state.py @@ -19,21 +19,21 @@ def notification(session_id: str, update): - return SessionNotification(sessionId=session_id, update=update) + return SessionNotification(session_id=session_id, update=update) def test_session_accumulator_merges_tool_calls(): acc = SessionAccumulator() start = ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-1", + session_update="tool_call", + tool_call_id="call-1", title="Read file", status="in_progress", ) acc.apply(notification("s", start)) progress = ToolCallProgress( - sessionUpdate="tool_call_update", - toolCallId="call-1", + session_update="tool_call_update", + tool_call_id="call-1", status="completed", content=[ ContentToolCallContent( @@ -55,7 +55,7 @@ def test_session_accumulator_records_plan_and_mode(): notification( "s", AgentPlanUpdate( - sessionUpdate="plan", + session_update="plan", entries=[ PlanEntry(content="Step 1", priority="medium", status="pending"), ], @@ -63,7 +63,7 @@ def test_session_accumulator_records_plan_and_mode(): ) ) snapshot = acc.apply( - notification("s", CurrentModeUpdate(sessionUpdate="current_mode_update", currentModeId="coding")) + notification("s", CurrentModeUpdate(session_update="current_mode_update", current_mode_id="coding")) ) assert snapshot.plan_entries[0].content == "Step 1" assert snapshot.current_mode_id == "coding" @@ -75,8 +75,8 @@ def test_session_accumulator_tracks_messages_and_commands(): notification( "s", AvailableCommandsUpdate( - sessionUpdate="available_commands_update", - availableCommands=[], + session_update="available_commands_update", + available_commands=[], ), ) ) @@ -84,7 +84,7 @@ def test_session_accumulator_tracks_messages_and_commands(): notification( "s", UserMessageChunk( - sessionUpdate="user_message_chunk", + session_update="user_message_chunk", content=TextContentBlock(type="text", text="Hello"), ), ) @@ -93,7 +93,7 @@ def test_session_accumulator_tracks_messages_and_commands(): notification( "s", AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="Hi!"), ), ) @@ -113,8 +113,8 @@ def test_session_accumulator_auto_resets_on_new_session(): notification( "s1", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-1", + session_update="tool_call", + tool_call_id="call-1", title="First", ), ) @@ -123,8 +123,8 @@ def test_session_accumulator_auto_resets_on_new_session(): notification( "s2", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-2", + session_update="tool_call", + tool_call_id="call-2", title="Second", ), ) @@ -142,8 +142,8 @@ def test_session_accumulator_rejects_cross_session_when_auto_reset_disabled(): notification( "s1", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-1", + session_update="tool_call", + tool_call_id="call-1", title="First", ), ) @@ -153,8 +153,8 @@ def test_session_accumulator_rejects_cross_session_when_auto_reset_disabled(): notification( "s2", ToolCallStart( - sessionUpdate="tool_call", - toolCallId="call-2", + session_update="tool_call", + tool_call_id="call-2", title="Second", ), ) diff --git a/tests/contrib/test_contrib_tool_calls.py b/tests/contrib/test_contrib_tool_calls.py index a3fb290..1cfc1f0 100644 --- a/tests/contrib/test_contrib_tool_calls.py +++ b/tests/contrib/test_contrib_tool_calls.py @@ -7,10 +7,10 @@ def test_tool_call_tracker_generates_ids_and_updates(): tracker = ToolCallTracker(id_factory=lambda: "generated-id") start = tracker.start("external", title="Run command") - assert start.toolCallId == "generated-id" + assert start.tool_call_id == "generated-id" progress = tracker.progress("external", status="completed") assert isinstance(progress, ToolCallProgress) - assert progress.toolCallId == "generated-id" + assert progress.tool_call_id == "generated-id" view = tracker.view("external") assert view.status == "completed" diff --git a/tests/real_user/test_cancel_prompt_flow.py b/tests/real_user/test_cancel_prompt_flow.py index 45d7798..1189e17 100644 --- a/tests/real_user/test_cancel_prompt_flow.py +++ b/tests/real_user/test_cancel_prompt_flow.py @@ -25,7 +25,7 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: except asyncio.TimeoutError as exc: msg = "Cancel notification did not arrive while prompt pending" raise AssertionError(msg) from exc - return PromptResponse(stopReason="cancelled") + return PromptResponse(stop_reason="cancelled") async def cancel(self, params: CancelNotification) -> None: await super().cancel(params) @@ -37,11 +37,11 @@ async def test_cancel_reaches_agent_during_prompt() -> None: async with _Server() as server: agent = LongRunningAgent() client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, server.client_writer, server.client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, server.server_writer, server.server_reader) + agent_conn = ClientSideConnection(lambda _conn: client, server._client_writer, server._client_reader) + _client_conn = AgentSideConnection(lambda _conn: agent, server._server_writer, server._server_reader) prompt_request = PromptRequest( - sessionId="sess-xyz", + session_id="sess-xyz", prompt=[TextContentBlock(type="text", text="hello")], ) prompt_task = asyncio.create_task(agent_conn.prompt(prompt_request)) @@ -49,10 +49,10 @@ async def test_cancel_reaches_agent_during_prompt() -> None: await agent.prompt_started.wait() assert not prompt_task.done(), "Prompt finished before cancel was sent" - await agent_conn.cancel(CancelNotification(sessionId="sess-xyz")) + await agent_conn.cancel(CancelNotification(session_id="sess-xyz")) await asyncio.wait_for(agent.cancel_received.wait(), timeout=1.0) response = await asyncio.wait_for(prompt_task, timeout=1.0) - assert response.stopReason == "cancelled" + assert response.stop_reason == "cancelled" assert agent.cancellations == ["sess-xyz"] diff --git a/tests/real_user/test_permission_flow.py b/tests/real_user/test_permission_flow.py index b07817c..3b051af 100644 --- a/tests/real_user/test_permission_flow.py +++ b/tests/real_user/test_permission_flow.py @@ -20,12 +20,12 @@ def __init__(self, conn: AgentSideConnection) -> None: async def prompt(self, params: PromptRequest) -> PromptResponse: permission = await self._conn.requestPermission( RequestPermissionRequest( - sessionId=params.sessionId, + session_id=params.session_id, options=[ - PermissionOption(optionId="allow", name="Allow", kind="allow_once"), - PermissionOption(optionId="deny", name="Deny", kind="reject_once"), + PermissionOption(option_id="allow", name="Allow", kind="allow_once"), + PermissionOption(option_id="deny", name="Deny", kind="reject_once"), ], - toolCall=ToolCall(toolCallId="call-1", title="Write File"), + tool_call=ToolCall(tool_call_id="call-1", title="Write File"), ) ) self.permission_responses.append(permission) @@ -40,27 +40,27 @@ async def test_agent_request_permission_roundtrip() -> None: captured_agent = [] - agent_conn = ClientSideConnection(lambda _conn: client, server.client_writer, server.client_reader) + agent_conn = ClientSideConnection(lambda _conn: client, server._client_writer, server._client_reader) _agent_conn = AgentSideConnection( lambda conn: captured_agent.append(PermissionRequestAgent(conn)) or captured_agent[-1], - server.server_writer, - server.server_reader, + server._server_writer, + server._server_reader, ) response = await asyncio.wait_for( agent_conn.prompt( PromptRequest( - sessionId="sess-perm", + session_id="sess-perm", prompt=[TextContentBlock(type="text", text="needs approval")], ) ), timeout=1.0, ) - assert response.stopReason == "end_turn" + assert response.stop_reason == "end_turn" assert captured_agent, "Agent was not constructed" [agent] = captured_agent assert agent.permission_responses, "Agent did not receive permission response" permission_response = agent.permission_responses[0] assert permission_response.outcome.outcome == "selected" - assert permission_response.outcome.optionId == "allow" + assert permission_response.outcome.option_id == "allow" diff --git a/tests/real_user/test_stdio_limits.py b/tests/real_user/test_stdio_limits.py index 3de0ef9..eb9be15 100644 --- a/tests/real_user/test_stdio_limits.py +++ b/tests/real_user/test_stdio_limits.py @@ -22,7 +22,7 @@ def _large_line_script(size: int = LARGE_LINE_SIZE) -> str: @pytest.mark.asyncio async def test_spawn_stdio_transport_hits_default_limit() -> None: script = _large_line_script() - async with spawn_stdio_transport(sys.executable, "-c", script) as (reader, writer, _process): + async with spawn_stdio_transport(sys.executable, "-c", script) as (reader, _writer, _process): # readline() re-raises LimitOverrunError as ValueError on CPython 3.12+. with pytest.raises(ValueError): await reader.readline() @@ -36,6 +36,6 @@ async def test_spawn_stdio_transport_custom_limit_handles_large_line() -> None: "-c", script, limit=LARGE_LINE_SIZE * 2, - ) as (reader, writer, _process): + ) as (reader, _writer, _process): line = await reader.readline() assert len(line) == LARGE_LINE_SIZE + 1 diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 5373045..7fe3653 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -59,42 +59,62 @@ class _Server: def __init__(self) -> None: self._server: asyncio.AbstractServer | None = None - self.server_reader: asyncio.StreamReader | None = None - self.server_writer: asyncio.StreamWriter | None = None - self.client_reader: asyncio.StreamReader | None = None - self.client_writer: asyncio.StreamWriter | None = None + self._server_reader: asyncio.StreamReader | None = None + self._server_writer: asyncio.StreamWriter | None = None + self._client_reader: asyncio.StreamReader | None = None + self._client_writer: asyncio.StreamWriter | None = None async def __aenter__(self): async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - self.server_reader = reader - self.server_writer = writer + self._server_reader = reader + self._server_writer = writer self._server = await asyncio.start_server(handle, host="127.0.0.1", port=0) host, port = self._server.sockets[0].getsockname()[:2] - self.client_reader, self.client_writer = await asyncio.open_connection(host, port) + self._client_reader, self._client_writer = await asyncio.open_connection(host, port) # wait until server side is set for _ in range(100): - if self.server_reader and self.server_writer: + if self._server_reader and self._server_writer: break await asyncio.sleep(0.01) - assert self.server_reader and self.server_writer - assert self.client_reader and self.client_writer + assert self._server_reader and self._server_writer + assert self._client_reader and self._client_writer return self async def __aexit__(self, exc_type, exc, tb): - if self.client_writer: - self.client_writer.close() + if self._client_writer: + self._client_writer.close() with contextlib.suppress(Exception): - await self.client_writer.wait_closed() - if self.server_writer: - self.server_writer.close() + await self._client_writer.wait_closed() + if self._server_writer: + self._server_writer.close() with contextlib.suppress(Exception): - await self.server_writer.wait_closed() + await self._server_writer.wait_closed() if self._server: self._server.close() await self._server.wait_closed() + @property + def server_writer(self) -> asyncio.StreamWriter: + assert self._server_writer is not None + return self._server_writer + + @property + def server_reader(self) -> asyncio.StreamReader: + assert self._server_reader is not None + return self._server_reader + + @property + def client_writer(self) -> asyncio.StreamWriter: + assert self._client_writer is not None + return self._client_writer + + @property + def client_reader(self) -> asyncio.StreamReader: + assert self._client_reader is not None + return self._client_reader + # --------------------- Test Doubles ----------------------- @@ -114,7 +134,7 @@ def queue_permission_cancelled(self) -> None: def queue_permission_selected(self, option_id: str) -> None: self.permission_outcomes.append( - RequestPermissionResponse(outcome=AllowedOutcome(optionId=option_id, outcome="selected")) + RequestPermissionResponse(outcome=AllowedOutcome(option_id=option_id, outcome="selected")) ) async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: @@ -170,10 +190,10 @@ def __init__(self) -> None: async def initialize(self, params: InitializeRequest) -> InitializeResponse: # Avoid serializer warnings by omitting defaults - return InitializeResponse(protocolVersion=params.protocolVersion, agentCapabilities=None, authMethods=[]) + return InitializeResponse(protocol_version=params.protocol_version, agent_capabilities=None, auth_methods=[]) async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: - return NewSessionResponse(sessionId="test-session-123") + return NewSessionResponse(session_id="test-session-123") async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse: return LoadSessionResponse() @@ -183,10 +203,10 @@ async def authenticate(self, params: AuthenticateRequest) -> AuthenticateRespons async def prompt(self, params: PromptRequest) -> PromptResponse: self.prompts.append(params) - return PromptResponse(stopReason="end_turn") + return PromptResponse(stop_reason="end_turn") async def cancel(self, params: CancelNotification) -> None: - self.cancellations.append(params.sessionId) + self.cancellations.append(params.session_id) async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse: return SetSessionModeResponse() @@ -213,29 +233,31 @@ async def test_initialize_and_new_session(): agent = TestAgent() client = TestClient() # server side is agent; client side is client - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) + _client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) - resp = await agent_conn.initialize(InitializeRequest(protocolVersion=1)) + resp = await agent_conn.initialize(InitializeRequest(protocol_version=1)) assert isinstance(resp, InitializeResponse) - assert resp.protocolVersion == 1 + assert resp.protocol_version == 1 - new_sess = await agent_conn.newSession(NewSessionRequest(mcpServers=[], cwd="/test")) - assert new_sess.sessionId == "test-session-123" + new_sess = await agent_conn.newSession(NewSessionRequest(mcp_servers=[], cwd="/test")) + assert new_sess.session_id == "test-session-123" load_resp = await agent_conn.loadSession( - LoadSessionRequest(sessionId=new_sess.sessionId, cwd="/test", mcpServers=[]) + LoadSessionRequest(session_id=new_sess.session_id, cwd="/test", mcp_servers=[]) ) assert isinstance(load_resp, LoadSessionResponse) - auth_resp = await agent_conn.authenticate(AuthenticateRequest(methodId="password")) + auth_resp = await agent_conn.authenticate(AuthenticateRequest(method_id="password")) assert isinstance(auth_resp, AuthenticateResponse) - mode_resp = await agent_conn.setSessionMode(SetSessionModeRequest(sessionId=new_sess.sessionId, modeId="ask")) + mode_resp = await agent_conn.setSessionMode( + SetSessionModeRequest(session_id=new_sess.session_id, mode_id="ask") + ) assert isinstance(mode_resp, SetSessionModeResponse) model_resp = await agent_conn.setSessionModel( - SetSessionModelRequest(sessionId=new_sess.sessionId, modelId="gpt-4o") + SetSessionModelRequest(session_id=new_sess.session_id, model_id="gpt-4o") ) assert isinstance(model_resp, SetSessionModelResponse) @@ -246,16 +268,16 @@ async def test_bidirectional_file_ops(): agent = TestAgent() client = TestClient() client.files["/test/file.txt"] = "Hello, World!" - _agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + _agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) + client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Agent asks client to read - res = await client_conn.readTextFile(ReadTextFileRequest(sessionId="sess", path="/test/file.txt")) + res = await client_conn.readTextFile(ReadTextFileRequest(session_id="sess", path="/test/file.txt")) assert res.content == "Hello, World!" # Agent asks client to write write_result = await client_conn.writeTextFile( - WriteTextFileRequest(sessionId="sess", path="/test/file.txt", content="Updated") + WriteTextFileRequest(session_id="sess", path="/test/file.txt", content="Updated") ) assert isinstance(write_result, WriteTextFileResponse) assert client.files["/test/file.txt"] == "Updated" @@ -267,11 +289,11 @@ async def test_cancel_notification_and_capture_wire(): # Build only agent-side (server) connection. Client side: raw reader to inspect wire agent = TestAgent() client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - _client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) + _client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Send cancel notification from client-side connection to agent - await agent_conn.cancel(CancelNotification(sessionId="test-123")) + await agent_conn.cancel(CancelNotification(session_id="test-123")) # Read raw line from server peer (it will be consumed by agent receive loop quickly). # Instead, wait a brief moment and assert agent recorded it. @@ -287,24 +309,24 @@ async def test_session_notifications_flow(): async with _Server() as s: agent = TestAgent() client = TestClient() - _agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + _agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) + client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Agent -> Client notifications await client_conn.sessionUpdate( SessionNotification( - sessionId="sess", + session_id="sess", update=AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="Hello"), ), ) ) await client_conn.sessionUpdate( SessionNotification( - sessionId="sess", + session_id="sess", update=UserMessageChunk( - sessionUpdate="user_message_chunk", + session_update="user_message_chunk", content=TextContentBlock(type="text", text="World"), ), ) @@ -316,7 +338,7 @@ async def test_session_notifications_flow(): break await asyncio.sleep(0.01) assert len(client.notifications) >= 2 - assert client.notifications[0].sessionId == "sess" + assert client.notifications[0].session_id == "sess" @pytest.mark.asyncio @@ -326,11 +348,11 @@ async def test_concurrent_reads(): client = TestClient() for i in range(5): client.files[f"/test/file{i}.txt"] = f"Content {i}" - _agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + _agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) + client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) async def read_one(i: int): - return await client_conn.readTextFile(ReadTextFileRequest(sessionId="sess", path=f"/test/file{i}.txt")) + return await client_conn.readTextFile(ReadTextFileRequest(session_id="sess", path=f"/test/file{i}.txt")) results = await asyncio.gather(*(read_one(i) for i in range(5))) for i, res in enumerate(results): @@ -342,7 +364,7 @@ async def test_invalid_params_results_in_error_response(): async with _Server() as s: # Only start agent-side (server) so we can inject raw request from client socket agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + _server_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Send initialize with wrong param type (protocolVersion should be int) req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "oops"}} @@ -361,7 +383,7 @@ async def test_invalid_params_results_in_error_response(): async def test_method_not_found_results_in_error_response(): async with _Server() as s: agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + _server_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) req = {"jsonrpc": "2.0", "id": 2, "method": "unknown/method", "params": {}} s.client_writer.write((json.dumps(req) + "\n").encode()) @@ -378,14 +400,14 @@ async def test_set_session_mode_and_extensions(): async with _Server() as s: agent = TestAgent() client = TestClient() - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) + client_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # setSessionMode - resp = await agent_conn.setSessionMode(SetSessionModeRequest(sessionId="sess", modeId="yolo")) + resp = await agent_conn.setSessionMode(SetSessionModeRequest(session_id="sess", mode_id="yolo")) assert isinstance(resp, SetSessionModeResponse) - model_resp = await agent_conn.setSessionModel(SetSessionModelRequest(sessionId="sess", modelId="gpt-4o-mini")) + model_resp = await agent_conn.setSessionModel(SetSessionModelRequest(session_id="sess", model_id="gpt-4o-mini")) assert isinstance(model_resp, SetSessionModelResponse) # extMethod @@ -408,7 +430,7 @@ async def test_set_session_mode_and_extensions(): async def test_ignore_invalid_messages(): async with _Server() as s: agent = TestAgent() - _server_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader) + _server_conn = AgentSideConnection(lambda _conn: agent, s._server_writer, s._server_reader) # Message without id and method msg1 = {"jsonrpc": "2.0"} @@ -438,10 +460,10 @@ def bind(self, conn: AgentSideConnection) -> "_ExampleAgent": return self async def initialize(self, params: InitializeRequest) -> InitializeResponse: - return InitializeResponse(protocolVersion=params.protocolVersion) + return InitializeResponse(protocol_version=params.protocol_version) async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: - return NewSessionResponse(sessionId="sess_demo") + return NewSessionResponse(session_id="sess_demo") async def prompt(self, params: PromptRequest) -> PromptResponse: assert self._conn is not None @@ -449,14 +471,14 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: await self._conn.sessionUpdate( session_notification( - params.sessionId, + params.session_id, update_agent_message_text("I'll help you with that."), ) ) await self._conn.sessionUpdate( session_notification( - params.sessionId, + params.session_id, start_tool_call( "call_1", "Modifying configuration", @@ -469,27 +491,27 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: ) permission_request = RequestPermissionRequest( - sessionId=params.sessionId, - toolCall=ToolCall( - toolCallId="call_1", + session_id=params.session_id, + tool_call=ToolCall( + tool_call_id="call_1", title="Modifying configuration", kind="edit", status="pending", locations=[ToolCallLocation(path="/project/config.json")], - rawInput={"path": "/project/config.json"}, + raw_input={"path": "/project/config.json"}, ), options=[ - PermissionOption(kind="allow_once", name="Allow", optionId="allow"), - PermissionOption(kind="reject_once", name="Reject", optionId="reject"), + PermissionOption(kind="allow_once", name="Allow", option_id="allow"), + PermissionOption(kind="reject_once", name="Reject", option_id="reject"), ], ) response = await self._conn.requestPermission(permission_request) self.permission_response = response - if isinstance(response.outcome, AllowedOutcome) and response.outcome.optionId == "allow": + if isinstance(response.outcome, AllowedOutcome) and response.outcome.option_id == "allow": await self._conn.sessionUpdate( session_notification( - params.sessionId, + params.session_id, update_tool_call( "call_1", status="completed", @@ -499,12 +521,12 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: ) await self._conn.sessionUpdate( session_notification( - params.sessionId, + params.session_id, update_agent_message_text("Done."), ) ) - return PromptResponse(stopReason="end_turn") + return PromptResponse(stop_reason="end_turn") class _ExampleClient(TestClient): @@ -519,7 +541,7 @@ async def requestPermission(self, params: RequestPermissionRequest) -> RequestPe if not params.options: return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) option = params.options[0] - return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected")) + return RequestPermissionResponse(outcome=AllowedOutcome(option_id=option.option_id, outcome="selected")) @pytest.mark.asyncio @@ -528,29 +550,27 @@ async def test_example_agent_permission_flow(): agent = _ExampleAgent() client = _ExampleClient() - agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) - AgentSideConnection(lambda conn: agent.bind(conn), s.server_writer, s.server_reader) - - init = await agent_conn.initialize(InitializeRequest(protocolVersion=1)) - assert init.protocolVersion == 1 + agent_conn = ClientSideConnection(lambda _conn: client, s._client_writer, s._client_reader) + AgentSideConnection(lambda conn: agent.bind(conn), s._server_writer, s._server_reader) - session = await agent_conn.newSession(NewSessionRequest(mcpServers=[], cwd="/workspace")) - assert session.sessionId == "sess_demo" + init = await agent_conn.initialize(InitializeRequest(protocol_version=1)) + assert init.protocol_version == 1 + session = await agent_conn.newSession(NewSessionRequest(mcp_servers=[], cwd="/workspace")) + assert session.session_id == "sess_demo" prompt = PromptRequest( - sessionId=session.sessionId, + session_id=session.session_id, prompt=[TextContentBlock(type="text", text="Please edit config")], ) resp = await agent_conn.prompt(prompt) - assert resp.stopReason == "end_turn" - + assert resp.stop_reason == "end_turn" for _ in range(50): if len(client.notifications) >= 4: break await asyncio.sleep(0.02) assert len(client.notifications) >= 4 - session_updates = [getattr(note.update, "sessionUpdate", None) for note in client.notifications] + session_updates = [getattr(note.update, "session_update", None) for note in client.notifications] assert session_updates[:4] == ["agent_message_chunk", "tool_call", "tool_call_update", "agent_message_chunk"] first_message = client.notifications[0].update @@ -566,7 +586,7 @@ async def test_example_agent_permission_flow(): tool_update = client.notifications[2].update assert isinstance(tool_update, ToolCallProgress) assert tool_update.status == "completed" - assert tool_update.rawOutput == {"success": True} + assert tool_update.raw_output == {"success": True} final_message = client.notifications[3].update assert isinstance(final_message, AgentMessageChunk) @@ -575,11 +595,11 @@ async def test_example_agent_permission_flow(): assert len(client.permission_requests) == 1 options = client.permission_requests[0].options - assert [opt.optionId for opt in options] == ["allow", "reject"] + assert [opt.option_id for opt in options] == ["allow", "reject"] assert agent.permission_response is not None assert isinstance(agent.permission_response.outcome, AllowedOutcome) - assert agent.permission_response.outcome.optionId == "allow" + assert agent.permission_response.outcome.option_id == "allow" @pytest.mark.asyncio @@ -590,11 +610,11 @@ async def test_spawn_agent_process_roundtrip(tmp_path): test_client = TestClient() async with spawn_agent_process(lambda _agent: test_client, sys.executable, str(script)) as (client_conn, process): - init = await client_conn.initialize(InitializeRequest(protocolVersion=1)) + init = await client_conn.initialize(InitializeRequest(protocol_version=1)) assert isinstance(init, InitializeResponse) - session = await client_conn.newSession(NewSessionRequest(cwd=str(tmp_path), mcpServers=[])) + session = await client_conn.newSession(NewSessionRequest(cwd=str(tmp_path), mcp_servers=[])) prompt = PromptRequest( - sessionId=session.sessionId, + session_id=session.session_id, prompt=[TextContentBlock(type="text", text="hi spawn")], ) await client_conn.prompt(prompt) From a35648c400f3ede6cd8b5df08f9bdb034df13a46 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 19 Nov 2025 11:27:29 +0800 Subject: [PATCH 2/2] fix lint errors Signed-off-by: Frost Ming --- src/acp/agent/connection.py | 2 +- src/acp/contrib/permissions.py | 10 ++++---- src/acp/contrib/tool_calls.py | 18 ++++++------- src/acp/helpers.py | 46 +++++++++++++++++----------------- tests/test_utils.py | 6 ++--- 5 files changed, 41 insertions(+), 41 deletions(-) diff --git a/src/acp/agent/connection.py b/src/acp/agent/connection.py index eab6766..5f992b1 100644 --- a/src/acp/agent/connection.py +++ b/src/acp/agent/connection.py @@ -97,7 +97,7 @@ async def createTerminal(self, params: CreateTerminalRequest) -> TerminalHandle: params, CreateTerminalResponse, ) - return TerminalHandle(create_response.terminalId, params.sessionId, self._conn) + return TerminalHandle(create_response.terminal_id, params.session_id, self._conn) async def terminalOutput(self, params: TerminalOutputRequest) -> TerminalOutputResponse: return await request_model( diff --git a/src/acp/contrib/permissions.py b/src/acp/contrib/permissions.py index 5008092..462be9d 100644 --- a/src/acp/contrib/permissions.py +++ b/src/acp/contrib/permissions.py @@ -29,9 +29,9 @@ def __init__(self) -> None: def default_permission_options() -> tuple[PermissionOption, PermissionOption, PermissionOption]: """Return a standard approval/reject option set.""" return ( - PermissionOption(optionId="approve", name="Approve", kind="allow_once"), - PermissionOption(optionId="approve_for_session", name="Approve for session", kind="allow_always"), - PermissionOption(optionId="reject", name="Reject", kind="reject_once"), + PermissionOption(option_id="approve", name="Approve", kind="allow_once"), + PermissionOption(option_id="approve_for_session", name="Approve for session", kind="allow_always"), + PermissionOption(option_id="reject", name="Reject", kind="reject_once"), ) @@ -83,8 +83,8 @@ async def request_for( raise MissingPermissionOptionsError() request = RequestPermissionRequest( - sessionId=self._session_id, - toolCall=tool_call, + session_id=self._session_id, + tool_call=tool_call, options=list(option_set), ) return await self._requester(request) diff --git a/src/acp/contrib/tool_calls.py b/src/acp/contrib/tool_calls.py index 5907485..1449d6d 100644 --- a/src/acp/contrib/tool_calls.py +++ b/src/acp/contrib/tool_calls.py @@ -93,29 +93,29 @@ def to_view(self) -> TrackedToolCallView: def to_tool_call_model(self) -> ToolCall: return ToolCall( - toolCallId=self.tool_call_id, + tool_call_id=self.tool_call_id, title=self.title, kind=self.kind, status=self.status, content=_copy_model_list(self.content), locations=_copy_model_list(self.locations), - rawInput=self.raw_input, - rawOutput=self.raw_output, + raw_input=self.raw_input, + raw_output=self.raw_output, ) def to_start_model(self) -> ToolCallStart: if self.title is None: raise _MissingToolCallTitleError() return ToolCallStart( - sessionUpdate="tool_call", - toolCallId=self.tool_call_id, + session_update="tool_call", + tool_call_id=self.tool_call_id, title=self.title, kind=self.kind, status=self.status, content=_copy_model_list(self.content), locations=_copy_model_list(self.locations), - rawInput=self.raw_input, - rawOutput=self.raw_output, + raw_input=self.raw_input, + raw_output=self.raw_output, ) def update( @@ -155,8 +155,8 @@ def update( kwargs["rawInput"] = raw_input if raw_output is not UNSET: self.raw_output = raw_output - kwargs["rawOutput"] = raw_output - return ToolCallProgress(sessionUpdate="tool_call_update", toolCallId=self.tool_call_id, **kwargs) + kwargs["raw_output"] = raw_output + return ToolCallProgress(session_update="tool_call_update", tool_call_id=self.tool_call_id, **kwargs) def append_stream_text( self, diff --git a/src/acp/helpers.py b/src/acp/helpers.py index d5a473f..701cda7 100644 --- a/src/acp/helpers.py +++ b/src/acp/helpers.py @@ -83,11 +83,11 @@ def text_block(text: str) -> TextContentBlock: def image_block(data: str, mime_type: str, *, uri: str | None = None) -> ImageContentBlock: - return ImageContentBlock(type="image", data=data, mimeType=mime_type, uri=uri) + return ImageContentBlock(type="image", data=data, mime_type=mime_type, uri=uri) def audio_block(data: str, mime_type: str) -> AudioContentBlock: - return AudioContentBlock(type="audio", data=data, mimeType=mime_type) + return AudioContentBlock(type="audio", data=data, mime_type=mime_type) def resource_link_block( @@ -103,7 +103,7 @@ def resource_link_block( type="resource_link", name=name, uri=uri, - mimeType=mime_type, + mime_type=mime_type, size=size, description=description, title=title, @@ -111,11 +111,11 @@ def resource_link_block( def embedded_text_resource(uri: str, text: str, *, mime_type: str | None = None) -> TextResourceContents: - return TextResourceContents(uri=uri, text=text, mimeType=mime_type) + return TextResourceContents(uri=uri, text=text, mime_type=mime_type) def embedded_blob_resource(uri: str, blob: str, *, mime_type: str | None = None) -> BlobResourceContents: - return BlobResourceContents(uri=uri, blob=blob, mimeType=mime_type) + return BlobResourceContents(uri=uri, blob=blob, mime_type=mime_type) def resource_block( @@ -129,11 +129,11 @@ def tool_content(block: ContentBlock) -> ContentToolCallContent: def tool_diff_content(path: str, new_text: str, old_text: str | None = None) -> FileEditToolCallContent: - return FileEditToolCallContent(type="diff", path=path, newText=new_text, oldText=old_text) + return FileEditToolCallContent(type="diff", path=path, new_text=new_text, old_text=old_text) def tool_terminal_ref(terminal_id: str) -> TerminalToolCallContent: - return TerminalToolCallContent(type="terminal", terminalId=terminal_id) + return TerminalToolCallContent(type="terminal", terminal_id=terminal_id) def plan_entry( @@ -146,11 +146,11 @@ def plan_entry( def update_plan(entries: Iterable[PlanEntry]) -> AgentPlanUpdate: - return AgentPlanUpdate(sessionUpdate="plan", entries=list(entries)) + return AgentPlanUpdate(session_update="plan", entries=list(entries)) def update_user_message(content: ContentBlock) -> UserMessageChunk: - return UserMessageChunk(sessionUpdate="user_message_chunk", content=content) + return UserMessageChunk(session_update="user_message_chunk", content=content) def update_user_message_text(text: str) -> UserMessageChunk: @@ -158,7 +158,7 @@ def update_user_message_text(text: str) -> UserMessageChunk: def update_agent_message(content: ContentBlock) -> AgentMessageChunk: - return AgentMessageChunk(sessionUpdate="agent_message_chunk", content=content) + return AgentMessageChunk(session_update="agent_message_chunk", content=content) def update_agent_message_text(text: str) -> AgentMessageChunk: @@ -166,7 +166,7 @@ def update_agent_message_text(text: str) -> AgentMessageChunk: def update_agent_thought(content: ContentBlock) -> AgentThoughtChunk: - return AgentThoughtChunk(sessionUpdate="agent_thought_chunk", content=content) + return AgentThoughtChunk(session_update="agent_thought_chunk", content=content) def update_agent_thought_text(text: str) -> AgentThoughtChunk: @@ -175,17 +175,17 @@ def update_agent_thought_text(text: str) -> AgentThoughtChunk: def update_available_commands(commands: Iterable[AvailableCommand]) -> AvailableCommandsUpdate: return AvailableCommandsUpdate( - sessionUpdate="available_commands_update", - availableCommands=list(commands), + session_update="available_commands_update", + available_commands=list(commands), ) def update_current_mode(current_mode_id: str) -> CurrentModeUpdate: - return CurrentModeUpdate(sessionUpdate="current_mode_update", currentModeId=current_mode_id) + return CurrentModeUpdate(session_update="current_mode_update", current_mode_id=current_mode_id) def session_notification(session_id: str, update: SessionUpdate) -> SessionNotification: - return SessionNotification(sessionId=session_id, update=update) + return SessionNotification(session_id=session_id, update=update) def start_tool_call( @@ -200,15 +200,15 @@ def start_tool_call( raw_output: Any | None = None, ) -> ToolCallStart: return ToolCallStart( - sessionUpdate="tool_call", - toolCallId=tool_call_id, + session_update="tool_call", + tool_call_id=tool_call_id, title=title, kind=kind, status=status, content=list(content) if content is not None else None, locations=list(locations) if locations is not None else None, - rawInput=raw_input, - rawOutput=raw_output, + raw_input=raw_input, + raw_output=raw_output, ) @@ -266,13 +266,13 @@ def update_tool_call( raw_output: Any | None = None, ) -> ToolCallProgress: return ToolCallProgress( - sessionUpdate="tool_call_update", - toolCallId=tool_call_id, + session_update="tool_call_update", + tool_call_id=tool_call_id, title=title, kind=kind, status=status, content=list(content) if content is not None else None, locations=list(locations) if locations is not None else None, - rawInput=raw_input, - rawOutput=raw_output, + raw_input=raw_input, + raw_output=raw_output, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index fbbf08e..4ac2a75 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,7 @@ def test_serialize_params_uses_meta_aliases() -> None: chunk = AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}), field_meta={"outer": "value"}, ) @@ -17,7 +17,7 @@ def test_serialize_params_uses_meta_aliases() -> None: def test_serialize_params_omits_meta_when_absent() -> None: chunk = AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="demo"), ) @@ -29,7 +29,7 @@ def test_serialize_params_omits_meta_when_absent() -> None: def test_field_meta_can_be_set_by_name_on_models() -> None: chunk = AgentMessageChunk( - sessionUpdate="agent_message_chunk", + session_update="agent_message_chunk", content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}), field_meta={"outer": "value"}, )