Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,48 @@

class AiohttpSyncStream:
"""Wraps aiohttp's async StreamReader as a synchronous file-like object.
read(n) blocks until exactly n bytes are available from the HTTP response."""
read(n) blocks until exactly n bytes are available from the HTTP response.

Maintains an internal byte buffer that is refilled one HTTP chunk at a time
so the deserializer's many small read(n) calls don't each pay the cost of a
full asyncio event-loop turn."""

# Max bytes pulled from the response per underlying read. Matches
# aiohttp.StreamReader's default 64 KB limit, which is the per-connection
# high-water mark, asking for more in one read() never returns more.
_FILL_SIZE = 64 * 1024

def __init__(self, response, loop, read_timeout):
self._response = response
self._loop = loop
self._read_timeout = read_timeout
self._buf = bytearray()
self._pos = 0

def read(self, n):
if n <= 0:
return b''
while len(self._buf) - self._pos < n:
data = self._read_chunk()
if not data:
partial = bytes(self._buf[self._pos:])
self._buf.clear()
self._pos = 0
raise asyncio.IncompleteReadError(partial=partial, expected=n)
self._buf.extend(data)
end = self._pos + n
out = bytes(self._buf[self._pos:end])
self._pos = end
# Reclaim memory once the buffer is fully drained
if self._pos == len(self._buf):
self._buf.clear()
self._pos = 0
return out

def _read_chunk(self):
async def _read():
async with async_timeout.timeout(self._read_timeout):
return await self._response.content.readexactly(n)
return await self._response.content.read(self._FILL_SIZE)
return self._loop.run_until_complete(_read())


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ class TestAiohttpSyncStream:

The class should:
- Have a read(n) method that blocks until exactly n bytes are available
- Bridge async readexactly(n) to sync via loop.run_until_complete()
- Refill its internal buffer one HTTP chunk at a time so the
deserializer's many small read(n) calls don't each cost a full
asyncio event-loop turn
- Raise on timeout
- Raise asyncio.IncompleteReadError on premature disconnect
"""
Expand All @@ -169,13 +171,12 @@ def test_read_returns_exact_bytes(self):
loop = asyncio.new_event_loop()
mock_response = MagicMock()
mock_response.content = MagicMock()
mock_response.content.readexactly = AsyncMock(return_value=b'\x01\x02\x03\x04')
mock_response.content.read = AsyncMock(side_effect=[b'\x01\x02\x03\x04', b''])

stream = AiohttpSyncStream(mock_response, loop, read_timeout=30)
result = stream.read(4)

assert result == b'\x01\x02\x03\x04'
mock_response.content.readexactly.assert_awaited_once_with(4)
loop.close()

def test_read_single_byte(self):
Expand All @@ -185,7 +186,7 @@ def test_read_single_byte(self):
loop = asyncio.new_event_loop()
mock_response = MagicMock()
mock_response.content = MagicMock()
mock_response.content.readexactly = AsyncMock(return_value=b'\x84')
mock_response.content.read = AsyncMock(side_effect=[b'\x84', b''])

stream = AiohttpSyncStream(mock_response, loop, read_timeout=30)
result = stream.read(1)
Expand All @@ -194,35 +195,53 @@ def test_read_single_byte(self):
loop.close()

def test_read_multiple_sequential_calls(self):
"""Multiple read() calls should each invoke readexactly independently."""
"""Multiple read() calls should be served from a single buffered chunk."""
from gremlin_python.driver.aiohttp.transport import AiohttpSyncStream

loop = asyncio.new_event_loop()
mock_response = MagicMock()
mock_response.content = MagicMock()
mock_response.content.readexactly = AsyncMock(side_effect=[b'\x84', b'\x00', b'\x01\x02\x03\x04'])
# The whole payload arrives in one chunk; subsequent calls return EOF.
mock_response.content.read = AsyncMock(side_effect=[b'\x84\x00\x01\x02\x03\x04', b''])

stream = AiohttpSyncStream(mock_response, loop, read_timeout=30)
assert stream.read(1) == b'\x84'
assert stream.read(1) == b'\x00'
assert stream.read(4) == b'\x01\x02\x03\x04'
assert mock_response.content.readexactly.await_count == 3
# Only one underlying read was needed for three user-level read() calls
assert mock_response.content.read.await_count == 1
loop.close()

def test_read_refills_buffer_across_chunks(self):
"""read(n) should refill from the underlying stream when the buffer is short."""
from gremlin_python.driver.aiohttp.transport import AiohttpSyncStream

loop = asyncio.new_event_loop()
mock_response = MagicMock()
mock_response.content = MagicMock()
# Data arrives in two chunks; read(6) must span both.
mock_response.content.read = AsyncMock(side_effect=[b'\x01\x02\x03', b'\x04\x05\x06', b''])

stream = AiohttpSyncStream(mock_response, loop, read_timeout=30)
assert stream.read(6) == b'\x01\x02\x03\x04\x05\x06'
assert mock_response.content.read.await_count == 2
loop.close()

def test_read_raises_on_incomplete_read(self):
"""read() should propagate IncompleteReadError when server disconnects mid-stream."""
"""read() should raise IncompleteReadError when the server disconnects mid-stream."""
from gremlin_python.driver.aiohttp.transport import AiohttpSyncStream

loop = asyncio.new_event_loop()
mock_response = MagicMock()
mock_response.content = MagicMock()
mock_response.content.readexactly = AsyncMock(
side_effect=asyncio.IncompleteReadError(partial=b'\x01', expected=4)
)
# First chunk delivers one byte, then EOF — caller asked for four.
mock_response.content.read = AsyncMock(side_effect=[b'\x01', b''])

stream = AiohttpSyncStream(mock_response, loop, read_timeout=30)
with pytest.raises(asyncio.IncompleteReadError):
with pytest.raises(asyncio.IncompleteReadError) as exc_info:
stream.read(4)
assert exc_info.value.partial == b'\x01'
assert exc_info.value.expected == 4
loop.close()

def test_read_raises_on_timeout(self):
Expand All @@ -232,7 +251,7 @@ def test_read_raises_on_timeout(self):
loop = asyncio.new_event_loop()
mock_response = MagicMock()
mock_response.content = MagicMock()
mock_response.content.readexactly = AsyncMock(side_effect=asyncio.TimeoutError())
mock_response.content.read = AsyncMock(side_effect=asyncio.TimeoutError())

stream = AiohttpSyncStream(mock_response, loop, read_timeout=1)
with pytest.raises(asyncio.TimeoutError):
Expand Down Expand Up @@ -268,7 +287,7 @@ def test_get_stream_uses_current_response(self):
transport._read_timeout = 30
mock_resp = MagicMock()
mock_resp.content = MagicMock()
mock_resp.content.readexactly = AsyncMock(return_value=b'\x84')
mock_resp.content.read = AsyncMock(side_effect=[b'\x84', b''])
transport._http_req_resp = mock_resp

stream = transport.get_stream()
Expand Down
Loading