Skip to content

Commit

Permalink
Merge pull request #206 from DABND19/feature/typed-aggregate
Browse files Browse the repository at this point in the history
feat: Added typehints for aggregate and aggregate_async.
  • Loading branch information
mosquito committed May 7, 2024
2 parents 48a4ae9 + a29c543 commit 6d2f617
Showing 1 changed file with 111 additions and 54 deletions.
165 changes: 111 additions & 54 deletions aiomisc/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import asyncio
import functools
import inspect
import logging
from asyncio import CancelledError, Event, Future, Lock, wait_for
from dataclasses import dataclass
from inspect import Parameter
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Union
from typing import (
Any,
Awaitable,
Callable,
Generic,
Iterable,
List,
Optional,
Protocol,
TypeVar,
)

from .compat import EventLoopMixin
from .counters import Statistic
Expand All @@ -13,19 +24,25 @@
log = logging.getLogger(__name__)


V = TypeVar("V")
R = TypeVar("R")


@dataclass(frozen=True)
class Arg:
value: Any
future: Future
class Arg(Generic[V, R]):
value: V
future: "Future[R]"


class ResultNotSetError(Exception):
pass


AggFuncHighLevel = Callable[[Any], Awaitable[Iterable]]
AggFuncAsync = Callable[[Arg], Awaitable]
AggFunc = Union[AggFuncHighLevel, AggFuncAsync]
class AggregateAsyncFunc(Protocol, Generic[V, R]):
__name__: str

async def __call__(self, *args: Arg[V, R]) -> None:
...


class AggregateStatistic(Statistic):
Expand All @@ -36,27 +53,30 @@ class AggregateStatistic(Statistic):
done: int


class Aggregator(EventLoopMixin):
def _has_variadic_positional(func: Callable[..., Any]) -> bool:
return any(
parameter.kind == Parameter.VAR_POSITIONAL
for parameter in inspect.signature(func).parameters.values()
)

_func: AggFunc

class AggregatorAsync(EventLoopMixin, Generic[V, R]):

_func: AggregateAsyncFunc[V, R]
_max_count: Optional[int]
_leeway: float
_first_call_at: Optional[float]
_args: list
_futures: List[Future]
_futures: "List[Future[R]]"
_event: Event
_lock: Lock

def __init__(
self, func: AggFunc, *, leeway_ms: float,
self, func: AggregateAsyncFunc[V, R], *, leeway_ms: float,
max_count: Optional[int] = None,
statistic_name: Optional[str] = None,
):
has_variadic_positional = any(
parameter.kind == Parameter.VAR_POSITIONAL
for parameter in inspect.signature(func).parameters.values()
)
if not has_variadic_positional:
if not _has_variadic_positional(func):
raise ValueError(
"Function must accept variadic positional arguments",
)
Expand Down Expand Up @@ -94,9 +114,18 @@ def leeway_ms(self) -> float:
def count(self) -> int:
return len(self._args)

async def _execute(self, *, args: list, futures: List[Future]) -> None:
async def _execute(
self,
*,
args: List[V],
futures: "List[Future[R]]",
) -> None:
args_ = [
Arg(value=arg, future=future)
for arg, future in zip(args, futures)
]
try:
results = await self._func(*args)
await self._func(*args_)
self._statistic.success += 1
except CancelledError:
# Other waiting tasks can try to finish the job instead.
Expand All @@ -108,31 +137,29 @@ async def _execute(self, *, args: list, futures: List[Future]) -> None:
finally:
self._statistic.done += 1

self._set_results(results, futures)

def _set_results(self, results: Iterable, futures: List[Future]) -> None:
for future, result in zip(futures, results):
# Validate that all results/exceptions are set by the func
for future in futures:
if not future.done():
future.set_result(result)
future.set_exception(ResultNotSetError)

def _set_exception(
self, exc: Exception, futures: List[Future],
self, exc: Exception, futures: List["Future[R]"],
) -> None:
for future in futures:
if not future.done():
future.set_exception(exc)

async def aggregate(self, arg: Any) -> Any:
async def aggregate(self, arg: V) -> R:
if self._first_call_at is None:
self._first_call_at = self.loop.time()
first_call_at = self._first_call_at

