Skip to content

Commit

Permalink
Fix a potential memory leak in PersistentTaskGroup (#54)
Browse files Browse the repository at this point in the history
- It no longer stores the unhandled exceptions from subtasks,
  as they will be kept for the whole lifespan of the
  PersistentTaskGroup.
- Update CI to enfroce typecheck
  • Loading branch information
achimnol committed May 2, 2023
1 parent 540daf2 commit 4a6e790
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 82 deletions.
13 changes: 5 additions & 8 deletions .github/workflows/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
Expand Down Expand Up @@ -37,7 +37,7 @@ jobs:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
Expand All @@ -51,16 +51,13 @@ jobs:
python -m pip install -U pip setuptools wheel
python -m pip install -U -r requirements/typecheck.txt
- name: Type check with mypy
continue-on-error: true
run: |
if [ "$GITHUB_EVENT_NAME" == "pull_request" -a -n "$GITHUB_HEAD_REF" ]; then
echo "(skipping matchers for pull request from local branches)"
else
echo "::add-matcher::.github/workflows/mypy-matcher.json"
fi
python -m mypy --no-color-output src/aiotools tests
- name: Allow failures
run: true
test:
runs-on: ubuntu-latest
Expand All @@ -70,7 +67,7 @@ jobs:
os: [ubuntu-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
Expand Down Expand Up @@ -106,7 +103,7 @@ jobs:
run: |
python -m pytest -v --cov=src tests
- name: Send code coverage report
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v3
with:
env_vars: GHA_OS,GHA_PYTHON

Expand All @@ -115,7 +112,7 @@ jobs:
if: github.event_name == 'push' && contains(github.ref, 'refs/tags/')
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
Expand Down
1 change: 1 addition & 0 deletions changes/54.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
PersistentTaskGroup no longer stores the history of unhandled exceptions and raises them as an exception group to prevent memory leaks
3 changes: 2 additions & 1 deletion src/aiotools/func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import functools
from typing import Optional

from .compat import get_running_loop

Expand All @@ -26,7 +27,7 @@ async def wrapped(*cargs, **ckwargs):

def lru_cache(maxsize: int = 128,
typed: bool = False,
expire_after: float = None):
expire_after: Optional[float] = None):
"""
A simple LRU cache just like :func:`functools.lru_cache`, but it works for
coroutines. This is not as heavily optimized as :func:`functools.lru_cache`
Expand Down
2 changes: 1 addition & 1 deletion src/aiotools/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def start_server(
),
num_workers: int = 1,
args: Iterable[Any] = tuple(),
wait_timeout: float = None,
wait_timeout: Optional[float] = None,
) -> None:
"""
Starts a multi-process server where each process has their own individual
Expand Down
21 changes: 12 additions & 9 deletions src/aiotools/taskgroup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,30 @@
from .types import MultiError, TaskGroupError

if hasattr(asyncio, 'TaskGroup'):
from . import base
from . import persistent
from .base import * # noqa
from .persistent import * # noqa
__all__ = [
'MultiError',
'TaskGroupError',
*base.__all__,
"MultiError",
"TaskGroup",
"TaskGroupError",
"current_taskgroup",
*persistent.__all__,
]
else:
from . import base_compat
from . import persistent_compat
from .base_compat import * # type: ignore # noqa
from .persistent_compat import * # type: ignore # noqa
__all__ = [ # type: ignore
'MultiError',
'TaskGroupError',
*base_compat.__all__,
from .base_compat import has_contextvars
__all__ = [ # type: ignore # noqa
"MultiError",
"TaskGroup",
"TaskGroupError",
*persistent_compat.__all__,
]
if has_contextvars:
__all__.append("current_taskgroup")


from .utils import as_completed_safe # noqa

Expand Down
34 changes: 5 additions & 29 deletions src/aiotools/taskgroup/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Awaitable,
Callable,
Coroutine,
List,
Optional,
Sequence,
Type,
Expand Down Expand Up @@ -42,7 +41,6 @@ class PersistentTaskGroup:

_base_error: Optional[BaseException]
_exc_handler: AsyncExceptionHandler
_errors: Optional[List[BaseException]]
_tasks: "weakref.WeakSet[asyncio.Task]"
_on_completed_fut: Optional[asyncio.Future]
_current_taskgroup_token: Optional[Token["PersistentTaskGroup"]]
Expand All @@ -54,13 +52,12 @@ def all_ptaskgroups(cls) -> Sequence['PersistentTaskGroup']:
def __init__(
self,
*,
name: str = None,
exception_handler: AsyncExceptionHandler = None,
name: Optional[str] = None,
exception_handler: Optional[AsyncExceptionHandler] = None,
) -> None:
self._entered = False
self._exiting = False
self._aborting = False
self._errors = []
self._base_error = None
self._name = name or f"{next(_ptaskgroup_idx)}"
self._parent_cancel_requested = False
Expand All @@ -85,7 +82,7 @@ def create_task(
self,
coro: Coroutine[Any, Any, Any],
*,
name: str = None,
name: Optional[str] = None,
) -> Awaitable[Any]:
if not self._entered:
# When used as object attribute, auto-enter.
Expand All @@ -98,7 +95,7 @@ def _create_task_with_name(
self,
coro: Coroutine[Any, Any, Any],
*,
name: str = None,
name: Optional[str] = None,
cb: Callable[[asyncio.Task], Any],
) -> Awaitable[Any]:
loop = compat.get_running_loop()
Expand Down Expand Up @@ -192,7 +189,6 @@ def _on_task_done(self, task: asyncio.Task) -> None:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
assert self._parent_task is not None
assert self._errors is not None

if self._on_completed_fut is not None and not self._unfinished_tasks:
if not self._on_completed_fut.done():
Expand All @@ -207,7 +203,6 @@ def _on_task_done(self, task: asyncio.Task) -> None:
return

# Now the exception is BaseException.
self._errors.append(exc)
if self._base_error is None:
self._base_error = exc

Expand All @@ -227,8 +222,8 @@ async def __aexit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._parent_task is not None
self._exiting = True
assert self._errors is not None
propagate_cancellation_error: Optional[
Union[Type[BaseException], BaseException]
] = None
Expand Down Expand Up @@ -263,23 +258,6 @@ async def __aexit__(
if propagate_cancellation_error is not None:
raise propagate_cancellation_error

if exc_val is not None and exc_type is not asyncio.CancelledError:
# If there are any unhandled errors, let's add them to
# the bubbled up exception group.
# Normally, they should have been swallowed and logged
# by the fallback exception handler.
self._errors.append(exc_val)

if self._errors:
# Bubble up errors
errors = self._errors
self._errors = None
me = BaseExceptionGroup(
'unhandled errors in a PersistentTaskGroup',
errors,
)
raise me from None

return None

def __repr__(self) -> str:
Expand All @@ -290,8 +268,6 @@ def __repr__(self) -> str:
info.append(f'tasks={len(self._tasks)}')
if self._unfinished_tasks:
info.append(f'unfinished={self._unfinished_tasks}')
if self._errors:
info.append(f'errors={len(self._errors)}')
if self._aborting:
info.append('cancelling')
elif self._entered:
Expand Down
35 changes: 5 additions & 30 deletions src/aiotools/taskgroup/persistent_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Awaitable,
Callable,
Coroutine,
List,
Optional,
Sequence,
Type,
Expand All @@ -23,7 +22,7 @@

from .. import compat
from .common import create_task_with_name, patch_task
from .types import AsyncExceptionHandler, TaskGroupError
from .types import AsyncExceptionHandler

__all__ = [
'PersistentTaskGroup',
Expand All @@ -47,7 +46,6 @@ class PersistentTaskGroup:

_base_error: Optional[BaseException]
_exc_handler: AsyncExceptionHandler
_errors: Optional[List[BaseException]]
_tasks: "weakref.WeakSet[asyncio.Task]"
_on_completed_fut: Optional[asyncio.Future]
_current_taskgroup_token: Optional["Token[PersistentTaskGroup]"]
Expand All @@ -59,13 +57,12 @@ def all_ptaskgroups(cls) -> Sequence['PersistentTaskGroup']:
def __init__(
self,
*,
name: str = None,
exception_handler: AsyncExceptionHandler = None,
name: Optional[str] = None,
exception_handler: Optional[AsyncExceptionHandler] = None,
) -> None:
self._entered = False
self._exiting = False
self._aborting = False
self._errors = []
self._base_error = None
self._name = name or f"{next(_ptaskgroup_idx)}"
self._parent_cancel_requested = False
Expand All @@ -90,7 +87,7 @@ def create_task(
self,
coro: Coroutine[Any, Any, Any],
*,
name: str = None,
name: Optional[str] = None,
) -> Awaitable[Any]:
if not self._entered:
# When used as object attribute, auto-enter.
Expand All @@ -103,7 +100,7 @@ def _create_task_with_name(
self,
coro: Coroutine[Any, Any, Any],
*,
name: str = None,
name: Optional[str] = None,
cb: Callable[[asyncio.Task], Any],
) -> Awaitable[Any]:
loop = compat.get_running_loop()
Expand Down Expand Up @@ -197,7 +194,6 @@ def _on_task_done(self, task: asyncio.Task) -> None:
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
assert self._parent_task is not None
assert self._errors is not None

if self._on_completed_fut is not None and not self._unfinished_tasks:
if not self._on_completed_fut.done():
Expand All @@ -212,7 +208,6 @@ def _on_task_done(self, task: asyncio.Task) -> None:
return

# Now the exception is BaseException.
self._errors.append(exc)
if self._base_error is None:
self._base_error = exc

Expand All @@ -235,7 +230,6 @@ async def __aexit__(
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
self._exiting = True
assert self._errors is not None
propagate_cancelation = False

if (exc_val is not None and
Expand Down Expand Up @@ -276,23 +270,6 @@ async def __aexit__(
# request now.
raise asyncio.CancelledError()

if exc_val is not None and exc_type is not asyncio.CancelledError:
# If there are any unhandled errors, let's add them to
# the bubbled up exception group.
# Normally, they should have been swallowed and logged
# by the fallback exception handler.
self._errors.append(exc_val)

if self._errors:
# Bubble up errors
errors = self._errors
self._errors = None
me = TaskGroupError(
'unhandled errors in a PersistentTaskGroup',
errors,
)
raise me from None

return None

def __repr__(self) -> str:
Expand All @@ -303,8 +280,6 @@ def __repr__(self) -> str:
info.append(f'tasks={len(self._tasks)}')
if self._unfinished_tasks:
info.append(f'unfinished={self._unfinished_tasks}')
if self._errors:
info.append(f'errors={len(self._errors)}')
if self._aborting:
info.append('cancelling')
elif self._entered:
Expand Down
8 changes: 4 additions & 4 deletions src/aiotools/taskgroup/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def __call__(

if not hasattr(builtins, 'ExceptionGroup'):

class MultiError(Exception):
class MultiError(Exception): # type: ignore[no-redef]

def __init__(self, msg, errors=()):
if errors:
Expand All @@ -47,15 +47,15 @@ def get_error_types(self):
def __reduce__(self):
return (type(self), (self.args,), {'__errors__': self.__errors__})

class TaskGroupError(MultiError):
class TaskGroupError(MultiError): # type: ignore[no-redef]
"""
An alias to :exc:`MultiError`.
"""
pass

else:

class MultiError(ExceptionGroup):
class MultiError(ExceptionGroup): # type: ignore[no-redef]

def __init__(self, msg, errors=()):
super().__init__(msg, errors)
Expand All @@ -64,5 +64,5 @@ def __init__(self, msg, errors=()):
def get_error_types(self):
return {type(e) for e in self.exceptions}

class TaskGroupError(MultiError):
class TaskGroupError(MultiError): # type: ignore[no-redef]
pass

0 comments on commit 4a6e790

Please sign in to comment.