Skip to content

Commit

Permalink
fix: Keep Python 3.6 support
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Nov 17, 2020
1 parent ebaaf4e commit 855d78f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 14 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ test =
codecov
dev =
lint =
flake8>=3.7.9
flake8>=3.8.4
typecheck =
mypy>=0.770
mypy>=0.790
docs =
sphinx
sphinx-autodoc-typehints
Expand Down
6 changes: 6 additions & 0 deletions src/aiotools/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@
all_tasks = asyncio.all_tasks
else:
all_tasks = asyncio.Task.all_tasks


if hasattr(asyncio, 'current_task'):
current_task = asyncio.current_task
else:
current_task = asyncio.Task.current_task
27 changes: 15 additions & 12 deletions src/aiotools/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,18 @@
from __future__ import annotations

import asyncio
from contextvars import ContextVar
try:
from contextvars import ContextVar
has_contextvars = True
except ImportError:
has_contextvars = False
import functools
import itertools
import textwrap
import traceback

from .compat import current_task, get_running_loop


__all__ = (
'MultiError',
Expand All @@ -40,7 +46,8 @@
)


current_taskgroup: ContextVar[TaskGroup] = ContextVar('current_taskgroup')
if has_contextvars:
current_taskgroup: ContextVar[TaskGroup] = ContextVar('current_taskgroup')


class TaskGroup:
Expand All @@ -61,7 +68,7 @@ def __init__(self, *, name=None):
self._entered = False
self._exiting = False
self._aborting = False
self._loop = None
self._loop = get_running_loop()
self._parent_task = None
self._parent_cancel_requested = False
self._tasks = set()
Expand Down Expand Up @@ -94,19 +101,14 @@ async def __aenter__(self):
f"TaskGroup {self!r} has been already entered")
self._entered = True

if self._loop is None:
self._loop = asyncio.get_event_loop()

if hasattr(asyncio, 'current_task'):
self._parent_task = asyncio.current_task(self._loop)
else:
self._parent_task = asyncio.Task.current_task(self._loop)
self._parent_task = current_task()

if self._parent_task is None:
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
self._patch_task(self._parent_task)
self._current_taskgroup_token = current_taskgroup.set(self)
if has_contextvars:
self._current_taskgroup_token = current_taskgroup.set(self)
return self

async def __aexit__(self, et, exc, tb):
Expand Down Expand Up @@ -171,7 +173,8 @@ async def __aexit__(self, et, exc, tb):

assert self._unfinished_tasks == 0
self._on_completed_fut = None # no longer needed
current_taskgroup.reset(self._current_taskgroup_token)
if has_contextvars:
current_taskgroup.reset(self._current_taskgroup_token)

if self._base_error is not None:
raise self._base_error
Expand Down
9 changes: 9 additions & 0 deletions tests/test_taskgroup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import sys

import pytest

Expand Down Expand Up @@ -26,6 +27,10 @@ async def test_delayed_subtasks():


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason='contextvars is available only in Python 3.7 or later',
)
async def test_contextual_taskgroup():

refs = []
Expand All @@ -50,6 +55,10 @@ async def check_tg(delay):
current_taskgroup.get()


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason='contextvars is available only in Python 3.7 or later',
)
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.asyncio
async def test_contextual_taskgroup_spawning():
Expand Down

0 comments on commit 855d78f

Please sign in to comment.