-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
NOTE: This works with pre-3.11 versions but timeout handling only works correctly with 3.11 or later.
- Loading branch information
Showing
9 changed files
with
223 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add `as_completed_safe()` which enhances `asyncio.as_completed()` using `PersistentTaskGroup` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""" | ||
A set of helper utilities to utilize taskgroups in better ways. | ||
""" | ||
|
||
import asyncio | ||
|
||
from . import PersistentTaskGroup | ||
|
||
__all__ = ("as_completed_safe", ) | ||
|
||
|
||
async def as_completed_safe(coros, timeout=None): | ||
""" | ||
This is a safer version of :func:`asyncio.as_completed()` which uses | ||
:class:`PersistentTaskGroup` as an underlying coroutine lifecycle keeper. | ||
Upon a timeout, it raises :class:`asyncio.TimeoutError` immediately | ||
and cancels all remaining tasks or coroutines. | ||
This requires Python 3.11 or higher to work properly with timeouts. | ||
.. versionadded:: 1.6 | ||
""" | ||
async with PersistentTaskGroup() as tg: | ||
tasks = [] | ||
for coro in coros: | ||
t = tg.create_task(coro) | ||
tasks.append(t) | ||
await asyncio.sleep(0) | ||
try: | ||
for result in asyncio.as_completed(tasks, timeout=timeout): | ||
yield result | ||
except GeneratorExit: | ||
# This happens when as_completed() is timeout. | ||
raise asyncio.TimeoutError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import asyncio | ||
import sys | ||
|
||
import async_timeout | ||
import pytest | ||
|
||
from aiotools import ( | ||
aclosing, | ||
as_completed_safe, | ||
VirtualClock, | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_as_completed_safe(): | ||
|
||
async def do_job(delay, idx): | ||
await asyncio.sleep(delay) | ||
return idx | ||
|
||
async def fail_job(delay): | ||
await asyncio.sleep(delay) | ||
1 / 0 | ||
|
||
with VirtualClock().patch_loop(): | ||
|
||
results = [] | ||
|
||
async with aclosing(as_completed_safe([ | ||
do_job(0.3, 1), | ||
do_job(0.2, 2), | ||
do_job(0.1, 3), | ||
])) as ag: | ||
async for result in ag: | ||
results.append(await result) | ||
|
||
assert results == [3, 2, 1] | ||
|
||
results = [] | ||
errors = [] | ||
|
||
async with aclosing(as_completed_safe([ | ||
do_job(0.1, 1), | ||
fail_job(0.2), | ||
do_job(0.3, 3), | ||
])) as ag: | ||
async for result in ag: | ||
try: | ||
results.append(await result) | ||
except Exception as e: | ||
errors.append(e) | ||
|
||
assert results == [1, 3] | ||
assert len(errors) == 1 | ||
assert isinstance(errors[0], ZeroDivisionError) | ||
|
||
results = [] | ||
errors = [] | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.skipif( | ||
sys.version_info < (3, 11), | ||
reason='timeout supoport requires Python 3.11 or higher', | ||
) | ||
async def test_as_completed_safe_timeout_intrinsic(): | ||
|
||
executed = 0 | ||
cancelled = 0 | ||
loop_count = 0 | ||
|
||
with VirtualClock().patch_loop(): | ||
|
||
async def do_job(delay, idx): | ||
nonlocal cancelled, executed | ||
try: | ||
await asyncio.sleep(delay) | ||
executed += 1 | ||
return idx | ||
except asyncio.CancelledError: | ||
cancelled += 1 | ||
|
||
results = [] | ||
timeouts = 0 | ||
|
||
try: | ||
async with aclosing(as_completed_safe([ | ||
do_job(0.1, 1), | ||
# timeout occurs here | ||
do_job(0.2, 2), | ||
do_job(10.0, 3), | ||
], timeout=0.15)) as ag: | ||
async for result in ag: | ||
results.append(await result) | ||
loop_count += 1 | ||
except asyncio.TimeoutError: | ||
timeouts += 1 | ||
|
||
assert loop_count == 1 | ||
assert executed == 1 | ||
assert cancelled == 2 | ||
assert results == [1] | ||
assert timeouts == 1 | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.skipif( | ||
sys.version_info < (3, 11), | ||
reason='timeout supoport requires Python 3.11 or higher', | ||
) | ||
async def test_as_completed_safe_timeout_extlib(): | ||
|
||
executed = 0 | ||
cancelled = 0 | ||
loop_count = 0 | ||
|
||
with VirtualClock().patch_loop(): | ||
|
||
async def do_job(delay, idx): | ||
nonlocal cancelled, executed | ||
try: | ||
await asyncio.sleep(delay) | ||
executed += 1 | ||
return idx | ||
except asyncio.CancelledError: | ||
cancelled += 1 | ||
|
||
results = [] | ||
timeouts = 0 | ||
|
||
try: | ||
async with async_timeout.timeout(0.15): | ||
async with aclosing(as_completed_safe([ | ||
do_job(0.1, 1), | ||
# timeout occurs here | ||
do_job(0.2, 2), | ||
do_job(10.0, 3), | ||
])) as ag: | ||
async for result in ag: | ||
results.append(await result) | ||
loop_count += 1 | ||
except asyncio.TimeoutError: | ||
timeouts += 1 | ||
|
||
assert loop_count == 1 | ||
assert executed == 1 | ||
assert cancelled == 2 | ||
assert results == [1] | ||
assert timeouts == 1 |