diff --git a/examples/echo_agent.py b/examples/echo_agent.py index 1bf04ff..657eb28 100644 --- a/examples/echo_agent.py +++ b/examples/echo_agent.py @@ -29,13 +29,15 @@ async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: async def prompt(self, params: PromptRequest) -> PromptResponse: for block in params.prompt: - text = getattr(block, "text", "") - await self._conn.sessionUpdate( - session_notification( - params.sessionId, - update_agent_message(text_block(text)), - ) - ) + text = block.get("text", "") if isinstance(block, dict) else getattr(block, "text", "") + chunk = update_agent_message(text_block(text)) + chunk.field_meta = {"echo": True} + chunk.content.field_meta = {"echo": True} + + notification = session_notification(params.sessionId, chunk) + notification.field_meta = {"source": "echo_agent"} + + await self._conn.sessionUpdate(notification) return PromptResponse(stopReason="end_turn") diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index c61e3c3..b95362e 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -199,6 +199,7 @@ def rename_types(output_path: Path) -> list[str]: content = _apply_field_overrides(content) content = _apply_default_overrides(content) content = _add_description_comments(content) + content = _ensure_custom_base_model(content) alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())] alias_block = BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n" @@ -220,6 +221,37 @@ def rename_types(output_path: Path) -> list[str]: return warnings +def _ensure_custom_base_model(content: str) -> str: + if "class BaseModel(_BaseModel):" in content: + return content + lines = content.splitlines() + for idx, line in enumerate(lines): + if not line.startswith("from pydantic import "): + continue + imports = [part.strip() for part in line[len("from pydantic import ") :].split(",")] + has_alias = any(part == "BaseModel as _BaseModel" for part in imports) + has_config = any(part == "ConfigDict" for part in imports) + new_imports = [] + for part in imports: + if part == "BaseModel": + new_imports.append("BaseModel as _BaseModel") + has_alias = True + else: + new_imports.append(part) + if not has_alias: + new_imports.append("BaseModel as _BaseModel") + if not has_config: + new_imports.append("ConfigDict") + lines[idx] = "from pydantic import " + ", ".join(new_imports) + insert_idx = idx + 1 + lines.insert(insert_idx, "") + lines.insert(insert_idx + 1, "class BaseModel(_BaseModel):") + lines.insert(insert_idx + 2, " model_config = ConfigDict(populate_by_name=True)") + lines.insert(insert_idx + 3, "") + break + return "\n".join(lines) + "\n" + + def _apply_field_overrides(content: str) -> str: for class_name, field_name, new_type, optional in FIELD_TYPE_OVERRIDES: if optional: diff --git a/src/acp/connection.py b/src/acp/connection.py index 0b3230e..34142d7 100644 --- a/src/acp/connection.py +++ b/src/acp/connection.py @@ -192,7 +192,12 @@ async def _run_request(self, message: dict[str, Any]) -> Any: try: result = await self._handler(method, message.get("params"), False) if isinstance(result, BaseModel): - result = result.model_dump() + result = result.model_dump( + mode="json", + by_alias=True, + exclude_none=True, + exclude_unset=True, + ) payload["result"] = result if result is not None else None await self._sender.send(payload) self._notify_observers(StreamDirection.OUTGOING, payload) diff --git a/src/acp/schema.py b/src/acp/schema.py index b65a63e..d9e2da1 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -6,8 +6,7 @@ from enum import Enum from typing import Annotated, Any, List, Literal, Optional, Union -from pydantic import BaseModel, Field, RootModel - +from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict PermissionOptionKind = Literal["allow_once", "allow_always", "reject_once", "reject_always"] PlanEntryPriority = Literal["high", "medium", "low"] @@ -17,6 +16,10 @@ ToolKind = Literal["read", "edit", "delete", "move", "search", "execute", "think", "fetch", "switch_mode", "other"] +class BaseModel(_BaseModel): + model_config = ConfigDict(populate_by_name=True) + + class Jsonrpc(Enum): field_2_0 = "2.0" diff --git a/src/acp/utils.py b/src/acp/utils.py index b84a5a7..e81d7ba 100644 --- a/src/acp/utils.py +++ b/src/acp/utils.py @@ -24,7 +24,7 @@ def serialize_params(params: BaseModel) -> dict[str, Any]: """Return a JSON-serializable representation used for RPC calls.""" - return params.model_dump(exclude_none=True, exclude_defaults=True) + return params.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True) def normalize_result(payload: Any) -> dict[str, Any]: diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..fbbf08e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,38 @@ +from acp.schema import AgentMessageChunk, TextContentBlock +from acp.utils import serialize_params + + +def test_serialize_params_uses_meta_aliases() -> None: + chunk = AgentMessageChunk( + sessionUpdate="agent_message_chunk", + content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}), + field_meta={"outer": "value"}, + ) + + payload = serialize_params(chunk) + + assert payload["_meta"] == {"outer": "value"} + assert payload["content"]["_meta"] == {"inner": "value"} + + +def test_serialize_params_omits_meta_when_absent() -> None: + chunk = AgentMessageChunk( + sessionUpdate="agent_message_chunk", + content=TextContentBlock(type="text", text="demo"), + ) + + payload = serialize_params(chunk) + + assert "_meta" not in payload + assert "_meta" not in payload["content"] + + +def test_field_meta_can_be_set_by_name_on_models() -> None: + chunk = AgentMessageChunk( + sessionUpdate="agent_message_chunk", + content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}), + field_meta={"outer": "value"}, + ) + + assert chunk.field_meta == {"outer": "value"} + assert chunk.content.field_meta == {"inner": "value"}