Skip to content

Commit

Permalink
Merge 46f1af0 into 234e434
Browse files Browse the repository at this point in the history
  • Loading branch information
heckad committed Apr 6, 2024
2 parents 234e434 + 46f1af0 commit 3c2b642
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
UNIXDatagramPacketType,
)
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from ..streams.memory import MemoryObjectReceiveStream

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand Down Expand Up @@ -1838,8 +1838,6 @@ def _create_task_info(task: asyncio.Task) -> TaskInfo:


class TestRunner(abc.TestRunner):
_send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

def __init__(
self,
*,
Expand Down Expand Up @@ -1922,6 +1920,8 @@ async def _call_in_runner_task(
self._run_tests_and_fixtures(receive_stream)
)

self._runner_task.add_done_callback(lambda _: self._send_stream.close())

coro = func(*args, **kwargs)
future: asyncio.Future[T_Retval] = self.get_loop().create_future()
self._send_stream.send_nowait((coro, future))
Expand Down
46 changes: 46 additions & 0 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import sys
import traceback
import warnings
from collections import OrderedDict, deque
from dataclasses import dataclass, field
from types import TracebackType
Expand Down Expand Up @@ -59,10 +62,17 @@ def statistics(self) -> MemoryObjectStreamStatistics:
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
_state: MemoryObjectStreamState[T_co]
_closed: bool = field(init=False, default=False)
_source_traceback: traceback.StackSummary | None = field(init=False, default=None)

def __post_init__(self) -> None:
self._state.open_receive_channels += 1

if self.is_source_traceback_capturing_enabled():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def is_source_traceback_capturing_enabled(self) -> bool:
return sys.flags.dev_mode

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.
Expand Down Expand Up @@ -164,15 +174,36 @@ def __exit__(
) -> None:
self.close()

def __del__(self) -> None:
if not self._closed:
created_at_message = ""

if self._source_traceback is not None:
frame = self._source_traceback[-3]
created_at_message = f", created_at {frame[0]}:{frame[1]}"

warnings.warn(
f"Unclosed <{self.__class__.__name__}{created_at_message}>",
ResourceWarning,
source=self,
)


@dataclass(eq=False)
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
_state: MemoryObjectStreamState[T_contra]
_closed: bool = field(init=False, default=False)
_source_traceback: traceback.StackSummary | None = field(init=False, default=None)

def __post_init__(self) -> None:
self._state.open_send_channels += 1

if self.is_source_traceback_capturing_enabled():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def is_source_traceback_capturing_enabled(self) -> bool:
return sys.flags.dev_mode

def send_nowait(self, item: T_contra) -> None:
"""
Send an item immediately if it can be done without waiting.
Expand Down Expand Up @@ -281,3 +312,18 @@ def __exit__(
exc_tb: TracebackType | None,
) -> None:
self.close()

def __del__(self) -> None:
if not self._closed:
created_at_message = ""
created_at_message = ""

if self._source_traceback is not None:
frame = self._source_traceback[-3]
created_at_message = f", created_at {frame.filename}:{frame.lineno}"

warnings.warn(
f"Unclosed <{self.__class__.__name__}{created_at_message}>",
ResourceWarning,
source=self,
)
12 changes: 12 additions & 0 deletions tests/streams/test_buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ async def test_receive_exactly() -> None:
assert result == b"abcdefgh"
assert isinstance(result, bytes)

send_stream.close()
receive_stream.close()


async def test_receive_exactly_incomplete() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](1)
Expand All @@ -26,6 +29,9 @@ async def test_receive_exactly_incomplete() -> None:
with pytest.raises(IncompleteRead):
await buffered_stream.receive_exactly(8)

send_stream.close()
receive_stream.close()


async def test_receive_until() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](2)
Expand All @@ -41,6 +47,9 @@ async def test_receive_until() -> None:
assert result == b"fg"
assert isinstance(result, bytes)

send_stream.close()
receive_stream.close()


async def test_receive_until_incomplete() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](1)
Expand All @@ -51,3 +60,6 @@ async def test_receive_until_incomplete() -> None:
assert await buffered_stream.receive_until(b"de", 10)

assert buffered_stream.buffer == b"abcd"

send_stream.close()
receive_stream.close()
102 changes: 101 additions & 1 deletion tests/streams/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import gc
import sys
from typing import NoReturn

import pytest
from pytest_mock import MockerFixture

from anyio import (
BrokenResourceError,
Expand Down Expand Up @@ -52,6 +54,9 @@ async def receiver() -> None:

assert received_objects == ["hello", "anyio"]

send.close()
receive.close()


async def test_receive_then_send_nowait() -> None:
async def receiver() -> None:
Expand All @@ -68,6 +73,9 @@ async def receiver() -> None:

assert sorted(received_objects, reverse=True) == ["hello", "anyio"]

send.close()
receive.close()


async def test_send_then_receive_nowait() -> None:
send, receive = create_memory_object_stream[str](0)
Expand All @@ -76,6 +84,9 @@ async def test_send_then_receive_nowait() -> None:
await wait_all_tasks_blocked()
assert receive.receive_nowait() == "hello"

send.close()
receive.close()


async def test_send_is_unblocked_after_receive_nowait() -> None:
send, receive = create_memory_object_stream[str](1)
Expand All @@ -89,6 +100,9 @@ async def test_send_is_unblocked_after_receive_nowait() -> None:

assert receive.receive_nowait() == "anyio"

send.close()
receive.close()


async def test_send_nowait_then_receive_nowait() -> None:
send, receive = create_memory_object_stream[str](2)
Expand All @@ -97,6 +111,9 @@ async def test_send_nowait_then_receive_nowait() -> None:
assert receive.receive_nowait() == "hello"
assert receive.receive_nowait() == "anyio"

send.close()
receive.close()


async def test_iterate() -> None:
async def receiver() -> None:
Expand All @@ -113,6 +130,9 @@ async def receiver() -> None:

assert received_objects == ["hello", "anyio"]

send.close()
receive.close()


async def test_receive_send_closed_send_stream() -> None:
send, receive = create_memory_object_stream[None]()
Expand All @@ -123,6 +143,8 @@ async def test_receive_send_closed_send_stream() -> None:
with pytest.raises(ClosedResourceError):
await send.send(None)

receive.close()


async def test_receive_send_closed_receive_stream() -> None:
send, receive = create_memory_object_stream[None]()
Expand All @@ -133,6 +155,8 @@ async def test_receive_send_closed_receive_stream() -> None:
with pytest.raises(BrokenResourceError):
await send.send(None)

send.close()


async def test_cancel_receive() -> None:
send, receive = create_memory_object_stream[str]()
Expand All @@ -144,6 +168,9 @@ async def test_cancel_receive() -> None:
with pytest.raises(WouldBlock):
send.send_nowait("hello")

send.close()
receive.close()


async def test_cancel_send() -> None:
send, receive = create_memory_object_stream[str]()
Expand All @@ -155,6 +182,9 @@ async def test_cancel_send() -> None:
with pytest.raises(WouldBlock):
receive.receive_nowait()

send.close()
receive.close()


async def test_clone() -> None:
send1, receive1 = create_memory_object_stream[str](1)
Expand All @@ -165,6 +195,11 @@ async def test_clone() -> None:
send2.send_nowait("hello")
assert receive2.receive_nowait() == "hello"

send1.close()
receive1.close()
send2.close()
receive2.close()


async def test_clone_closed() -> None:
send, receive = create_memory_object_stream[NoReturn](1)
Expand All @@ -185,6 +220,9 @@ async def test_close_send_while_receiving() -> None:
assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], EndOfStream)

send.close()
receive.close()


async def test_close_receive_while_sending() -> None:
send, receive = create_memory_object_stream[str](0)
Expand All @@ -197,13 +235,19 @@ async def test_close_receive_while_sending() -> None:
assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], BrokenResourceError)

send.close()
receive.close()


async def test_receive_after_send_closed() -> None:
send, receive = create_memory_object_stream[str](1)
await send.send("hello")
await send.aclose()
assert await receive.receive() == "hello"

send.close()
receive.close()


async def test_receive_when_cancelled() -> None:
"""
Expand All @@ -225,6 +269,9 @@ async def test_receive_when_cancelled() -> None:
assert await receive.receive() == "hello"
assert await receive.receive() == "world"

send.close()
receive.close()


async def test_send_when_cancelled() -> None:
"""
Expand All @@ -248,6 +295,9 @@ async def receiver() -> None:

assert received == ["world"]

send.close()
receive.close()


async def test_cancel_during_receive() -> None:
"""
Expand Down Expand Up @@ -275,6 +325,9 @@ async def scoped_receiver() -> None:

assert received == ["hello"]

send.close()
receive.close()


async def test_close_receive_after_send() -> None:
async def send() -> None:
Expand All @@ -290,6 +343,9 @@ async def receive() -> None:
tg.start_soon(send)
tg.start_soon(receive)

send_stream.close()
receive_stream.close()


async def test_statistics() -> None:
send_stream, receive_stream = create_memory_object_stream[None](1)
Expand Down Expand Up @@ -347,6 +403,9 @@ async def test_statistics() -> None:
assert stream.statistics().tasks_waiting_send == 0
assert stream.statistics().tasks_waiting_receive == 0

send_stream.close()
receive_stream.close()


async def test_sync_close() -> None:
send_stream, receive_stream = create_memory_object_stream[None](1)
Expand Down Expand Up @@ -374,7 +433,48 @@ async def test_type_variance() -> None:
send1: MemoryObjectSendStream[int] = send # noqa: F841
send2: ObjectSendStream[int] = send # noqa: F841

send.close()
receive.close()


async def test_deprecated_item_type_parameter() -> None:
with pytest.warns(DeprecationWarning, match="item_type argument has been "):
create_memory_object_stream(item_type=int)
send, receive = create_memory_object_stream(item_type=int) # type: ignore[var-annotated]

send.close()
receive.close()


@pytest.mark.parametrize("is_source_traceback_capturing_enabled", [True, False])
async def test_not_closed_warning(
mocker: MockerFixture, is_source_traceback_capturing_enabled: bool
) -> None:
mocker.patch.object(
MemoryObjectReceiveStream,
"is_source_traceback_capturing_enabled",
return_value=is_source_traceback_capturing_enabled,
)
mocker.patch.object(
MemoryObjectSendStream,
"is_source_traceback_capturing_enabled",
return_value=is_source_traceback_capturing_enabled,
)

send, receive = create_memory_object_stream[int]()

if is_source_traceback_capturing_enabled:
match_suffix = ", .*>$"
else:
match_suffix = ">$"

with pytest.warns(
ResourceWarning, match=f"Unclosed <MemoryObjectSendStream{match_suffix}"
):
del send
gc.collect()

with pytest.warns(
ResourceWarning, match=f"Unclosed <MemoryObjectReceiveStream{match_suffix}"
):
del receive
gc.collect()

0 comments on commit 3c2b642

Please sign in to comment.