diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 78ef5105..77bac3b1 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -8,6 +8,7 @@ This library adheres to `Semantic Versioning 2.0 `_. - Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``, ``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by Lura Skye) +- Added the ``receive_nowait()`` method to the entire stream class hierarchy - Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help from Egor Blagov) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index bdc0aa8d..0d46307b 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -880,6 +880,16 @@ def _spawn_task_from_thread( class StreamReaderWrapper(abc.ByteReceiveStream): _stream: asyncio.StreamReader + def receive_nowait(self, max_bytes: int = 65536) -> bytes: + if self._stream.exception(): + raise self._stream.exception() + elif not self._stream._buffer: # type: ignore[attr-defined] + raise WouldBlock + + data = self._stream._buffer[:max_bytes] # type: ignore[attr-defined] + del self._stream._buffer[:max_bytes] # type: ignore[attr-defined] + return data + async def receive(self, max_bytes: int = 65536) -> bytes: data = await self._stream.read(max_bytes) if data: diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index eb891d22..7df56875 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -375,6 +375,18 @@ def __init__(self, trio_socket: TrioSocketType) -> None: self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") + def receive_nowait(self, max_bytes: int = 65536) -> bytes: + with self._receive_guard: + try: + data = self._raw_socket.recv(max_bytes) + except BaseException as exc: + self._convert_socket_error(exc) + + if data: + return data + else: + raise EndOfStream + async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: try: diff --git a/src/anyio/abc/_streams.py b/src/anyio/abc/_streams.py index 8c638683..e7de1562 100644 --- a/src/anyio/abc/_streams.py +++ b/src/anyio/abc/_streams.py @@ -4,7 +4,7 @@ from collections.abc import Callable from typing import Any, Generic, TypeVar, Union -from .._core._exceptions import EndOfStream +from .._core._exceptions import EndOfStream, WouldBlock from .._core._typedattr import TypedAttributeProvider from ._resources import AsyncResource from ._tasks import TaskGroup @@ -36,6 +36,20 @@ async def __anext__(self) -> T_co: except EndOfStream: raise StopAsyncIteration + def receive_nowait(self) -> T_co: + """ + Receive the next item if it can be done without waiting. + + :raises ~anyio.ClosedResourceError: if the receive stream has been explicitly + closed + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + :raises ~anyio.BrokenResourceError: if this stream has been rendered unusable + due to external causes + :raises ~anyio.WouldBlock: if there is no item immeditately available + + """ + raise WouldBlock + @abstractmethod async def receive(self) -> T_co: """ @@ -132,6 +146,21 @@ async def __anext__(self) -> bytes: except EndOfStream: raise StopAsyncIteration + def receive_nowait(self, max_bytes: int = 65536) -> bytes: + """ + Receive at most ``max_bytes`` bytes from the peer, if it can be done without + blocking. + + .. note:: Implementors of this interface should not return an empty + :class:`bytes` object, and users should ignore them. + + :param max_bytes: maximum number of bytes to receive + :return: the received bytes + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + :raises ~anyio.WouldBlock: if there is no data waiting to be received + """ + raise WouldBlock + @abstractmethod async def receive(self, max_bytes: int = 65536) -> bytes: """ diff --git a/src/anyio/streams/buffered.py b/src/anyio/streams/buffered.py index f5d5e836..1ff61875 100644 --- a/src/anyio/streams/buffered.py +++ b/src/anyio/streams/buffered.py @@ -32,6 +32,27 @@ def buffer(self) -> bytes: def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.receive_stream.extra_attributes + def receive_nowait(self, max_bytes: int = 65536) -> bytes: + if self._closed: + raise ClosedResourceError + + if self._buffer: + chunk = bytes(self._buffer[:max_bytes]) + del self._buffer[:max_bytes] + return chunk + elif isinstance(self.receive_stream, ByteReceiveStream): + return self.receive_stream.receive_nowait(max_bytes) + else: + # With a bytes-oriented object stream, we need to handle any surplus bytes + # we get from the receive_nowait() call + chunk = self.receive_stream.receive_nowait() + if len(chunk) > max_bytes: + # Save the surplus bytes in the buffer + self._buffer.extend(chunk[max_bytes:]) + return chunk[:max_bytes] + else: + return chunk + async def receive(self, max_bytes: int = 65536) -> bytes: if self._closed: raise ClosedResourceError diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py index bc2425b7..cfb3fa60 100644 --- a/src/anyio/streams/memory.py +++ b/src/anyio/streams/memory.py @@ -64,17 +64,6 @@ def __post_init__(self) -> None: self._state.open_receive_channels += 1 def receive_nowait(self) -> T_co: - """ - Receive the next item if it can be done without waiting. - - :return: the received item - :raises ~anyio.ClosedResourceError: if this send stream has been closed - :raises ~anyio.EndOfStream: if the buffer is empty and this stream has been - closed from the sending end - :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks - waiting to send - - """ if self._closed: raise ClosedResourceError diff --git a/src/anyio/streams/stapled.py b/src/anyio/streams/stapled.py index 80f64a2e..e49f24a2 100644 --- a/src/anyio/streams/stapled.py +++ b/src/anyio/streams/stapled.py @@ -34,6 +34,9 @@ class StapledByteStream(ByteStream): send_stream: ByteSendStream receive_stream: ByteReceiveStream + def receive_nowait(self, max_bytes: int = 65536) -> bytes: + return self.receive_stream.receive_nowait(max_bytes) + async def receive(self, max_bytes: int = 65536) -> bytes: return await self.receive_stream.receive(max_bytes) diff --git a/src/anyio/streams/text.py b/src/anyio/streams/text.py index f1a11278..8e4a38e0 100644 --- a/src/anyio/streams/text.py +++ b/src/anyio/streams/text.py @@ -42,6 +42,13 @@ def __post_init__(self, encoding: str, errors: str) -> None: decoder_class = codecs.getincrementaldecoder(encoding) self._decoder = decoder_class(errors=errors) + def receive_nowait(self) -> str: + while True: + chunk = self.transport_stream.receive_nowait() + decoded = self._decoder.decode(chunk) + if decoded: + return decoded + async def receive(self) -> str: while True: chunk = await self.transport_stream.receive() diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index 8468f33d..7a20564f 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -11,6 +11,7 @@ from .. import ( BrokenResourceError, EndOfStream, + WouldBlock, aclose_forcefully, get_cancelled_exc_class, ) @@ -194,6 +195,45 @@ async def aclose(self) -> None: await self.transport_stream.aclose() + def receive_nowait(self, max_bytes: int = 65536) -> bytes: + while True: + try: + data = self._ssl_object.read(max_bytes) + break + except ssl.SSLWantReadError: + try: + data = self.transport_stream.receive_nowait() + except WouldBlock: + raise WouldBlock from None + except EndOfStream: + self._read_bio.write_eof() + except OSError as exc: + self._read_bio.write_eof() + self._write_bio.write_eof() + raise BrokenResourceError from exc + else: + self._read_bio.write(data) + except ssl.SSLWantWriteError: + raise WouldBlock from None + except ssl.SSLError as exc: + self._read_bio.write_eof() + self._write_bio.write_eof() + if ( + isinstance(exc, ssl.SSLEOFError) + or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror + ): + if self.standard_compatible: + raise BrokenResourceError from exc + else: + raise EndOfStream from None + + raise + + if not data: + raise EndOfStream + + return data + async def receive(self, max_bytes: int = 65536) -> bytes: data = await self._call_sslobject_method(self._ssl_object.read, max_bytes) if not data: