Skip to content

Commit

Permalink
Fixed memory object stream sometimes dropping sent items (#735)
Browse files Browse the repository at this point in the history
Check if the receiving task has a pending cancellation before sending an item.

Fixes #728.
  • Loading branch information
agronholm committed May 26, 2024
1 parent 9f5f14b commit e7f750b
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 56 deletions.
8 changes: 8 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
portals
- Added ``__slots__`` to ``AsyncResource`` so that child classes can use ``__slots__``
(`#733 <https://github.com/agronholm/anyio/pull/733>`_; PR by Justin Su)
- Added the ``TaskInfo.has_pending_cancellation()`` method
- Fixed erroneous ``RuntimeError: called 'started' twice on the same task status``
when cancelling a task in a TaskGroup created with the ``start()`` method before
the first checkpoint is reached after calling ``task_status.started()``
(`#706 <https://github.com/agronholm/anyio/issues/706>`_; PR by Dominik Schwabe)
- Fixed two bugs with ``TaskGroup.start()`` on asyncio:

* Fixed erroneous ``RuntimeError: called 'started' twice on the same task status``
Expand All @@ -32,6 +37,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
variable when setting the ``debug`` flag in ``anyio.run()``
- Fixed ``SocketStream.receive()`` not detecting EOF on asyncio if there is also data in
the read buffer (`#701 <https://github.com/agronholm/anyio/issues/701>`_)
- Fixed ``MemoryObjectStream`` dropping an item if the item is delivered to a recipient
that is waiting to receive an item but has a cancellation pending
(`#728 <https://github.com/agronholm/anyio/issues/728>`_)
- Emit a ``ResourceWarning`` for ``MemoryObjectReceiveStream`` and
``MemoryObjectSendStream`` that were garbage collected without being closed (PR by
Andrey Kazantcev)
Expand Down
47 changes: 35 additions & 12 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import socket
import sys
import threading
import weakref
from asyncio import (
AbstractEventLoop,
CancelledError,
Expand Down Expand Up @@ -596,14 +597,14 @@ class TaskState:
itself because there are no guarantees about its implementation.
"""

__slots__ = "parent_id", "cancel_scope"
__slots__ = "parent_id", "cancel_scope", "__weakref__"

def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.parent_id = parent_id
self.cancel_scope = cancel_scope


_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState]
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()


#
Expand Down Expand Up @@ -1833,14 +1834,36 @@ async def __anext__(self) -> Signals:
#


def _create_task_info(task: asyncio.Task) -> TaskInfo:
task_state = _task_states.get(task)
if task_state is None:
parent_id = None
else:
parent_id = task_state.parent_id
class AsyncIOTaskInfo(TaskInfo):
def __init__(self, task: asyncio.Task):
task_state = _task_states.get(task)
if task_state is None:
parent_id = None
else:
parent_id = task_state.parent_id

super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
self._task = weakref.ref(task)

return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
def has_pending_cancellation(self) -> bool:
    """
    Return ``True`` if the tracked asyncio task has a cancellation pending.

    The check looks at three things in order: whether the weakly referenced task
    still exists, whether asyncio itself reports a cancellation in flight, and
    whether the cancel scope recorded for the task in ``_task_states`` has had
    ``cancel_called`` set or reports ``_parent_cancelled()``.
    """
    target = self._task()
    if not target:
        # The weak reference is dead; a task that is gone has nothing pending
        return False

    if sys.version_info >= (3, 11):
        # Task.cancelling() is only available on Python 3.11 and later
        if target.cancelling():
            return True
    else:
        # Older Pythons: inspect the future the task is currently waiting on
        waiter = target._fut_waiter
        if isinstance(waiter, asyncio.Future) and waiter.cancelled():
            return True

    state = _task_states.get(target)
    if state:
        scope = state.cancel_scope
        if scope:
            return scope.cancel_called or scope._parent_cancelled()

    return False


class TestRunner(abc.TestRunner):
Expand Down Expand Up @@ -2458,11 +2481,11 @@ def open_signal_receiver(

@classmethod
def get_current_task(cls) -> TaskInfo:
return _create_task_info(current_task()) # type: ignore[arg-type]
return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]

@classmethod
def get_running_tasks(cls) -> list[TaskInfo]:
return [_create_task_info(task) for task in all_tasks() if not task.done()]
def get_running_tasks(cls) -> Sequence[TaskInfo]:
return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]

@classmethod
async def wait_all_tasks_blocked(cls) -> None:
Expand Down
36 changes: 23 additions & 13 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import socket
import sys
import types
import weakref
from collections.abc import AsyncIterator, Iterable
from concurrent.futures import Future
from dataclasses import dataclass
Expand Down Expand Up @@ -839,6 +840,24 @@ def run_test(
self._call_in_runner_task(test_func, **kwargs)


class TrioTaskInfo(TaskInfo):
    """Task info for a Trio task that can also report a pending cancellation."""

    def __init__(self, task: trio.lowlevel.Task):
        # A parent id is only available when the task has a parent nursery and
        # that nursery in turn has a parent task
        nursery = task.parent_nursery
        parent_task = nursery.parent_task if nursery else None
        parent_id = id(parent_task) if parent_task else None

        super().__init__(id(task), parent_id, task.name, task.coro)
        # Keep only a weak proxy so this object does not hold on to the task
        self._task = weakref.proxy(task)

    def has_pending_cancellation(self) -> bool:
        """
        Return ``True`` if the task's cancel status is effectively cancelled.

        Returns ``False`` if the task has already been garbage collected.
        """
        try:
            return self._task._cancel_status.effectively_cancelled
        except ReferenceError:
            # The proxy's referent is gone, so no cancellation can be pending
            return False


class TrioBackend(AsyncBackend):
@classmethod
def run(
Expand Down Expand Up @@ -1125,28 +1144,19 @@ def open_signal_receiver(
@classmethod
def get_current_task(cls) -> TaskInfo:
task = current_task()

parent_id = None
if task.parent_nursery and task.parent_nursery.parent_task:
parent_id = id(task.parent_nursery.parent_task)

return TaskInfo(id(task), parent_id, task.name, task.coro)
return TrioTaskInfo(task)

@classmethod
def get_running_tasks(cls) -> list[TaskInfo]:
def get_running_tasks(cls) -> Sequence[TaskInfo]:
root_task = current_root_task()
assert root_task
task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
task_infos = [TrioTaskInfo(root_task)]
nurseries = root_task.child_nurseries
while nurseries:
new_nurseries: list[trio.Nursery] = []
for nursery in nurseries:
for task in nursery.child_tasks:
task_infos.append(
TaskInfo(
id(task), id(nursery.parent_task), task.name, task.coro
)
)
task_infos.append(TrioTaskInfo(task))
new_nurseries.extend(task.child_nurseries)

nurseries = new_nurseries
Expand Down
12 changes: 8 additions & 4 deletions src/anyio/_core/_testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Awaitable, Generator
from typing import Any
from typing import Any, cast

from ._eventloop import get_async_backend

Expand Down Expand Up @@ -45,8 +45,12 @@ def __hash__(self) -> int:
def __repr__(self) -> str:
return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})"

def _unwrap(self) -> TaskInfo:
return self
def has_pending_cancellation(self) -> bool:
    """
    Return ``True`` if the task has a cancellation pending, ``False`` otherwise.

    This default implementation always returns ``False``; subclasses that have
    access to the underlying task object may override it.
    """
    return False


def get_current_task() -> TaskInfo:
Expand All @@ -66,7 +70,7 @@ def get_running_tasks() -> list[TaskInfo]:
:return: a list of task info objects
"""
return get_async_backend().get_running_tasks()
return cast("list[TaskInfo]", get_async_backend().get_running_tasks())


async def wait_all_tasks_blocked() -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def get_current_task(cls) -> TaskInfo:

@classmethod
@abstractmethod
def get_running_tasks(cls) -> list[TaskInfo]:
def get_running_tasks(cls) -> Sequence[TaskInfo]:
pass

@classmethod
Expand Down
32 changes: 21 additions & 11 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EndOfStream,
WouldBlock,
)
from .._core._testing import TaskInfo, get_current_task
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
from ..lowlevel import checkpoint

Expand All @@ -32,13 +33,19 @@ class MemoryObjectStreamStatistics(NamedTuple):
tasks_waiting_receive: int


# Record kept for a task that is waiting to receive an item from the stream.
@dataclass(eq=False)
class MemoryObjectItemReceiver(Generic[T_Item]):
    # Info about the task that created this receiver; filled in at instantiation
    # time by calling get_current_task()
    task_info: TaskInfo = field(init=False, default_factory=get_current_task)
    # The delivered item; there is no default, so the attribute does not exist
    # until something assigns it
    item: T_Item = field(init=False)


@dataclass(eq=False)
class MemoryObjectStreamState(Generic[T_Item]):
max_buffer_size: float = field()
buffer: deque[T_Item] = field(init=False, default_factory=deque)
open_send_channels: int = field(init=False, default=0)
open_receive_channels: int = field(init=False, default=0)
waiting_receivers: OrderedDict[Event, list[T_Item]] = field(
waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver[T_Item]] = field(
init=False, default_factory=OrderedDict
)
waiting_senders: OrderedDict[Event, T_Item] = field(
Expand Down Expand Up @@ -99,17 +106,17 @@ async def receive(self) -> T_co:
except WouldBlock:
# Add ourselves in the queue
receive_event = Event()
container: list[T_co] = []
self._state.waiting_receivers[receive_event] = container
receiver = MemoryObjectItemReceiver[T_co]()
self._state.waiting_receivers[receive_event] = receiver

try:
await receive_event.wait()
finally:
self._state.waiting_receivers.pop(receive_event, None)

if container:
return container[0]
else:
try:
return receiver.item
except AttributeError:
raise EndOfStream

def clone(self) -> MemoryObjectReceiveStream[T_co]:
Expand Down Expand Up @@ -199,11 +206,14 @@ def send_nowait(self, item: T_contra) -> None:
if not self._state.open_receive_channels:
raise BrokenResourceError

if self._state.waiting_receivers:
receive_event, container = self._state.waiting_receivers.popitem(last=False)
container.append(item)
receive_event.set()
elif len(self._state.buffer) < self._state.max_buffer_size:
while self._state.waiting_receivers:
receive_event, receiver = self._state.waiting_receivers.popitem(last=False)
if not receiver.task_info.has_pending_cancellation():
receiver.item = item
receive_event.set()
return

if len(self._state.buffer) < self._state.max_buffer_size:
self._state.buffer.append(item)
else:
raise WouldBlock
Expand Down
71 changes: 57 additions & 14 deletions tests/streams/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
fail_after,
wait_all_tasks_blocked,
)
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.abc import ObjectReceiveStream, ObjectSendStream, TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if sys.version_info < (3, 11):
Expand Down Expand Up @@ -305,28 +305,49 @@ async def test_cancel_during_receive() -> None:
stream to be lost.
"""
receiver_scope = None

async def scoped_receiver() -> None:
nonlocal receiver_scope
with CancelScope() as receiver_scope:
async def scoped_receiver(task_status: TaskStatus[CancelScope]) -> None:
with CancelScope() as cancel_scope:
task_status.started(cancel_scope)
received.append(await receive.receive())

assert receiver_scope.cancel_called
assert cancel_scope.cancel_called

received: list[str] = []
send, receive = create_memory_object_stream[str]()
async with create_task_group() as tg:
tg.start_soon(scoped_receiver)
await wait_all_tasks_blocked()
send.send_nowait("hello")
assert receiver_scope is not None
receiver_scope.cancel()
with send, receive:
async with create_task_group() as tg:
receiver_scope = await tg.start(scoped_receiver)
await wait_all_tasks_blocked()
send.send_nowait("hello")
receiver_scope.cancel()

assert received == ["hello"]

send.close()
receive.close()

async def test_cancel_during_receive_buffered() -> None:
    """
    Verify that an item is not lost when the receiver at the front of the queue has
    already been cancelled by the time the item is sent.
    """

    async def blocked_receiver(
        stream: MemoryObjectReceiveStream[str],
        task_status: TaskStatus[CancelScope],
    ) -> None:
        # Hand the scope to the parent so it can cancel this task while it waits
        with CancelScope() as scope:
            task_status.started(scope)
            await stream.receive()

    send, receive = create_memory_object_stream[str](1)
    with send, receive:
        async with create_task_group() as tg:
            receiver_scope = await tg.start(blocked_receiver, receive)
            await wait_all_tasks_blocked()
            # Cancel first and send right after, with no checkpoint in between
            receiver_scope.cancel()
            send.send_nowait("item")

        # The cancelled receiver must have been skipped, leaving the item buffered
        leftover = receive.receive_nowait()
        assert leftover == "item"


async def test_close_receive_after_send() -> None:
Expand Down Expand Up @@ -455,3 +476,25 @@ async def test_not_closed_warning() -> None:
with pytest.warns(ResourceWarning, match="Unclosed <MemoryObjectReceiveStream>"):
del receive
gc.collect()


@pytest.mark.parametrize("anyio_backend", ["asyncio"], indirect=True)
async def test_send_to_natively_cancelled_receiver() -> None:
    """
    Verify that an item sent after a natively cancelled asyncio task started
    waiting in receive.receive() is not handed to that task. With no other
    receiver waiting, the item must remain in the stream's buffer.
    """
    from asyncio import CancelledError, create_task

    send, receive = create_memory_object_stream[str](1)
    with send, receive:
        pending_receive = create_task(receive.receive())
        # Let the new task run until it is blocked waiting for an item
        await wait_all_tasks_blocked()
        pending_receive.cancel()
        send.send_nowait("hello")
        with pytest.raises(CancelledError):
            await pending_receive

        # The item was not consumed by the cancelled task
        buffered = receive.receive_nowait()
        assert buffered == "hello"
2 changes: 1 addition & 1 deletion tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def inspect() -> None:
for task, expected_name in zip(task_infos, expected_names):
assert task.parent_id == host_task.id
assert task.name == expected_name
assert repr(task) == f"TaskInfo(id={task.id}, name={expected_name!r})"
assert repr(task).endswith(f"TaskInfo(id={task.id}, name={expected_name!r})")


@pytest.mark.skipif(
Expand Down

0 comments on commit e7f750b

Please sign in to comment.