Skip to content

Commit

Permalink
feat(taskgroup): Add "current_taskgroup" context variable with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Nov 17, 2020
1 parent bc9bada commit ebaaf4e
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/aiotools/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
# The original source code is taken from:
# https://github.com/edgedb/edgedb-python/blob/bcbe005/edgedb/_taskgroup.py

from __future__ import annotations

import asyncio
from contextvars import ContextVar
import functools
import itertools
import textwrap
Expand All @@ -34,9 +36,13 @@
'MultiError',
'TaskGroup',
'TaskGroupError',
'current_taskgroup',
)


current_taskgroup: ContextVar[TaskGroup] = ContextVar('current_taskgroup')


class TaskGroup:
"""
Provides a guard against a group of tasks spawend via its :meth:`create_task`
Expand Down Expand Up @@ -100,7 +106,7 @@ async def __aenter__(self):
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
self._patch_task(self._parent_task)

self._current_taskgroup_token = current_taskgroup.set(self)
return self

async def __aexit__(self, et, exc, tb):
Expand Down Expand Up @@ -165,6 +171,7 @@ async def __aexit__(self, et, exc, tb):

assert self._unfinished_tasks == 0
self._on_completed_fut = None # no longer needed
current_taskgroup.reset(self._current_taskgroup_token)

if self._base_error is not None:
raise self._base_error
Expand Down
141 changes: 141 additions & 0 deletions tests/test_taskgroup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import asyncio

import pytest

from aiotools import (
current_taskgroup,
TaskGroup,
TaskGroupError,
VirtualClock,
)


@pytest.mark.asyncio
async def test_delayed_subtasks():
with VirtualClock().patch_loop():
async with TaskGroup() as tg:
t1 = tg.create_task(asyncio.sleep(3, 'a'))
t2 = tg.create_task(asyncio.sleep(2, 'b'))
t3 = tg.create_task(asyncio.sleep(1, 'c'))
assert t1.done()
assert t2.done()
assert t3.done()
assert t1.result() == 'a'
assert t2.result() == 'b'
assert t3.result() == 'c'


@pytest.mark.asyncio
async def test_contextual_taskgroup():

refs = []

async def check_tg(delay):
await asyncio.sleep(delay)
refs.append(current_taskgroup.get())

with VirtualClock().patch_loop():
async with TaskGroup() as outer_tg:
ot1 = outer_tg.create_task(check_tg(0.1))
async with TaskGroup() as inner_tg:
it1 = inner_tg.create_task(check_tg(0.2))
ot2 = outer_tg.create_task(check_tg(0.3))
assert ot1.done()
assert ot2.done()
assert it1.done()
assert refs == [outer_tg, inner_tg, outer_tg]

with pytest.raises(LookupError):
# outside of any taskgroup, this is an error.
current_taskgroup.get()


@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.asyncio
async def test_contextual_taskgroup_spawning():

total_jobs = 0

async def job():
nonlocal total_jobs
await asyncio.sleep(0)
total_jobs += 1

async def spawn_job():
await asyncio.sleep(0)
tg = current_taskgroup.get()
tg.create_task(job())

async def inner_tg_job():
await asyncio.sleep(0)
async with TaskGroup() as tg:
tg.create_task(job())

with VirtualClock().patch_loop():

total_jobs = 0
with pytest.raises(TaskGroupError), pytest.warns(RuntimeWarning):
# When the taskgroup terminates immediately after spawning subtasks,
# the spawned subtasks may not be allowed to proceed because the parent
# taskgroup is already in the terminating procedure.
async with TaskGroup() as tg:
t = tg.create_task(spawn_job())
assert not t.done()
assert total_jobs == 0

total_jobs = 0
async with TaskGroup() as tg:
tg.create_task(inner_tg_job())
tg.create_task(spawn_job())
tg.create_task(inner_tg_job())
tg.create_task(spawn_job())
# Give the subtasks chances to run.
await asyncio.sleep(1)
assert total_jobs == 4


@pytest.mark.asyncio
async def test_taskgroup_cancellation():
with VirtualClock().patch_loop():

async def do_job(delay, result):
# NOTE: replacing do_job directly with asyncio.sleep
# results future-pending-after-loop-closed error,
# because asyncio.sleep() is not a task but a future.
await asyncio.sleep(delay)
return result

with pytest.raises(asyncio.CancelledError):
async with TaskGroup() as tg:
t1 = tg.create_task(do_job(0.3, 'a'))
t2 = tg.create_task(do_job(0.6, 'b'))
await asyncio.sleep(0.5)
raise asyncio.CancelledError

assert t1.done()
assert t2.cancelled()
assert t1.result() == 'a'


@pytest.mark.asyncio
async def test_subtask_cancellation():

results = []

async def do_job():
await asyncio.sleep(1)
results.append('a')

async def do_cancel():
await asyncio.sleep(0.5)
raise asyncio.CancelledError

with VirtualClock().patch_loop():
async with TaskGroup() as tg:
t1 = tg.create_task(do_job())
t2 = tg.create_task(do_cancel())
t3 = tg.create_task(do_job())
assert t1.done()
assert t2.cancelled()
assert t3.done()
assert results == ['a', 'a']

0 comments on commit ebaaf4e

Please sign in to comment.