Skip to content

Commit

Permalink
Fixed cancellation propagation when task group host is in a shielded …
Browse files Browse the repository at this point in the history
…scope (#648)

Co-authored-by: Ganden Schaffner <gschaffner@pm.me>
  • Loading branch information
agronholm and gschaffner committed Dec 14, 2023
1 parent 3ea17f9 commit 44ca5ea
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 47 deletions.
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
from Egor Blagov)
- Fixed ``loop_factory`` and ``use_uvloop`` options not being used on the asyncio
backend (`#643 <https://github.com/agronholm/anyio/issues/643>`_)
- Fixed cancellation propagating on asyncio from a task group to child tasks if the task
hosting the task group is in a shielded cancel scope
(`#642 <https://github.com/agronholm/anyio/issues/642>`_)

**4.1.0**

Expand Down
109 changes: 63 additions & 46 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope: CancelScope | None = None
self._child_scopes: set[CancelScope] = set()
self._cancel_called = False
self._cancelled_caught = False
self._active = False
Expand All @@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope:
else:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
if self._parent_scope is not None:
self._parent_scope._child_scopes.add(self)
self._parent_scope._tasks.remove(host_task)

self._timeout()
self._active = True
Expand All @@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope:

# Start cancelling the host task if the scope was cancelled before entering
if self._cancel_called:
self._deliver_cancellation()
self._deliver_cancellation(self)

return self

