Skip to content

Commit

Permalink
pythongh-115957: Close coroutine if the TaskGroup is inactive
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason-Y-Z committed Mar 5, 2024
1 parent 7af063d commit 5bcc3d5
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
4 changes: 4 additions & 0 deletions Doc/library/asyncio-task.rst
Expand Up @@ -335,6 +335,10 @@ and reliable way to wait for all tasks in the group to finish.
Create a task in this task group.
The signature matches that of :func:`asyncio.create_task`.

.. versionchanged:: 3.13

Close the given coroutine if the task group is not active.

Example::

async def main():
Expand Down
5 changes: 5 additions & 0 deletions Doc/whatsnew/3.13.rst
Expand Up @@ -185,6 +185,11 @@ Other Language Changes

(Contributed by Sebastian Pipping in :gh:`115623`.)

* When :func:`asyncio.TaskGroup.create_task` is called on an inactive
:class:`asyncio.TaskGroup`, the given coroutine will be closed (which
prevents a :exc:`RuntimeWarning`).

(Contributed by Arthur Tacca and Jason Zhang in :gh:`115957`.)

New Modules
===========
Expand Down
3 changes: 3 additions & 0 deletions Lib/asyncio/taskgroups.py
Expand Up @@ -154,10 +154,13 @@ def create_task(self, coro, *, name=None, context=None):
Similar to `asyncio.create_task`.
"""
if not self._entered:
coro.close()
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and not self._tasks:
coro.close()
raise RuntimeError(f"TaskGroup {self!r} is finished")
if self._aborting:
coro.close()
raise RuntimeError(f"TaskGroup {self!r} is shutting down")
if context is None:
task = self._loop.create_task(coro, name=name)
Expand Down
21 changes: 11 additions & 10 deletions Lib/test/test_asyncio/test_taskgroups.py
Expand Up @@ -738,10 +738,7 @@ async def coro2(g):
await asyncio.sleep(1)
except asyncio.CancelledError:
with self.assertRaises(RuntimeError):
g.create_task(c1 := coro1())
# We still have to await c1 to avoid a warning
with self.assertRaises(ZeroDivisionError):
await c1
g.create_task(coro1())

with self.assertRaises(ExceptionGroup) as cm:
async with taskgroups.TaskGroup() as g:
Expand Down Expand Up @@ -803,16 +800,12 @@ async def test_taskgroup_finished(self):
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

async def test_taskgroup_not_entered(self):
tg = taskgroups.TaskGroup()
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

async def test_taskgroup_without_parent_task(self):
tg = taskgroups.TaskGroup()
Expand All @@ -821,8 +814,16 @@ async def test_taskgroup_without_parent_task(self):
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

def test_coro_closed_when_tg_closed(self):
async def run_coro_after_tg_closes():
async with taskgroups.TaskGroup() as tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
loop = asyncio.get_event_loop()
loop.run_until_complete(run_coro_after_tg_closes())


if __name__ == "__main__":
Expand Down

0 comments on commit 5bcc3d5

Please sign in to comment.