args: list = self._args
futures: List[Future] = self._futures
futures: "List[Future[R]]" = self._futures
event: Event = self._event
lock: Lock = self._lock
args.append(arg)
future: Future = Future()
future: "Future[R]" = Future()
futures.append(future)

if self.count == self.max_count:
Expand Down Expand Up @@ -165,33 +192,61 @@ async def aggregate(self, arg: Any) -> Any:
return future.result()


class AggregatorAsync(Aggregator):
S = TypeVar("S", contravariant=True)
T = TypeVar("T", covariant=True)

async def _execute(self, *, args: list, futures: List[Future]) -> None:
args = [
Arg(value=arg, future=future)
for arg, future in zip(args, futures)
]
try:
await self._func(*args)
self._statistic.success += 1
except CancelledError:
# Other waiting tasks can try to finish the job instead.
raise
except Exception as e:
self._set_exception(e, futures)
self._statistic.error += 1
return
finally:
self._statistic.done += 1

# Validate that all results/exceptions are set by the func
for future in futures:
if not future.done():
future.set_exception(ResultNotSetError)
class AggregateFunc(Protocol, Generic[S, T]):
__name__: str

async def __call__(self, *args: S) -> Iterable[T]:
...


def _to_async_aggregate(func: AggregateFunc[V, R]) -> AggregateAsyncFunc[V, R]:
@functools.wraps(
func,
assigned=tuple(
item
for item in functools.WRAPPER_ASSIGNMENTS
if item != "__annotations__"
),
)
async def wrapper(*args: Arg[V, R]) -> None:
args_ = [item.value for item in args]
results = await func(*args_)
for res, arg in zip(results, args):
if not arg.future.done():
arg.future.set_result(res)

return wrapper


class Aggregator(AggregatorAsync[V, R], Generic[V, R]):
def __init__(
self,
func: AggregateFunc[V, R],
*,
leeway_ms: float,
max_count: Optional[int] = None,
statistic_name: Optional[str] = None,
) -> None:
if not _has_variadic_positional(func):
raise ValueError(
"Function must accept variadic positional arguments",
)

super().__init__(
_to_async_aggregate(func),
leeway_ms=leeway_ms,
max_count=max_count,
statistic_name=statistic_name,
)


def aggregate(leeway_ms: float, max_count: Optional[int] = None) -> Callable:
def aggregate(
leeway_ms: float, max_count: Optional[int] = None
) -> Callable[[AggregateFunc[V, R]], Callable[[V], Awaitable[R]]]:
"""
Parametric decorator that aggregates multiple
(but no more than ``max_count`` defaulting to ``None``) single-argument
Expand Down Expand Up @@ -220,7 +275,7 @@ def aggregate(leeway_ms: float, max_count: Optional[int] = None) -> Callable:
:return:
"""
def decorator(func: AggFuncHighLevel) -> Callable[[Any], Awaitable]:
def decorator(func: AggregateFunc[V, R]) -> Callable[[V], Awaitable[R]]:
aggregator = Aggregator(
func, max_count=max_count, leeway_ms=leeway_ms,
)
Expand All @@ -229,8 +284,8 @@ def decorator(func: AggFuncHighLevel) -> Callable[[Any], Awaitable]:


def aggregate_async(
leeway_ms: float, max_count: Optional[int] = None,
) -> Callable:
leeway_ms: float, max_count: Optional[int] = None,
) -> Callable[[AggregateAsyncFunc[V, R]], Callable[[V], Awaitable[R]]]:
"""
Same as ``aggregate``, but with ``func`` arguments of type ``Arg``
containing ``value`` and ``future`` attributes instead. In this setting
Expand All @@ -241,7 +296,9 @@ def aggregate_async(
:return:
"""
def decorator(func: AggFuncAsync) -> Callable[[Any], Awaitable]:
def decorator(
func: AggregateAsyncFunc[V, R]
) -> Callable[[V], Awaitable[R]]:
aggregator = AggregatorAsync(
func, max_count=max_count, leeway_ms=leeway_ms,
)
Expand Down

0 comments on commit 6d2f617

Please sign in to comment.