Skip to content

Commit

Permalink
Added the receive_nowait() method to all streams
Browse files Browse the repository at this point in the history
Closes #482.
  • Loading branch information
agronholm committed Oct 25, 2022
1 parent 1dabfd8 commit 87c4c78
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Bumped minimum version of trio to v0.22
- Added ``create_unix_datagram_socket`` and ``create_connected_unix_datagram_socket`` to
create UNIX datagram sockets (PR by Jean Hominal)
- Added the ``receive_nowait()`` method to the entire stream class hierarchy
- Improved type annotations:

- Several functions and methods that previously only accepted coroutines as the return
Expand Down
10 changes: 10 additions & 0 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,16 @@ def _spawn_task_from_thread(
class StreamReaderWrapper(abc.ByteReceiveStream):
_stream: asyncio.StreamReader

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
if self._stream.exception():
raise self._stream.exception()
elif not self._stream._buffer: # type: ignore[attr-defined]
raise WouldBlock

data = self._stream._buffer[:max_bytes] # type: ignore[attr-defined]
del self._stream._buffer[:max_bytes] # type: ignore[attr-defined]
return data

async def receive(self, max_bytes: int = 65536) -> bytes:
data = await self._stream.read(max_bytes)
if data:
Expand Down
12 changes: 12 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,18 @@ def __init__(self, trio_socket: TrioSocketType) -> None:
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
try:
data = self._raw_socket.recv(max_bytes)
except BaseException as exc:
self._convert_socket_error(exc)

if data:
return data
else:
raise EndOfStream

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
try:
Expand Down
31 changes: 30 additions & 1 deletion src/anyio/abc/_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar, Union

from .._core._exceptions import EndOfStream
from .._core._exceptions import EndOfStream, WouldBlock
from .._core._typedattr import TypedAttributeProvider
from ._resources import AsyncResource
from ._tasks import TaskGroup
Expand Down Expand Up @@ -36,6 +36,20 @@ async def __anext__(self) -> T_co:
except EndOfStream:
raise StopAsyncIteration

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.
:raises ~anyio.ClosedResourceError: if the receive stream has been explicitly
closed
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
due to external causes
:raises ~anyio.WouldBlock: if there is no item immeditately available
"""
raise WouldBlock

@abstractmethod
async def receive(self) -> T_co:
"""
Expand Down Expand Up @@ -132,6 +146,21 @@ async def __anext__(self) -> bytes:
except EndOfStream:
raise StopAsyncIteration

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
"""
Receive at most ``max_bytes`` bytes from the peer, if it can be done without
blocking.
.. note:: Implementors of this interface should not return an empty
:class:`bytes` object, and users should ignore them.
:param max_bytes: maximum number of bytes to receive
:return: the received bytes
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
:raises ~anyio.WouldBlock: if there is no data waiting to be received
"""
raise WouldBlock

@abstractmethod
async def receive(self, max_bytes: int = 65536) -> bytes:
"""
Expand Down
21 changes: 21 additions & 0 deletions src/anyio/streams/buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ def buffer(self) -> bytes:
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.receive_stream.extra_attributes

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
if self._closed:
raise ClosedResourceError

if self._buffer:
chunk = bytes(self._buffer[:max_bytes])
del self._buffer[:max_bytes]
return chunk
elif isinstance(self.receive_stream, ByteReceiveStream):
return self.receive_stream.receive_nowait(max_bytes)
else:
# With a bytes-oriented object stream, we need to handle any surplus bytes
# we get from the receive_nowait() call
chunk = self.receive_stream.receive_nowait()
if len(chunk) > max_bytes:
# Save the surplus bytes in the buffer
self._buffer.extend(chunk[max_bytes:])
return chunk[:max_bytes]
else:
return chunk

async def receive(self, max_bytes: int = 65536) -> bytes:
if self._closed:
raise ClosedResourceError
Expand Down
11 changes: 0 additions & 11 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@ def __post_init__(self) -> None:
self._state.open_receive_channels += 1

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.
:return: the received item
:raises ~anyio.ClosedResourceError: if this send stream has been closed
:raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
closed from the sending end
:raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
waiting to send
"""
if self._closed:
raise ClosedResourceError

Expand Down
3 changes: 3 additions & 0 deletions src/anyio/streams/stapled.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class StapledByteStream(ByteStream):
send_stream: ByteSendStream
receive_stream: ByteReceiveStream

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
return self.receive_stream.receive_nowait(max_bytes)

async def receive(self, max_bytes: int = 65536) -> bytes:
return await self.receive_stream.receive(max_bytes)

Expand Down
7 changes: 7 additions & 0 deletions src/anyio/streams/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def __post_init__(self, encoding: str, errors: str) -> None:
decoder_class = codecs.getincrementaldecoder(encoding)
self._decoder = decoder_class(errors=errors)

def receive_nowait(self) -> str:
while True:
chunk = self.transport_stream.receive_nowait()
decoded = self._decoder.decode(chunk)
if decoded:
return decoded

async def receive(self) -> str:
while True:
chunk = await self.transport_stream.receive()
Expand Down
12 changes: 12 additions & 0 deletions src/anyio/streams/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .. import (
BrokenResourceError,
EndOfStream,
WouldBlock,
aclose_forcefully,
get_cancelled_exc_class,
)
Expand Down Expand Up @@ -195,6 +196,17 @@ async def aclose(self) -> None:

await self.transport_stream.aclose()

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
try:
data = self._ssl_object.read(max_bytes)
except ssl.SSLError:
raise WouldBlock

if not data:
raise EndOfStream

return data

async def receive(self, max_bytes: int = 65536) -> bytes:
data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
if not data:
Expand Down

0 comments on commit 87c4c78

Please sign in to comment.