Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/acp/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ def get_extra_info(self, name: str, default=None): # type: ignore[override]
return default


async def _windows_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
reader = asyncio.StreamReader()
async def _windows_stdio_streams(
loop: asyncio.AbstractEventLoop,
limit: int | None = None,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
reader = asyncio.StreamReader(limit=limit) if limit is not None else asyncio.StreamReader()
_ = asyncio.StreamReaderProtocol(reader)

_start_stdin_feeder(loop, reader)
Expand All @@ -108,9 +111,12 @@ async def _windows_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[async
return reader, writer


async def _posix_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
async def _posix_stdio_streams(
loop: asyncio.AbstractEventLoop,
limit: int | None = None,
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
# Reader from stdin
reader = asyncio.StreamReader()
reader = asyncio.StreamReader(limit=limit) if limit is not None else asyncio.StreamReader()
reader_protocol = asyncio.StreamReaderProtocol(reader)
await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)

Expand All @@ -121,12 +127,16 @@ async def _posix_stdio_streams(loop: asyncio.AbstractEventLoop) -> tuple[asyncio
return reader, writer


async def stdio_streams() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""Create stdio asyncio streams; on Windows use a thread feeder + custom stdout transport."""
async def stdio_streams(limit: int | None = None) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""Create stdio asyncio streams; on Windows use a thread feeder + custom stdout transport.

Args:
limit: Optional buffer limit for the stdin reader.
"""
loop = asyncio.get_running_loop()
if platform.system() == "Windows":
return await _windows_stdio_streams(loop)
return await _posix_stdio_streams(loop)
return await _windows_stdio_streams(loop, limit=limit)
return await _posix_stdio_streams(loop, limit=limit)


@asynccontextmanager
Expand Down
31 changes: 22 additions & 9 deletions src/acp/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async def spawn_stdio_transport(
env: Mapping[str, str] | None = None,
cwd: str | Path | None = None,
stderr: int | None = aio_subprocess.PIPE,
limit: int | None = None,
shutdown_timeout: float = 2.0,
) -> AsyncIterator[tuple[asyncio.StreamReader, asyncio.StreamWriter, aio_subprocess.Process]]:
"""Launch a subprocess and expose its stdio streams as asyncio transports.
Expand All @@ -62,15 +63,27 @@ async def spawn_stdio_transport(
if env:
merged_env.update(env)

process = await asyncio.create_subprocess_exec(
command,
*args,
stdin=aio_subprocess.PIPE,
stdout=aio_subprocess.PIPE,
stderr=stderr,
env=merged_env,
cwd=str(cwd) if cwd is not None else None,
)
if limit is None:
process = await asyncio.create_subprocess_exec(
command,
*args,
stdin=aio_subprocess.PIPE,
stdout=aio_subprocess.PIPE,
stderr=stderr,
env=merged_env,
cwd=str(cwd) if cwd is not None else None,
)
else:
process = await asyncio.create_subprocess_exec(
command,
*args,
stdin=aio_subprocess.PIPE,
stdout=aio_subprocess.PIPE,
stderr=stderr,
env=merged_env,
cwd=str(cwd) if cwd is not None else None,
limit=limit,
)

if process.stdout is None or process.stdin is None:
process.kill()
Expand Down
41 changes: 41 additions & 0 deletions tests/real_user/test_stdio_limits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import sys
import textwrap

import pytest

from acp.transports import spawn_stdio_transport

LARGE_LINE_SIZE = 70 * 1024


def _large_line_script(size: int = LARGE_LINE_SIZE) -> str:
return textwrap.dedent(
f"""
import sys
sys.stdout.write("X" * {size})
sys.stdout.write("\\n")
sys.stdout.flush()
"""
).strip()


@pytest.mark.asyncio
async def test_spawn_stdio_transport_hits_default_limit() -> None:
script = _large_line_script()
async with spawn_stdio_transport(sys.executable, "-c", script) as (reader, writer, _process):
# readline() re-raises LimitOverrunError as ValueError on CPython 3.12+.
with pytest.raises(ValueError):
await reader.readline()


@pytest.mark.asyncio
async def test_spawn_stdio_transport_custom_limit_handles_large_line() -> None:
script = _large_line_script()
async with spawn_stdio_transport(
sys.executable,
"-c",
script,
limit=LARGE_LINE_SIZE * 2,
) as (reader, writer, _process):
line = await reader.readline()
assert len(line) == LARGE_LINE_SIZE + 1