diff --git a/examples/agent.py b/examples/agent.py index 1356c37..a75e1a5 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -24,69 +24,60 @@ update_agent_message, PROTOCOL_VERSION, ) -from acp.schema import AgentCapabilities, McpCapabilities, PromptCapabilities +from acp.schema import AgentCapabilities, AgentMessageChunk, Implementation class ExampleAgent(Agent): def __init__(self, conn: AgentSideConnection) -> None: self._conn = conn self._next_session_id = 0 + self._sessions: set[str] = set() - async def _send_chunk(self, session_id: str, content: Any) -> None: - await self._conn.sessionUpdate( - session_notification( - session_id, - update_agent_message(content), - ) - ) + async def _send_agent_message(self, session_id: str, content: Any) -> None: + update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content) + await self._conn.sessionUpdate(session_notification(session_id, update)) async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002 logging.info("Received initialize request") - mcp_caps: McpCapabilities = McpCapabilities(http=False, sse=False) - prompt_caps: PromptCapabilities = PromptCapabilities(audio=False, embeddedContext=False, image=False) - agent_caps: AgentCapabilities = AgentCapabilities( - loadSession=False, - mcpCapabilities=mcp_caps, - promptCapabilities=prompt_caps, - ) return InitializeResponse( protocolVersion=PROTOCOL_VERSION, - agentCapabilities=agent_caps, + agentCapabilities=AgentCapabilities(), + agentInfo=Implementation(name="example-agent", title="Example Agent", version="0.1.0"), ) async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002 - logging.info("Received authenticate request") + logging.info("Received authenticate request %s", params.methodId) return AuthenticateResponse() async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002 logging.info("Received new session request") session_id = str(self._next_session_id) self._next_session_id += 1 - return NewSessionResponse(sessionId=session_id) + self._sessions.add(session_id) + return NewSessionResponse(sessionId=session_id, modes=None) async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002 - logging.info("Received load session request") + logging.info("Received load session request %s", params.sessionId) + self._sessions.add(params.sessionId) return LoadSessionResponse() async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002 - logging.info("Received set session mode request") + logging.info("Received set session mode request %s -> %s", params.sessionId, params.modeId) return SetSessionModeResponse() async def prompt(self, params: PromptRequest) -> PromptResponse: - logging.info("Received prompt request") + logging.info("Received prompt request for session %s", params.sessionId) + if params.sessionId not in self._sessions: + self._sessions.add(params.sessionId) - # Notify the client what it just sent and then echo each content block back. - await self._send_chunk( - params.sessionId, - text_block("Client sent:"), - ) + await self._send_agent_message(params.sessionId, text_block("Client sent:")) for block in params.prompt: - await self._send_chunk(params.sessionId, block) + await self._send_agent_message(params.sessionId, block) return PromptResponse(stopReason="end_turn") async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002 - logging.info("Received cancel notification") + logging.info("Received cancel notification for session %s", params.sessionId) async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 logging.info("Received extension method call: %s", method) diff --git a/examples/client.py b/examples/client.py index bdb2ae9..8c62462 100644 --- a/examples/client.py +++ b/examples/client.py @@ -17,6 +17,16 @@ text_block, PROTOCOL_VERSION, ) +from acp.schema import ( + AgentMessageChunk, + AudioContentBlock, + ClientCapabilities, + EmbeddedResourceContentBlock, + ImageContentBlock, + Implementation, + ResourceContentBlock, + TextContentBlock, +) class ExampleClient(Client): @@ -46,20 +56,24 @@ async def killTerminal(self, params): # type: ignore[override] async def sessionUpdate(self, params: SessionNotification) -> None: update = params.update - if isinstance(update, dict): - kind = update.get("sessionUpdate") - content = update.get("content") - else: - kind = getattr(update, "sessionUpdate", None) - content = getattr(update, "content", None) - - if kind != "agent_message_chunk" or content is None: + if not isinstance(update, AgentMessageChunk): return - if isinstance(content, dict): - text = content.get("text", "") + content = update.content + text: str + if isinstance(content, TextContentBlock): + text = content.text + elif isinstance(content, ImageContentBlock): + text = "" + elif isinstance(content, AudioContentBlock): + text = "