Expand Down Expand Up @@ -409,13 +413,15 @@ def __exit__(
self._timeout_handle = None

self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Restart the cancellation effort in the farthest directly cancelled parent
# Restart the cancellation effort in the closest directly cancelled parent
# scope if this one was shielded
if self._shield:
self._deliver_cancellation_to_parent()
self._restart_cancellation_in_parent()

if self._cancel_called and exc_val is not None:
for exc in iterate_exceptions(exc_val):
Expand Down Expand Up @@ -451,65 +457,70 @@ def _timeout(self) -> None:
else:
self._timeout_handle = loop.call_at(self._deadline, self._timeout)

def _deliver_cancellation(self) -> None:
def _deliver_cancellation(self, origin: CancelScope) -> bool:
"""
Deliver cancellation to directly contained tasks and nested cancel scopes.
Schedule another run at the end if we still have tasks eligible for
cancellation.
:param origin: the cancel scope that originated the cancellation
:return: ``True`` if the delivery needs to be retried on the next cycle
"""
should_retry = False
current = current_task()
for task in self._tasks:
if task._must_cancel: # type: ignore[attr-defined]
continue

# The task is eligible for cancellation if it has started and is not in a
# cancel scope shielded from this one
cancel_scope = _task_states[task].cancel_scope
while cancel_scope is not self:
if cancel_scope is None or cancel_scope._shield:
break
else:
cancel_scope = cancel_scope._parent_scope
else:
should_retry = True
if task is not current and (
task is self._host_task or _task_started(task)
):
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
self._cancel_calls += 1
if sys.version_info >= (3, 9):
task.cancel(f"Cancelled by cancel scope {id(self):x}")
else:
task.cancel()
# The task is eligible for cancellation if it has started
should_retry = True
if task is not current and (task is self._host_task or _task_started(task)):
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
self._cancel_calls += 1
if sys.version_info >= (3, 9):
task.cancel(f"Cancelled by cancel scope {id(origin):x}")
else:
task.cancel()

# Deliver cancellation to child scopes that aren't shielded or running their own
# cancellation callbacks
for scope in self._child_scopes:
if not scope._shield and not scope.cancel_called:
should_retry = scope._deliver_cancellation(origin) or should_retry

# Schedule another callback if there are still tasks left
if should_retry:
self._cancel_handle = get_running_loop().call_soon(
self._deliver_cancellation
)
else:
self._cancel_handle = None
if origin is self:
if should_retry:
self._cancel_handle = get_running_loop().call_soon(
self._deliver_cancellation, origin
)
else:
self._cancel_handle = None

return should_retry

def _restart_cancellation_in_parent(self) -> None:
"""
Restart the cancellation effort in the closest directly cancelled parent scope.
def _deliver_cancellation_to_parent(self) -> None:
"""Start cancellation effort in the farthest directly cancelled parent scope"""
"""
scope = self._parent_scope
scope_to_cancel: CancelScope | None = None
while scope is not None:
if scope._cancel_called and scope._cancel_handle is None:
scope_to_cancel = scope
if scope._cancel_called:
if scope._cancel_handle is None:
scope._deliver_cancellation(scope)

break

# No point in looking beyond any shielded scope
if scope._shield:
break

scope = scope._parent_scope

if scope_to_cancel is not None:
scope_to_cancel._deliver_cancellation()

def _parent_cancelled(self) -> bool:
# Check whether any parent has been cancelled
cancel_scope = self._parent_scope
Expand All @@ -529,7 +540,7 @@ def cancel(self) -> None:

self._cancel_called = True
if self._host_task is not None:
self._deliver_cancellation()
self._deliver_cancellation(self)

@property
def deadline(self) -> float:
Expand Down Expand Up @@ -562,7 +573,7 @@ def shield(self, value: bool) -> None:
if self._shield != value:
self._shield = value
if not value:
self._deliver_cancellation_to_parent()
self._restart_cancellation_in_parent()


#
Expand Down Expand Up @@ -623,6 +634,7 @@ def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
self._active = False
self._exceptions: list[BaseException] = []
self._tasks: set[asyncio.Task] = set()

async def __aenter__(self) -> TaskGroup:
self.cancel_scope.__enter__()
Expand All @@ -642,9 +654,9 @@ async def __aexit__(
self._exceptions.append(exc_val)

cancelled_exc_while_waiting_tasks: CancelledError | None = None
while self.cancel_scope._tasks:
while self._tasks:
try:
await asyncio.wait(self.cancel_scope._tasks)
await asyncio.wait(self._tasks)
except CancelledError as exc:
# This task was cancelled natively; reraise the CancelledError later
# unless this task was already interrupted by another exception
Expand Down Expand Up @@ -676,8 +688,11 @@ def _spawn(
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
assert _task in self.cancel_scope._tasks
self.cancel_scope._tasks.remove(_task)
task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
task_state.cancel_scope._tasks.remove(_task)
self._tasks.remove(task)
del _task_states[_task]

try:
Expand All @@ -693,7 +708,8 @@ def task_done(_task: asyncio.Task) -> None:
if not isinstance(exc, CancelledError):
self._exceptions.append(exc)

self.cancel_scope.cancel()
if not self.cancel_scope._parent_cancelled():
self.cancel_scope.cancel()
else:
task_status_future.set_exception(exc)
elif task_status_future is not None and not task_status_future.done():
Expand Down Expand Up @@ -732,6 +748,7 @@ def task_done(_task: asyncio.Task) -> None:
parent_id=parent_id, cancel_scope=self.cancel_scope
)
self.cancel_scope._tasks.add(task)
self._tasks.add(task)
return task

def start_soon(
Expand Down
45 changes: 44 additions & 1 deletion tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ async def killer(scope: CancelScope) -> None:
assert isinstance(exc.value.exceptions[0], TimeoutError)


async def test_triple_nested_shield() -> None:
async def test_triple_nested_shield_checkpoint_in_outer() -> None:
"""Regression test for #370."""

got_past_checkpoint = False
Expand All @@ -867,6 +867,26 @@ async def taskfunc() -> None:
assert not got_past_checkpoint


async def test_triple_nested_shield_checkpoint_in_middle() -> None:
got_past_checkpoint = False

async def taskfunc() -> None:
nonlocal got_past_checkpoint

with CancelScope() as scope1:
with CancelScope():
with CancelScope(shield=True):
scope1.cancel()

await checkpoint()
got_past_checkpoint = True

async with create_task_group() as tg:
tg.start_soon(taskfunc)

assert not got_past_checkpoint


def test_task_group_in_generator(
anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
Expand Down Expand Up @@ -1293,6 +1313,29 @@ def handler(excgrp: BaseExceptionGroup) -> None:
await anyio.sleep_forever()


async def test_cancel_child_task_when_host_is_shielded() -> None:
# Regression test for #642
# Tests that cancellation propagates to a child task even if the host task is within
# a shielded cancel scope.
cancelled = anyio.Event()

async def wait_cancel() -> None:
try:
await anyio.sleep_forever()
except anyio.get_cancelled_exc_class():
cancelled.set()
raise

with CancelScope() as parent_scope:
async with anyio.create_task_group() as task_group:
task_group.start_soon(wait_cancel)
await wait_all_tasks_blocked()

with CancelScope(shield=True), fail_after(1):
parent_scope.cancel()
await cancelled.wait()


class TestTaskStatusTyping:
"""
These tests do not do anything at run time, but since the test suite is also checked
Expand Down

0 comments on commit 44ca5ea

Please sign in to comment.