Skip to content

Commit

Permalink
Merge 8a469ea into 8c0bd7e
Browse files Browse the repository at this point in the history
  • Loading branch information
smurfix committed Apr 15, 2019
2 parents 8c0bd7e + 8a469ea commit fb8094a
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 20 deletions.
66 changes: 56 additions & 10 deletions anyio/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,18 @@ async def sleep(delay: float) -> None:

class CancelScope:
__slots__ = ('_deadline', '_shield', '_parent_scope', '_cancel_called', '_active',
'_timeout_task', '_tasks', '_timeout_expired')
'_timeout_task', '_tasks', '_timeout_expired', '_children')

def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope = None
self._cancel_called = False
self._cancel_called = None
self._active = False
self._timeout_task = None
self._tasks = set() # type: Set[asyncio.Task]
self._timeout_expired = False
self._children = set() # type: Set[CancelScope]

async def __aenter__(self):
async def timeout():
Expand All @@ -191,13 +192,19 @@ async def timeout():

host_task = current_task()
self._parent_scope = get_cancel_scope(host_task)
if self._parent_scope is not None:
self._parent_scope._children.add(self)
self._tasks.add(host_task)
set_cancel_scope(host_task, self)

if self._deadline != math.inf:
self._timeout_task = get_running_loop().create_task(timeout())

self._active = True

if self._parent_scope is not None:
if self._parent_scope._cancel_called and not self._shield:
await self.cancel(self._parent_scope._cancel_called)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -207,20 +214,28 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

host_task = current_task()
self._tasks.remove(host_task)
if self._parent_scope is not None:
self._parent_scope._children.remove(self)
set_cancel_scope(host_task, self._parent_scope)

if isinstance(exc_val, asyncio.CancelledError):
if self._timeout_expired:
return True
elif self._cancel_called:
elif self._cancel_called == id(self):
# This scope was directly cancelled
return True
elif isinstance(exc_val, CancelledError):
return exc_val.cause == id(self)

async def cancel(self):
async def cancel(self, real_scope=None):
if self._cancel_called:
return

self._cancel_called = True
if real_scope is None:
real_scope = id(self)
elif type(real_scope) is not int:
real_scope = id(real_scope)
self._cancel_called = real_scope

# Cancel any contained tasks
for task in self._tasks:
Expand All @@ -231,13 +246,33 @@ async def cancel(self):
if scope is self or not scope.shield:
task.cancel()

for child in self._children:
child._sub_cancel(real_scope)

def _sub_cancel(self, real_scope):
if self.shield or self._cancel_called:
return

self._cancel_called = real_scope

for task in self._tasks:
if task._coro.cr_await is not None and not task._coro.cr_running:
# Cancel the task directly, but only if it's blocked and isn't within a shielded
# scope
scope = get_cancel_scope(task)
if scope is self or not scope.shield:
task.cancel()

for child in self._children:
child._sub_cancel(real_scope)

@property
def deadline(self) -> float:
return self._deadline

@property
def cancel_called(self) -> bool:
return self._cancel_called
return self._cancel_called == id(self)

@property
def shield(self) -> bool:
Expand Down Expand Up @@ -269,8 +304,10 @@ def set_cancel_scope(task: asyncio.Task, scope: Optional[CancelScope]):
def check_cancelled():
task = current_task()
cancel_scope = get_cancel_scope(task)
if cancel_scope is not None and not cancel_scope._shield and cancel_scope._cancel_called:
raise CancelledError
if cancel_scope is None:
return
if not cancel_scope._shield and cancel_scope._cancel_called:
raise CancelledError(cancel_scope._cancel_called)


def open_cancel_scope(deadline: float = math.inf, shield: bool = False) -> CancelScope:
Expand All @@ -282,7 +319,10 @@ def open_cancel_scope(deadline: float = math.inf, shield: bool = False) -> Cance
async def fail_after(delay: float, shield: bool):
deadline = get_running_loop().time() + delay
async with CancelScope(deadline, shield) as scope:
await yield_(scope)
try:
await yield_(scope)
except asyncio.CancelledError:
raise CancelledError(scope._cancel_called)

if scope._timeout_expired:
raise TimeoutError
Expand All @@ -293,7 +333,10 @@ async def fail_after(delay: float, shield: bool):
async def move_on_after(delay: float, shield: bool):
deadline = get_running_loop().time() + delay
async with CancelScope(deadline=deadline, shield=shield) as scope:
await yield_(scope)
try:
await yield_(scope)
except asyncio.CancelledError:
raise CancelledError(scope._cancel_called)


async def current_effective_deadline():
Expand Down Expand Up @@ -358,6 +401,9 @@ async def _run_wrapped_task(self, func, *args):
task = current_task()
try:
await func(*args)
except CancelledError as exc:
await self.cancel_scope.cancel(exc.cause)
raise
except BaseException:
await self.cancel_scope.cancel()
raise
Expand Down
66 changes: 56 additions & 10 deletions anyio/_backends/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,18 @@ async def sleep(delay: float):

class CancelScope:
__slots__ = ('_deadline', '_shield', '_parent_scope', '_cancel_called', '_active',
'_timeout_task', '_tasks', '_timeout_expired')
'_timeout_task', '_tasks', '_timeout_expired', '_children')

def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope = None
self._cancel_called = False
self._cancel_called = None
self._active = False
self._timeout_task = None
self._tasks = set() # type: Set[curio.Task]
self._timeout_expired = False
self._children = set() # type: Set[CancelScope]

async def __aenter__(self):
async def timeout():
Expand All @@ -78,12 +79,18 @@ async def timeout():

