Skip to content

Commit

Permalink
Merge 52e3d3e into e0529a3
Browse files Browse the repository at this point in the history
  • Loading branch information
heckad committed Mar 2, 2024
2 parents e0529a3 + 52e3d3e commit aeda0b8
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"exceptiongroup >= 1.0.2; python_version < '3.11'",
"idna >= 2.8",
"sniffio >= 1.1",
"typing_extensions >= 4.1; python_version < '3.11'",
"typing_extensions >= 4.1; python_version < '3.12'",
]
dynamic = ["version"]

Expand Down
4 changes: 4 additions & 0 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1851,6 +1851,7 @@ def __init__(
self._runner = Runner(debug=debug, loop_factory=loop_factory)
self._exceptions: list[BaseException] = []
self._runner_task: asyncio.Task | None = None
self._send_stream: MemoryObjectSendStream[Any] | None = None

def __enter__(self) -> TestRunner:
self._runner.__enter__()
Expand All @@ -1865,6 +1866,9 @@ def __exit__(
) -> None:
self._runner.__exit__(exc_type, exc_val, exc_tb)

if (send_stream := self._send_stream) is not None:
send_stream.close()

def get_loop(self) -> AbstractEventLoop:
return self._runner.get_loop()

Expand Down
118 changes: 84 additions & 34 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

import sys
import traceback
import warnings
from abc import ABCMeta
from collections import OrderedDict, deque
from dataclasses import dataclass, field
from types import TracebackType
Expand All @@ -14,6 +18,11 @@
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
from ..lowlevel import checkpoint

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

T_Item = TypeVar("T_Item")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)
Expand Down Expand Up @@ -55,12 +64,72 @@ def statistics(self) -> MemoryObjectStreamStatistics:
)


@dataclass(eq=False)
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
class AsyncCloseableResource(metaclass=ABCMeta):
"""
Abstract base class for all closeable asynchronous resources.
Works as an asynchronous context manager which returns the instance itself on enter,
and calls :meth:`aclose` on exit.
"""

def __init__(self):
self._closed = False

self._source_traceback = traceback.extract_stack(sys._getframe(1))

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()

async def aclose(self) -> None:
"""Close the resource."""

self.close()

def __enter__(self) -> Self:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

def close(self) -> None:
"""Close the resource."""

self._closed = True

def __del__(self) -> None:
if not self._closed:
frame = self._source_traceback[-3]
created_at_message = f", created_at {frame[0]}:{frame[1]}"

warnings.warn(
f"Unclosed <{self.__class__.__name__}{created_at_message}>",
ResourceWarning,
source=self,
)


class MemoryObjectReceiveStream(
Generic[T_co], AsyncCloseableResource, ObjectReceiveStream[T_co]
):
_state: MemoryObjectStreamState[T_co]
_closed: bool = field(init=False, default=False)

def __post_init__(self) -> None:
def __init__(self, _state: MemoryObjectStreamState[T_co]) -> None:
super().__init__()

self._state = _state
self._state.open_receive_channels += 1

def receive_nowait(self) -> T_co:
Expand Down Expand Up @@ -142,9 +211,6 @@ def close(self) -> None:
for event in send_events:
event.set()

async def aclose(self) -> None:
self.close()

def statistics(self) -> MemoryObjectStreamStatistics:
"""
Return statistics about the current state of this stream.
Expand All @@ -153,24 +219,19 @@ def statistics(self) -> MemoryObjectStreamStatistics:
"""
return self._state.statistics()

def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
return self
def __repr__(self) -> str:
return f"{self.__class__.__name__}"

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()


@dataclass(eq=False)
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
class MemoryObjectSendStream(
Generic[T_contra], AsyncCloseableResource, ObjectSendStream[T_contra]
):
_state: MemoryObjectStreamState[T_contra]
_closed: bool = field(init=False, default=False)

def __post_init__(self) -> None:
def __init__(self, _state: MemoryObjectStreamState[T_co]) -> None:
super().__init__()

self._state = _state
self._state.open_send_channels += 1

def send_nowait(self, item: T_contra) -> None:
Expand Down Expand Up @@ -260,9 +321,6 @@ def close(self) -> None:
for event in receive_events:
event.set()

async def aclose(self) -> None:
self.close()

def statistics(self) -> MemoryObjectStreamStatistics:
"""
Return statistics about the current state of this stream.
Expand All @@ -271,13 +329,5 @@ def statistics(self) -> MemoryObjectStreamStatistics:
"""
return self._state.statistics()

def __enter__(self) -> MemoryObjectSendStream[T_contra]:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
8 changes: 8 additions & 0 deletions tests/streams/test_buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ async def test_receive_exactly() -> None:
assert result == b"abcdefgh"
assert isinstance(result, bytes)

send_stream.close(), receive_stream.close()


async def test_receive_exactly_incomplete() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](1)
Expand All @@ -26,6 +28,8 @@ async def test_receive_exactly_incomplete() -> None:
with pytest.raises(IncompleteRead):
await buffered_stream.receive_exactly(8)

send_stream.close(), receive_stream.close()


async def test_receive_until() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](2)
Expand All @@ -41,6 +45,8 @@ async def test_receive_until() -> None:
assert result == b"fg"
assert isinstance(result, bytes)

send_stream.close(), receive_stream.close()


async def test_receive_until_incomplete() -> None:
send_stream, receive_stream = create_memory_object_stream[bytes](1)
Expand All @@ -51,3 +57,5 @@ async def test_receive_until_incomplete() -> None:
assert await buffered_stream.receive_until(b"de", 10)

assert buffered_stream.buffer == b"abcd"

send_stream.close(), receive_stream.close()

0 comments on commit aeda0b8

Please sign in to comment.