Skip to content

Commit

Permalink
feat: Add as_completed_safe() (#52)
Browse files Browse the repository at this point in the history
NOTE: This works with pre-3.11 versions but timeout handling only works correctly with 3.11 or later.
  • Loading branch information
achimnol committed Mar 14, 2023
1 parent bd3e26a commit ce5e13a
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 14 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11.0-alpha.6 - 3.11"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python
Expand Down Expand Up @@ -35,7 +35,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11.0-alpha.6 - 3.11"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python
Expand Down Expand Up @@ -68,7 +68,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11.0-alpha.6 - 3.11"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python
Expand Down Expand Up @@ -102,7 +102,7 @@ jobs:
--initial-cluster-state new \
--auto-compaction-retention 1
- name: Test with pytest
timeout-minutes: 1
timeout-minutes: 3
run: |
python -m pytest -v --cov=src tests
- name: Send code coverage report
Expand Down
1 change: 1 addition & 0 deletions changes/52.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `as_completed_safe()` which enhances `asyncio.as_completed()` using `PersistentTaskGroup`
7 changes: 7 additions & 0 deletions docs/aiotools.taskgroup.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,10 @@ Task Group
``except*`` syntax and :exc:`ExceptionGroup` methods if they use Python
3.11 or later. Note that if none of the passed exceptions passed is a
:exc:`BaseException`, it automatically becomes :exc:`ExceptionGroup`.


Task Group Utilities
====================

.. automodule:: aiotools.taskgroup.utils
:members:
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ build =
twine>=3.8.0
towncrier~=21.3
test =
pytest~=6.2.5
pytest-asyncio~=0.16.0
async_timeout~=4.0.2
pytest~=7.2.2
pytest-asyncio~=0.20.3
pytest-cov
pytest-mock
codecov
Expand Down
12 changes: 8 additions & 4 deletions src/aiotools/taskgroup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,24 @@
from . import persistent
from .base import * # noqa
from .persistent import * # noqa
__all__ = (
__all__ = [
'MultiError',
'TaskGroupError',
*base.__all__,
*persistent.__all__,
)
]
else:
from . import base_compat
from . import persistent_compat
from .base_compat import * # type: ignore # noqa
from .persistent_compat import * # type: ignore # noqa
__all__ = ( # type: ignore
__all__ = [ # type: ignore
'MultiError',
'TaskGroupError',
*base_compat.__all__,
*persistent_compat.__all__,
)
]

from .utils import as_completed_safe # noqa

__all__.append("as_completed_safe")
10 changes: 8 additions & 2 deletions src/aiotools/taskgroup/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,19 @@ async def __aexit__(
self._base_error is None):
self._base_error = exc_val

if exc_type is asyncio.CancelledError:
if (
exc_type is asyncio.CancelledError or
exc_type is asyncio.TimeoutError
):
if self._parent_cancel_requested:
self._parent_task.uncancel()
else:
propagate_cancellation_error = exc_type
if exc_type is not None and not self._aborting:
if exc_type is asyncio.CancelledError:
if (
exc_type is asyncio.CancelledError or
exc_type is asyncio.TimeoutError
):
propagate_cancellation_error = exc_type
self._trigger_shutdown()

Expand Down
10 changes: 8 additions & 2 deletions src/aiotools/taskgroup/persistent_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,10 @@ async def __aexit__(
self._base_error is None):
self._base_error = exc_val

if exc_type is asyncio.CancelledError:
if (
exc_type is asyncio.CancelledError or
exc_type is asyncio.TimeoutError
):
if self._parent_cancel_requested:
# Only if we did request task to cancel ourselves
# we mark it as no longer cancelled.
Expand All @@ -252,7 +255,10 @@ async def __aexit__(
propagate_cancelation = True

if exc_type is not None and not self._aborting:
if exc_type is asyncio.CancelledError:
if (
exc_type is asyncio.CancelledError or
exc_type is asyncio.TimeoutError
):
propagate_cancelation = True
self._trigger_shutdown()

Expand Down
35 changes: 35 additions & 0 deletions src/aiotools/taskgroup/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
A set of helper utilities to utilize taskgroups in better ways.
"""

import asyncio

from . import PersistentTaskGroup

__all__ = ("as_completed_safe", )


async def as_completed_safe(coros, timeout=None):
"""
This is a safer version of :func:`asyncio.as_completed()` which uses
:class:`PersistentTaskGroup` as an underlying coroutine lifecycle keeper.
Upon a timeout, it raises :class:`asyncio.TimeoutError` immediately
and cancels all remaining tasks or coroutines.
This requires Python 3.11 or higher to work properly with timeouts.
.. versionadded:: 1.6
"""
async with PersistentTaskGroup() as tg:
tasks = []
for coro in coros:
t = tg.create_task(coro)
tasks.append(t)
await asyncio.sleep(0)
try:
for result in asyncio.as_completed(tasks, timeout=timeout):
yield result
except GeneratorExit:
# This happens when as_completed() is timeout.
raise asyncio.TimeoutError()
149 changes: 149 additions & 0 deletions tests/test_taskgroup_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import asyncio
import sys

import async_timeout
import pytest

from aiotools import (
aclosing,
as_completed_safe,
VirtualClock,
)


@pytest.mark.asyncio
async def test_as_completed_safe():

async def do_job(delay, idx):
await asyncio.sleep(delay)
return idx

async def fail_job(delay):
await asyncio.sleep(delay)
1 / 0

with VirtualClock().patch_loop():

results = []

async with aclosing(as_completed_safe([
do_job(0.3, 1),
do_job(0.2, 2),
do_job(0.1, 3),
])) as ag:
async for result in ag:
results.append(await result)

assert results == [3, 2, 1]

results = []
errors = []

async with aclosing(as_completed_safe([
do_job(0.1, 1),
fail_job(0.2),
do_job(0.3, 3),
])) as ag:
async for result in ag:
try:
results.append(await result)
except Exception as e:
errors.append(e)

assert results == [1, 3]
assert len(errors) == 1
assert isinstance(errors[0], ZeroDivisionError)

results = []
errors = []


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason='timeout supoport requires Python 3.11 or higher',
)
async def test_as_completed_safe_timeout_intrinsic():

executed = 0
cancelled = 0
loop_count = 0

with VirtualClock().patch_loop():

async def do_job(delay, idx):
nonlocal cancelled, executed
try:
await asyncio.sleep(delay)
executed += 1
return idx
except asyncio.CancelledError:
cancelled += 1

results = []
timeouts = 0

try:
async with aclosing(as_completed_safe([
do_job(0.1, 1),
# timeout occurs here
do_job(0.2, 2),
do_job(10.0, 3),
], timeout=0.15)) as ag:
async for result in ag:
results.append(await result)
loop_count += 1
except asyncio.TimeoutError:
timeouts += 1

assert loop_count == 1
assert executed == 1
assert cancelled == 2
assert results == [1]
assert timeouts == 1


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason='timeout supoport requires Python 3.11 or higher',
)
async def test_as_completed_safe_timeout_extlib():

executed = 0
cancelled = 0
loop_count = 0

with VirtualClock().patch_loop():

async def do_job(delay, idx):
nonlocal cancelled, executed
try:
await asyncio.sleep(delay)
executed += 1
return idx
except asyncio.CancelledError:
cancelled += 1

results = []
timeouts = 0

try:
async with async_timeout.timeout(0.15):
async with aclosing(as_completed_safe([
do_job(0.1, 1),
# timeout occurs here
do_job(0.2, 2),
do_job(10.0, 3),
])) as ag:
async for result in ag:
results.append(await result)
loop_count += 1
except asyncio.TimeoutError:
timeouts += 1

assert loop_count == 1
assert executed == 1
assert cancelled == 2
assert results == [1]
assert timeouts == 1

0 comments on commit ce5e13a

Please sign in to comment.