Skip to content

Commit b90b19a

Browse files
committed
feat: Expose new APIs for running agent and connect with client
Signed-off-by: Frost Ming <me@frostming.com>
1 parent 856ea0d commit b90b19a

21 files changed

+962
-551
lines changed

docs/quickstart.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class SimpleClient(Client):
9090

9191
async def main() -> None:
9292
script = Path("examples/echo_agent.py")
93-
async with spawn_agent_process(lambda _agent: SimpleClient(), sys.executable, str(script)) as (conn, _proc):
93+
async with spawn_agent_process(SimpleClient(), sys.executable, str(script)) as (conn, _proc):
9494
await conn.initialize(protocol_version=1)
9595
session = await conn.new_session(cwd=str(script.parent), mcp_servers=[])
9696
await conn.prompt(
@@ -119,7 +119,7 @@ class MyAgent(Agent):
119119
return PromptResponse(stop_reason="end_turn")
120120
```
121121

122-
Hook it up with `AgentSideConnection` inside an async entrypoint and wire it to your client. Refer to:
122+
Run it with `run_agent()` inside an async entrypoint and wire it to your client. Refer to:
123123

124124
- [`examples/echo_agent.py`](https://github.com/agentclientprotocol/python-sdk/blob/main/examples/echo_agent.py) for the smallest streaming agent
125125
- [`examples/agent.py`](https://github.com/agentclientprotocol/python-sdk/blob/main/examples/agent.py) for an implementation that negotiates capabilities and streams richer updates

examples/agent.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
NewSessionResponse,
1212
PromptResponse,
1313
SetSessionModeResponse,
14-
stdio_streams,
14+
run_agent,
1515
text_block,
1616
update_agent_message,
1717
PROTOCOL_VERSION,
1818
)
19+
from acp.interfaces import Client
1920
from acp.schema import (
2021
AgentCapabilities,
2122
AgentMessageChunk,
@@ -33,11 +34,15 @@
3334

3435

3536
class ExampleAgent(Agent):
36-
def __init__(self, conn: AgentSideConnection) -> None:
37-
self._conn = conn
37+
_conn: Client
38+
39+
def __init__(self) -> None:
3840
self._next_session_id = 0
3941
self._sessions: set[str] = set()
4042

43+
def on_connect(self, conn: Client) -> None:
44+
self._conn = conn
45+
4146
async def _send_agent_message(self, session_id: str, content: Any) -> None:
4247
update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content)
4348
await self._conn.session_update(session_id, update)
@@ -114,9 +119,7 @@ async def ext_notification(self, method: str, params: dict[str, Any]) -> None:
114119

115120
async def main() -> None:
116121
logging.basicConfig(level=logging.INFO)
117-
reader, writer = await stdio_streams()
118-
AgentSideConnection(ExampleAgent, writer, reader)
119-
await asyncio.Event().wait()
122+
await run_agent(ExampleAgent())
120123

121124

122125
if __name__ == "__main__":

examples/client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from acp import (
1111
Client,
1212
ClientSideConnection,
13-
InitializeRequest,
14-
NewSessionRequest,
15-
PromptRequest,
13+
connect_to_agent,
1614
RequestError,
1715
text_block,
1816
PROTOCOL_VERSION,
@@ -190,7 +188,7 @@ async def main(argv: list[str]) -> int:
190188
return 1
191189

192190
client_impl = ExampleClient()
193-
conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout)
191+
conn = connect_to_agent(client_impl, proc.stdin, proc.stdout)
194192

195193
await conn.initialize(
196194
protocol_version=PROTOCOL_VERSION,

examples/duet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ async def main() -> int:
3030
client_module = _load_client_module(root / "client.py")
3131
client = client_module.ExampleClient()
3232

33-
async with spawn_agent_process(lambda _agent: client, sys.executable, str(agent_path), env=env) as (
33+
async with spawn_agent_process(client, sys.executable, str(agent_path), env=env) as (
3434
conn,
3535
process,
3636
):

examples/echo_agent.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
InitializeResponse,
99
NewSessionResponse,
1010
PromptResponse,
11-
stdio_streams,
11+
run_agent,
1212
text_block,
1313
update_agent_message,
1414
)
15+
from acp.interfaces import Client
1516
from acp.schema import (
1617
AudioContentBlock,
1718
ClientCapabilities,
@@ -27,7 +28,9 @@
2728

2829

2930
class EchoAgent(Agent):
30-
def __init__(self, conn: AgentSideConnection) -> None:
31+
_conn: Client
32+
33+
def on_connect(self, conn: Client) -> None:
3134
self._conn = conn
3235

3336
async def initialize(
@@ -67,9 +70,7 @@ async def prompt(
6770

6871

6972
async def main() -> None:
70-
reader, writer = await stdio_streams()
71-
AgentSideConnection(EchoAgent, writer, reader)
72-
await asyncio.Event().wait()
73+
await run_agent(EchoAgent())
7374

7475

7576
if __name__ == "__main__":

examples/gemini.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313

1414
from acp import (
1515
Client,
16-
ClientSideConnection,
16+
connect_to_agent,
1717
PROTOCOL_VERSION,
1818
RequestError,
1919
text_block,
2020
)
21+
from acp.core import ClientSideConnection
2122
from acp.schema import (
2223
AgentMessageChunk,
2324
AgentPlanUpdate,
@@ -316,7 +317,7 @@ async def run(argv: list[str]) -> int:
316317
return 1
317318

318319
client_impl = GeminiClient(auto_approve=args.yolo)
319-
conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout)
320+
conn = connect_to_agent(client_impl, proc.stdin, proc.stdout)
320321

321322
try:
322323
init_resp = await conn.initialize(

src/acp/__init__.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from typing import Any
2+
13
from .core import (
24
Agent,
3-
AgentSideConnection,
45
Client,
5-
ClientSideConnection,
66
RequestError,
77
TerminalHandle,
8+
connect_to_agent,
9+
run_agent,
810
)
911
from .helpers import (
1012
audio_block,
@@ -73,6 +75,19 @@
7375
from .stdio import spawn_agent_process, spawn_client_process, spawn_stdio_connection, stdio_streams
7476
from .transports import default_environment, spawn_stdio_transport
7577

78+
_DEPRECATED_NAMES = [
79+
(
80+
"AgentSideConnection",
81+
"acp.core:AgentSideConnection",
82+
"Using `AgentSideConnection` directly is deprecated, please use `acp.run_agent` instead.",
83+
),
84+
(
85+
"ClientSideConnection",
86+
"acp.core:ClientSideConnection",
87+
"Using `ClientSideConnection` directly is deprecated, please use `acp.connect_to_agent` instead.",
88+
),
89+
]
90+
7691
__all__ = [ # noqa: RUF022
7792
# constants
7893
"PROTOCOL_VERSION",
@@ -113,8 +128,8 @@
113128
"ReleaseTerminalRequest",
114129
"ReleaseTerminalResponse",
115130
# core
116-
"AgentSideConnection",
117-
"ClientSideConnection",
131+
"run_agent",
132+
"connect_to_agent",
118133
"RequestError",
119134
"Agent",
120135
"Client",
@@ -151,3 +166,16 @@
151166
"start_edit_tool_call",
152167
"update_tool_call",
153168
]
169+
170+
171+
def __getattr__(name: str) -> Any:
172+
import warnings
173+
from importlib import import_module
174+
175+
for deprecated_name, new_path, warning in _DEPRECATED_NAMES:
176+
if name == deprecated_name:
177+
warnings.warn(warning, DeprecationWarning, stacklevel=2)
178+
module_name, attr_name = new_path.split(":")
179+
module = import_module(module_name)
180+
return getattr(module, attr_name)
181+
raise AttributeError(f"module {__name__} has no attribute {name}") # noqa: TRY003

src/acp/agent/connection.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import asyncio
44
from collections.abc import Callable
5-
from typing import Any
5+
from typing import Any, cast, final
66

77
from ..connection import Connection
8-
from ..interfaces import Agent
8+
from ..interfaces import Agent, Client
99
from ..meta import CLIENT_METHODS
1010
from ..schema import (
1111
AgentMessageChunk,
@@ -38,28 +38,37 @@
3838
WriteTextFileResponse,
3939
)
4040
from ..terminal import TerminalHandle
41-
from ..utils import notify_model, param_model, request_model, request_optional_model
41+
from ..utils import compatible_class, notify_model, param_model, request_model, request_optional_model
4242
from .router import build_agent_router
4343

4444
__all__ = ["AgentSideConnection"]
4545
_AGENT_CONNECTION_ERROR = "AgentSideConnection requires asyncio StreamWriter/StreamReader"
4646

4747

48+
@final
49+
@compatible_class
4850
class AgentSideConnection:
4951
"""Agent-side connection wrapper that dispatches JSON-RPC messages to a Client implementation."""
5052

5153
def __init__(
5254
self,
53-
to_agent: Callable[[AgentSideConnection], Agent],
55+
to_agent: Callable[[Client], Agent] | Agent,
5456
input_stream: Any,
5557
output_stream: Any,
58+
listening: bool = False,
5659
**connection_kwargs: Any,
5760
) -> None:
58-
agent = to_agent(self)
59-
handler = build_agent_router(agent)
61+
agent = to_agent(cast(Client, self)) if callable(to_agent) else to_agent
6062
if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader):
6163
raise TypeError(_AGENT_CONNECTION_ERROR)
62-
self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs)
64+
handler = build_agent_router(agent) # type: ignore[arg-type]
65+
self._conn = Connection(handler, input_stream, output_stream, listening=listening, **connection_kwargs)
66+
if on_connect := getattr(agent, "on_connect", None):
67+
on_connect(self)
68+
69+
async def listen(self) -> None:
70+
"""Start listening for incoming messages."""
71+
await self._conn.main_loop()
6372

6473
@param_model(SessionNotification)
6574
async def session_update(

src/acp/client/connection.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
from collections.abc import Callable
5-
from typing import Any
5+
from typing import Any, cast, final
66

77
from ..connection import Connection
88
from ..interfaces import Agent, Client
@@ -34,24 +34,32 @@
3434
StdioMcpServer,
3535
TextContentBlock,
3636
)
37-
from ..utils import notify_model, param_model, request_model, request_model_from_dict
37+
from ..utils import compatible_class, notify_model, param_model, request_model, request_model_from_dict
3838
from .router import build_client_router
3939

4040
__all__ = ["ClientSideConnection"]
4141
_CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader"
4242

4343

44+
@final
45+
@compatible_class
4446
class ClientSideConnection:
4547
"""Client-side connection wrapper that dispatches JSON-RPC messages to an Agent implementation."""
4648

4749
def __init__(
48-
self, to_client: Callable[[Agent], Client], input_stream: Any, output_stream: Any, **connection_kwargs: Any
50+
self,
51+
to_client: Callable[[Agent], Client] | Client,
52+
input_stream: Any,
53+
output_stream: Any,
54+
**connection_kwargs: Any,
4955
) -> None:
5056
if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader):
5157
raise TypeError(_CLIENT_CONNECTION_ERROR)
52-
client = to_client(self)
53-
handler = build_client_router(client)
58+
client = to_client(cast(Agent, self)) if callable(to_client) else to_client
59+
handler = build_client_router(client) # type: ignore[arg-type]
5460
self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs)
61+
if on_connect := getattr(client, "on_connect", None):
62+
on_connect(self)
5563

5664
@param_model(InitializeRequest)
5765
async def initialize(

src/acp/connection.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
dispatcher_factory: DispatcherFactory | None = None,
7373
sender_factory: SenderFactory | None = None,
7474
observers: list[StreamObserver] | None = None,
75+
listening: bool = True,
7576
) -> None:
7677
self._handler = handler
7778
self._writer = writer
@@ -83,11 +84,14 @@ def __init__(
8384
self._queue = queue or InMemoryMessageQueue()
8485
self._closed = False
8586
self._sender = (sender_factory or self._default_sender_factory)(self._writer, self._tasks)
86-
self._recv_task = self._tasks.create(
87-
self._receive_loop(),
88-
name="acp.Connection.receive",
89-
on_error=self._on_receive_error,
90-
)
87+
if listening:
88+
self._recv_task = self._tasks.create(
89+
self._receive_loop(),
90+
name="acp.Connection.receive",
91+
on_error=self._on_receive_error,
92+
)
93+
else:
94+
self._recv_task = None
9195
dispatcher_factory = dispatcher_factory or self._default_dispatcher_factory
9296
self._dispatcher = dispatcher_factory(
9397
self._queue,
@@ -109,6 +113,14 @@ async def close(self) -> None:
109113
await self._tasks.shutdown()
110114
self._state.reject_all_outgoing(ConnectionError("Connection closed"))
111115

116+
async def main_loop(self) -> None:
117+
try:
118+
await self._receive_loop()
119+
except Exception as exc:
120+
logging.exception("Connection main loop failed", exc_info=exc)
121+
self._on_receive_error(None, exc) # type: ignore[arg-type]
122+
raise
123+
112124
async def __aenter__(self) -> Connection:
113125
return self
114126

0 commit comments

Comments
 (0)