|
2 | 2 |
|
3 | 3 | import asyncio |
4 | 4 | from collections.abc import Callable |
5 | | -from typing import Any |
| 5 | +from typing import Any, cast, final |
6 | 6 |
|
7 | 7 | from ..connection import Connection |
8 | | -from ..interfaces import Agent |
| 8 | +from ..interfaces import Agent, Client |
9 | 9 | from ..meta import CLIENT_METHODS |
10 | 10 | from ..schema import ( |
11 | 11 | AgentMessageChunk, |
|
38 | 38 | WriteTextFileResponse, |
39 | 39 | ) |
40 | 40 | from ..terminal import TerminalHandle |
41 | | -from ..utils import notify_model, param_model, request_model, request_optional_model |
| 41 | +from ..utils import compatible_class, notify_model, param_model, request_model, request_optional_model |
42 | 42 | from .router import build_agent_router |
43 | 43 |
|
44 | 44 | __all__ = ["AgentSideConnection"] |
45 | 45 | _AGENT_CONNECTION_ERROR = "AgentSideConnection requires asyncio StreamWriter/StreamReader" |
46 | 46 |
|
47 | 47 |
|
| 48 | +@final |
| 49 | +@compatible_class |
48 | 50 | class AgentSideConnection: |
49 | 51 | """Agent-side connection wrapper that dispatches JSON-RPC messages to a Client implementation.""" |
50 | 52 |
|
51 | 53 | def __init__( |
52 | 54 | self, |
53 | | - to_agent: Callable[[AgentSideConnection], Agent], |
| 55 | + to_agent: Callable[[Client], Agent] | Agent, |
54 | 56 | input_stream: Any, |
55 | 57 | output_stream: Any, |
| 58 | + listening: bool = False, |
56 | 59 | **connection_kwargs: Any, |
57 | 60 | ) -> None: |
58 | | - agent = to_agent(self) |
59 | | - handler = build_agent_router(agent) |
| 61 | + agent = to_agent(cast(Client, self)) if callable(to_agent) else to_agent |
60 | 62 | if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): |
61 | 63 | raise TypeError(_AGENT_CONNECTION_ERROR) |
62 | | - self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) |
| 64 | + handler = build_agent_router(agent) # type: ignore[arg-type] |
| 65 | + self._conn = Connection(handler, input_stream, output_stream, listening=listening, **connection_kwargs) |
| 66 | + if on_connect := getattr(agent, "on_connect", None): |
| 67 | + on_connect(self) |
| 68 | + |
| 69 | + async def listen(self) -> None: |
| 70 | + """Start listening for incoming messages.""" |
| 71 | + await self._conn.main_loop() |
63 | 72 |
|
64 | 73 | @param_model(SessionNotification) |
65 | 74 | async def session_update( |
|
0 commit comments