host_task = await curio.current_task()
self._parent_scope = get_cancel_scope(host_task)
if self._parent_scope is not None:
self._parent_scope._children.add(self)
self._tasks.add(host_task)
set_cancel_scope(host_task, self)

if self._deadline != math.inf:
self._timeout_task = await curio.spawn(timeout)

if self._parent_scope is not None:
if self._parent_scope._cancel_called and not self._shield:
await self.cancel(self._parent_scope._cancel_called)

self._active = True
return self

Expand All @@ -94,20 +101,28 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

host_task = await curio.current_task()
self._tasks.remove(host_task)
if self._parent_scope is not None:
self._parent_scope._children.remove(self)
set_cancel_scope(host_task, self._parent_scope)

if isinstance(exc_val, curio.TaskCancelled):
if self._timeout_expired:
return True
elif self._cancel_called:
elif self._cancel_called == id(self):
# This scope was directly cancelled
return True
elif isinstance(exc_val, CancelledError):
return exc_val.cause == id(self)

async def cancel(self):
async def cancel(self, real_scope=None):
if self._cancel_called:
return

self._cancel_called = True
if real_scope is None:
real_scope = id(self)
elif type(real_scope) is not int:
real_scope = id(real_scope)
self._cancel_called = real_scope

# Cancel any contained tasks
for task in self._tasks:
Expand All @@ -118,13 +133,33 @@ async def cancel(self):
if scope is self or not scope.shield:
await task.cancel(blocking=False)

for child in list(self._children):
await child._sub_cancel(real_scope)

async def _sub_cancel(self, real_scope):
if self.shield or self._cancel_called:
return

self._cancel_called = real_scope

for task in list(self._tasks):
if task.coro.cr_await is not None and not task.coro.cr_running:
# Cancel the task directly, but only if it's blocked and isn't within a shielded
# scope
scope = get_cancel_scope(task)
if scope is self or not scope.shield:
await task.cancel()

for child in self._children:
await child._sub_cancel(real_scope)

@property
def deadline(self) -> float:
return self._deadline

@property
def cancel_called(self) -> bool:
return self._cancel_called
return self._cancel_called == id(self)

@property
def shield(self) -> bool:
Expand Down Expand Up @@ -156,16 +191,21 @@ def set_cancel_scope(task: curio.Task, scope: Optional[CancelScope]) -> None:
async def check_cancelled():
task = await curio.current_task()
cancel_scope = get_cancel_scope(task)
if cancel_scope is not None and not cancel_scope._shield and cancel_scope._cancel_called:
raise CancelledError
if cancel_scope is None:
return
if not cancel_scope._shield and cancel_scope._cancel_called:
raise CancelledError(cancel_scope._cancel_called)


@asynccontextmanager
@async_generator
async def fail_after(delay: float, shield: bool):
deadline = await curio.clock() + delay
async with CancelScope(deadline, shield) as scope:
await yield_(scope)
try:
await yield_(scope)
except curio.TaskCancelled:
raise CancelledError(scope._cancel_called)

if scope._timeout_expired:
raise TimeoutError
Expand All @@ -176,7 +216,10 @@ async def fail_after(delay: float, shield: bool):
async def move_on_after(delay: float, shield: bool):
deadline = await curio.clock() + delay
async with CancelScope(deadline=deadline, shield=shield) as scope:
await yield_(scope)
try:
await yield_(scope)
except curio.TaskCancelled:
raise CancelledError(scope._cancel_called)


async def current_effective_deadline():
Expand Down Expand Up @@ -241,6 +284,9 @@ async def _run_wrapped_task(self, func, *args):
task = await curio.current_task()
try:
await func(*args)
except CancelledError as exc:
await self.cancel_scope.cancel(exc.cause)
raise
except BaseException:
await self.cancel_scope.cancel()
raise
Expand Down
4 changes: 4 additions & 0 deletions anyio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def __repr__(self) -> str:

class CancelledError(Exception):
"""Raised when the enclosing cancel scope has been cancelled."""
def __init__(self, cause=None):
if type(cause) is not int:
cause = id(cause)
self.cause = cause


class IncompleteRead(Exception):
Expand Down
53 changes: 53 additions & 0 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,27 @@ async def test_nested_move_on_after():
assert not inner_scope.cancel_called


@pytest.mark.anyio
async def test_fail_in_taskgroup():
bad = False

async def check():
async with fail_after(1):
await sleep(1)
nonlocal bad
bad = True

try:
async with create_task_group() as tg:
await tg.start(check)
await sleep(0.1)
await tg.cancel_scope.cancel()
raise RuntimeError("Owch")
except Exception:
pass
assert not bad


@pytest.mark.anyio
async def test_shielding():
async def cancel_when_ready():
Expand Down Expand Up @@ -362,3 +383,35 @@ async def child(fail):

exc.match('foo')
assert not sleep_completed


@pytest.mark.anyio
async def test_cancelled_parent():
async def child():
async with open_cancel_scope():
await sleep(1)
raise RuntimeError("This should not be printed")

async def parent(tg):
try:
await sleep(1)
finally:
await tg.spawn(child)
pass

async with create_task_group() as tg:
await tg.spawn(parent, tg)
await tg.cancel_scope.cancel()


@pytest.mark.anyio
async def test_cancelled_parent_scope():
async with open_cancel_scope() as scope:
async with open_cancel_scope() as sc2:
try:
await scope.cancel()
await sleep(2)
except BaseException:
await sc2.cancel()
raise
raise RuntimeError("This should not be printed")

0 comments on commit fb8094a

Please sign in to comment.