diff --git a/AGENTS.md b/AGENTS.md
index 2c04e75..fdc7f48 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -1,19 +1,35 @@
 # Repository Guidelines

 ## Project Structure & Module Organization
-The package code lives under `src/acp`, exposing the high-level Agent, transport helpers, and generated protocol schema. Generated artifacts such as `schema/` and `src/acp/schema.py` are refreshed via `scripts/gen_all.py` against the upstream ACP schema. Integration examples are in `examples/`, including `echo_agent.py` and the mini SWE bridge. Tests reside in `tests/` with async fixtures and doctests; documentation sources live in `docs/` and publish via MkDocs. Built distributions drop into `dist/` after builds.
+- `src/acp/`: runtime package exposing agent/client abstractions, transports, and the generated `schema.py`.
+- `schema/`: upstream JSON schema sources; regenerate Python bindings with `make gen-all`.
+- `examples/`: runnable scripts (`echo_agent.py`, `client.py`, `gemini.py`, etc.) demonstrating stdio orchestration patterns.
+- `tests/`: pytest suite, including opt-in Gemini smoke checks under `tests/test_gemini_example.py`.
+- `docs/`: MkDocs content powering the hosted documentation.

 ## Build, Test, and Development Commands
-Run `make install` to create a `uv` managed virtualenv and install pre-commit hooks. `make check` executes lock verification, Ruff linting, `ty` static checks, and deptry analysis. `make test` calls `uv run python -m pytest --doctest-modules`. For release prep use `make build` or `make build-and-publish`. `make gen-all` regenerates protocol models; export `ACP_SCHEMA_VERSION=` beforehand to fetch a specific upstream schema (defaults to the cached copy). `make docs` serves MkDocs locally; `make docs-test` ensures clean builds.
+- `make install` — provision the `uv` virtualenv and install pre-commit hooks.
+- `make check` — run Ruff linting/formatting, type analysis, dependency hygiene, and lock verification.
+- `make test` — execute `pytest` (with doctests) inside the managed environment.
+- `make gen-all` — refresh protocol artifacts when the ACP schema version advances (`ACP_SCHEMA_VERSION=` to pin an upstream tag).

 ## Coding Style & Naming Conventions
-Target Python 3.10+ with type hints and 120-character lines enforced by Ruff (`pyproject.toml`). Prefer dataclasses/pydantic models from the schema modules rather than bare dicts. Tests may ignore security lint (see per-file ignores) but still follow snake_case names. Keep public API modules under `acp/*` lean; place utilities in internal `_`-prefixed modules when needed.
+- Target Python 3.10+ with four-space indentation and type hints on public APIs.
+- Ruff enforces formatting and lint rules (`uv run ruff check`, `uv run ruff format`); keep both clean before publishing.
+- Prefer dataclasses or generated Pydantic models from `acp.schema` over ad-hoc dicts. Place shared utilities in `_`-prefixed internal modules.
+- Prefer the builders in `acp.helpers` (for example `text_block`, `start_tool_call`) when constructing ACP payloads. The helpers instantiate the generated Pydantic models for you, keep literal discriminator fields out of call sites, and stay in lockstep with the schema thanks to the golden tests (`tests/test_golden.py`).

 ## Testing Guidelines
-Pytest is the main framework with `pytest-asyncio` for coroutine tests and doctests activated on modules. Name test files `test_*.py` and co-locate fixtures under `tests/conftest.py`. Aim to cover new protocol surfaces with integration-style tests using the async agent stubs. Generate coverage reports via `tox -e py310` when assessing CI parity.
+- Tests live in `tests/` and must be named `test_*.py`. Use `pytest.mark.asyncio` for coroutine coverage.
+- Run `make test` (or `uv run python -m pytest`) prior to commits; include reproducing steps for any added fixtures.
+- Gemini CLI coverage is disabled by default. Set `ACP_ENABLE_GEMINI_TESTS=1` (and `ACP_GEMINI_BIN=/path/to/gemini`) to exercise `tests/test_gemini_example.py`.

 ## Commit & Pull Request Guidelines
-Commit history follows Conventional Commits (`feat:`, `fix:`, `docs:`). Scope commits narrowly and include context on affected protocol version or tooling. PRs should describe agent behaviors exercised, link related issues, and mention schema regeneration if applicable. Attach test output (`make check` or targeted pytest) and screenshots only when UI-adjacent docs change. Update docs/examples when altering the public agent API.
+- Follow Conventional Commits (`feat:`, `fix:`, `docs:`, etc.) with succinct scopes, noting schema regenerations when applicable.
+- PRs should describe exercised agent behaviours, link relevant issues, and include output from `make check` or focused pytest runs.
+- Update documentation and examples whenever public APIs or transport behaviours change, and call out environment prerequisites for new integrations.

 ## Agent Integration Tips
-Leverage `examples/mini_swe_agent/` as a template when bridging other command executors. Use `AgentSideConnection` with `stdio_streams()` for ACP-compliant clients; document any extra environment variables in README updates.
+- Bootstrap agents from `examples/echo_agent.py` or `examples/agent.py`; pair with `examples/client.py` for round-trip validation.
+- Use `spawn_agent_process` / `spawn_client_process` to embed ACP parties directly in Python applications.
+- Validate new transports against `tests/test_rpc.py` and, when applicable, the Gemini example to ensure streaming updates and permission flows stay compliant.
diff --git a/README.md b/README.md
index 94408fc..0ab193b 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,7 @@
+
+ Agent Client Protocol
+
+
 # Agent Client Protocol (Python)

 Python SDK for the Agent Client Protocol (ACP). Build agents that speak ACP over stdio so tools like Zed can orchestrate them.
@@ -6,9 +10,11 @@ Python SDK for the Agent Client Protocol (ACP). Build agents that speak ACP over

 **Highlights**

-- Typed dataclasses generated from the upstream ACP schema (`acp.schema`)
-- Async agent base class plus stdio transport helpers for quick bootstrapping
-- Included examples that stream content updates and tool calls end-to-end
+- Generated `pydantic` models that track the upstream ACP schema (`acp.schema`)
+- Async base classes and JSON-RPC plumbing that keep stdio agents tiny
+- Process helpers such as `spawn_agent_process` for embedding agents and clients directly in Python
+- Batteries-included examples that exercise streaming updates, file I/O, and permission flows
+- Optional Gemini CLI bridge (`examples/gemini.py`) for the `gemini --experimental-acp` integration

 ## Install

@@ -29,29 +35,75 @@ uv add agent-client-protocol

 Prefer a step-by-step walkthrough? Read the [Quickstart guide](docs/quickstart.md) or the hosted docs: https://psiace.github.io/agent-client-protocol-python/.

+### Launching from Python
+
+Embed the agent inside another Python process without spawning your own pipes:
+
+```python
+import asyncio
+import sys
+from pathlib import Path
+
+from acp import spawn_agent_process, text_block
+from acp.schema import InitializeRequest, NewSessionRequest, PromptRequest
+
+
+async def main() -> None:
+    agent_script = Path("examples/echo_agent.py")
+    async with spawn_agent_process(lambda _agent: YourClient(), sys.executable, str(agent_script)) as (conn, _proc):
+        await conn.initialize(InitializeRequest(protocolVersion=1))
+        session = await conn.newSession(NewSessionRequest(cwd=str(agent_script.parent), mcpServers=[]))
+        await conn.prompt(
+            PromptRequest(
+                sessionId=session.sessionId,
+                prompt=[text_block("Hello!")],
+            )
+        )
+
+
+asyncio.run(main())
+```
+
+`spawn_client_process` mirrors this pattern for the inverse direction.
+
 ### Minimal agent sketch

 ```python
 import asyncio

-from acp import Agent, AgentSideConnection, PromptRequest, PromptResponse, SessionNotification, stdio_streams
-from acp.schema import AgentMessageChunk, TextContentBlock
+from acp import (
+    Agent,
+    AgentSideConnection,
+    InitializeRequest,
+    InitializeResponse,
+    NewSessionRequest,
+    NewSessionResponse,
+    PromptRequest,
+    PromptResponse,
+    session_notification,
+    stdio_streams,
+    text_block,
+    update_agent_message,
+)


 class EchoAgent(Agent):
     def __init__(self, conn):
         self._conn = conn

+    async def initialize(self, params: InitializeRequest) -> InitializeResponse:
+        return InitializeResponse(protocolVersion=params.protocolVersion)
+
+    async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
+        return NewSessionResponse(sessionId="sess-1")
+
     async def prompt(self, params: PromptRequest) -> PromptResponse:
         for block in params.prompt:
-            text = getattr(block, "text", "")
+            text = block.get("text", "") if isinstance(block, dict) else getattr(block, "text", "")
             await self._conn.sessionUpdate(
-                SessionNotification(
-                    sessionId=params.sessionId,
-                    update=AgentMessageChunk(
-                        sessionUpdate="agent_message_chunk",
-                        content=TextContentBlock(type="text", text=text),
-                    ),
+                session_notification(
+                    params.sessionId,
+                    update_agent_message(text_block(text)),
                 )
             )
         return PromptResponse(stopReason="end_turn")
@@ -71,15 +123,54 @@ Full example with streaming and lifecycle hooks lives in [examples/echo_agent.py

 ## Examples

-- `examples/mini_swe_agent`: bridges mini-swe-agent into ACP, including a duet launcher and Textual TUI client
-- Additional transport helpers are documented in the [Mini SWE guide](docs/mini-swe-agent.md)
+- `examples/echo_agent.py`: the canonical streaming agent with lifecycle hooks
+- `examples/client.py`: interactive console client that can launch any ACP agent via stdio
+- `examples/agent.py`: richer agent showcasing initialization, authentication, and chunked updates
+- `examples/duet.py`: launches both example agent and client using `spawn_agent_process`
+- `examples/gemini.py`: connects to the Gemini CLI in `--experimental-acp` mode, with optional auto-approval and sandbox flags
+
+## Helper APIs
+
+Use `acp.helpers` to build protocol payloads without manually shaping dictionaries:
+
+```python
+from acp import start_tool_call, text_block, tool_content, update_tool_call
+
+start = start_tool_call("call-1", "Inspect config", kind="read", status="pending")
+update = update_tool_call(
+    "call-1",
+    status="completed",
+    content=[tool_content(text_block("Inspection finished."))],
+)
+```
+
+Helpers cover content blocks (`text_block`, `resource_link_block`), embedded resources, tool calls (`start_edit_tool_call`, `update_tool_call`), and session updates (`update_agent_message_text`, `session_notification`).

 ## Documentation

 - Project docs (MkDocs): https://psiace.github.io/agent-client-protocol-python/
 - Local sources: `docs/`
   - [Quickstart](docs/quickstart.md)
-  - [Mini SWE Agent bridge](docs/mini-swe-agent.md)
+  - [Releasing](docs/releasing.md)
+
+## Gemini CLI bridge
+
+Want to exercise the `gemini` CLI over ACP? The repository includes a Python replica of the Go SDK's REPL:
+
+```bash
+python examples/gemini.py --yolo # auto-approve permissions
+python examples/gemini.py --sandbox --model gemini-2.5-pro
+```
+
+Defaults assume the CLI is discoverable via `PATH`; override with `--gemini` or `ACP_GEMINI_BIN=/path/to/gemini`.
+
+The smoke test (`tests/test_gemini_example.py`) is opt-in to avoid false negatives when the CLI is unavailable or lacks credentials. Enable it locally with:
+
+```bash
+ACP_ENABLE_GEMINI_TESTS=1 ACP_GEMINI_BIN=/path/to/gemini uv run python -m pytest tests/test_gemini_example.py
+```
+
+The test gracefully skips when authentication prompts (e.g. missing `GOOGLE_CLOUD_PROJECT`) block the interaction.
 ## Development workflow

diff --git a/docs/index.md b/docs/index.md
index 77f022e..5a16958 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,12 +1,18 @@
+
+ Agent Client Protocol
+
+
 # Agent Client Protocol SDK (Python)

 Welcome to the Python SDK for the Agent Client Protocol (ACP). The package ships ready-to-use transports, typed protocol models, and examples that stream messages to ACP-aware clients such as Zed.

 ## What you get

-- Fully typed dataclasses generated from the upstream ACP schema (`acp.schema`)
-- Async agent base class and stdio helpers to spin up an agent in a few lines
-- Examples that demonstrate streaming updates and tool execution over ACP
+- Pydantic models generated from the upstream ACP schema (`acp.schema`)
+- Async agent/client wrappers with JSON-RPC task supervision built in
+- Process helpers (`spawn_agent_process`, `spawn_client_process`) for embedding ACP nodes inside Python applications
+- Helper APIs in `acp.helpers` that mirror the Go/TS SDK builders for content blocks, tool calls, and session updates. They instantiate the generated Pydantic types for you, so call sites stay concise without sacrificing validation.
+- Examples that showcase streaming updates, file operations, permission flows, and even a Gemini CLI bridge (`examples/gemini.py`)

 ## Getting started

@@ -20,11 +26,27 @@ Welcome to the Python SDK for the Agent Client Protocol (ACP). The package ships
 ```
 3. Point your ACP-capable client at the running process (for Zed, configure an Agent Server entry). The SDK takes care of JSON-RPC framing and lifecycle transitions.

-Prefer a guided tour? Head to the [Quickstart](quickstart.md) for step-by-step instructions, including how to run the agent from an editor or terminal.
+Prefer a guided tour? Head to the [Quickstart](quickstart.md) for terminal, editor, and programmatic launch walkthroughs.
+
+## Gemini CLI bridge
+
+If you have access to the Gemini CLI (`gemini --experimental-acp`), run:
+
+```bash
+python examples/gemini.py --yolo
+```
+
+Flags mirror the Go SDK example:
+
+- `--gemini /path/to/cli` or `ACP_GEMINI_BIN` to override discovery
+- `--model`, `--sandbox`, `--debug` forwarded verbatim
+- `--yolo` auto-approves permission prompts with sensible defaults
+
+An opt-in smoke test lives at `tests/test_gemini_example.py`. Enable it with `ACP_ENABLE_GEMINI_TESTS=1` (and optionally `ACP_GEMINI_TEST_ARGS`) when the CLI is authenticated; otherwise the test stays skipped.

 ## Documentation map

-- [Quickstart](quickstart.md): install, run, and extend the echo agent
-- [Mini SWE Agent guide](mini-swe-agent.md): bridge mini-swe-agent over ACP, including duet launcher and Textual client
+- [Quickstart](quickstart.md): install, run, and embed the echo agent, plus next steps for extending it
+- [Releasing](releasing.md): schema upgrade workflow, version bumps, and publishing checklist

 Source code lives under `src/acp/`, while tests and additional examples are available in `tests/` and `examples/`. If you plan to contribute, see the repository README for the development workflow.
diff --git a/docs/mini-swe-agent.md b/docs/mini-swe-agent.md
deleted file mode 100644
index 70df79d..0000000
--- a/docs/mini-swe-agent.md
+++ /dev/null
@@ -1,54 +0,0 @@
-# Mini SWE Agent bridge
-
-This example wraps mini-swe-agent behind ACP so editors such as Zed can interact with it over stdio. A duet launcher is included to run a local Textual client beside the bridge for quick experimentation.
-
-## Overview
-
-- Accepts ACP prompts, concatenates text blocks, and forwards them to mini-swe-agent
-- Streams language-model output via `session/update` → `agent_message_chunk`
-- Emits `tool_call` / `tool_call_update` pairs for shell execution, including stdout and return codes
-- Sends a final `agent_message_chunk` when mini-swe-agent prints `COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT`
-
-## Requirements
-
-- Python environment with `mini-swe-agent` installed (`pip install mini-swe-agent`)
-- ACP-capable client (e.g. Zed) or the bundled Textual client
-- Optional: `.env` file at the repo root for shared configuration when using the duet launcher
-
-If `mini-swe-agent` is missing, the bridge falls back to the reference copy at `reference/mini-swe-agent/src`.
-
-## Configure models and credentials
-
-Set environment variables before launching the bridge:
-
-- `MINI_SWE_MODEL`: model identifier such as `openrouter/openai/gpt-4o-mini`
-- `OPENROUTER_API_KEY` for OpenRouter models, or `OPENAI_API_KEY` / `ANTHROPIC_API_KEY` for native providers
-- Optional `MINI_SWE_MODEL_KWARGS`: JSON blob of extra keyword arguments (OpenRouter defaults are injected automatically when omitted)
-
-The bridge selects the correct API key based on the chosen model and available variables.
-
-## Run inside Zed
-
-Add an Agent Server entry targeting `examples/mini_swe_agent/agent.py` and provide the environment variables there. Use Zed’s “Open ACP Logs” panel to observe streamed message chunks and tool call events in real time.
-
-## Run locally with the duet launcher
-
-To pair the bridge with the Textual TUI client, run:
-
-```bash
-python examples/mini_swe_agent/duet.py
-```
-
-Both processes inherit settings from `.env` (thanks to `python-dotenv`) and communicate over dedicated pipes.
-
-**TUI shortcuts**
-- `y`: YOLO
-- `c`: Confirm
-- `u`: Human (prompts for a shell command and streams it back as a tool call)
-- `Enter`: Continue
-
-## Related files
-
-- Agent entrypoint: `examples/mini_swe_agent/agent.py`
-- Duet launcher: `examples/mini_swe_agent/duet.py`
-- Textual client: `examples/mini_swe_agent/client.py`
diff --git a/docs/quickstart.md b/docs/quickstart.md
index a840e37..03c2dee 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -15,17 +15,17 @@ pip install agent-client-protocol
 uv add agent-client-protocol
 ```

-## 2. Run the echo agent
+## 2. Launch the Echo agent (terminal)

-Launch the ready-made echo example, which streams text blocks back over ACP:
+Start the ready-made echo example — it streams text blocks back to any ACP client:

 ```bash
 python examples/echo_agent.py
 ```

-Keep it running while you connect your client.
+Leave this process running while you connect from an editor or another program.

-## 3. Connect from your client
+## 3. Connect from an editor

 ### Zed

@@ -50,6 +50,43 @@ Open the Agents panel and start the session. Each message you send should be ech

 Any ACP client that communicates over stdio can spawn the same script; no additional transport configuration is required.

+### Programmatic launch
+
+```python
+import asyncio
+import sys
+from pathlib import Path
+
+from acp import spawn_agent_process, text_block
+from acp.interfaces import Client
+from acp.schema import InitializeRequest, NewSessionRequest, PromptRequest, SessionNotification
+
+
+class SimpleClient(Client):
+    async def requestPermission(self, params):  # pragma: no cover - minimal stub
+        return {"outcome": {"outcome": "cancelled"}}
+
+    async def sessionUpdate(self, params: SessionNotification) -> None:
+        print("update:", params.sessionId, params.update)
+
+
+async def main() -> None:
+    script = Path("examples/echo_agent.py")
+    async with spawn_agent_process(lambda _agent: SimpleClient(), sys.executable, str(script)) as (conn, _proc):
+        await conn.initialize(InitializeRequest(protocolVersion=1))
+        session = await conn.newSession(NewSessionRequest(cwd=str(script.parent), mcpServers=[]))
+        await conn.prompt(
+            PromptRequest(
+                sessionId=session.sessionId,
+                prompt=[text_block("Hello from spawn!")],
+            )
+        )
+
+asyncio.run(main())
+```
+
+`spawn_agent_process` manages the child process, wires its stdio into ACP framing, and closes everything when the block exits. The mirror helper `spawn_client_process` lets you drive an ACP client from Python as well.
+
 ## 4. Extend the agent

 Create your own agent by subclassing `acp.Agent`. The pattern mirrors the echo example:
@@ -64,15 +101,41 @@ class MyAgent(Agent):
         return PromptResponse(stopReason="end_turn")
 ```

-Hook it up with `AgentSideConnection` inside an async entrypoint and wire it to your client. Refer to [examples/echo_agent.py](https://github.com/psiace/agent-client-protocol-python/blob/main/examples/echo_agent.py) for the complete structure, including lifetime hooks (`initialize`, `newSession`) and streaming responses.
+Hook it up with `AgentSideConnection` inside an async entrypoint and wire it to your client. Refer to:
+
+- [`examples/echo_agent.py`](https://github.com/psiace/agent-client-protocol-python/blob/main/examples/echo_agent.py) for the smallest streaming agent
+- [`examples/agent.py`](https://github.com/psiace/agent-client-protocol-python/blob/main/examples/agent.py) for an implementation that negotiates capabilities and streams richer updates
+- [`examples/duet.py`](https://github.com/psiace/agent-client-protocol-python/blob/main/examples/duet.py) to see `spawn_agent_process` in action alongside the interactive client
+- [`examples/gemini.py`](https://github.com/psiace/agent-client-protocol-python/blob/main/examples/gemini.py) to drive the Gemini CLI (`--experimental-acp`) directly from Python
+
+Need builders for common payloads? `acp.helpers` mirrors the Go/TS helper APIs:
+
+```python
+from acp import start_tool_call, update_tool_call, text_block, tool_content
+
+start_update = start_tool_call("call-42", "Open file", kind="read", status="pending")
+finish_update = update_tool_call(
+    "call-42",
+    status="completed",
+    content=[tool_content(text_block("File opened."))],
+)
+```
+
+Each helper wraps the generated Pydantic models in `acp.schema`, so the right discriminator fields (`type`, `sessionUpdate`, and friends) are always populated. That keeps examples readable while maintaining the same validation guarantees as constructing the models directly. Golden fixtures in `tests/test_golden.py` ensure the helpers stay in sync with future schema revisions.
+
+## 5. Optional: Talk to the Gemini CLI
+
+If you have the Gemini CLI installed and authenticated:
+
+```bash
+python examples/gemini.py --yolo # auto-approve permission prompts
+python examples/gemini.py --sandbox --model gemini-1.5-pro
+```

-## Optional: Mini SWE Agent bridge
+Environment helpers:

-The repository also ships a bridge for [mini-swe-agent](https://github.com/groundx-ai/mini-swe-agent). To try it:
+- `ACP_GEMINI_BIN` — override the CLI path (defaults to `PATH` lookup)
+- `ACP_GEMINI_TEST_ARGS` — extra flags forwarded during the smoke test
+- `ACP_ENABLE_GEMINI_TESTS=1` — opt-in toggle for `tests/test_gemini_example.py`

-1. Install the dependency:
-   ```bash
-   pip install mini-swe-agent
-   ```
-2. Configure Zed to run `examples/mini_swe_agent/agent.py` and supply environment variables such as `MINI_SWE_MODEL` and `OPENROUTER_API_KEY`.
-3. Review the [Mini SWE Agent guide](mini-swe-agent.md) for environment options, tool-call mapping, and a duet launcher that starts both the bridge and a Textual client (`python examples/mini_swe_agent/duet.py`).
+Authentication hiccups (e.g. missing `GOOGLE_CLOUD_PROJECT`) are surfaced but treated as skips during testing so the suite stays green on machines without credentials.
diff --git a/docs/releasing.md b/docs/releasing.md
new file mode 100644
index 0000000..d6b448b
--- /dev/null
+++ b/docs/releasing.md
@@ -0,0 +1,57 @@
+# Releasing
+
+This project tracks the ACP schema tags published by
+[`agentclientprotocol/agent-client-protocol`](https://github.com/agentclientprotocol/agent-client-protocol).
+Every release should line up with one of those tags so that the generated `acp.schema` module, examples, and package
+version remain consistent.
+
+## Preparation
+
+1. Pick the target schema tag (for example `v0.4.5`) and regenerate the protocol bindings:
+
+   ```bash
+   ACP_SCHEMA_VERSION=v0.4.5 make gen-all
+   ```
+
+   This downloads the upstream schema package and rewrites `schema/` plus the generated `src/acp/schema.py`.
+
+2. Bump the project version in `pyproject.toml`, updating `uv.lock` if dependencies changed.
+
+3. Run the standard checks:
+
+   ```bash
+   make check
+   make test
+   ```
+
+   `make check` covers Ruff formatting/linting, static analysis, and dependency hygiene.
+   `make test` executes pytest (including doctests).
+
+4. Refresh documentation and examples (for instance the Gemini walkthrough) so they match the new schema behaviour.
+
+## Commit & Merge
+
+1. Make sure the diff only includes the expected artifacts: regenerated schema sources, `src/acp/schema.py`, version bumps, and doc updates.
+2. Commit with a Conventional Commit message (for example `release: v0.4.5`) and note in the PR:
+   - The ACP schema tag you targeted
+   - Results from `make check` / `make test`
+   - Any behavioural or API changes worth highlighting
+3. Merge once the review is approved.
+
+## Publish via GitHub Release
+
+Publishing is automated through `on-release-main.yml`. After the release PR merges to `main`:
+
+1. Draft a GitHub Release for the new tag (e.g. `v0.4.5`). If the tag is missing, the release UI will create it.
+2. Once published, the workflow will:
+   - Write the tag back into `pyproject.toml` to keep the package version aligned
+   - Build and publish to PyPI via `uv publish` (using the `PYPI_TOKEN` secret)
+   - Deploy updated documentation with `mkdocs gh-deploy`
+
+No local `uv build`/`uv publish` runs are required—focus on providing a complete release summary (highlights, compatibility notes, etc.).
+
+## Additional Notes
+
+- Breaking schema updates often require refreshing golden fixtures (`tests/test_golden.py`), end-to-end cases such as `tests/test_rpc.py`, and any affected examples.
+- Use `make clean` to remove generated artifacts if you need a fresh baseline before re-running `make gen-all`.
+- Run optional checks like the Gemini smoke test (`ACP_ENABLE_GEMINI_TESTS=1`) whenever the environment is available to catch regressions before publishing.
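The release checklist in `docs/releasing.md` asks that the package version line up with the targeted ACP schema tag. A minimal sketch of a local pre-release check, not part of this PR: the helper names `project_version` and `tag_matches_version` are hypothetical, and the regex assumes the first top-level `version = "..."` line in `pyproject.toml` is the project version.

```python
import re
from typing import Optional

# Assumption: the project version is the first line of the form `version = "..."`.
_VERSION_RE = re.compile(r'^version\s*=\s*"([^"]+)"', re.MULTILINE)


def project_version(pyproject_text: str) -> Optional[str]:
    """Return the first declared version string, or None when none is found."""
    match = _VERSION_RE.search(pyproject_text)
    return match.group(1) if match else None


def tag_matches_version(tag: str, pyproject_text: str) -> bool:
    """Compare a release tag such as `v0.4.5` with the declared version."""
    version = project_version(pyproject_text)
    # Tags carry a leading "v"; the package version does not.
    return version is not None and tag.removeprefix("v") == version


if __name__ == "__main__":
    sample = '[project]\nname = "agent-client-protocol"\nversion = "0.4.5"\n'
    print(tag_matches_version("v0.4.5", sample))
```

A regex keeps the sketch runnable on Python 3.10, where the standard-library `tomllib` parser is not yet available; a real check would more likely parse the TOML properly.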
diff --git a/examples/agent.py b/examples/agent.py
index c398cca..1356c37 100644
--- a/examples/agent.py
+++ b/examples/agent.py
@@ -18,17 +18,13 @@
     PromptResponse,
     SetSessionModeRequest,
     SetSessionModeResponse,
+    session_notification,
     stdio_streams,
+    text_block,
+    update_agent_message,
     PROTOCOL_VERSION,
 )
-from acp.schema import (
-    AgentCapabilities,
-    AgentMessageChunk,
-    McpCapabilities,
-    PromptCapabilities,
-    SessionNotification,
-    TextContentBlock,
-)
+from acp.schema import AgentCapabilities, McpCapabilities, PromptCapabilities


 class ExampleAgent(Agent):
@@ -38,24 +34,24 @@ def __init__(self, conn: AgentSideConnection) -> None:

     async def _send_chunk(self, session_id: str, content: Any) -> None:
         await self._conn.sessionUpdate(
-            SessionNotification(
-                sessionId=session_id,
-                update=AgentMessageChunk(
-                    sessionUpdate="agent_message_chunk",
-                    content=content,
-                ),
+            session_notification(
+                session_id,
+                update_agent_message(content),
             )
         )

     async def initialize(self, params: InitializeRequest) -> InitializeResponse:  # noqa: ARG002
         logging.info("Received initialize request")
+        mcp_caps: McpCapabilities = McpCapabilities(http=False, sse=False)
+        prompt_caps: PromptCapabilities = PromptCapabilities(audio=False, embeddedContext=False, image=False)
+        agent_caps: AgentCapabilities = AgentCapabilities(
+            loadSession=False,
+            mcpCapabilities=mcp_caps,
+            promptCapabilities=prompt_caps,
+        )
         return InitializeResponse(
             protocolVersion=PROTOCOL_VERSION,
-            agentCapabilities=AgentCapabilities(
-                loadSession=False,
-                mcpCapabilities=McpCapabilities(http=False, sse=False),
-                promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False),
-            ),
+            agentCapabilities=agent_caps,
         )

     async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None:  # noqa: ARG002
@@ -82,7 +78,7 @@ async def prompt(self, params: PromptRequest) -> PromptResponse:
         # Notify the client what it just sent and then echo each content block back.
         await self._send_chunk(
             params.sessionId,
-            TextContentBlock(type="text", text="Client sent:"),
+            text_block("Client sent:"),
         )
         for block in params.prompt:
             await self._send_chunk(params.sessionId, block)
diff --git a/examples/client.py b/examples/client.py
index 6cde6a8..bdb2ae9 100644
--- a/examples/client.py
+++ b/examples/client.py
@@ -1,4 +1,5 @@
 import asyncio
+import asyncio.subprocess as aio_subprocess
 import contextlib
 import logging
 import os
@@ -13,9 +14,9 @@
     PromptRequest,
     RequestError,
     SessionNotification,
+    text_block,
     PROTOCOL_VERSION,
 )
-from acp.schema import TextContentBlock


 class ExampleClient(Client):
@@ -90,7 +91,7 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None:
             await conn.prompt(
                 PromptRequest(
                     sessionId=session_id,
-                    prompt=[TextContentBlock(type="text", text=line)],
+                    prompt=[text_block(line)],
                 )
             )
         except Exception as exc:  # noqa: BLE001
@@ -118,8 +119,8 @@ async def main(argv: list[str]) -> int:
     proc = await asyncio.create_subprocess_exec(
         spawn_program,
         *spawn_args,
-        stdin=asyncio.subprocess.PIPE,
-        stdout=asyncio.subprocess.PIPE,
+        stdin=aio_subprocess.PIPE,
+        stdout=aio_subprocess.PIPE,
     )

     if proc.stdin is None or proc.stdout is None:
diff --git a/examples/duet.py b/examples/duet.py
index 049f164..de8d9ca 100644
--- a/examples/duet.py
+++ b/examples/duet.py
@@ -1,27 +1,44 @@
 import asyncio
+import importlib.util
 import os
 import sys
 from pathlib import Path


+def _load_client_module(path: Path):
+    spec = importlib.util.spec_from_file_location("examples_client", path)
+    if spec is None or spec.loader is None:
+        raise RuntimeError(f"Unable to load client module from {path}")
+    module = importlib.util.module_from_spec(spec)
+    sys.modules.setdefault("examples_client", module)
+    spec.loader.exec_module(module)
+    return module
+
+
+from acp import PROTOCOL_VERSION, spawn_agent_process
+from acp.schema import InitializeRequest, NewSessionRequest
+
+
 async def main() -> int:
     root = Path(__file__).resolve().parent
-    agent_path = str(root / "agent.py")
-    client_path = str(root / "client.py")
+    agent_path = root / "agent.py"

-    # Ensure PYTHONPATH includes project src for `from acp import ...`
     env = os.environ.copy()
     src_dir = str((root.parent / "src").resolve())
     env["PYTHONPATH"] = src_dir + os.pathsep + env.get("PYTHONPATH", "")

-    # Run the client and let it spawn the agent, wiring stdio automatically.
-    proc = await asyncio.create_subprocess_exec(
-        sys.executable,
-        client_path,
-        agent_path,
-        env=env,
-    )
-    return await proc.wait()
+    client_module = _load_client_module(root / "client.py")
+    client = client_module.ExampleClient()
+
+    async with spawn_agent_process(lambda _agent: client, sys.executable, str(agent_path), env=env) as (
+        conn,
+        process,
+    ):
+        await conn.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION, clientCapabilities=None))
+        session = await conn.newSession(NewSessionRequest(mcpServers=[], cwd=str(root)))
+        await client_module.interactive_loop(conn, session.sessionId)
+
+    return process.returncode or 0


 if __name__ == "__main__":
diff --git a/examples/echo_agent.py b/examples/echo_agent.py
index 3a7f1c9..1bf04ff 100644
--- a/examples/echo_agent.py
+++ b/examples/echo_agent.py
@@ -1,4 +1,5 @@
 import asyncio
+from uuid import uuid4

 from acp import (
     Agent,
@@ -9,10 +10,11 @@
     NewSessionResponse,
     PromptRequest,
     PromptResponse,
-    SessionNotification,
+    session_notification,
     stdio_streams,
+    text_block,
+    update_agent_message,
 )
-from acp.schema import TextContentBlock, AgentMessageChunk


 class EchoAgent(Agent):
@@ -23,18 +25,15 @@ async def initialize(self, params: InitializeRequest) -> InitializeResponse:
         return InitializeResponse(protocolVersion=params.protocolVersion)

     async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
-        return NewSessionResponse(sessionId="sess-1")
+        return NewSessionResponse(sessionId=uuid4().hex)

     async def prompt(self, params: PromptRequest) -> PromptResponse:
         for block in params.prompt:
-            text = block.get("text", "") if isinstance(block, dict) else getattr(block, "text", "")
+            text = getattr(block, "text", "")
             await self._conn.sessionUpdate(
-                SessionNotification(
-                    sessionId=params.sessionId,
-                    update=AgentMessageChunk(
-                        sessionUpdate="agent_message_chunk",
-                        content=TextContentBlock(type="text", text=text),
-                    ),
+                session_notification(
+                    params.sessionId,
+                    update_agent_message(text_block(text)),
                 )
             )
         return PromptResponse(stopReason="end_turn")
diff --git a/examples/gemini.py b/examples/gemini.py
new file mode 100644
index 0000000..f1fe9a9
--- /dev/null
+++ b/examples/gemini.py
@@ -0,0 +1,400 @@
+from __future__ import annotations
+
+import argparse
+import asyncio
+import asyncio.subprocess
+import contextlib
+import json
+import os
+import shutil
+import sys
+from pathlib import Path
+from typing import Iterable
+
+from acp import (
+    Client,
+    ClientSideConnection,
+    PROTOCOL_VERSION,
+    RequestError,
+    text_block,
+)
+from acp.schema import (
+    AgentMessageChunk,
+    AgentPlanUpdate,
+    AgentThoughtChunk,
+    AllowedOutcome,
+    CancelNotification,
+    ClientCapabilities,
+    FileEditToolCallContent,
+    FileSystemCapability,
+    CreateTerminalRequest,
+    CreateTerminalResponse,
+    DeniedOutcome,
+    EmbeddedResourceContentBlock,
+    KillTerminalCommandRequest,
+    KillTerminalCommandResponse,
+    InitializeRequest,
+    NewSessionRequest,
+    PermissionOption,
+    PromptRequest,
+    ReadTextFileRequest,
+    ReadTextFileResponse,
+    RequestPermissionRequest,
+    RequestPermissionResponse,
+    ResourceContentBlock,
+    ReleaseTerminalRequest,
+    ReleaseTerminalResponse,
+    SessionNotification,
+    TerminalToolCallContent,
+    TerminalOutputRequest,
+    TerminalOutputResponse,
+    TextContentBlock,
+    ToolCallProgress,
+    ToolCallStart,
+    UserMessageChunk,
+    WaitForTerminalExitRequest,
+    WaitForTerminalExitResponse,
+    WriteTextFileRequest,
+    WriteTextFileResponse,
+)
+
+
+class GeminiClient(Client):
+    """Minimal client implementation that can drive the Gemini CLI over ACP."""
+
+    def __init__(self, auto_approve: bool) -> None:
+        self._auto_approve = auto_approve
+
+    async def requestPermission(
+        self,
+        params: RequestPermissionRequest,
+    ) -> RequestPermissionResponse:  # type: ignore[override]
+        if self._auto_approve:
+            option = _pick_preferred_option(params.options)
+            if option is None:
+                return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))
+            return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected"))
+
+        title = params.toolCall.title or ""
+        if not params.options:
+            print(f"\n🔐 Permission requested: {title} (no options, cancelling)")
+            return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))
+        print(f"\n🔐 Permission requested: {title}")
+        for idx, opt in enumerate(params.options, start=1):
+            print(f" {idx}. {opt.name} ({opt.kind})")
+
+        loop = asyncio.get_running_loop()
+        while True:
+            choice = await loop.run_in_executor(None, lambda: input("Select option: ").strip())
+            if not choice:
+                continue
+            if choice.isdigit():
+                idx = int(choice) - 1
+                if 0 <= idx < len(params.options):
+                    opt = params.options[idx]
+                    return RequestPermissionResponse(outcome=AllowedOutcome(optionId=opt.optionId, outcome="selected"))
+            print("Invalid selection, try again.")
+
+    async def writeTextFile(
+        self,
+        params: WriteTextFileRequest,
+    ) -> WriteTextFileResponse:  # type: ignore[override]
+        path = Path(params.path)
+        if not path.is_absolute():
+            raise RequestError.invalid_params({"path": params.path, "reason": "path must be absolute"})
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(params.content)
+        print(f"[Client] Wrote {path} ({len(params.content)} bytes)")
+        return WriteTextFileResponse()
+
+    async def readTextFile(
+        self,
+        params: ReadTextFileRequest,
+    ) -> ReadTextFileResponse:  # type: ignore[override]
+        path = Path(params.path)
+        if not path.is_absolute():
+            raise RequestError.invalid_params({"path": params.path, "reason": "path must be absolute"})
+        text = path.read_text()
+        print(f"[Client] Read {path} ({len(text)} bytes)")
+        if params.line is not None or params.limit is not None:
+            text = _slice_text(text, params.line, params.limit)
+        return ReadTextFileResponse(content=text)
+
+    async def sessionUpdate(
+        self,
+        params: SessionNotification,
+    ) -> None:  # type: ignore[override]
+        update = params.update
+        if isinstance(update, AgentMessageChunk):
+            _print_text_content(update.content)
+        elif isinstance(update, AgentThoughtChunk):
+            print("\n[agent_thought]")
+            _print_text_content(update.content)
+        elif isinstance(update, UserMessageChunk):
+            print("\n[user_message]")
+            _print_text_content(update.content)
+        elif isinstance(update, AgentPlanUpdate):
+            print("\n[plan]")
+            for entry in update.entries:
+                print(f" - {entry.status.upper():<10} {entry.content}")
+        elif isinstance(update, ToolCallStart):
+            print(f"\n🔧 {update.title} ({update.status or 'pending'})")
+        elif isinstance(update, ToolCallProgress):
+            status = update.status or "in_progress"
+            print(f"\n🔧 Tool call `{update.toolCallId}` → {status}")
+            if update.content:
+                for item in update.content:
+                    if isinstance(item, FileEditToolCallContent):
+                        print(f" diff: {item.path}")
+                    elif isinstance(item, TerminalToolCallContent):
+                        print(f" terminal: {item.terminalId}")
+                    elif isinstance(item, dict):
+                        print(f" content: {json.dumps(item, indent=2)}")
+        else:
+            print(f"\n[session update] {update}")
+
+    # Optional / terminal-related methods ---------------------------------
+    async def createTerminal(
+        self,
+        params: CreateTerminalRequest,
+    ) -> CreateTerminalResponse:  # type: ignore[override]
+        print(f"[Client] createTerminal: {params}")
+        return CreateTerminalResponse(terminalId="term-1")
+
+    async def terminalOutput(
+        self,
+        params: TerminalOutputRequest,
+    ) -> TerminalOutputResponse:  # type: ignore[override]
+        print(f"[Client] terminalOutput: {params}")
+        return TerminalOutputResponse(output="", truncated=False)
+
+    async def
releaseTerminal( + self, + params: ReleaseTerminalRequest, + ) -> ReleaseTerminalResponse: # type: ignore[override] + print(f"[Client] releaseTerminal: {params}") + return ReleaseTerminalResponse() + + async def waitForTerminalExit( + self, + params: WaitForTerminalExitRequest, + ) -> WaitForTerminalExitResponse: # type: ignore[override] + print(f"[Client] waitForTerminalExit: {params}") + return WaitForTerminalExitResponse() + + async def killTerminal( + self, + params: KillTerminalCommandRequest, + ) -> KillTerminalCommandResponse: # type: ignore[override] + print(f"[Client] killTerminal: {params}") + return KillTerminalCommandResponse() + + +def _pick_preferred_option(options: Iterable[PermissionOption]) -> PermissionOption | None: + best: PermissionOption | None = None + for option in options: + if option.kind in {"allow_once", "allow_always"}: + return option + best = best or option + return best + + +def _slice_text(content: str, line: int | None, limit: int | None) -> str: + lines = content.splitlines() + start = 0 + if line: + start = max(line - 1, 0) + end = len(lines) + if limit: + end = min(start + limit, end) + return "\n".join(lines[start:end]) + + +def _print_text_content(content: object) -> None: + if isinstance(content, TextContentBlock): + print(content.text) + elif isinstance(content, ResourceContentBlock): + print(f"{content.name or content.uri}") + elif isinstance(content, EmbeddedResourceContentBlock): + resource = content.resource + text = getattr(resource, "text", None) + if text: + print(text) + else: + blob = getattr(resource, "blob", None) + print(blob if blob else "") + elif isinstance(content, dict): + text = content.get("text") # type: ignore[union-attr] + if text: + print(text) + + +async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: + print("Type a message and press Enter to send.") + print("Commands: :cancel, :exit") + + loop = asyncio.get_running_loop() + while True: + try: + line = await 
loop.run_in_executor(None, lambda: input("\n> ").strip()) + except (EOFError, KeyboardInterrupt): + print("\nExiting.") + break + + if not line: + continue + if line in {":exit", ":quit"}: + break + if line == ":cancel": + await conn.cancel(CancelNotification(sessionId=session_id)) + continue + + try: + await conn.prompt( + PromptRequest( + sessionId=session_id, + prompt=[text_block(line)], + ) + ) + except RequestError as err: + _print_request_error("prompt", err) + except Exception as exc: # noqa: BLE001 + print(f"Prompt failed: {exc}", file=sys.stderr) + + +def _resolve_gemini_cli(binary: str | None) -> str: + if binary: + return binary + env_value = os.environ.get("ACP_GEMINI_BIN") + if env_value: + return env_value + resolved = shutil.which("gemini") + if resolved: + return resolved + raise FileNotFoundError("Unable to locate `gemini` CLI, provide --gemini path") + + +async def run(argv: list[str]) -> int: + parser = argparse.ArgumentParser(description="Interact with the Gemini CLI over ACP.") + parser.add_argument("--gemini", help="Path to the Gemini CLI binary") + parser.add_argument("--model", help="Model identifier to pass to Gemini") + parser.add_argument("--sandbox", action="store_true", help="Enable Gemini sandbox mode") + parser.add_argument("--debug", action="store_true", help="Pass --debug to Gemini") + parser.add_argument("--yolo", action="store_true", help="Auto-approve permission prompts") + args = parser.parse_args(argv[1:]) + + try: + gemini_path = _resolve_gemini_cli(args.gemini) + except FileNotFoundError as exc: + print(exc, file=sys.stderr) + return 1 + + cmd = [gemini_path, "--experimental-acp"] + if args.model: + cmd += ["--model", args.model] + if args.sandbox: + cmd.append("--sandbox") + if args.debug: + cmd.append("--debug") + + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=None, + ) + except FileNotFoundError as exc: + print(f"Failed to start 
Gemini CLI: {exc}", file=sys.stderr) + return 1 + + if proc.stdin is None or proc.stdout is None: + print("Gemini process did not expose stdio pipes.", file=sys.stderr) + proc.terminate() + with contextlib.suppress(ProcessLookupError): + await proc.wait() + return 1 + + client_impl = GeminiClient(auto_approve=args.yolo) + conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout) + + try: + init_resp = await conn.initialize( + InitializeRequest( + protocolVersion=PROTOCOL_VERSION, + clientCapabilities=ClientCapabilities( + fs=FileSystemCapability(readTextFile=True, writeTextFile=True), + terminal=True, + ), + ) + ) + except RequestError as err: + _print_request_error("initialize", err) + await _shutdown(proc, conn) + return 1 + except Exception as exc: # noqa: BLE001 + print(f"initialize error: {exc}", file=sys.stderr) + await _shutdown(proc, conn) + return 1 + + print(f"✅ Connected to Gemini (protocol v{init_resp.protocolVersion})") + + try: + session = await conn.newSession( + NewSessionRequest( + cwd=os.getcwd(), + mcpServers=[], + ) + ) + except RequestError as err: + _print_request_error("new_session", err) + await _shutdown(proc, conn) + return 1 + except Exception as exc: # noqa: BLE001 + print(f"new_session error: {exc}", file=sys.stderr) + await _shutdown(proc, conn) + return 1 + + print(f"📝 Created session: {session.sessionId}") + + try: + await interactive_loop(conn, session.sessionId) + finally: + await _shutdown(proc, conn) + + return 0 + + +def _print_request_error(stage: str, err: RequestError) -> None: + payload = err.to_error_obj() + message = payload.get("message", "") + code = payload.get("code") + print(f"{stage} error ({code}): {message}", file=sys.stderr) + data = payload.get("data") + if data is not None: + try: + formatted = json.dumps(data, indent=2) + except TypeError: + formatted = str(data) + print(formatted, file=sys.stderr) + + +async def _shutdown(proc: asyncio.subprocess.Process, conn: ClientSideConnection) -> 
None: + with contextlib.suppress(Exception): + await conn.close() + if proc.returncode is None: + proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=5) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + + +def main(argv: list[str] | None = None) -> int: + args = sys.argv if argv is None else argv + return asyncio.run(run(list(args))) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/mini_swe_agent/README.md b/examples/mini_swe_agent/README.md deleted file mode 100644 index 64b445e..0000000 --- a/examples/mini_swe_agent/README.md +++ /dev/null @@ -1,67 +0,0 @@ -# Mini SWE Agent (Python) — ACP Bridge - -> Just a show of the bridge in action. Not a best-effort or absolutely-correct implementation of the agent. - -A minimal Agent Client Protocol (ACP) bridge that wraps mini-swe-agent so it can be run by Zed as an external agent over stdio, and also provides a local Textual UI client. - -## Configure in Zed (recommended for editor integration) - -Add an `agent_servers` entry to Zed’s `settings.json`. Point `command` to the Python interpreter that has both `agent-client-protocol` and `mini-swe-agent` installed, and `args` to this example script: - -```json -{ - "agent_servers": { - "Mini SWE Agent (Python)": { - "command": "/abs/path/to/python", - "args": [ - "/abs/path/to/agent-client-protocol-python/examples/mini_swe_agent/agent.py" - ], - "env": { - "MINI_SWE_MODEL": "openrouter/openai/gpt-4o-mini", - "MINI_SWE_MODEL_KWARGS": "{\"api_base\":\"https://openrouter.ai/api/v1\"}", - "OPENROUTER_API_KEY": "sk-or-..." - } - } - } -} -``` - -Notes -- If you install `agent-client-protocol` from PyPI, you do not need to set `PYTHONPATH`. -- Using OpenRouter: - - Set `MINI_SWE_MODEL` to a model supported by OpenRouter (e.g. `openrouter/openai/gpt-4o-mini`, `openrouter/anthropic/claude-3.5-sonnet`). - - Set `MINI_SWE_MODEL_KWARGS` to a JSON containing `api_base`: `{ "api_base": "https://openrouter.ai/api/v1" }`. 
- - Set `OPENROUTER_API_KEY` to your API key. -- Alternatively, you can use native OpenAI/Anthropic APIs. Set `MINI_SWE_MODEL` accordingly and provide the vendor-specific API key; `MINI_SWE_MODEL_KWARGS` is optional. - -## Run locally with a TUI (Textual) - -Use the duet launcher to run both the ACP agent and the local Textual client connected over dedicated pipes. The client keeps your terminal stdio; ACP messages flow over separate FDs. - -```bash -# From repo root -python examples/mini_swe_agent/duet.py -``` - -Environment -- The launcher loads `.env` from the repo root using python-dotenv (override=True) so both child processes inherit the same environment. -- Minimum for OpenRouter: - - `MINI_SWE_MODEL="openrouter/openai/gpt-4o-mini"` - - `OPENROUTER_API_KEY="sk-or-..."` - - Optional: `MINI_SWE_MODEL_KWARGS='{"api_base":"https://openrouter.ai/api/v1"}'` (auto-injected if missing) - -Quit behavior -- Quit from the TUI cleanly ends the background loop; duet will terminate both processes gracefully and force-kill after a short timeout if needed. - -## Behavior overview - -- User prompt handling - - Text blocks are concatenated into a task and passed to mini-swe-agent. -- Streaming updates - - The agent sends `session/update` with `agent_message_chunk` for incremental messages. -- Command execution visualization - - Each bash execution is reported with a `tool_call` (start) and a `tool_call_update` (complete) including command and output (`returncode` in rawOutput). -- Final result - - A final `agent_message_chunk` is sent at the end of the turn with the submitted output. - -Use Zed’s “open acp logs” command to inspect ACP traffic if needed. 
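The command-execution flow in the README's behavior overview (a `tool_call` start followed by a `tool_call_update` on completion) can be sketched with plain dictionaries. This is a minimal illustration, not the bridge's implementation: the field names mirror those used in the removed `agent.py`, while the helpers `tool_call_start`, `tool_call_complete`, and `run_with_updates`, and the `runner` callable, are hypothetical names introduced here.

```python
import uuid
from typing import Callable, Tuple


def tool_call_start(tool_call_id: str, command: str) -> dict:
    """Hypothetical helper: the `tool_call` update announced before a command runs."""
    return {
        "sessionUpdate": "tool_call",
        "toolCallId": tool_call_id,
        "title": "bash",
        "kind": "execute",
        "status": "pending",
        "rawInput": {"command": command},
    }


def tool_call_complete(tool_call_id: str, output: str, returncode: int) -> dict:
    """Hypothetical helper: the `tool_call_update` carrying output and return code."""
    return {
        "sessionUpdate": "tool_call_update",
        "toolCallId": tool_call_id,
        "status": "completed",
        "rawOutput": {"output": output, "returncode": returncode},
    }


def run_with_updates(command: str, runner: Callable[[str], Tuple[str, int]]) -> list:
    """Run `command` through `runner` and return the two updates in send order.

    `runner` is an assumed stand-in for whatever actually executes the command;
    it returns an (output, returncode) pair.
    """
    # Both updates share one id so a client can pair the start with its result.
    tool_call_id = f"mini-bash-{uuid.uuid4().hex[:8]}"
    updates = [tool_call_start(tool_call_id, command)]
    output, returncode = runner(command)
    updates.append(tool_call_complete(tool_call_id, output, returncode))
    return updates
```

A real agent would wrap each dictionary in a `session/update` notification and send it as soon as it is produced, rather than collecting the updates in a list.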
diff --git a/examples/mini_swe_agent/agent.py b/examples/mini_swe_agent/agent.py deleted file mode 100644 index 4b24838..0000000 --- a/examples/mini_swe_agent/agent.py +++ /dev/null @@ -1,550 +0,0 @@ -import asyncio -import os -import re -import sys -import uuid -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, Literal - -from acp import ( - Agent, - AgentSideConnection, - AuthenticateRequest, - CancelNotification, - Client, - InitializeRequest, - InitializeResponse, - NewSessionRequest, - NewSessionResponse, - PromptRequest, - PromptResponse, - SessionNotification, - SetSessionModeRequest, - SetSessionModeResponse, - stdio_streams, - PROTOCOL_VERSION, -) -from acp.schema import ( - AgentMessageChunk, - AgentThoughtChunk, - AllowedOutcome, - ContentToolCallContent, - PermissionOption, - RequestPermissionRequest, - RequestPermissionResponse, - TextContentBlock, - ToolCallStart, - ToolCallProgress, - ToolCallUpdate, - UserMessageChunk, -) - - -# Lazily import mini-swe-agent to avoid hard dependency for users who don't need this example - - -@dataclass -class ACPAgentConfig: # Extra controls layered on top of mini-swe-agent defaults - mode: Literal["confirm", "yolo", "human"] = "confirm" - whitelist_actions: list[str] = field(default_factory=list) - confirm_exit: bool = True - - -def _create_streaming_mini_agent( - *, - client: Client, - session_id: str, - cwd: str, - model_name: str, - model_kwargs: dict[str, Any], - loop: asyncio.AbstractEventLoop, - ext_config: ACPAgentConfig, -): - """Create a DefaultAgent that emits ACP session/update events during execution. - - Returns (agent, error_message_if_any). 
- """ - try: - try: - from minisweagent.agents.default import ( - DefaultAgent, - NonTerminatingException, - Submitted, - LimitsExceeded, - AgentConfig as _BaseCfg, - ) # type: ignore - from minisweagent.environments.local import LocalEnvironment # type: ignore - from minisweagent.models.litellm_model import LitellmModel # type: ignore - except Exception: - # Fallback to vendored reference copy if available - REF_SRC = Path(__file__).resolve().parents[2] / "reference" / "mini-swe-agent" / "src" - if REF_SRC.is_dir(): - if str(REF_SRC) not in sys.path: - sys.path.insert(0, str(REF_SRC)) - from minisweagent.agents.default import ( - DefaultAgent, - NonTerminatingException, - Submitted, - LimitsExceeded, - AgentConfig as _BaseCfg, - ) # type: ignore - from minisweagent.environments.local import LocalEnvironment # type: ignore - from minisweagent.models.litellm_model import LitellmModel # type: ignore - else: - raise - - class _StreamingMiniAgent(DefaultAgent): # type: ignore[misc] - def __init__(self) -> None: - self._acp_client = client - self._session_id = session_id - self._tool_seq = 0 - self._loop = loop - # expose mini-swe-agent exception types for outer loop - self._Submitted = Submitted - self._NonTerminatingException = NonTerminatingException - self._LimitsExceeded = LimitsExceeded - model = LitellmModel(model_name=model_name, model_kwargs=model_kwargs) - env = LocalEnvironment(cwd=cwd) - super().__init__(model=model, env=env, config_class=_BaseCfg) - # extra config - self.acp_config = ext_config - # During initial seeding (system/user templates), suppress updates - self._emit_updates = False - - # --- ACP streaming helpers --- - - def _schedule(self, coro): - import asyncio as _asyncio - - return _asyncio.run_coroutine_threadsafe(coro, self._loop) - - async def _send(self, update_model) -> None: - await self._acp_client.sessionUpdate( - SessionNotification(sessionId=self._session_id, update=update_model) - ) - - def _send_cost_hint(self) -> None: - try: - 
cost = float(getattr(self.model, "cost", 0.0)) - except Exception: - cost = 0.0 - hint = AgentThoughtChunk( - sessionUpdate="agent_thought_chunk", - content=TextContentBlock(type="text", text=f"__COST__:{cost:.2f}"), - ) - try: - loop = asyncio.get_running_loop() - loop.create_task(self._send(hint)) - except RuntimeError: - self._schedule(self._send(hint)) - - async def on_tool_start(self, title: str, command: str, tool_call_id: str) -> None: - """Send a tool_call start notification for a bash command.""" - update = ToolCallStart( - sessionUpdate="tool_call", - toolCallId=tool_call_id, - title=title, - kind="execute", - status="pending", - content=[ - ContentToolCallContent( - type="content", content=TextContentBlock(type="text", text=f"```bash\n{command}\n```") - ) - ], - rawInput={"command": command}, - ) - await self._send(update) - - async def on_tool_complete( - self, - tool_call_id: str, - output: str, - returncode: int, - *, - status: str = "completed", - ) -> None: - """Send a tool_call_update with the final output and return code.""" - update = ToolCallProgress( - sessionUpdate="tool_call_update", - toolCallId=tool_call_id, - status=status, - content=[ - ContentToolCallContent( - type="content", content=TextContentBlock(type="text", text=f"```ansi\n{output}\n```") - ) - ], - rawOutput={"output": output, "returncode": returncode}, - ) - await self._send(update) - - def add_message(self, role: str, content: str, **kwargs): - super().add_message(role, content, **kwargs) - # Only stream LM output as agent_message_chunk; tool output is handled via tool_call_update. 
- if not getattr(self, "_emit_updates", True) or role != "assistant": - return - text = str(content) - block = TextContentBlock(type="text", text=text) - update = AgentMessageChunk(sessionUpdate="agent_message_chunk", content=block) - try: - loop = asyncio.get_running_loop() - loop.create_task(self._send(update)) - except RuntimeError: - self._schedule(self._send(update)) - # Fire-and-forget - - def _confirm_action_sync(self, tool_call_id: str, command: str) -> bool: - # Build request and block until client responds - req = RequestPermissionRequest( - sessionId=self._session_id, - options=[ - PermissionOption(optionId="allow-once", name="Allow once", kind="allow_once"), - PermissionOption(optionId="reject-once", name="Reject", kind="reject_once"), - ], - toolCall=ToolCallUpdate( - toolCallId=tool_call_id, - title="bash", - kind="execute", - status="pending", - content=[ - ContentToolCallContent( - type="content", - content=TextContentBlock(type="text", text=f"```bash\n{command}\n```"), - ) - ], - rawInput={"command": command}, - ), - ) - fut = self._schedule(self._acp_client.requestPermission(req)) - try: - resp: RequestPermissionResponse = fut.result() # type: ignore[assignment] - except Exception: - return False - out = resp.outcome - if isinstance(out, AllowedOutcome) and out.optionId in ("allow-once", "allow-always"): - return True - return False - - def execute_action(self, action: dict) -> dict: # type: ignore[override] - self._tool_seq += 1 - tool_id = f"mini-bash-{self._tool_seq}-{uuid.uuid4().hex[:8]}" - command = action.get("action", "") - - # Always create tool_call first (pending) - self._schedule(self.on_tool_start("bash", command, tool_id)) - - # Request permission unless whitelisted - if command.strip() and not any(re.match(r, command) for r in self.acp_config.whitelist_actions): - allowed = self._confirm_action_sync(tool_id, command) - if not allowed: - # Mark as cancelled/failed accordingly and abort this step - self._schedule( - 
self.on_tool_complete( - tool_id, - "Permission denied by user", - 0, - status="cancelled", - ) - ) - raise self._NonTerminatingException("Command not executed: denied by user") - - try: - # Mark in progress - self._schedule( - self._send( - ToolCallProgress( - sessionUpdate="tool_call_update", - toolCallId=tool_id, - status="in_progress", - ) - ) - ) - result = super().execute_action(action) - output = result.get("output", "") - returncode = int(result.get("returncode", 0) or 0) - self._schedule(self.on_tool_complete(tool_id, output, returncode, status="completed")) - return result - except self._Submitted as e: # type: ignore[misc] - final_text = str(e) - self._schedule(self.on_tool_complete(tool_id, final_text, 0, status="completed")) - raise - except self._NonTerminatingException as e: # type: ignore[misc] - msg = str(e) - status = ( - "cancelled" - if any( - key in msg - for key in ( - "Command not executed", - "Switching to human mode", - "switched to manual mode", - "Interrupted by user", - ) - ) - else "failed" - ) - self._schedule( - self.on_tool_complete(tool_id, msg, 124 if status != "cancelled" else 0, status=status) - ) - raise - except Exception as e: # include other failures - msg = str(e) or "execution failed" - self._schedule(self.on_tool_complete(tool_id, msg, 124, status="failed")) - raise - - return _StreamingMiniAgent(), None - except Exception as e: - return None, f"Failed to load mini-swe-agent: {e}" - - -class MiniSweACPAgent(Agent): - def __init__(self, client: Client) -> None: - self._client = client - self._sessions: Dict[str, Dict[str, Any]] = {} - - async def initialize(self, _params: InitializeRequest) -> InitializeResponse: - from acp.schema import AgentCapabilities, PromptCapabilities - - return InitializeResponse( - protocolVersion=PROTOCOL_VERSION, - agentCapabilities=AgentCapabilities( - loadSession=True, - promptCapabilities=PromptCapabilities(audio=False, image=False, embeddedContext=True), - ), - authMethods=[], - ) - - async 
def newSession(self, params: NewSessionRequest) -> NewSessionResponse: - session_id = f"sess-{uuid.uuid4().hex[:12]}" - # load config from env for whitelist & confirm_exit - cfg = ACPAgentConfig() - try: - import json as _json - - wl = os.getenv("MINI_SWE_WHITELIST", "[]") - cfg.whitelist_actions = list(_json.loads(wl)) if wl else [] # type: ignore[assignment] - except Exception: - pass - ce = os.getenv("MINI_SWE_CONFIRM_EXIT") - if ce is not None: - cfg.confirm_exit = ce.lower() not in ("0", "false", "no") - self._sessions[session_id] = { - "cwd": params.cwd, - "agent": None, - "task": None, - "config": cfg, - } - return NewSessionResponse(sessionId=session_id) - - async def loadSession(self, params) -> None: # type: ignore[override] - try: - session_id = params.sessionId # type: ignore[attr-defined] - cwd = params.cwd # type: ignore[attr-defined] - except Exception: - session_id = getattr(params, "sessionId", "sess-unknown") - cwd = getattr(params, "cwd", os.getcwd()) - if session_id not in self._sessions: - cfg = ACPAgentConfig() - try: - import json as _json - - wl = os.getenv("MINI_SWE_WHITELIST", "[]") - cfg.whitelist_actions = list(_json.loads(wl)) if wl else [] # type: ignore[assignment] - except Exception: - pass - ce = os.getenv("MINI_SWE_CONFIRM_EXIT") - if ce is not None: - cfg.confirm_exit = ce.lower() not in ("0", "false", "no") - self._sessions[session_id] = {"cwd": cwd, "agent": None, "task": None, "config": cfg} - return None - - async def authenticate(self, _params: AuthenticateRequest) -> None: - return None - - async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # type: ignore[override] - sess = self._sessions.get(params.sessionId) - if not sess: - return SetSessionModeResponse() - mode = params.modeId.lower() - if mode in ("confirm", "yolo", "human"): - sess["config"].mode = mode # type: ignore[attr-defined] - return SetSessionModeResponse() - - def _extract_mode_from_blocks(self, blocks) -> 
Literal["confirm", "yolo", "human"] | None: - for b in blocks: - if getattr(b, "type", None) == "text": - t = getattr(b, "text", "") or "" - m = re.search(r"\[\[MODE:([a-zA-Z]+)\]\]", t) - if m: - mode = m.group(1).lower() - if mode in ("confirm", "yolo", "human"): - return mode # type: ignore[return-value] - return None - - def _extract_code_from_blocks(self, blocks) -> str | None: - for b in blocks: - if getattr(b, "type", None) == "text": - t = getattr(b, "text", "") or "" - actions = re.findall(r"```bash\n(.*?)\n```", t, re.DOTALL) - if actions: - return actions[0].strip() - return None - - async def prompt(self, params: PromptRequest) -> PromptResponse: - sess = self._sessions.get(params.sessionId) - if not sess: - self._sessions[params.sessionId] = { - "cwd": os.getcwd(), - "agent": None, - "task": None, - "config": ACPAgentConfig(), - } - sess = self._sessions[params.sessionId] - - # Init or reuse agent - agent = sess.get("agent") - if agent is None: - model_name = os.getenv("MINI_SWE_MODEL", os.getenv("OPENAI_MODEL", "gpt-4o-mini")) - try: - import json - - model_kwargs = json.loads(os.getenv("MINI_SWE_MODEL_KWARGS", "{}")) - if not isinstance(model_kwargs, dict): - model_kwargs = {} - except Exception: - model_kwargs = {} - loop = asyncio.get_running_loop() - agent, err = _create_streaming_mini_agent( - client=self._client, - session_id=params.sessionId, - cwd=sess.get("cwd") or os.getcwd(), - model_name=model_name, - model_kwargs=model_kwargs, - loop=loop, - ext_config=sess["config"], - ) - if err: - await self._client.sessionUpdate( - SessionNotification( - sessionId=params.sessionId, - update=AgentMessageChunk( - sessionUpdate="agent_message_chunk", - content=TextContentBlock( - type="text", - text=( - "mini-swe-agent load error: " - + err - + "\nPlease install mini-swe-agent or its dependencies in the configured venv." 
- ), - ), - ), - ) - ) - return PromptResponse(stopReason="end_turn") - sess["agent"] = agent - - # Mode is controlled entirely client-side via requestPermission behavior; no control blocks are parsed. - - # Initialize conversation on first task - if not sess.get("task"): - # Build task - task_parts: list[str] = [] - for block in params.prompt: - btype = getattr(block, "type", None) - if btype == "text": - text = getattr(block, "text", "") - if text and not text.strip().startswith("[[MODE:"): - task_parts.append(str(text)) - task = "\n".join(task_parts).strip() or "Help me with the current repository." - sess["task"] = task - agent.extra_template_vars |= {"task": task} - agent.messages = [] - # Seed templates without emitting updates - agent._emit_updates = False # type: ignore[attr-defined] - agent.add_message("system", agent.render_template(agent.config.system_template)) - agent.add_message("user", agent.render_template(agent.config.instance_template)) - agent._emit_updates = True # type: ignore[attr-defined] - - # Decide the source of the next action - try: - if sess["config"].mode == "human": - # Expect a bash command from the client - cmd = self._extract_code_from_blocks(params.prompt) - if not cmd: - # Ask user to provide a command and return - await self._client.sessionUpdate( - SessionNotification( - sessionId=params.sessionId, - update=AgentMessageChunk( - sessionUpdate="agent_message_chunk", - content=TextContentBlock(type="text", text="Human mode: please submit a bash command."), - ), - ) - ) - return PromptResponse(stopReason="end_turn") - # Fabricate assistant message with the command - msg_content = f"\n```bash\n{cmd}\n```" - agent.add_message("assistant", msg_content) - response = {"content": msg_content} - else: - # Query the model in a worker thread to keep the event loop free - response = await asyncio.to_thread(agent.query) - # Send cost hint after each model call - try: - agent._send_cost_hint() # type: ignore[attr-defined] - except Exception: - 
pass - - # Execute and record observation in worker thread - await asyncio.to_thread(agent.get_observation, response) - except getattr(agent, "_NonTerminatingException") as e: # type: ignore[misc] - agent.add_message("user", str(e)) - except getattr(agent, "_Submitted") as e: # type: ignore[misc] - final_message = str(e) - agent.add_message("user", final_message) - # Ask for confirmation / new task if configured - if sess["config"].confirm_exit: - await self._client.sessionUpdate( - SessionNotification( - sessionId=params.sessionId, - update=AgentMessageChunk( - sessionUpdate="agent_message_chunk", - content=TextContentBlock( - type="text", - text=( - "Agent finished. Type a new task in the next message to continue, or do nothing to end." - ), - ), - ), - ) - ) - # Reset task so that next prompt can set a new one - sess["task"] = None - except getattr(agent, "_LimitsExceeded") as e: # type: ignore[misc] - agent.add_message("user", f"Limits exceeded: {e}") - except Exception as e: - # Surface unexpected errors to the client to avoid silent waits - await self._client.sessionUpdate( - SessionNotification( - sessionId=params.sessionId, - update=AgentMessageChunk( - sessionUpdate="agent_message_chunk", - content=TextContentBlock(type="text", text=f"Error while processing: {e}"), - ), - ) - ) - - return PromptResponse(stopReason="end_turn") - - async def cancel(self, _params: CancelNotification) -> None: - return None - - -async def main() -> None: - reader, writer = await stdio_streams() - AgentSideConnection(lambda client: MiniSweACPAgent(client), writer, reader) - await asyncio.Event().wait() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/mini_swe_agent/client.py b/examples/mini_swe_agent/client.py deleted file mode 100644 index 33d6380..0000000 --- a/examples/mini_swe_agent/client.py +++ /dev/null @@ -1,650 +0,0 @@ -import asyncio -import os -import queue -import re -import threading -import time -from dataclasses import dataclass -from 
pathlib import Path -from typing import Iterable, Literal, Optional - - -from rich.spinner import Spinner -from rich.text import Text -from textual.app import App, ComposeResult, SystemCommand -from textual.binding import Binding -from textual.containers import Container, Vertical, VerticalScroll -from textual.css.query import NoMatches -from textual.events import Key -from textual.screen import Screen -from textual.widgets import Footer, Header, Input, Static, TextArea - -from acp import ( - Client, - PROTOCOL_VERSION, - ClientSideConnection, - InitializeRequest, - NewSessionRequest, - PromptRequest, - RequestPermissionRequest, - RequestPermissionResponse, - SessionNotification, - SetSessionModeRequest, -) -from acp.schema import ( - AgentMessageChunk, - AgentThoughtChunk, - AllowedOutcome, - ContentToolCallContent, - PermissionOption, - TextContentBlock, - ToolCallStart, - ToolCallProgress, - UserMessageChunk, -) -from acp.stdio import _WritePipeProtocol - - -MODE = Literal["confirm", "yolo", "human"] - - -@dataclass -class UIMessage: - role: str # "assistant" or "user" - content: str - - -def _messages_to_steps(messages: list[UIMessage]) -> list[list[UIMessage]]: - steps: list[list[UIMessage]] = [] - current: list[UIMessage] = [] - for m in messages: - current.append(m) - if m.role == "user": - steps.append(current) - current = [] - if current: - steps.append(current) - return steps - - -class SmartInputContainer(Container): - def __init__(self, app: "TextualMiniSweClient"): - super().__init__(classes="smart-input-container") - self._app = app - self._multiline_mode = False - self.can_focus = True - self.display = False - - self.pending_prompt: Optional[str] = None - self._input_event = threading.Event() - self._input_result: Optional[str] = None - - self._header_display = Static(id="input-header-display", classes="message-header input-request-header") - self._hint_text = Static(classes="hint-text") - self._single_input = Input(placeholder="Type your input...") 
- self._multi_input = TextArea(show_line_numbers=False, classes="multi-input") - self._input_elements_container = Vertical( - self._header_display, - self._hint_text, - self._single_input, - self._multi_input, - classes="message-container", - ) - - def compose(self) -> ComposeResult: - yield self._input_elements_container - - def on_mount(self) -> None: - self._multi_input.display = False - self._update_mode_display() - - def on_focus(self) -> None: - if self._multiline_mode: - self._multi_input.focus() - else: - self._single_input.focus() - - def request_input(self, prompt: str) -> str: - self._input_event.clear() - self._input_result = None - self.pending_prompt = prompt - self._header_display.update(prompt) - self._update_mode_display() - # If we're already on the Textual thread, call directly; otherwise marshal. - if getattr(self._app, "_thread_id", None) == threading.get_ident(): - self._app.update_content() - else: - self._app.call_from_thread(self._app.update_content) - self._input_event.wait() - return self._input_result or "" - - def _complete_input(self, input_text: str): - self._input_result = input_text - self.pending_prompt = None - self.display = False - self._single_input.value = "" - self._multi_input.text = "" - self._multiline_mode = False - self._update_mode_display() - self._app.update_content() - # Reset scroll position to bottom - self._app._vscroll.scroll_y = 0 - self._input_event.set() - - def action_toggle_mode(self) -> None: - if self.pending_prompt is None or self._multiline_mode: - return - self._multiline_mode = True - self._update_mode_display() - self.on_focus() - - def _update_mode_display(self) -> None: - if self._multiline_mode: - self._multi_input.text = self._single_input.value - self._single_input.display = False - self._multi_input.display = True - self._hint_text.update( - "[reverse][bold][$accent] Ctrl+D [/][/][/] to submit, [reverse][bold][$accent] Tab [/][/][/] to switch focus with other controls" - ) - else: - 
self._hint_text.update( - "[reverse][bold][$accent] Enter [/][/][/] to submit, [reverse][bold][$accent] Ctrl+T [/][/][/] to switch to multi-line input, [reverse][bold][$accent] Tab [/][/][/] to switch focus with other controls", - ) - self._multi_input.display = False - self._single_input.display = True - - def on_input_submitted(self, event: Input.Submitted) -> None: - if not self._multiline_mode: - text = event.input.value.strip() - self._complete_input(text) - - def on_key(self, event: Key) -> None: - if event.key == "ctrl+t" and not self._multiline_mode: - event.prevent_default() - self.action_toggle_mode() - return - if self._multiline_mode and event.key == "ctrl+d": - event.prevent_default() - self._complete_input(self._multi_input.text.strip()) - return - if event.key == "escape": - event.prevent_default() - self.can_focus = False - self._app.set_focus(None) - return - - -class MiniSweClientImpl(Client): - def __init__(self, app: "TextualMiniSweClient") -> None: - self._app = app - - async def sessionUpdate(self, params: SessionNotification) -> None: - upd = params.update - - def _post(msg: UIMessage) -> None: - if getattr(self._app, "_thread_id", None) == threading.get_ident(): - self._app.enqueue_message(msg) - self._app.on_message_added() - else: - self._app.call_from_thread(lambda: (self._app.enqueue_message(msg), self._app.on_message_added())) - - if isinstance(upd, AgentMessageChunk): - # agent message - txt = _content_to_text(upd.content) - _post(UIMessage("assistant", txt)) - elif isinstance(upd, UserMessageChunk): - txt = _content_to_text(upd.content) - _post(UIMessage("user", txt)) - elif isinstance(upd, AgentThoughtChunk): - # agent thought chunk (informational) - txt = _content_to_text(upd.content) - _post(UIMessage("assistant", f"[thought]\n{txt}")) - elif isinstance(upd, ToolCallStart): - # tool call start → record structured state - self._app._update_tool_call( - upd.toolCallId, title=upd.title or "", status=upd.status or "pending", 
content=upd.content - ) - self._app.call_from_thread(self._app.update_content) - elif isinstance(upd, ToolCallProgress): - # tool call update → update structured state - self._app._update_tool_call(upd.toolCallId, status=upd.status, content=upd.content) - self._app.call_from_thread(self._app.update_content) - - async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: - # Respect client-side mode shortcuts - mode = self._app.mode - if mode == "yolo": - return RequestPermissionResponse(outcome=AllowedOutcome(outcome="selected", optionId="allow-once")) - # Prompt user for decision - prompt = "Approve tool call? Press Enter to allow once, type 'n' to reject" - ans = self._app.input_container.request_input(prompt).strip().lower() - if ans in ("", "y", "yes"): - return RequestPermissionResponse(outcome=AllowedOutcome(outcome="selected", optionId="allow-once")) - return RequestPermissionResponse(outcome=AllowedOutcome(outcome="selected", optionId="reject-once")) - - # Optional features not used in this example - async def writeTextFile(self, params): - return None - - async def readTextFile(self, params): - return None - - -def _content_to_text(content) -> str: - if hasattr(content, "text"): - return str(content.text) - return str(content) - - -class TextualMiniSweClient(App): - BINDINGS = [ - Binding("right,l", "next_step", "Step++", tooltip="Show next step of the agent"), - Binding("left,h", "previous_step", "Step--", tooltip="Show previous step of the agent"), - Binding("0", "first_step", "Step=0", tooltip="Show first step of the agent", show=False), - Binding("$", "last_step", "Step=-1", tooltip="Show last step of the agent", show=False), - Binding("j,down", "scroll_down", "Scroll down", show=False), - Binding("k,up", "scroll_up", "Scroll up", show=False), - Binding("q,ctrl+q", "quit", "Quit", tooltip="Quit the agent"), - Binding("y,ctrl+y", "yolo", "YOLO mode", tooltip="Switch to YOLO Mode (LM actions will execute 
immediately)"), - Binding( - "c", - "confirm", - "CONFIRM mode", - tooltip="Switch to Confirm Mode (LM proposes commands and you confirm/reject them)", - ), - Binding("u,ctrl+u", "human", "HUMAN mode", tooltip="Switch to Human Mode (you can now type commands directly)"), - Binding("enter", "continue_step", "Next"), - Binding("f1,question_mark", "toggle_help_panel", "Help", tooltip="Show help"), - ] - - def __init__(self) -> None: - # Load CSS - css_path = os.environ.get( - "MSWEA_MINI_STYLE_PATH", - str( - Path(__file__).resolve().parents[2] - / "reference" - / "mini-swe-agent" - / "src" - / "minisweagent" - / "config" - / "mini.tcss" - ), - ) - try: - self.__class__.CSS = Path(css_path).read_text() - except Exception: - self.__class__.CSS = "" - super().__init__() - self.mode: MODE = "confirm" - self._vscroll = VerticalScroll() - self.input_container = SmartInputContainer(self) - self.messages: list[UIMessage] = [] - self._spinner = Spinner("dots") - self.agent_state: Literal["UNINITIALIZED", "RUNNING", "AWAITING_INPUT", "STOPPED"] = "UNINITIALIZED" - self._bg_loop: Optional[asyncio.AbstractEventLoop] = None - self._bg_thread: Optional[threading.Thread] = None - self._conn: Optional[ClientSideConnection] = None - self._session_id: Optional[str] = None - self._pending_human_command: Optional[str] = None - self._outbox: "queue.Queue[list[TextContentBlock]]" = queue.Queue() - # Pagination and metrics - self._i_step: int = 0 - self.n_steps: int = 1 - # Structured state for tool calls and plan - self._tool_calls: dict[str, dict] = {} - self._plan: list[dict] = [] - self._ask_new_task_pending = False - - # --- Textual lifecycle --- - - def compose(self) -> ComposeResult: - yield Header() - with Container(id="main"): - with self._vscroll: - with Vertical(id="content"): - pass - yield self.input_container - yield Footer() - - def on_mount(self) -> None: - self.agent_state = "RUNNING" - self.update_content() - self.set_interval(1 / 8, self._update_headers) - # Ask for 
initial task without blocking UI - threading.Thread(target=self._ask_initial_task, daemon=True).start() - - def _ask_initial_task(self) -> None: - task = self.input_container.request_input("Enter your task for mini-swe-agent:") - blocks = [TextContentBlock(type="text", text=task)] - self._outbox.put(blocks) - self._start_connection_thread() - - def on_unmount(self) -> None: - if self._bg_loop: - try: - self._bg_loop.call_soon_threadsafe(self._bg_loop.stop) - except Exception: - pass - - # --- Backend comms --- - - def _start_connection_thread(self) -> None: - """Start a background thread running the ACP connection event loop.""" - - def _runner() -> None: - loop = asyncio.new_event_loop() - self._bg_loop = loop - asyncio.set_event_loop(loop) - loop.run_until_complete(self._run_connection()) - - t = threading.Thread(target=_runner, daemon=True) - t.start() - self._bg_thread = t - - async def _open_acp_streams_from_env(self) -> tuple[Optional[asyncio.StreamReader], Optional[asyncio.StreamWriter]]: - """If launched via duet, open ACP streams from inherited FDs; else return (None, None).""" - read_fd_s = os.environ.get("MSWEA_READ_FD") - write_fd_s = os.environ.get("MSWEA_WRITE_FD") - if not read_fd_s or not write_fd_s: - return None, None - read_fd = int(read_fd_s) - write_fd = int(write_fd_s) - loop = asyncio.get_running_loop() - # Reader - reader = asyncio.StreamReader() - reader_proto = asyncio.StreamReaderProtocol(reader) - r_file = os.fdopen(read_fd, "rb", buffering=0) - await loop.connect_read_pipe(lambda: reader_proto, r_file) - # Writer - write_proto = _WritePipeProtocol() - w_file = os.fdopen(write_fd, "wb", buffering=0) - transport, _ = await loop.connect_write_pipe(lambda: write_proto, w_file) - writer = asyncio.StreamWriter(transport, write_proto, None, loop) - return reader, writer - - async def _run_connection(self) -> None: - """Run the ACP client connection using FDs provided by duet; do not fallback.""" - reader, writer = await 
self._open_acp_streams_from_env() - if reader is None or writer is None: # type: ignore[truthy-bool] - # Do not fallback; inform user and stop - self.call_from_thread( - lambda: ( - self.enqueue_message( - UIMessage( - "assistant", - "Communication endpoints not provided. Please launch via examples/mini_swe_agent/duet.py", - ) - ), - self.on_message_added(), - ) - ) - self.agent_state = "STOPPED" - return - - self._conn = ClientSideConnection(lambda _agent: MiniSweClientImpl(self), writer, reader) - try: - resp = await self._conn.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION)) - self.call_from_thread( - lambda: ( - self.enqueue_message(UIMessage("assistant", f"Initialized v{resp.protocolVersion}")), - self.on_message_added(), - ) - ) - new_sess = await self._conn.newSession(NewSessionRequest(mcpServers=[], cwd=os.getcwd())) - self._session_id = new_sess.sessionId - self.call_from_thread( - lambda: ( - self.enqueue_message(UIMessage("assistant", f"Session {self._session_id} created")), - self.on_message_added(), - ) - ) - except Exception as e: - self.call_from_thread( - lambda: ( - self.enqueue_message(UIMessage("assistant", f"ACP connect error: {e}")), - self.on_message_added(), - ) - ) - self.agent_state = "STOPPED" - return - - # Autostep loop: take queued prompts and send; if none and mode != human, keep stepping - while self.agent_state != "STOPPED": - blocks: list[TextContentBlock] - try: - blocks = self._outbox.get_nowait() - except queue.Empty: - # Auto-advance a step when not in human mode and we're not awaiting input - if self.mode != "human" and self.input_container.pending_prompt is None: - blocks = [] - else: - await asyncio.sleep(0.05) - continue - # Send prompt turn - try: - result = await self._conn.prompt(PromptRequest(sessionId=self._session_id, prompt=blocks)) - # Minimal finish/new task UX: after each stopReason, if not human and idle, offer new task - if ( - self.mode != "human" - and not self._ask_new_task_pending - and 
self.input_container.pending_prompt is None - ): - self._ask_new_task_pending = True - - def _ask_new(): - task = self.input_container.request_input( - "Turn complete. Type a new task or press Enter to continue:" - ) - if task.strip(): - self._outbox.put([TextContentBlock(type="text", text=task)]) - else: - self._outbox.put([]) - self._ask_new_task_pending = False - - threading.Thread(target=_ask_new, daemon=True).start() - except Exception as e: - # Break on connection shutdowns to stop background thread cleanly - msg = str(e) - if isinstance(e, (BrokenPipeError, ConnectionResetError)) or "Broken pipe" in msg or "closed" in msg: - self.agent_state = "STOPPED" - break - self.call_from_thread(lambda: self.enqueue_message(UIMessage("assistant", f"prompt error: {e}"))) - # Tiny delay to avoid busy-looping - await asyncio.sleep(0.05) - - def send_human_command(self, cmd: str) -> None: - if not cmd.strip(): - return - code = f"```bash\n{cmd.strip()}\n```" - self._outbox.put([TextContentBlock(type="text", text=code)]) - - # --- UI updates --- - - def enqueue_message(self, msg: UIMessage) -> None: - self.messages.append(msg) - - def on_message_added(self) -> None: - auto_follow = self._vscroll.scroll_y <= 1 and self._i_step == self.n_steps - 1 - # recompute step pages - items = _messages_to_steps(self.messages) - self.n_steps = max(1, len(items)) - self.update_content() - if auto_follow: - self.action_last_step() - - # --- Structured state helpers --- - - def _update_tool_call( - self, tool_id: str, *, title: Optional[str] = None, status: Optional[str] = None, content=None - ) -> None: - tc = self._tool_calls.get(tool_id, {"toolCallId": tool_id, "title": "", "status": "pending", "content": []}) - if title is not None: - tc["title"] = title - if status is not None: - tc["status"] = status - if content: - # Append any text content blocks - texts = [] - for c in content: - if isinstance(c, ContentToolCallContent) and getattr(c.content, "type", None) == "text": - 
texts.append(getattr(c.content, "text", "")) - if texts: - tc.setdefault("content", []).append("\n".join(texts)) - self._tool_calls[tool_id] = tc - - def update_content(self) -> None: - container = self.query_one("#content", Vertical) - container.remove_children() - if not self.messages: - container.mount(Static("Waiting for agent…")) - return - items = _messages_to_steps(self.messages) - page = items[self._i_step] if items else [] - for m in page[-400:]: - message_container = Vertical(classes="message-container") - container.mount(message_container) - role = m.role.replace("assistant", "mini-swe-agent").upper() - message_container.mount(Static(role, classes="message-header")) - message_container.mount(Static(Text(m.content, no_wrap=False), classes="message-content")) - # Render structured tool calls at the end of the page - if self._tool_calls: - tc_container = Vertical(classes="message-container") - container.mount(tc_container) - tc_container.mount(Static("TOOL CALLS", classes="message-header")) - for tcid, tc in self._tool_calls.items(): - block = Vertical(classes="message-content") - tc_container.mount(block) - status = tc.get("status", "") - title = tc.get("title", "") - block.mount(Static(Text(f"[TOOL] {title} — {status}", no_wrap=False))) - for chunk in tc.get("content", []) or []: - block.mount(Static(Text(chunk, no_wrap=False))) - if self.input_container.pending_prompt is not None: - self.agent_state = "AWAITING_INPUT" - self.input_container.display = ( - self.input_container.pending_prompt is not None and self._i_step == len(items) - 1 - ) - if self.input_container.display: - self.input_container.on_focus() - self._update_headers() - self.refresh() - - def _update_headers(self) -> None: - status_text = self.agent_state - if self.agent_state == "RUNNING": - spinner_frame = str(self._spinner.render(time.time())).strip() - status_text = f"{self.agent_state} {spinner_frame}" - self.title = f"Step {self._i_step + 1}/{self.n_steps} - {status_text}" - try: - 
self.query_one("Header").set_class(self.agent_state == "RUNNING", "running") - except NoMatches: - pass - - # --- Actions --- - - # --- Pagination helpers --- - - @property - def i_step(self) -> int: - return self._i_step - - @i_step.setter - def i_step(self, value: int) -> None: - if value != self._i_step: - self._i_step = max(0, min(value, self.n_steps - 1)) - self._vscroll.scroll_to(y=0, animate=False) - self.update_content() - - # --- Actions --- - - def action_next_step(self) -> None: - self.i_step += 1 - - def action_previous_step(self) -> None: - self.i_step -= 1 - - def action_first_step(self) -> None: - self.i_step = 0 - - def action_last_step(self) -> None: - self.i_step = self.n_steps - 1 - - def action_scroll_down(self) -> None: - self._vscroll.scroll_to(y=self._vscroll.scroll_target_y + 15) - - def action_scroll_up(self) -> None: - self._vscroll.scroll_to(y=self._vscroll.scroll_target_y - 15) - - def _set_agent_mode_async(self, mode_id: str) -> None: - if not self._conn or not self._session_id or not self._bg_loop: - return - - def _schedule() -> None: - try: - self._bg_loop.create_task( - self._conn.setSessionMode(SetSessionModeRequest(sessionId=self._session_id, modeId=mode_id)) - ) - except Exception: - pass - - try: - self._bg_loop.call_soon_threadsafe(_schedule) - except Exception: - pass - - def action_yolo(self): - self.mode = "yolo" - self._set_agent_mode_async("yolo") - if self.input_container.pending_prompt is not None: - self.input_container._complete_input("") - self.notify("YOLO mode enabled") - - def action_confirm(self): - self.mode = "confirm" - self._set_agent_mode_async("confirm") - if self.input_container.pending_prompt is not None: - self.input_container._complete_input("") - self.notify("Confirm mode enabled") - - def action_human(self): - self.mode = "human" - self._set_agent_mode_async("human") - - # Ask for a command asynchronously to avoid blocking UI - def _ask(): - cmd = self.input_container.request_input("Type a bash command 
to run:") - if cmd.strip(): - self.send_human_command(cmd) - - threading.Thread(target=_ask, daemon=True).start() - self.notify("Human mode: commands will be executed as you submit them") - - def action_continue_step(self): - # For non-human modes, enqueue an empty turn to advance one step. - if self.mode != "human": - self._outbox.put([]) - return - - # For human, prompt for next command. - def _ask(): - cmd = self.input_container.request_input("Type a bash command to run:") - if cmd.strip(): - self.send_human_command(cmd) - - threading.Thread(target=_ask, daemon=True).start() - - def action_toggle_help_panel(self) -> None: - if self.query("HelpPanel"): - self.action_hide_help_panel() - else: - self.action_show_help_panel() - - -def main() -> None: - app = TextualMiniSweClient() - app.run() - - -if __name__ == "__main__": - main() diff --git a/examples/mini_swe_agent/duet.py b/examples/mini_swe_agent/duet.py deleted file mode 100644 index a0e0487..0000000 --- a/examples/mini_swe_agent/duet.py +++ /dev/null @@ -1,91 +0,0 @@ -import asyncio -import contextlib -import os -import sys -from pathlib import Path - - -async def main() -> None: - # Launch agent and client, wiring a dedicated pipe pair for ACP protocol. - # Client keeps its own stdin/stdout for the Textual UI. 
- root = Path(__file__).resolve().parent - agent_path = str(root / "agent.py") - client_path = str(root / "client.py") - - # Load .env into process env so children inherit it (prefer python-dotenv if available) - try: - from dotenv import load_dotenv # type: ignore - - # Load .env from repo root: examples/mini_swe_agent -> examples -> REPO - load_dotenv(dotenv_path=str(root.parents[1] / ".env"), override=True) - except Exception: - pass - - base_env = os.environ.copy() - src_dir = str((root.parents[1] / "src").resolve()) - base_env["PYTHONPATH"] = src_dir + os.pathsep + base_env.get("PYTHONPATH", "") - - # Create two pipes: agent->client and client->agent - a2c_r, a2c_w = os.pipe() - c2a_r, c2a_w = os.pipe() - # Ensure the FDs we pass to children are inheritable - for fd in (a2c_r, a2c_w, c2a_r, c2a_w): - os.set_inheritable(fd, True) - - # Start agent: stdin <- client (c2a_r), stdout -> client (a2c_w) - agent = await asyncio.create_subprocess_exec( - sys.executable, - agent_path, - stdin=c2a_r, - stdout=a2c_w, - stderr=sys.stderr, - env=base_env, - close_fds=True, - ) - - # Start client with ACP FDs exported via environment; keep terminal IO for UI - client_env = base_env.copy() - client_env["MSWEA_READ_FD"] = str(a2c_r) # where client reads ACP messages - client_env["MSWEA_WRITE_FD"] = str(c2a_w) # where client writes ACP messages - - client = await asyncio.create_subprocess_exec( - sys.executable, - client_path, - env=client_env, - pass_fds=(a2c_r, c2a_w), # ensure client inherits these FDs - close_fds=True, - ) - - # Close parent's copies of the pipe ends to avoid leaks - for fd in (a2c_r, a2c_w, c2a_r, c2a_w): - with contextlib.suppress(OSError): - os.close(fd) - - # If either process exits, terminate the other gracefully - agent_task = asyncio.create_task(agent.wait()) - client_task = asyncio.create_task(client.wait()) - done, pending = await asyncio.wait({agent_task, client_task}, return_when=asyncio.FIRST_COMPLETED) - - # Terminate the peer process - if 
agent_task in done and client.returncode is None: - with contextlib.suppress(ProcessLookupError): - client.terminate() - if client_task in done and agent.returncode is None: - with contextlib.suppress(ProcessLookupError): - agent.terminate() - - # Wait a bit, then kill if still running - try: - await asyncio.wait_for(asyncio.gather(agent.wait(), client.wait()), timeout=3) - except asyncio.TimeoutError: - with contextlib.suppress(ProcessLookupError): - if agent.returncode is None: - agent.kill() - with contextlib.suppress(ProcessLookupError): - if client.returncode is None: - client.kill() - await asyncio.gather(agent.wait(), client.wait()) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 8287852..9ff07e3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,7 +10,6 @@ copyright: Maintained by psiace. nav: - Home: index.md - Quick Start: quickstart.md - - Mini SWE Agent: mini-swe-agent.md plugins: - search - mkdocstrings: diff --git a/pyproject.toml b/pyproject.toml index 097b3c4..19090c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ dev = [ "mkdocs>=1.4.2", "mkdocs-material>=8.5.10", "mkdocstrings[python]>=0.26.1", - "mini-swe-agent>=1.10.0", "python-dotenv>=1.1.1", ] diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 10d4366..897afce 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -6,6 +6,7 @@ import shutil import subprocess import sys +from collections.abc import Callable from pathlib import Path ROOT = Path(__file__).resolve().parents[1] @@ -43,6 +44,56 @@ "ToolCallContent3": "TerminalToolCallContent", } +ENUM_LITERAL_MAP: dict[str, tuple[str, ...]] = { + "PermissionOptionKind": ( + "allow_once", + "allow_always", + "reject_once", + "reject_always", + ), + "PlanEntryPriority": ("high", "medium", "low"), + "PlanEntryStatus": ("pending", "in_progress", "completed"), + "StopReason": ("end_turn", "max_tokens", "max_turn_requests", "refusal", "cancelled"), + 
"ToolCallStatus": ("pending", "in_progress", "completed", "failed"), + "ToolKind": ("read", "edit", "delete", "move", "search", "execute", "think", "fetch", "switch_mode", "other"), +} + +FIELD_TYPE_OVERRIDES: tuple[tuple[str, str, str, bool], ...] = ( + ("PermissionOption", "kind", "PermissionOptionKind", False), + ("PlanEntry", "priority", "PlanEntryPriority", False), + ("PlanEntry", "status", "PlanEntryStatus", False), + ("PromptResponse", "stopReason", "StopReason", False), + ("ToolCallUpdate", "kind", "ToolKind", True), + ("ToolCallUpdate", "status", "ToolCallStatus", True), + ("ToolCallProgress", "kind", "ToolKind", True), + ("ToolCallProgress", "status", "ToolCallStatus", True), + ("ToolCallStart", "kind", "ToolKind", True), + ("ToolCallStart", "status", "ToolCallStatus", True), + ("ToolCall", "kind", "ToolKind", True), + ("ToolCall", "status", "ToolCallStatus", True), +) + +DEFAULT_VALUE_OVERRIDES: tuple[tuple[str, str, str], ...] = ( + ("AgentCapabilities", "mcpCapabilities", "McpCapabilities(http=False, sse=False)"), + ( + "AgentCapabilities", + "promptCapabilities", + "PromptCapabilities(audio=False, embeddedContext=False, image=False)", + ), + ("ClientCapabilities", "fs", "FileSystemCapability(readTextFile=False, writeTextFile=False)"), + ("ClientCapabilities", "terminal", "False"), + ( + "InitializeRequest", + "clientCapabilities", + "ClientCapabilities(fs=FileSystemCapability(readTextFile=False, writeTextFile=False), terminal=False)", + ), + ( + "InitializeResponse", + "agentCapabilities", + "AgentCapabilities(loadSession=False, mcpCapabilities=McpCapabilities(http=False, sse=False), promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False))", + ), +) + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Generate src/acp/schema.py from the ACP JSON schema.") @@ -131,9 +182,13 @@ def rename_types(output_path: Path) -> list[str]: leftover_classes = 
sorted(set(leftover_class_pattern.findall(content))) header_block = "\n".join(header_lines) + "\n\n" + content = _apply_field_overrides(content) + content = _apply_default_overrides(content) + alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())] alias_block = BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n" + content = _inject_enum_aliases(content) content = header_block + content.rstrip() + "\n\n" + alias_block if not content.endswith("\n"): content += "\n" @@ -150,6 +205,91 @@ def rename_types(output_path: Path) -> list[str]: return warnings +def _apply_field_overrides(content: str) -> str: + for class_name, field_name, new_type, optional in FIELD_TYPE_OVERRIDES: + if optional: + pattern = re.compile( + rf"(class {class_name}\(BaseModel\):.*?\n\s+{field_name}:\s+Annotated\[\s*)Optional\[str],", + re.DOTALL, + ) + content, count = pattern.subn(rf"\1Optional[{new_type}],", content) + else: + pattern = re.compile( + rf"(class {class_name}\(BaseModel\):.*?\n\s+{field_name}:\s+Annotated\[\s*)str,", + re.DOTALL, + ) + content, count = pattern.subn(rf"\1{new_type},", content) + if count == 0: + print( + f"Warning: failed to apply type override for {class_name}.{field_name} -> {new_type}", + file=sys.stderr, + ) + return content + + +def _apply_default_overrides(content: str) -> str: + for class_name, field_name, replacement in DEFAULT_VALUE_OVERRIDES: + class_pattern = re.compile( + rf"(class {class_name}\(BaseModel\):)(.*?)(?=\nclass |\Z)", + re.DOTALL, + ) + + def replace_block( + match: re.Match[str], + _field_name: str = field_name, + _replacement: str = replacement, + _class_name: str = class_name, + ) -> str: + header, block = match.group(1), match.group(2) + field_patterns: tuple[tuple[re.Pattern[str], Callable[[re.Match[str]], str]], ...] 
= ( + ( + re.compile( + rf"(\n\s+{_field_name}:.*?\]\s*=\s*)([\s\S]*?)(?=\n\s{{4}}[A-Za-z_]|$)", + re.DOTALL, + ), + lambda m, _rep=_replacement: m.group(1) + _rep, + ), + ( + re.compile( + rf"(\n\s+{_field_name}:[^\n]*=)\s*([^\n]+)", + re.MULTILINE, + ), + lambda m, _rep=_replacement: m.group(1) + " " + _rep, + ), + ) + for pattern, replacer in field_patterns: + new_block, count = pattern.subn(replacer, block, count=1) + if count: + return header + new_block + print( + f"Warning: failed to override default for {_class_name}.{_field_name}", + file=sys.stderr, + ) + return match.group(0) + + content, count = class_pattern.subn(replace_block, content, count=1) + if count == 0: + print( + f"Warning: class {class_name} not found for default override on {field_name}", + file=sys.stderr, + ) + return content + + +def _inject_enum_aliases(content: str) -> str: + enum_lines = [ + f"{name} = Literal[{', '.join(repr(value) for value in values)}]" for name, values in ENUM_LITERAL_MAP.items() + ] + if not enum_lines: + return content + block = "\n".join(enum_lines) + "\n\n" + class_index = content.find("\nclass ") + if class_index == -1: + return content + insertion_point = class_index + 1 # include leading newline + return content[:insertion_point] + block + content[insertion_point:] + + def format_with_ruff(file_path: Path) -> None: uv_executable = shutil.which("uv") if uv_executable is None: diff --git a/src/acp/__init__.py b/src/acp/__init__.py index 9e916de..3f5e72f 100644 --- a/src/acp/__init__.py +++ b/src/acp/__init__.py @@ -6,6 +6,31 @@ RequestError, TerminalHandle, ) +from .helpers import ( + audio_block, + embedded_blob_resource, + embedded_text_resource, + image_block, + plan_entry, + resource_block, + resource_link_block, + session_notification, + start_edit_tool_call, + start_read_tool_call, + start_tool_call, + text_block, + tool_content, + tool_diff_content, + tool_terminal_ref, + update_agent_message, + update_agent_message_text, + update_agent_thought, + 
update_agent_thought_text, + update_plan, + update_tool_call, + update_user_message, + update_user_message_text, +) from .meta import ( AGENT_METHODS, CLIENT_METHODS, @@ -45,7 +70,8 @@ WriteTextFileRequest, WriteTextFileResponse, ) -from .stdio import stdio_streams +from .stdio import spawn_agent_process, spawn_client_process, spawn_stdio_connection, stdio_streams +from .transports import default_environment, spawn_stdio_transport __all__ = [ # noqa: RUF022 # constants @@ -95,4 +121,33 @@ "TerminalHandle", # stdio helper "stdio_streams", + "spawn_stdio_connection", + "spawn_agent_process", + "spawn_client_process", + "default_environment", + "spawn_stdio_transport", + # helpers + "text_block", + "image_block", + "audio_block", + "resource_link_block", + "embedded_text_resource", + "embedded_blob_resource", + "resource_block", + "tool_content", + "tool_diff_content", + "tool_terminal_ref", + "plan_entry", + "update_plan", + "update_user_message", + "update_user_message_text", + "update_agent_message", + "update_agent_message_text", + "update_agent_thought", + "update_agent_thought_text", + "session_notification", + "start_tool_call", + "start_read_tool_call", + "start_edit_tool_call", + "update_tool_call", ] diff --git a/src/acp/agent/connection.py b/src/acp/agent/connection.py index c26fe07..eab6766 100644 --- a/src/acp/agent/connection.py +++ b/src/acp/agent/connection.py @@ -43,13 +43,14 @@ def __init__( to_agent: Callable[[AgentSideConnection], Agent], input_stream: Any, output_stream: Any, + **connection_kwargs: Any, ) -> None: agent = to_agent(self) handler = self._create_handler(agent) if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_AGENT_CONNECTION_ERROR) - self._conn = Connection(handler, input_stream, output_stream) + self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) def _create_handler(self, agent: Agent) -> MethodHandler: router = 
build_agent_router(agent) @@ -135,3 +136,12 @@ async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any] async def extNotification(self, method: str, params: dict[str, Any]) -> None: await self._conn.send_notification(f"_{method}", params) + + async def close(self) -> None: + await self._conn.close() + + async def __aenter__(self) -> AgentSideConnection: + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index 1177dcd..f97ff25 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -44,13 +44,14 @@ def __init__( to_client: Callable[[Agent], Client], input_stream: Any, output_stream: Any, + **connection_kwargs: Any, ) -> None: if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_CLIENT_CONNECTION_ERROR) client = to_client(self) # type: ignore[arg-type] handler = self._create_handler(client) - self._conn = Connection(handler, input_stream, output_stream) + self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) def _create_handler(self, client: Client) -> MethodHandler: router = build_client_router(client) @@ -127,3 +128,12 @@ async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any] async def extNotification(self, method: str, params: dict[str, Any]) -> None: await self._conn.send_notification(f"_{method}", params) + + async def close(self) -> None: + await self._conn.close() + + async def __aenter__(self) -> ClientSideConnection: + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() diff --git a/src/acp/connection.py b/src/acp/connection.py index 960ee37..0b3230e 100644 --- a/src/acp/connection.py +++ b/src/acp/connection.py @@ -2,9 +2,13 @@ import asyncio import contextlib +import copy +import inspect import json import logging from 
collections.abc import Awaitable, Callable +from dataclasses import dataclass +from enum import Enum from typing import Any from pydantic import BaseModel, ValidationError @@ -31,7 +35,7 @@ MethodHandler = Callable[[str, JsonValue | None, bool], Awaitable[JsonValue | None]] -__all__ = ["Connection", "JsonValue", "MethodHandler"] +__all__ = ["Connection", "JsonValue", "MethodHandler", "StreamDirection", "StreamEvent"] DispatcherFactory = Callable[ @@ -40,6 +44,20 @@ ] +class StreamDirection(str, Enum): + INCOMING = "incoming" + OUTGOING = "outgoing" + + +@dataclass(frozen=True, slots=True) +class StreamEvent: + direction: StreamDirection + message: dict[str, Any] + + +StreamObserver = Callable[[StreamEvent], Awaitable[None] | None] + + class Connection: """Minimal JSON-RPC 2.0 connection over newline-delimited JSON frames.""" @@ -53,6 +71,7 @@ def __init__( state_store: MessageStateStore | None = None, dispatcher_factory: DispatcherFactory | None = None, sender_factory: SenderFactory | None = None, + observers: list[StreamObserver] | None = None, ) -> None: self._handler = handler self._writer = writer @@ -78,6 +97,7 @@ def __init__( self._run_notification, ) self._dispatcher.start() + self._observers: list[StreamObserver] = list(observers or []) async def close(self) -> None: """Stop the receive loop and cancel any in-flight handler tasks.""" @@ -95,17 +115,23 @@ async def __aenter__(self) -> Connection: async def __aexit__(self, exc_type, exc, tb) -> None: await self.close() + def add_observer(self, observer: StreamObserver) -> None: + """Register a callback that receives every raw JSON-RPC message.""" + self._observers.append(observer) + async def send_request(self, method: str, params: JsonValue | None = None) -> Any: request_id = self._next_request_id self._next_request_id += 1 future = self._state.register_outgoing(request_id, method) payload = {"jsonrpc": "2.0", "id": request_id, "method": method, "params": params} await self._sender.send(payload) + 
self._notify_observers(StreamDirection.OUTGOING, payload) return await future async def send_notification(self, method: str, params: JsonValue | None = None) -> None: payload = {"jsonrpc": "2.0", "method": method, "params": params} await self._sender.send(payload) + self._notify_observers(StreamDirection.OUTGOING, payload) async def _receive_loop(self) -> None: try: @@ -118,6 +144,7 @@ async def _receive_loop(self) -> None: except Exception: logging.exception("Error parsing JSON-RPC message") continue + self._notify_observers(StreamDirection.INCOMING, message) await self._process_message(message) except asyncio.CancelledError: return @@ -134,6 +161,27 @@ async def _process_message(self, message: dict[str, Any]) -> None: if has_id: await self._handle_response(message) + def _notify_observers(self, direction: StreamDirection, message: dict[str, Any]) -> None: + if not self._observers: + return + snapshot = copy.deepcopy(message) + event = StreamEvent(direction, snapshot) + for observer in list(self._observers): + try: + result = observer(event) + except Exception: + logging.exception("Stream observer failed", exc_info=True) + continue + if inspect.isawaitable(result): + self._tasks.create( + result, + name=f"acp.Connection.observer.{direction.value}", + on_error=self._on_observer_error, + ) + + def _on_observer_error(self, task: asyncio.Task[Any], exc: BaseException) -> None: + logging.exception("Stream observer coroutine failed", exc_info=exc) + async def _run_request(self, message: dict[str, Any]) -> Any: payload: dict[str, Any] = {"jsonrpc": "2.0", "id": message["id"]} method = message["method"] @@ -147,15 +195,18 @@ async def _run_request(self, message: dict[str, Any]) -> Any: result = result.model_dump() payload["result"] = result if result is not None else None await self._sender.send(payload) + self._notify_observers(StreamDirection.OUTGOING, payload) return payload.get("result") except RequestError as exc: payload["error"] = exc.to_error_obj() await 
self._sender.send(payload) + self._notify_observers(StreamDirection.OUTGOING, payload) raise except ValidationError as exc: err = RequestError.invalid_params({"errors": exc.errors()}) payload["error"] = err.to_error_obj() await self._sender.send(payload) + self._notify_observers(StreamDirection.OUTGOING, payload) raise err from None except Exception as exc: try: @@ -165,6 +216,7 @@ async def _run_request(self, message: dict[str, Any]) -> Any: err = RequestError.internal_error(data) payload["error"] = err.to_error_obj() await self._sender.send(payload) + self._notify_observers(StreamDirection.OUTGOING, payload) raise err from None async def _run_notification(self, message: dict[str, Any]) -> None: diff --git a/src/acp/helpers.py b/src/acp/helpers.py new file mode 100644 index 0000000..8514da2 --- /dev/null +++ b/src/acp/helpers.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from typing import Any + +from .schema import ( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AudioContentBlock, + BlobResourceContents, + ContentToolCallContent, + EmbeddedResource, + EmbeddedResourceContentBlock, + FileEditToolCallContent, + ImageContentBlock, + PlanEntry, + PlanEntryPriority, + PlanEntryStatus, + ResourceContentBlock, + SessionNotification, + TerminalToolCallContent, + TextContentBlock, + TextResourceContents, + ToolCallLocation, + ToolCallProgress, + ToolCallStart, + ToolCallStatus, + ToolKind, + UserMessageChunk, +) + +ContentBlock = ( + TextContentBlock | ImageContentBlock | AudioContentBlock | ResourceContentBlock | EmbeddedResourceContentBlock +) + +SessionUpdate = ( + AgentMessageChunk | AgentPlanUpdate | AgentThoughtChunk | UserMessageChunk | ToolCallStart | ToolCallProgress +) + +ToolCallContentVariant = ContentToolCallContent | FileEditToolCallContent | TerminalToolCallContent + +__all__ = [ + "audio_block", + "embedded_blob_resource", + "embedded_text_resource", + "image_block", + 
"plan_entry", + "resource_block", + "resource_link_block", + "session_notification", + "start_edit_tool_call", + "start_read_tool_call", + "start_tool_call", + "text_block", + "tool_content", + "tool_diff_content", + "tool_terminal_ref", + "update_agent_message", + "update_agent_message_text", + "update_agent_thought", + "update_agent_thought_text", + "update_plan", + "update_tool_call", + "update_user_message", + "update_user_message_text", +] + + +def text_block(text: str) -> TextContentBlock: + return TextContentBlock(type="text", text=text) + + +def image_block(data: str, mime_type: str, *, uri: str | None = None) -> ImageContentBlock: + return ImageContentBlock(type="image", data=data, mimeType=mime_type, uri=uri) + + +def audio_block(data: str, mime_type: str) -> AudioContentBlock: + return AudioContentBlock(type="audio", data=data, mimeType=mime_type) + + +def resource_link_block( + name: str, + uri: str, + *, + mime_type: str | None = None, + size: int | None = None, + description: str | None = None, + title: str | None = None, +) -> ResourceContentBlock: + return ResourceContentBlock( + type="resource_link", + name=name, + uri=uri, + mimeType=mime_type, + size=size, + description=description, + title=title, + ) + + +def embedded_text_resource(uri: str, text: str, *, mime_type: str | None = None) -> EmbeddedResource: + return EmbeddedResource(resource=TextResourceContents(uri=uri, text=text, mimeType=mime_type)) + + +def embedded_blob_resource(uri: str, blob: str, *, mime_type: str | None = None) -> EmbeddedResource: + return EmbeddedResource(resource=BlobResourceContents(uri=uri, blob=blob, mimeType=mime_type)) + + +def resource_block( + resource: EmbeddedResource | TextResourceContents | BlobResourceContents, +) -> EmbeddedResourceContentBlock: + resource_obj = resource.resource if isinstance(resource, EmbeddedResource) else resource + return EmbeddedResourceContentBlock(type="resource", resource=resource_obj) + + +def tool_content(block: ContentBlock) -> 
ContentToolCallContent: + return ContentToolCallContent(type="content", content=block) + + +def tool_diff_content(path: str, new_text: str, old_text: str | None = None) -> FileEditToolCallContent: + return FileEditToolCallContent(type="diff", path=path, newText=new_text, oldText=old_text) + + +def tool_terminal_ref(terminal_id: str) -> TerminalToolCallContent: + return TerminalToolCallContent(type="terminal", terminalId=terminal_id) + + +def plan_entry( + content: str, + *, + priority: PlanEntryPriority = "medium", + status: PlanEntryStatus = "pending", +) -> PlanEntry: + return PlanEntry(content=content, priority=priority, status=status) + + +def update_plan(entries: Iterable[PlanEntry]) -> AgentPlanUpdate: + return AgentPlanUpdate(sessionUpdate="plan", entries=list(entries)) + + +def update_user_message(content: ContentBlock) -> UserMessageChunk: + return UserMessageChunk(sessionUpdate="user_message_chunk", content=content) + + +def update_user_message_text(text: str) -> UserMessageChunk: + return update_user_message(text_block(text)) + + +def update_agent_message(content: ContentBlock) -> AgentMessageChunk: + return AgentMessageChunk(sessionUpdate="agent_message_chunk", content=content) + + +def update_agent_message_text(text: str) -> AgentMessageChunk: + return update_agent_message(text_block(text)) + + +def update_agent_thought(content: ContentBlock) -> AgentThoughtChunk: + return AgentThoughtChunk(sessionUpdate="agent_thought_chunk", content=content) + + +def update_agent_thought_text(text: str) -> AgentThoughtChunk: + return update_agent_thought(text_block(text)) + + +def session_notification(session_id: str, update: SessionUpdate) -> SessionNotification: + return SessionNotification(sessionId=session_id, update=update) + + +def start_tool_call( + tool_call_id: str, + title: str, + *, + kind: ToolKind | None = None, + status: ToolCallStatus | None = None, + content: Sequence[ToolCallContentVariant] | None = None, + locations: Sequence[ToolCallLocation] | 
None = None, + raw_input: Any | None = None, + raw_output: Any | None = None, +) -> ToolCallStart: + return ToolCallStart( + sessionUpdate="tool_call", + toolCallId=tool_call_id, + title=title, + kind=kind, + status=status, + content=list(content) if content is not None else None, + locations=list(locations) if locations is not None else None, + rawInput=raw_input, + rawOutput=raw_output, + ) + + +def start_read_tool_call( + tool_call_id: str, + title: str, + path: str, + *, + extra_options: Sequence[ToolCallContentVariant] | None = None, +) -> ToolCallStart: + content = list(extra_options) if extra_options is not None else None + locations = [ToolCallLocation(path=path)] + raw_input = {"path": path} + return start_tool_call( + tool_call_id, + title, + kind="read", + status="pending", + content=content, + locations=locations, + raw_input=raw_input, + ) + + +def start_edit_tool_call( + tool_call_id: str, + title: str, + path: str, + content: Any, + *, + extra_options: Sequence[ToolCallContentVariant] | None = None, +) -> ToolCallStart: + locations = [ToolCallLocation(path=path)] + raw_input = {"path": path, "content": content} + return start_tool_call( + tool_call_id, + title, + kind="edit", + status="pending", + content=list(extra_options) if extra_options is not None else None, + locations=locations, + raw_input=raw_input, + ) + + +def update_tool_call( + tool_call_id: str, + *, + title: str | None = None, + kind: ToolKind | None = None, + status: ToolCallStatus | None = None, + content: Sequence[ToolCallContentVariant] | None = None, + locations: Sequence[ToolCallLocation] | None = None, + raw_input: Any | None = None, + raw_output: Any | None = None, +) -> ToolCallProgress: + return ToolCallProgress( + sessionUpdate="tool_call_update", + toolCallId=tool_call_id, + title=title, + kind=kind, + status=status, + content=list(content) if content is not None else None, + locations=list(locations) if locations is not None else None, + rawInput=raw_input, + 
rawOutput=raw_output, + ) diff --git a/src/acp/schema.py b/src/acp/schema.py index 13f1c66..79d83a9 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -9,6 +9,14 @@ from pydantic import BaseModel, Field, RootModel +PermissionOptionKind = Literal["allow_once", "allow_always", "reject_once", "reject_always"] +PlanEntryPriority = Literal["high", "medium", "low"] +PlanEntryStatus = Literal["pending", "in_progress", "completed"] +StopReason = Literal["end_turn", "max_tokens", "max_turn_requests", "refusal", "cancelled"] +ToolCallStatus = Literal["pending", "in_progress", "completed", "failed"] +ToolKind = Literal["read", "edit", "delete", "move", "search", "execute", "think", "fetch", "switch_mode", "other"] + + class AuthenticateRequest(BaseModel): field_meta: Annotated[ Optional[Any], @@ -394,11 +402,11 @@ class AgentCapabilities(BaseModel): mcpCapabilities: Annotated[ Optional[McpCapabilities], Field(description="MCP capabilities supported by the agent."), - ] = {"http": False, "sse": False} + ] = McpCapabilities(http=False, sse=False) promptCapabilities: Annotated[ Optional[PromptCapabilities], Field(description="Prompt capabilities supported by the agent."), - ] = {"audio": False, "embeddedContext": False, "image": False} + ] = PromptCapabilities(audio=False, embeddedContext=False, image=False) class Annotations(BaseModel): @@ -468,7 +476,7 @@ class ClientCapabilities(BaseModel): Field( description="File system capabilities supported by the client.\nDetermines which file operations the agent can request." 
), - ] = {"readTextFile": False, "writeTextFile": False} + ] = FileSystemCapability(readTextFile=False, writeTextFile=False) terminal: Annotated[ Optional[bool], Field(description="Whether the Client support all `terminal/*` methods."), @@ -567,7 +575,7 @@ class InitializeRequest(BaseModel): clientCapabilities: Annotated[ Optional[ClientCapabilities], Field(description="Capabilities supported by the client."), - ] = {"fs": {"readTextFile": False, "writeTextFile": False}, "terminal": False} + ] = ClientCapabilities(fs=FileSystemCapability(readTextFile=False, writeTextFile=False), terminal=False) protocolVersion: Annotated[ int, Field( @@ -586,15 +594,11 @@ class InitializeResponse(BaseModel): agentCapabilities: Annotated[ Optional[AgentCapabilities], Field(description="Capabilities supported by the agent."), - ] = { - "loadSession": False, - "mcpCapabilities": {"http": False, "sse": False}, - "promptCapabilities": { - "audio": False, - "embeddedContext": False, - "image": False, - }, - } + ] = AgentCapabilities( + loadSession=False, + mcpCapabilities=McpCapabilities(http=False, sse=False), + promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False), + ) authMethods: Annotated[ Optional[List[AuthMethod]], Field(description="Authentication methods supported by the agent."), @@ -636,7 +640,7 @@ class PermissionOption(BaseModel): Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - kind: Annotated[str, Field(description="Hint about the nature of this permission option.")] + kind: Annotated[PermissionOptionKind, Field(description="Hint about the nature of this permission option.")] name: Annotated[str, Field(description="Human-readable label to display to the user.")] optionId: Annotated[str, Field(description="Unique identifier for this permission option.")] @@ -651,12 +655,12 @@ class PlanEntry(BaseModel): Field(description="Human-readable description of what this task aims to accomplish."), ] 
priority: Annotated[ - str, + PlanEntryPriority, Field( description="The relative importance of this task.\nUsed to indicate which tasks are most critical to the overall goal." ), ] - status: Annotated[str, Field(description="Current execution status of this task.")] + status: Annotated[PlanEntryStatus, Field(description="Current execution status of this task.")] class PromptResponse(BaseModel): @@ -664,7 +668,7 @@ class PromptResponse(BaseModel): Optional[Any], Field(alias="_meta", description="Extension point for implementations"), ] = None - stopReason: Annotated[str, Field(description="Indicates why the agent stopped processing the turn.")] + stopReason: Annotated[StopReason, Field(description="Indicates why the agent stopped processing the turn.")] class ReadTextFileRequest(BaseModel): @@ -913,14 +917,14 @@ class ToolCallUpdate(BaseModel): Optional[List[Union[ContentToolCallContent, FileEditToolCallContent, TerminalToolCallContent]]], Field(description="Replace the content collection."), ] = None - kind: Annotated[Optional[str], Field(description="Update the tool kind.")] = None + kind: Annotated[Optional[ToolKind], Field(description="Update the tool kind.")] = None locations: Annotated[ Optional[List[ToolCallLocation]], Field(description="Replace the locations collection."), ] = None rawInput: Annotated[Optional[Any], Field(description="Update the raw input.")] = None rawOutput: Annotated[Optional[Any], Field(description="Update the raw output.")] = None - status: Annotated[Optional[str], Field(description="Update the execution status.")] = None + status: Annotated[Optional[ToolCallStatus], Field(description="Update the execution status.")] = None title: Annotated[Optional[str], Field(description="Update the human-readable title.")] = None toolCallId: Annotated[str, Field(description="The ID of the tool call being updated.")] @@ -951,7 +955,7 @@ class ToolCallStart(BaseModel): Field(description="Content produced by the tool call."), ] = None kind: Annotated[ 
- Optional[str], + Optional[ToolKind], Field( description="The category of tool being invoked.\nHelps clients choose appropriate icons and UI treatment." ), @@ -963,7 +967,7 @@ class ToolCallStart(BaseModel): rawInput: Annotated[Optional[Any], Field(description="Raw input parameters sent to the tool.")] = None rawOutput: Annotated[Optional[Any], Field(description="Raw output returned by the tool.")] = None sessionUpdate: Literal["tool_call"] - status: Annotated[Optional[str], Field(description="Current execution status of the tool call.")] = None + status: Annotated[Optional[ToolCallStatus], Field(description="Current execution status of the tool call.")] = None title: Annotated[ str, Field(description="Human-readable title describing what the tool is doing."), @@ -983,7 +987,7 @@ class ToolCallProgress(BaseModel): Optional[List[Union[ContentToolCallContent, FileEditToolCallContent, TerminalToolCallContent]]], Field(description="Replace the content collection."), ] = None - kind: Annotated[Optional[str], Field(description="Update the tool kind.")] = None + kind: Annotated[Optional[ToolKind], Field(description="Update the tool kind.")] = None locations: Annotated[ Optional[List[ToolCallLocation]], Field(description="Replace the locations collection."), @@ -991,7 +995,7 @@ class ToolCallProgress(BaseModel): rawInput: Annotated[Optional[Any], Field(description="Update the raw input.")] = None rawOutput: Annotated[Optional[Any], Field(description="Update the raw output.")] = None sessionUpdate: Literal["tool_call_update"] - status: Annotated[Optional[str], Field(description="Update the execution status.")] = None + status: Annotated[Optional[ToolCallStatus], Field(description="Update the execution status.")] = None title: Annotated[Optional[str], Field(description="Update the human-readable title.")] = None toolCallId: Annotated[str, Field(description="The ID of the tool call being updated.")] @@ -1006,7 +1010,7 @@ class ToolCall(BaseModel): Field(description="Content 
produced by the tool call."), ] = None kind: Annotated[ - Optional[str], + Optional[ToolKind], Field( description="The category of tool being invoked.\nHelps clients choose appropriate icons and UI treatment." ), @@ -1017,7 +1021,7 @@ class ToolCall(BaseModel): ] = None rawInput: Annotated[Optional[Any], Field(description="Raw input parameters sent to the tool.")] = None rawOutput: Annotated[Optional[Any], Field(description="Raw output returned by the tool.")] = None - status: Annotated[Optional[str], Field(description="Current execution status of the tool call.")] = None + status: Annotated[Optional[ToolCallStatus], Field(description="Current execution status of the tool call.")] = None title: Annotated[ str, Field(description="Human-readable title describing what the tool is doing."), diff --git a/src/acp/stdio.py b/src/acp/stdio.py index a0c1011..40aa5a8 100644 --- a/src/acp/stdio.py +++ b/src/acp/stdio.py @@ -1,12 +1,29 @@ from __future__ import annotations import asyncio +import asyncio.subprocess as aio_subprocess import contextlib import logging import platform import sys from asyncio import transports as aio_transports -from typing import cast +from collections.abc import AsyncIterator, Callable, Mapping +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Any, cast + +from .agent.connection import AgentSideConnection +from .client.connection import ClientSideConnection +from .connection import Connection, MethodHandler, StreamObserver +from .interfaces import Agent, Client +from .transports import spawn_stdio_transport + +__all__ = [ + "spawn_agent_process", + "spawn_client_process", + "spawn_stdio_connection", + "stdio_streams", +] class _WritePipeProtocol(asyncio.BaseProtocol): @@ -110,3 +127,72 @@ async def stdio_streams() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: if platform.system() == "Windows": return await _windows_stdio_streams(loop) return await _posix_stdio_streams(loop) + + +@asynccontextmanager 
+async def spawn_stdio_connection( + handler: MethodHandler, + command: str, + *args: str, + env: Mapping[str, str] | None = None, + cwd: str | Path | None = None, + observers: list[StreamObserver] | None = None, + **transport_kwargs: Any, +) -> AsyncIterator[tuple[Connection, aio_subprocess.Process]]: + """Spawn a subprocess and bind its stdio to a low-level Connection.""" + async with spawn_stdio_transport(command, *args, env=env, cwd=cwd, **transport_kwargs) as (reader, writer, process): + conn = Connection(handler, writer, reader, observers=observers) + try: + yield conn, process + finally: + await conn.close() + + +@asynccontextmanager +async def spawn_agent_process( + to_client: Callable[[Agent], Client], + command: str, + *args: str, + env: Mapping[str, str] | None = None, + cwd: str | Path | None = None, + transport_kwargs: Mapping[str, Any] | None = None, + **connection_kwargs: Any, +) -> AsyncIterator[tuple[ClientSideConnection, aio_subprocess.Process]]: + """Spawn an ACP agent subprocess and return a ClientSideConnection to it.""" + async with spawn_stdio_transport( + command, + *args, + env=env, + cwd=cwd, + **(dict(transport_kwargs) if transport_kwargs else {}), + ) as (reader, writer, process): + conn = ClientSideConnection(to_client, writer, reader, **connection_kwargs) + try: + yield conn, process + finally: + await conn.close() + + +@asynccontextmanager +async def spawn_client_process( + to_agent: Callable[[AgentSideConnection], Agent], + command: str, + *args: str, + env: Mapping[str, str] | None = None, + cwd: str | Path | None = None, + transport_kwargs: Mapping[str, Any] | None = None, + **connection_kwargs: Any, +) -> AsyncIterator[tuple[AgentSideConnection, aio_subprocess.Process]]: + """Spawn an ACP client subprocess and return an AgentSideConnection to it.""" + async with spawn_stdio_transport( + command, + *args, + env=env, + cwd=cwd, + **(dict(transport_kwargs) if transport_kwargs else {}), + ) as (reader, writer, process): + conn = 
AgentSideConnection(to_agent, writer, reader, **connection_kwargs) + try: + yield conn, process + finally: + await conn.close() diff --git a/src/acp/telemetry.py b/src/acp/telemetry.py index 1d7e340..011ed46 100644 --- a/src/acp/telemetry.py +++ b/src/acp/telemetry.py @@ -3,7 +3,7 @@ import os from collections.abc import Mapping from contextlib import AbstractContextManager, ExitStack, nullcontext -from typing import Any +from typing import Any, cast try: from logfire import span as logfire_span @@ -38,4 +38,4 @@ def span_context(name: str, *, attributes: Mapping[str, Any] | None = None) -> A if logfire_span is not None: stack.enter_context(logfire_span(name, attributes=attrs)) stack.enter_context(_start_tracer_span(name, attributes=attributes)) - return stack + return cast(AbstractContextManager[None], stack) diff --git a/src/acp/terminal.py b/src/acp/terminal.py index 698611e..619039c 100644 --- a/src/acp/terminal.py +++ b/src/acp/terminal.py @@ -1,5 +1,7 @@ from __future__ import annotations +from contextlib import suppress + from .connection import Connection from .meta import CLIENT_METHODS from .schema import ( @@ -47,3 +49,17 @@ async def release(self) -> ReleaseTerminalResponse: ) payload = response if isinstance(response, dict) else {} return ReleaseTerminalResponse.model_validate(payload) + + async def aclose(self) -> None: + """Release the terminal, ignoring errors that occur during shutdown.""" + with suppress(Exception): + await self.release() + + async def close(self) -> None: + await self.aclose() + + async def __aenter__(self) -> TerminalHandle: + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.aclose() diff --git a/src/acp/transports.py b/src/acp/transports.py new file mode 100644 index 0000000..be2a002 --- /dev/null +++ b/src/acp/transports.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import asyncio +import asyncio.subprocess as aio_subprocess +import contextlib +import os +from collections.abc 
import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from pathlib import Path + +__all__ = ["DEFAULT_INHERITED_ENV_VARS", "default_environment", "spawn_stdio_transport"] + +DEFAULT_INHERITED_ENV_VARS = ( + [ + "APPDATA", + "HOMEDRIVE", + "HOMEPATH", + "LOCALAPPDATA", + "PATH", + "PATHEXT", + "PROCESSOR_ARCHITECTURE", + "SYSTEMDRIVE", + "SYSTEMROOT", + "TEMP", + "USERNAME", + "USERPROFILE", + ] + if os.name == "nt" + else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"] +) + + +def default_environment() -> dict[str, str]: + """Return a trimmed environment based on MCP best practices.""" + env: dict[str, str] = {} + for key in DEFAULT_INHERITED_ENV_VARS: + value = os.environ.get(key) + if value is None: + continue + # Skip function-style env vars on some shells (see MCP reference) + if value.startswith("()"): + continue + env[key] = value + return env + + +@asynccontextmanager +async def spawn_stdio_transport( + command: str, + *args: str, + env: Mapping[str, str] | None = None, + cwd: str | Path | None = None, + stderr: int | None = aio_subprocess.PIPE, + shutdown_timeout: float = 2.0, +) -> AsyncIterator[tuple[asyncio.StreamReader, asyncio.StreamWriter, aio_subprocess.Process]]: + """Launch a subprocess and expose its stdio streams as asyncio transports. + + This mirrors the defensive shutdown behaviour used by the MCP Python SDK: + close stdin first, wait for graceful exit, then escalate to terminate/kill. 
+ """ + merged_env = dict(default_environment()) + if env: + merged_env.update(env) + + process = await asyncio.create_subprocess_exec( + command, + *args, + stdin=aio_subprocess.PIPE, + stdout=aio_subprocess.PIPE, + stderr=stderr, + env=merged_env, + cwd=str(cwd) if cwd is not None else None, + ) + + if process.stdout is None or process.stdin is None: + process.kill() + await process.wait() + msg = "spawn_stdio_transport requires stdin/stdout pipes" + raise RuntimeError(msg) + + try: + yield process.stdout, process.stdin, process + finally: + # Attempt graceful stdin shutdown first + if process.stdin is not None: + try: + process.stdin.write_eof() + except (AttributeError, OSError, RuntimeError): + process.stdin.close() + with contextlib.suppress(Exception): + await process.stdin.drain() + with contextlib.suppress(Exception): + process.stdin.close() + with contextlib.suppress(Exception): + await process.stdin.wait_closed() + + try: + await asyncio.wait_for(process.wait(), timeout=shutdown_timeout) + except asyncio.TimeoutError: + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=shutdown_timeout) + except asyncio.TimeoutError: + process.kill() + await process.wait() diff --git a/tests/golden/cancel_notification.json b/tests/golden/cancel_notification.json new file mode 100644 index 0000000..a5461d2 --- /dev/null +++ b/tests/golden/cancel_notification.json @@ -0,0 +1,3 @@ +{ + "sessionId": "sess_abc123def456" +} diff --git a/tests/golden/content_audio.json b/tests/golden/content_audio.json new file mode 100644 index 0000000..6cd650e --- /dev/null +++ b/tests/golden/content_audio.json @@ -0,0 +1,5 @@ +{ + "type": "audio", + "mimeType": "audio/wav", + "data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB..." 
+} diff --git a/tests/golden/content_image.json b/tests/golden/content_image.json new file mode 100644 index 0000000..fca8b88 --- /dev/null +++ b/tests/golden/content_image.json @@ -0,0 +1,5 @@ +{ + "type": "image", + "mimeType": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB..." +} diff --git a/tests/golden/content_resource_blob.json b/tests/golden/content_resource_blob.json new file mode 100644 index 0000000..4832503 --- /dev/null +++ b/tests/golden/content_resource_blob.json @@ -0,0 +1,8 @@ +{ + "type": "resource", + "resource": { + "uri": "file:///home/user/document.pdf", + "mimeType": "application/pdf", + "blob": "" + } +} diff --git a/tests/golden/content_resource_link.json b/tests/golden/content_resource_link.json new file mode 100644 index 0000000..4e33c1e --- /dev/null +++ b/tests/golden/content_resource_link.json @@ -0,0 +1,7 @@ +{ + "type": "resource_link", + "uri": "file:///home/user/document.pdf", + "name": "document.pdf", + "mimeType": "application/pdf", + "size": 1024000 +} diff --git a/tests/golden/content_resource_text.json b/tests/golden/content_resource_text.json new file mode 100644 index 0000000..f73945a --- /dev/null +++ b/tests/golden/content_resource_text.json @@ -0,0 +1,8 @@ +{ + "type": "resource", + "resource": { + "uri": "file:///home/user/script.py", + "mimeType": "text/x-python", + "text": "def hello():\n print('Hello, world!')" + } +} diff --git a/tests/golden/content_text.json b/tests/golden/content_text.json new file mode 100644 index 0000000..63b2e85 --- /dev/null +++ b/tests/golden/content_text.json @@ -0,0 +1,4 @@ +{ + "type": "text", + "text": "What's the weather like today?" 
+} diff --git a/tests/golden/fs_read_text_file_request.json b/tests/golden/fs_read_text_file_request.json new file mode 100644 index 0000000..3d3ccca --- /dev/null +++ b/tests/golden/fs_read_text_file_request.json @@ -0,0 +1,6 @@ +{ + "sessionId": "sess_abc123def456", + "path": "/home/user/project/src/main.py", + "line": 10, + "limit": 50 +} diff --git a/tests/golden/fs_read_text_file_response.json b/tests/golden/fs_read_text_file_response.json new file mode 100644 index 0000000..b5dac57 --- /dev/null +++ b/tests/golden/fs_read_text_file_response.json @@ -0,0 +1,3 @@ +{ + "content": "def hello_world():\n print('Hello, world!')\n" +} diff --git a/tests/golden/fs_write_text_file_request.json b/tests/golden/fs_write_text_file_request.json new file mode 100644 index 0000000..efbad09 --- /dev/null +++ b/tests/golden/fs_write_text_file_request.json @@ -0,0 +1,5 @@ +{ + "sessionId": "sess_abc123def456", + "path": "/home/user/project/config.json", + "content": "{\n \"debug\": true,\n \"version\": \"1.0.0\"\n}" +} diff --git a/tests/golden/initialize_request.json b/tests/golden/initialize_request.json new file mode 100644 index 0000000..b239909 --- /dev/null +++ b/tests/golden/initialize_request.json @@ -0,0 +1,9 @@ +{ + "protocolVersion": 1, + "clientCapabilities": { + "fs": { + "readTextFile": true, + "writeTextFile": true + } + } +} diff --git a/tests/golden/initialize_response.json b/tests/golden/initialize_response.json new file mode 100644 index 0000000..66abb81 --- /dev/null +++ b/tests/golden/initialize_response.json @@ -0,0 +1,13 @@ +{ + "protocolVersion": 1, + "agentCapabilities": { + "loadSession": true, + "mcpCapabilities": {}, + "promptCapabilities": { + "image": true, + "audio": true, + "embeddedContext": true + } + }, + "authMethods": [] +} diff --git a/tests/golden/new_session_request.json b/tests/golden/new_session_request.json new file mode 100644 index 0000000..132c920 --- /dev/null +++ b/tests/golden/new_session_request.json @@ -0,0 +1,13 @@ +{ + "cwd": 
"/home/user/project", + "mcpServers": [ + { + "name": "filesystem", + "command": "/path/to/mcp-server", + "args": [ + "--stdio" + ], + "env": [] + } + ] +} diff --git a/tests/golden/new_session_response.json b/tests/golden/new_session_response.json new file mode 100644 index 0000000..a5461d2 --- /dev/null +++ b/tests/golden/new_session_response.json @@ -0,0 +1,3 @@ +{ + "sessionId": "sess_abc123def456" +} diff --git a/tests/golden/permission_outcome_cancelled.json b/tests/golden/permission_outcome_cancelled.json new file mode 100644 index 0000000..38f0331 --- /dev/null +++ b/tests/golden/permission_outcome_cancelled.json @@ -0,0 +1,3 @@ +{ + "outcome": "cancelled" +} diff --git a/tests/golden/permission_outcome_selected.json b/tests/golden/permission_outcome_selected.json new file mode 100644 index 0000000..3a194c2 --- /dev/null +++ b/tests/golden/permission_outcome_selected.json @@ -0,0 +1,4 @@ +{ + "outcome": "selected", + "optionId": "allow-once" +} diff --git a/tests/golden/prompt_request.json b/tests/golden/prompt_request.json new file mode 100644 index 0000000..816fae1 --- /dev/null +++ b/tests/golden/prompt_request.json @@ -0,0 +1,17 @@ +{ + "sessionId": "sess_abc123def456", + "prompt": [ + { + "type": "text", + "text": "Can you analyze this code for potential issues?" 
+ }, + { + "type": "resource", + "resource": { + "uri": "file:///home/user/project/main.py", + "mimeType": "text/x-python", + "text": "def process_data(items):\n for item in items:\n print(item)" + } + } + ] +} diff --git a/tests/golden/request_permission_request.json b/tests/golden/request_permission_request.json new file mode 100644 index 0000000..1fb297f --- /dev/null +++ b/tests/golden/request_permission_request.json @@ -0,0 +1,18 @@ +{ + "sessionId": "sess_abc123def456", + "toolCall": { + "toolCallId": "call_001" + }, + "options": [ + { + "optionId": "allow-once", + "name": "Allow once", + "kind": "allow_once" + }, + { + "optionId": "reject-once", + "name": "Reject", + "kind": "reject_once" + } + ] +} diff --git a/tests/golden/request_permission_response_selected.json b/tests/golden/request_permission_response_selected.json new file mode 100644 index 0000000..e29b89b --- /dev/null +++ b/tests/golden/request_permission_response_selected.json @@ -0,0 +1,6 @@ +{ + "outcome": { + "outcome": "selected", + "optionId": "allow-once" + } +} diff --git a/tests/golden/session_update_agent_message_chunk.json b/tests/golden/session_update_agent_message_chunk.json new file mode 100644 index 0000000..7ace7ed --- /dev/null +++ b/tests/golden/session_update_agent_message_chunk.json @@ -0,0 +1,7 @@ +{ + "sessionUpdate": "agent_message_chunk", + "content": { + "type": "text", + "text": "The capital of France is Paris." + } +} diff --git a/tests/golden/session_update_agent_thought_chunk.json b/tests/golden/session_update_agent_thought_chunk.json new file mode 100644 index 0000000..893c13b --- /dev/null +++ b/tests/golden/session_update_agent_thought_chunk.json @@ -0,0 +1,7 @@ +{ + "sessionUpdate": "agent_thought_chunk", + "content": { + "type": "text", + "text": "Thinking about best approach..." 
+ } +} diff --git a/tests/golden/session_update_plan.json b/tests/golden/session_update_plan.json new file mode 100644 index 0000000..bad3e8a --- /dev/null +++ b/tests/golden/session_update_plan.json @@ -0,0 +1,15 @@ +{ + "sessionUpdate": "plan", + "entries": [ + { + "content": "Check for syntax errors", + "priority": "high", + "status": "pending" + }, + { + "content": "Identify potential type issues", + "priority": "medium", + "status": "pending" + } + ] +} diff --git a/tests/golden/session_update_tool_call.json b/tests/golden/session_update_tool_call.json new file mode 100644 index 0000000..448649d --- /dev/null +++ b/tests/golden/session_update_tool_call.json @@ -0,0 +1,7 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_001", + "title": "Reading configuration file", + "kind": "read", + "status": "pending" +} diff --git a/tests/golden/session_update_tool_call_edit.json b/tests/golden/session_update_tool_call_edit.json new file mode 100644 index 0000000..1cf0bda --- /dev/null +++ b/tests/golden/session_update_tool_call_edit.json @@ -0,0 +1,16 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_003", + "title": "Apply edit", + "kind": "edit", + "status": "pending", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json", + "content": "print('hello')" + } +} diff --git a/tests/golden/session_update_tool_call_locations_rawinput.json b/tests/golden/session_update_tool_call_locations_rawinput.json new file mode 100644 index 0000000..a1ac3e4 --- /dev/null +++ b/tests/golden/session_update_tool_call_locations_rawinput.json @@ -0,0 +1,13 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_lr", + "title": "Tracking file", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json" + } +} diff --git a/tests/golden/session_update_tool_call_read.json 
b/tests/golden/session_update_tool_call_read.json new file mode 100644 index 0000000..d533afb --- /dev/null +++ b/tests/golden/session_update_tool_call_read.json @@ -0,0 +1,15 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_001", + "title": "Reading configuration file", + "kind": "read", + "status": "pending", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json" + } +} diff --git a/tests/golden/session_update_tool_call_update_content.json b/tests/golden/session_update_tool_call_update_content.json new file mode 100644 index 0000000..e28b461 --- /dev/null +++ b/tests/golden/session_update_tool_call_update_content.json @@ -0,0 +1,14 @@ +{ + "sessionUpdate": "tool_call_update", + "toolCallId": "call_001", + "status": "in_progress", + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "Found 3 configuration files..." + } + } + ] +} diff --git a/tests/golden/session_update_tool_call_update_more_fields.json b/tests/golden/session_update_tool_call_update_more_fields.json new file mode 100644 index 0000000..d5af335 --- /dev/null +++ b/tests/golden/session_update_tool_call_update_more_fields.json @@ -0,0 +1,27 @@ +{ + "sessionUpdate": "tool_call_update", + "toolCallId": "call_010", + "title": "Processing changes", + "kind": "edit", + "status": "completed", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json" + }, + "rawOutput": { + "result": "ok" + }, + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "Edit completed." 
+ } + } + ] +} diff --git a/tests/golden/session_update_user_message_chunk.json b/tests/golden/session_update_user_message_chunk.json new file mode 100644 index 0000000..8ca73e7 --- /dev/null +++ b/tests/golden/session_update_user_message_chunk.json @@ -0,0 +1,7 @@ +{ + "sessionUpdate": "user_message_chunk", + "content": { + "type": "text", + "text": "What's the capital of France?" + } +} diff --git a/tests/golden/tool_content_content_text.json b/tests/golden/tool_content_content_text.json new file mode 100644 index 0000000..bf3b6f7 --- /dev/null +++ b/tests/golden/tool_content_content_text.json @@ -0,0 +1,7 @@ +{ + "type": "content", + "content": { + "type": "text", + "text": "Analysis complete. Found 3 issues." + } +} diff --git a/tests/golden/tool_content_diff.json b/tests/golden/tool_content_diff.json new file mode 100644 index 0000000..98482cb --- /dev/null +++ b/tests/golden/tool_content_diff.json @@ -0,0 +1,6 @@ +{ + "type": "diff", + "path": "/home/user/project/src/config.json", + "oldText": "{\n \"debug\": false\n}", + "newText": "{\n \"debug\": true\n}" +} diff --git a/tests/golden/tool_content_diff_no_old.json b/tests/golden/tool_content_diff_no_old.json new file mode 100644 index 0000000..c044187 --- /dev/null +++ b/tests/golden/tool_content_diff_no_old.json @@ -0,0 +1,5 @@ +{ + "type": "diff", + "path": "/home/user/project/src/config.json", + "newText": "{\n \"debug\": true\n}" +} diff --git a/tests/golden/tool_content_terminal.json b/tests/golden/tool_content_terminal.json new file mode 100644 index 0000000..fd0c676 --- /dev/null +++ b/tests/golden/tool_content_terminal.json @@ -0,0 +1,4 @@ +{ + "type": "terminal", + "terminalId": "term_001" +} diff --git a/tests/real_user/test_permission_flow.py b/tests/real_user/test_permission_flow.py index 7987b7a..f337cce 100644 --- a/tests/real_user/test_permission_flow.py +++ b/tests/real_user/test_permission_flow.py @@ -22,8 +22,8 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: 
RequestPermissionRequest( sessionId=params.sessionId, options=[ - PermissionOption(optionId="allow", name="Allow", kind="allow"), - PermissionOption(optionId="deny", name="Deny", kind="deny"), + PermissionOption(optionId="allow", name="Allow", kind="allow_once"), + PermissionOption(optionId="deny", name="Deny", kind="reject_once"), ], toolCall=ToolCallUpdate(toolCallId="call-1", title="Write File"), ) diff --git a/tests/test_gemini_example.py b/tests/test_gemini_example.py new file mode 100644 index 0000000..1702855 --- /dev/null +++ b/tests/test_gemini_example.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import os +import shlex +import shutil +import subprocess +import sys +from pathlib import Path + +import pytest + + +def _flag_enabled() -> bool: + value = os.environ.get("ACP_ENABLE_GEMINI_TESTS", "").strip().lower() + return value in {"1", "true", "yes", "on"} + + +def _resolve_gemini_binary() -> str | None: + override = os.environ.get("ACP_GEMINI_BIN") + if override: + return override + return shutil.which("gemini") + + +GEMINI_BIN = _resolve_gemini_binary() +pytestmark = pytest.mark.skipif( + not (_flag_enabled() and GEMINI_BIN), + reason="Gemini tests disabled. 
Set ACP_ENABLE_GEMINI_TESTS=1 and provide the gemini CLI.", +) + + +def test_gemini_example_smoke() -> None: + env = os.environ.copy() + src_path = str(Path(__file__).resolve().parent.parent / "src") + python_path = env.get("PYTHONPATH") + env["PYTHONPATH"] = src_path if not python_path else os.pathsep.join([src_path, python_path]) + + extra_args = shlex.split(env.get("ACP_GEMINI_TEST_ARGS", "")) + cmd = [ + sys.executable, + str(Path(__file__).resolve().parent.parent / "examples" / "gemini.py"), + "--gemini", + GEMINI_BIN or "gemini", + "--yolo", + *extra_args, + ] + + proc = subprocess.Popen( # noqa: S603 - command is built from trusted inputs + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=Path(__file__).resolve().parent.parent, + ) + + assert proc.stdin is not None + assert proc.stdout is not None + + try: + stdout, stderr = proc.communicate(":exit\n", timeout=120) + except subprocess.TimeoutExpired: + proc.kill() + stdout, stderr = proc.communicate() + pytest.fail(_format_failure("Gemini example timed out", stdout, stderr), pytrace=False) + + combined = f"{stdout}\n{stderr}" + if proc.returncode != 0: + auth_errors = ( + "Authentication failed", + "Authentication required", + "GOOGLE_CLOUD_PROJECT", + ) + if any(token in combined for token in auth_errors): + pytest.skip(f"Gemini CLI authentication required:\n{combined}") + pytest.fail( + _format_failure(f"Gemini example exited with {proc.returncode}", stdout, stderr), + pytrace=False, + ) + + assert "Connected to Gemini" in combined + + +def _format_failure(prefix: str, stdout: str, stderr: str) -> str: + return f"{prefix}.\nstdout:\n{stdout}\nstderr:\n{stderr}" diff --git a/tests/test_golden.py b/tests/test_golden.py new file mode 100644 index 0000000..430bd04 --- /dev/null +++ b/tests/test_golden.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import json +from collections.abc import Callable +from pathlib import Path + +import
pytest +from pydantic import BaseModel + +from acp import ( + audio_block, + embedded_blob_resource, + embedded_text_resource, + image_block, + plan_entry, + resource_block, + resource_link_block, + start_edit_tool_call, + start_read_tool_call, + start_tool_call, + text_block, + tool_content, + tool_diff_content, + tool_terminal_ref, + update_agent_message_text, + update_agent_thought_text, + update_plan, + update_tool_call, + update_user_message_text, +) +from acp.schema import ( + AgentMessageChunk, + AgentPlanUpdate, + AgentThoughtChunk, + AllowedOutcome, + AudioContentBlock, + CancelNotification, + ContentToolCallContent, + DeniedOutcome, + EmbeddedResourceContentBlock, + FileEditToolCallContent, + ImageContentBlock, + InitializeRequest, + InitializeResponse, + NewSessionRequest, + NewSessionResponse, + PromptRequest, + ReadTextFileRequest, + ReadTextFileResponse, + RequestPermissionRequest, + RequestPermissionResponse, + ResourceContentBlock, + TerminalToolCallContent, + TextContentBlock, + ToolCallLocation, + ToolCallProgress, + ToolCallStart, + UserMessageChunk, + WriteTextFileRequest, +) + +GOLDEN_DIR = Path(__file__).parent / "golden" + +# Map each golden fixture to the concrete schema model it should conform to. 
+GOLDEN_CASES: dict[str, type[BaseModel]] = { + "cancel_notification": CancelNotification, + "content_audio": AudioContentBlock, + "content_image": ImageContentBlock, + "content_resource_blob": EmbeddedResourceContentBlock, + "content_resource_link": ResourceContentBlock, + "content_resource_text": EmbeddedResourceContentBlock, + "content_text": TextContentBlock, + "fs_read_text_file_request": ReadTextFileRequest, + "fs_read_text_file_response": ReadTextFileResponse, + "fs_write_text_file_request": WriteTextFileRequest, + "initialize_request": InitializeRequest, + "initialize_response": InitializeResponse, + "new_session_request": NewSessionRequest, + "new_session_response": NewSessionResponse, + "permission_outcome_cancelled": DeniedOutcome, + "permission_outcome_selected": AllowedOutcome, + "prompt_request": PromptRequest, + "request_permission_request": RequestPermissionRequest, + "request_permission_response_selected": RequestPermissionResponse, + "session_update_agent_message_chunk": AgentMessageChunk, + "session_update_agent_thought_chunk": AgentThoughtChunk, + "session_update_plan": AgentPlanUpdate, + "session_update_tool_call": ToolCallStart, + "session_update_tool_call_edit": ToolCallStart, + "session_update_tool_call_locations_rawinput": ToolCallStart, + "session_update_tool_call_read": ToolCallStart, + "session_update_tool_call_update_content": ToolCallProgress, + "session_update_tool_call_update_more_fields": ToolCallProgress, + "session_update_user_message_chunk": UserMessageChunk, + "tool_content_content_text": ContentToolCallContent, + "tool_content_diff": FileEditToolCallContent, + "tool_content_diff_no_old": FileEditToolCallContent, + "tool_content_terminal": TerminalToolCallContent, +} + +_PARAMS = tuple(sorted(GOLDEN_CASES.items())) +_PARAM_IDS = [name for name, _ in _PARAMS] + +GOLDEN_BUILDERS: dict[str, Callable[[], BaseModel]] = { + "content_text": lambda: text_block("What's the weather like today?"), + "content_image": lambda: 
image_block("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB...", "image/png"), + "content_audio": lambda: audio_block("UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB...", "audio/wav"), + "content_resource_text": lambda: resource_block( + embedded_text_resource( + "file:///home/user/script.py", + "def hello():\n print('Hello, world!')", + mime_type="text/x-python", + ) + ), + "content_resource_blob": lambda: resource_block( + embedded_blob_resource( + "file:///home/user/document.pdf", + "", + mime_type="application/pdf", + ) + ), + "content_resource_link": lambda: resource_link_block( + "document.pdf", + "file:///home/user/document.pdf", + mime_type="application/pdf", + size=1_024_000, + ), + "tool_content_content_text": lambda: tool_content(text_block("Analysis complete. Found 3 issues.")), + "tool_content_diff": lambda: tool_diff_content( + "/home/user/project/src/config.json", + '{\n "debug": true\n}', + '{\n "debug": false\n}', + ), + "tool_content_diff_no_old": lambda: tool_diff_content( + "/home/user/project/src/config.json", + '{\n "debug": true\n}', + ), + "tool_content_terminal": lambda: tool_terminal_ref("term_001"), + "session_update_user_message_chunk": lambda: update_user_message_text("What's the capital of France?"), + "session_update_agent_message_chunk": lambda: update_agent_message_text("The capital of France is Paris."), + "session_update_agent_thought_chunk": lambda: update_agent_thought_text("Thinking about best approach..."), + "session_update_plan": lambda: update_plan([ + plan_entry( + "Check for syntax errors", + priority="high", + status="pending", + ), + plan_entry( + "Identify potential type issues", + priority="medium", + status="pending", + ), + ]), + "session_update_tool_call": lambda: start_tool_call( + "call_001", + "Reading configuration file", + kind="read", + status="pending", + ), + "session_update_tool_call_read": lambda: start_read_tool_call( + "call_001", + "Reading configuration file", + "/home/user/project/src/config.json", + ), + 
"session_update_tool_call_edit": lambda: start_edit_tool_call( + "call_003", + "Apply edit", + "/home/user/project/src/config.json", + "print('hello')", + ), + "session_update_tool_call_locations_rawinput": lambda: start_tool_call( + "call_lr", + "Tracking file", + locations=[ToolCallLocation(path="/home/user/project/src/config.json")], + raw_input={"path": "/home/user/project/src/config.json"}, + ), + "session_update_tool_call_update_content": lambda: update_tool_call( + "call_001", + status="in_progress", + content=[tool_content(text_block("Found 3 configuration files..."))], + ), + "session_update_tool_call_update_more_fields": lambda: update_tool_call( + "call_010", + title="Processing changes", + kind="edit", + status="completed", + locations=[ToolCallLocation(path="/home/user/project/src/config.json")], + raw_input={"path": "/home/user/project/src/config.json"}, + raw_output={"result": "ok"}, + content=[tool_content(text_block("Edit completed."))], + ), +} + +_HELPER_PARAMS = tuple(sorted(GOLDEN_BUILDERS.items())) +_HELPER_IDS = [name for name, _ in _HELPER_PARAMS] + + +def _load_golden(name: str) -> dict: + path = GOLDEN_DIR / f"{name}.json" + return json.loads(path.read_text()) + + +def _dump_model(model: BaseModel) -> dict: + return model.model_dump(mode="json", by_alias=True, exclude_none=True, exclude_unset=True) + + +def test_golden_cases_covered() -> None: + available = {path.stem for path in GOLDEN_DIR.glob("*.json")} + assert available == set(GOLDEN_CASES), "Add the new golden file to GOLDEN_CASES." 
+ + +@pytest.mark.parametrize( + ("name", "model_cls"), + _PARAMS, + ids=_PARAM_IDS, +) +def test_json_golden_roundtrip(name: str, model_cls: type[BaseModel]) -> None: + raw = _load_golden(name) + model = model_cls.model_validate(raw) + assert _dump_model(model) == raw + + +@pytest.mark.parametrize( + ("name", "builder"), + _HELPER_PARAMS, + ids=_HELPER_IDS, +) +def test_helpers_match_golden(name: str, builder: Callable[[], BaseModel]) -> None: + raw = _load_golden(name) + model = builder() + assert isinstance(model, BaseModel) + assert _dump_model(model) == raw diff --git a/tests/test_rpc.py b/tests/test_rpc.py index ea6fb6e..eba7321 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -1,6 +1,8 @@ import asyncio import contextlib import json +import sys +from pathlib import Path import pytest @@ -32,12 +34,22 @@ SetSessionModeResponse, WriteTextFileRequest, WriteTextFileResponse, + session_notification, + spawn_agent_process, + start_tool_call, + update_agent_message_text, + update_tool_call, ) from acp.schema import ( AgentMessageChunk, AllowedOutcome, DeniedOutcome, + PermissionOption, TextContentBlock, + ToolCallLocation, + ToolCallProgress, + ToolCallStart, + ToolCallUpdate, UserMessageChunk, ) @@ -411,3 +423,188 @@ async def test_ignore_invalid_messages(): # Should not receive any response lines with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(s.client_reader.readline(), timeout=0.1) + + +class _ExampleAgent(Agent): + __test__ = False + + def __init__(self) -> None: + self._conn: AgentSideConnection | None = None + self.permission_response: RequestPermissionResponse | None = None + self.prompt_requests: list[PromptRequest] = [] + + def bind(self, conn: AgentSideConnection) -> "_ExampleAgent": + self._conn = conn + return self + + async def initialize(self, params: InitializeRequest) -> InitializeResponse: + return InitializeResponse(protocolVersion=params.protocolVersion) + + async def newSession(self, params: NewSessionRequest) -> 
NewSessionResponse: + return NewSessionResponse(sessionId="sess_demo") + + async def prompt(self, params: PromptRequest) -> PromptResponse: + assert self._conn is not None + self.prompt_requests.append(params) + + await self._conn.sessionUpdate( + session_notification( + params.sessionId, + update_agent_message_text("I'll help you with that."), + ) + ) + + await self._conn.sessionUpdate( + session_notification( + params.sessionId, + start_tool_call( + "call_1", + "Modifying configuration", + kind="edit", + status="pending", + locations=[ToolCallLocation(path="/project/config.json")], + raw_input={"path": "/project/config.json"}, + ), + ) + ) + + permission_request = RequestPermissionRequest( + sessionId=params.sessionId, + toolCall=ToolCallUpdate( + toolCallId="call_1", + title="Modifying configuration", + kind="edit", + status="pending", + locations=[ToolCallLocation(path="/project/config.json")], + rawInput={"path": "/project/config.json"}, + ), + options=[ + PermissionOption(kind="allow_once", name="Allow", optionId="allow"), + PermissionOption(kind="reject_once", name="Reject", optionId="reject"), + ], + ) + response = await self._conn.requestPermission(permission_request) + self.permission_response = response + + if isinstance(response.outcome, AllowedOutcome) and response.outcome.optionId == "allow": + await self._conn.sessionUpdate( + session_notification( + params.sessionId, + update_tool_call( + "call_1", + status="completed", + raw_output={"success": True}, + ), + ) + ) + await self._conn.sessionUpdate( + session_notification( + params.sessionId, + update_agent_message_text("Done."), + ) + ) + + return PromptResponse(stopReason="end_turn") + + +class _ExampleClient(TestClient): + __test__ = False + + def __init__(self) -> None: + super().__init__() + self.permission_requests: list[RequestPermissionRequest] = [] + + async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: + self.permission_requests.append(params) + 
if not params.options: + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) + option = params.options[0] + return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected")) + + +@pytest.mark.asyncio +async def test_example_agent_permission_flow(): + async with _Server() as s: + agent = _ExampleAgent() + client = _ExampleClient() + + agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) + AgentSideConnection(lambda conn: agent.bind(conn), s.server_writer, s.server_reader) + + init = await agent_conn.initialize(InitializeRequest(protocolVersion=1)) + assert init.protocolVersion == 1 + + session = await agent_conn.newSession(NewSessionRequest(mcpServers=[], cwd="/workspace")) + assert session.sessionId == "sess_demo" + + prompt = PromptRequest( + sessionId=session.sessionId, + prompt=[TextContentBlock(type="text", text="Please edit config")], + ) + resp = await agent_conn.prompt(prompt) + assert resp.stopReason == "end_turn" + + for _ in range(50): + if len(client.notifications) >= 4: + break + await asyncio.sleep(0.02) + + assert len(client.notifications) >= 4 + session_updates = [getattr(note.update, "sessionUpdate", None) for note in client.notifications] + assert session_updates[:4] == ["agent_message_chunk", "tool_call", "tool_call_update", "agent_message_chunk"] + + first_message = client.notifications[0].update + assert isinstance(first_message, AgentMessageChunk) + assert isinstance(first_message.content, TextContentBlock) + assert first_message.content.text == "I'll help you with that." 
+ + tool_call = client.notifications[1].update + assert isinstance(tool_call, ToolCallStart) + assert tool_call.title == "Modifying configuration" + assert tool_call.status == "pending" + + tool_update = client.notifications[2].update + assert isinstance(tool_update, ToolCallProgress) + assert tool_update.status == "completed" + assert tool_update.rawOutput == {"success": True} + + final_message = client.notifications[3].update + assert isinstance(final_message, AgentMessageChunk) + assert isinstance(final_message.content, TextContentBlock) + assert final_message.content.text == "Done." + + assert len(client.permission_requests) == 1 + options = client.permission_requests[0].options + assert [opt.optionId for opt in options] == ["allow", "reject"] + + assert agent.permission_response is not None + assert isinstance(agent.permission_response.outcome, AllowedOutcome) + assert agent.permission_response.outcome.optionId == "allow" + + +@pytest.mark.asyncio +async def test_spawn_agent_process_roundtrip(tmp_path): + script = Path(__file__).parents[1] / "examples" / "echo_agent.py" + assert script.exists() + + test_client = TestClient() + + async with spawn_agent_process(lambda _agent: test_client, sys.executable, str(script)) as (client_conn, process): + init = await client_conn.initialize(InitializeRequest(protocolVersion=1)) + assert isinstance(init, InitializeResponse) + session = await client_conn.newSession(NewSessionRequest(cwd=str(tmp_path), mcpServers=[])) + prompt = PromptRequest( + sessionId=session.sessionId, + prompt=[TextContentBlock(type="text", text="hi spawn")], + ) + await client_conn.prompt(prompt) + + # Wait for echo agent notification to arrive + for _ in range(50): + if test_client.notifications: + break + await asyncio.sleep(0.02) + + assert test_client.notifications + + assert process.returncode is not None diff --git a/uv.lock b/uv.lock index 679436a..c0c7387 100644 --- a/uv.lock +++ b/uv.lock @@ -20,7 +20,6 @@ logfire = [ dev = [ { name = 
"datamodel-code-generator" }, { name = "deptry" }, - { name = "mini-swe-agent" }, { name = "mkdocs" }, { name = "mkdocs-material" }, { name = "mkdocstrings", extra = ["python"] }, @@ -45,7 +44,6 @@ provides-extras = ["logfire"] dev = [ { name = "datamodel-code-generator", specifier = ">=0.25" }, { name = "deptry", specifier = ">=0.23.0" }, - { name = "mini-swe-agent", specifier = ">=1.10.0" }, { name = "mkdocs", specifier = ">=1.4.2" }, { name = "mkdocs-material", specifier = ">=8.5.10" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.26.1" }, @@ -1089,29 +1087,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/19/04f9b178c2d8a15b076c8b5140708fa6ffc5601fb6f1e975537072df5b2a/mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307", size = 6354, upload-time = "2021-02-05T18:55:29.583Z" }, ] -[[package]] -name = "mini-swe-agent" -version = "1.10.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jinja2" }, - { name = "litellm" }, - { name = "openai" }, - { name = "platformdirs" }, - { name = "prompt-toolkit" }, - { name = "python-dotenv" }, - { name = "pyyaml" }, - { name = "requests" }, - { name = "rich" }, - { name = "tenacity" }, - { name = "textual" }, - { name = "typer" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ee/72/3ee88176f2a7dd99da3bf57f70c0a0aa5d3728e3b6711ed659a0cdd67eb9/mini_swe_agent-1.10.0.tar.gz", hash = "sha256:c0fe700fe58bb24aa706f5aec7a812ead63cbf522a67dcb96dba26d3fc21136f", size = 45158, upload-time = "2025-08-29T01:40:26.475Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/06/fd298c6781261084a8233218231ffebdc315593b728f4535ef2c80517b9f/mini_swe_agent-1.10.0-py3-none-any.whl", hash = "sha256:47be76446f7b1975d844270a5d8349cc09811adc8ea3c70cac47f2b29cc7a63b", size = 68367, upload-time = "2025-08-29T01:40:25.528Z" }, -] - [[package]] name = "mkdocs" version = "1.6.1"