Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the receive_nowait() method to all streams #487

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``,
``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by
Lura Skye)
- Added the ``receive_nowait()`` method to the entire stream class hierarchy
- Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing
to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help
from Egor Blagov)
Expand Down
10 changes: 10 additions & 0 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,16 @@ def _spawn_task_from_thread(
class StreamReaderWrapper(abc.ByteReceiveStream):
_stream: asyncio.StreamReader

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
if self._stream.exception():
raise self._stream.exception()
elif not self._stream._buffer: # type: ignore[attr-defined]
raise WouldBlock

data = self._stream._buffer[:max_bytes] # type: ignore[attr-defined]
del self._stream._buffer[:max_bytes] # type: ignore[attr-defined]
return data

async def receive(self, max_bytes: int = 65536) -> bytes:
data = await self._stream.read(max_bytes)
if data:
Expand Down
12 changes: 12 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,18 @@ def __init__(self, trio_socket: TrioSocketType) -> None:
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
try:
data = self._raw_socket.recv(max_bytes)
except BaseException as exc:
self._convert_socket_error(exc)

if data:
return data
else:
raise EndOfStream

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
try:
Expand Down
31 changes: 30 additions & 1 deletion src/anyio/abc/_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar, Union

from .._core._exceptions import EndOfStream
from .._core._exceptions import EndOfStream, WouldBlock
from .._core._typedattr import TypedAttributeProvider
from ._resources import AsyncResource
from ._tasks import TaskGroup
Expand Down Expand Up @@ -36,6 +36,20 @@ async def __anext__(self) -> T_co:
except EndOfStream:
raise StopAsyncIteration

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.

:raises ~anyio.ClosedResourceError: if the receive stream has been explicitly
closed
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
due to external causes
:raises ~anyio.WouldBlock: if there is no item immeditately available

"""
raise WouldBlock

@abstractmethod
async def receive(self) -> T_co:
"""
Expand Down Expand Up @@ -132,6 +146,21 @@ async def __anext__(self) -> bytes:
except EndOfStream:
raise StopAsyncIteration

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
"""
Receive at most ``max_bytes`` bytes from the peer, if it can be done without
blocking.

.. note:: Implementors of this interface should not return an empty
:class:`bytes` object, and users should ignore them.

:param max_bytes: maximum number of bytes to receive
:return: the received bytes
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
:raises ~anyio.WouldBlock: if there is no data waiting to be received
"""
raise WouldBlock

@abstractmethod
async def receive(self, max_bytes: int = 65536) -> bytes:
"""
Expand Down
21 changes: 21 additions & 0 deletions src/anyio/streams/buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ def buffer(self) -> bytes:
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.receive_stream.extra_attributes

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
if self._closed:
raise ClosedResourceError

if self._buffer:
chunk = bytes(self._buffer[:max_bytes])
del self._buffer[:max_bytes]
return chunk
elif isinstance(self.receive_stream, ByteReceiveStream):
return self.receive_stream.receive_nowait(max_bytes)
else:
# With a bytes-oriented object stream, we need to handle any surplus bytes
# we get from the receive_nowait() call
chunk = self.receive_stream.receive_nowait()
if len(chunk) > max_bytes:
# Save the surplus bytes in the buffer
self._buffer.extend(chunk[max_bytes:])
return chunk[:max_bytes]
else:
return chunk

async def receive(self, max_bytes: int = 65536) -> bytes:
if self._closed:
raise ClosedResourceError
Expand Down
11 changes: 0 additions & 11 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,6 @@ def __post_init__(self) -> None:
self._state.open_receive_channels += 1

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.

:return: the received item
:raises ~anyio.ClosedResourceError: if this send stream has been closed
:raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
closed from the sending end
:raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
waiting to send

"""
if self._closed:
raise ClosedResourceError

Expand Down
3 changes: 3 additions & 0 deletions src/anyio/streams/stapled.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class StapledByteStream(ByteStream):
send_stream: ByteSendStream
receive_stream: ByteReceiveStream

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
return self.receive_stream.receive_nowait(max_bytes)

async def receive(self, max_bytes: int = 65536) -> bytes:
return await self.receive_stream.receive(max_bytes)

Expand Down
7 changes: 7 additions & 0 deletions src/anyio/streams/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def __post_init__(self, encoding: str, errors: str) -> None:
decoder_class = codecs.getincrementaldecoder(encoding)
self._decoder = decoder_class(errors=errors)

def receive_nowait(self) -> str:
while True:
chunk = self.transport_stream.receive_nowait()
decoded = self._decoder.decode(chunk)
if decoded:
return decoded

async def receive(self) -> str:
while True:
chunk = await self.transport_stream.receive()
Expand Down
40 changes: 40 additions & 0 deletions src/anyio/streams/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .. import (
BrokenResourceError,
EndOfStream,
WouldBlock,
aclose_forcefully,
get_cancelled_exc_class,
)
Expand Down Expand Up @@ -194,6 +195,45 @@ async def aclose(self) -> None:

await self.transport_stream.aclose()

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
while True:
try:
data = self._ssl_object.read(max_bytes)
break
except ssl.SSLWantReadError:
try:
data = self.transport_stream.receive_nowait()
except WouldBlock:
raise WouldBlock from None
except EndOfStream:
self._read_bio.write_eof()
except OSError as exc:
self._read_bio.write_eof()
self._write_bio.write_eof()
raise BrokenResourceError from exc
else:
self._read_bio.write(data)
except ssl.SSLWantWriteError:
raise WouldBlock from None
except ssl.SSLError as exc:
self._read_bio.write_eof()
self._write_bio.write_eof()
if (
isinstance(exc, ssl.SSLEOFError)
or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
):
if self.standard_compatible:
raise BrokenResourceError from exc
else:
raise EndOfStream from None

raise

if not data:
raise EndOfStream

return data

async def receive(self, max_bytes: int = 65536) -> bytes:
data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
if not data:
Expand Down