Skip to content

Commit

Permalink
Merge f328d02 into 234e434
Browse files Browse the repository at this point in the history
  • Loading branch information
heckad committed Apr 6, 2024
2 parents 234e434 + f328d02 commit 34df288
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
variable when setting the ``debug`` flag in ``anyio.run()``
- Fixed ``SocketStream.receive()`` not detecting EOF on asyncio if there is also data in
the read buffer (`#701 <https://github.com/agronholm/anyio/issues/701>`_)
- Add check on closed ``MemoryObjectReceiveStream`` and ``MemoryObjectSendStream``
in ``__del__`` method

**4.3.0**

Expand Down
6 changes: 3 additions & 3 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
UNIXDatagramPacketType,
)
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from ..streams.memory import MemoryObjectReceiveStream

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand Down Expand Up @@ -1838,8 +1838,6 @@ def _create_task_info(task: asyncio.Task) -> TaskInfo:


class TestRunner(abc.TestRunner):
_send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

def __init__(
self,
*,
Expand Down Expand Up @@ -1922,6 +1920,8 @@ async def _call_in_runner_task(
self._run_tests_and_fixtures(receive_stream)
)

self._runner_task.add_done_callback(lambda _: self._send_stream.close())

coro = func(*args, **kwargs)
future: asyncio.Future[T_Retval] = self.get_loop().create_future()
self._send_stream.send_nowait((coro, future))
Expand Down
45 changes: 45 additions & 0 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import sys
import traceback
import warnings
from collections import OrderedDict, deque
from dataclasses import dataclass, field
from types import TracebackType
Expand Down Expand Up @@ -59,10 +62,17 @@ def statistics(self) -> MemoryObjectStreamStatistics:
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
_state: MemoryObjectStreamState[T_co]
_closed: bool = field(init=False, default=False)
_source_traceback: traceback.StackSummary | None = field(init=False, default=None)

def __post_init__(self) -> None:
self._state.open_receive_channels += 1

if self.is_source_traceback_capturing_enabled():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def is_source_traceback_capturing_enabled(self) -> bool:
return sys.flags.dev_mode

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.
Expand Down Expand Up @@ -164,15 +174,36 @@ def __exit__(
) -> None:
self.close()

def __del__(self) -> None:
if not self._closed:
created_at_message = ""

if self._source_traceback is not None:
frame = self._source_traceback[-3]
created_at_message = f", created_at {frame[0]}:{frame[1]}"

warnings.warn(
f"Unclosed <{self.__class__.__name__}{created_at_message}>",
ResourceWarning,
source=self,
)


@dataclass(eq=False)
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
_state: MemoryObjectStreamState[T_contra]
_closed: bool = field(init=False, default=False)
_source_traceback: traceback.StackSummary | None = field(init=False, default=None)

def __post_init__(self) -> None:
self._state.open_send_channels += 1

if self.is_source_traceback_capturing_enabled():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def is_source_traceback_capturing_enabled(self) -> bool:
return sys.flags.dev_mode

def send_nowait(self, item: T_contra) -> None:
"""
Send an item immediately if it can be done without waiting.
Expand Down Expand Up @@ -281,3 +312,17 @@ def __exit__(
exc_tb: TracebackType | None,
) -> None:
self.close()

def __del__(self) -> None:
if not self._closed:
created_at_message = ""

if self._source_traceback is not None:
frame = self._source_traceback[-3]
created_at_message = f", created_at {frame.filename}:{frame.lineno}"

warnings.warn(
f"Unclosed <{self.__class__.__name__}{created_at_message}>",
ResourceWarning,
source=self,
)
12 changes: 12 additions & 0 deletions tests/streams/test_buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ async def test_receive_exactly() -> None:
assert result == b"abcdefgh"
assert isinstance(result, bytes)

send_stream.close()
receive_stream.close()


async def test_receive_exactly_incomplete() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](1)
Expand All @@ -26,6 +29,9 @@ async def test_receive_exactly_incomplete() -> None:
with pytest.raises(IncompleteRead):
await buffered_stream.receive_exactly(8)

send_stream.close()
receive_stream.close()


async def test_receive_until() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](2)
Expand All @@ -41,6 +47,9 @@ async def test_receive_until() -> None:
assert result == b"fg"
assert isinstance(result, bytes)

send_stream.close()
receive_stream.close()


async def test_receive_until_incomplete() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](1)
Expand All @@ -51,3 +60,6 @@ async def test_receive_until_incomplete() -> None:
assert await buffered_stream.receive_until(b"de", 10)

assert buffered_stream.buffer == b"abcd"

send_stream.close()
receive_stream.close()

0 comments on commit 34df288

Please sign in to comment.