From 89795b9cf2a35a6b8972ffc7f01f00f87a899d66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 13:55:17 +0200 Subject: [PATCH] Used TypeVarTuple and ParamSpec in several places (#652) Co-authored-by: Ganden Schaffner --- docs/versionhistory.rst | 16 +++++++ pyproject.toml | 1 + src/anyio/_backends/_asyncio.py | 48 +++++++++++++------ src/anyio/_backends/_trio.py | 51 ++++++++++++++------ src/anyio/_core/_eventloop.py | 11 ++++- src/anyio/_core/_fileio.py | 9 ++-- src/anyio/abc/_eventloop.py | 24 +++++++--- src/anyio/abc/_tasks.py | 11 ++++- src/anyio/from_thread.py | 85 +++++++++++++++++++++------------ src/anyio/streams/tls.py | 9 +++- src/anyio/to_process.py | 11 ++++- src/anyio/to_thread.py | 11 ++++- tests/test_from_thread.py | 27 +++++------ 13 files changed, 217 insertions(+), 97 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 7931bf46..f8c0b293 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -10,6 +10,21 @@ This library adheres to `Semantic Versioning 2.0 `_. Lura Skye) - Enabled the ``Event`` and ``CapacityLimiter`` classes to be instantiated outside an event loop thread +- Broadly improved/fixed the type annotations. Among other things, many functions and + methods that take variadic positional arguments now make use of PEP 646 + ``TypeVarTuple`` to allow the positional arguments to be validated by static type + checkers. These changes affected numerous methods and functions, including: + + * ``anyio.run()`` + * ``TaskGroup.start_soon()`` + * ``anyio.from_thread.run()`` + * ``anyio.to_thread.run_sync()`` + * ``anyio.to_process.run_sync()`` + * ``BlockingPortal.call()`` + * ``BlockingPortal.start_task_soon()`` + * ``BlockingPortal.start_task()`` + + (`#560 `_) - Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help from Egor Blagov) @@ -18,6 +33,7 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed cancellation propagating on asyncio from a task group to child tasks if the task hosting the task group is in a shielded cancel scope (`#642 `_) +- Fixed the type annotation of ``anyio.Path.samefile()`` to match Typeshed **4.1.0** diff --git a/pyproject.toml b/pyproject.toml index 5c4af55c..d2d9690f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "exceptiongroup >= 1.0.2; python_version < '3.11'", "idna >= 2.8", "sniffio >= 1.1", + "typing_extensions >= 4.1; python_version < '3.11'", ] dynamic = ["version"] diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 95b8e556..e884f564 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -82,8 +82,14 @@ from ..lowlevel import RunVar from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + if sys.version_info >= (3, 11): from asyncio import Runner + from typing import TypeVarTuple, Unpack else: import contextvars import enum @@ -91,6 +97,7 @@ from asyncio import coroutines, events, exceptions, tasks from exceptiongroup import BaseExceptionGroup + from typing_extensions import TypeVarTuple, Unpack class _State(enum.Enum): CREATED = "created" @@ -271,6 +278,8 @@ def _do_shutdown(future: asyncio.futures.Future) -> None: T_Retval = TypeVar("T_Retval") T_contra = TypeVar("T_contra", contravariant=True) +PosArgsT = TypeVarTuple("PosArgsT") +P = ParamSpec("P") _root_task: RunVar[asyncio.Task | None] = RunVar("_root_task") @@ -682,8 +691,8 @@ async def __aexit__( def _spawn( self, - func: Callable[..., Awaitable[Any]], - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + args: tuple[Unpack[PosArgsT]], name: object, task_status_future: asyncio.Future | None = None, ) -> asyncio.Task: @@ -752,13 +761,16 @@ def task_done(_task: asyncio.Task) -> None: return task def start_soon( - self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, ) -> None: self._spawn(func, args, name) async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None - ) -> None: + ) -> Any: future: asyncio.Future = asyncio.Future() task = self._spawn(func, args, name, future) @@ -875,11 +887,11 @@ def __init__(self) -> None: def _spawn_task_from_thread( self, - func: Callable, - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: AsyncIOBackend.run_sync_from_thread( partial(self._task_group.start_soon, name=name), @@ -1883,7 +1895,10 @@ async def _run_tests_and_fixtures( future.set_result(retval) async def _call_in_runner_task( - self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object + self, + func: Callable[P, Awaitable[T_Retval]], + *args: P.args, + **kwargs: P.kwargs, ) -> T_Retval: if not self._runner_task: self._send_stream, receive_stream = create_memory_object_stream[ @@ -1949,8 +1964,8 @@ class AsyncIOBackend(AsyncBackend): @classmethod def run( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: @@ -2062,8 +2077,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter: @classmethod async def run_sync_in_worker_thread( cls, - func: Callable[..., T_Retval], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], abandon_on_cancel: bool = False, limiter: abc.CapacityLimiter | None = None, ) -> T_Retval: @@ -2133,8 +2148,8 @@ def check_cancelled(cls) -> None: @classmethod def run_async_from_thread( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: async def task_wrapper(scope: CancelScope) -> T_Retval: @@ -2160,7 +2175,10 @@ async def task_wrapper(scope: CancelScope) -> T_Retval: @classmethod def run_sync_from_thread( - cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, ) -> T_Retval: @wraps(func) def wrapper() -> None: diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 2caa2a43..a0d14c74 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -62,12 +62,22 @@ from ..abc._eventloop import AsyncBackend from ..streams.memory import MemoryObjectSendStream -if sys.version_info < (3, 11): +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: from exceptiongroup import BaseExceptionGroup + from typing_extensions import TypeVarTuple, Unpack T = TypeVar("T") T_Retval = TypeVar("T_Retval") T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) +PosArgsT = TypeVarTuple("PosArgsT") +P = ParamSpec("P") # @@ -167,7 +177,12 @@ async def __aexit__( finally: self._active = False - def start_soon(self, func: Callable, *args: object, name: object = None) -> None: + def start_soon( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, + ) -> None: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." @@ -177,7 +192,7 @@ def start_soon(self, func: Callable, *args: object, name: object = None) -> None async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None - ) -> object: + ) -> Any: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." @@ -201,11 +216,11 @@ def __init__(self) -> None: def _spawn_task_from_thread( self, - func: Callable, - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: trio.from_thread.run_sync( partial(self._task_group.start_soon, name=name), @@ -724,7 +739,7 @@ class TestRunner(abc.TestRunner): def __init__(self, **options: Any) -> None: from queue import Queue - self._call_queue: Queue[Callable[..., object]] = Queue() + self._call_queue: Queue[Callable[[], object]] = Queue() self._send_stream: MemoryObjectSendStream | None = None self._options = options @@ -754,7 +769,10 @@ def _main_task_finished(self, outcome: object) -> None: self._send_stream = None def _call_in_runner_task( - self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object + self, + func: Callable[P, Awaitable[T_Retval]], + *args: P.args, + **kwargs: P.kwargs, ) -> T_Retval: if self._send_stream is None: trio.lowlevel.start_guest_run( @@ -808,8 +826,8 @@ class TrioBackend(AsyncBackend): @classmethod def run( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: @@ -868,8 +886,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: @classmethod async def run_sync_in_worker_thread( cls, - func: Callable[..., T_Retval], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], abandon_on_cancel: bool = False, limiter: abc.CapacityLimiter | None = None, ) -> T_Retval: @@ -891,15 +909,18 @@ def check_cancelled(cls) -> None: @classmethod def run_async_from_thread( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: return trio.from_thread.run(func, *args) @classmethod def run_sync_from_thread( - cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, ) -> T_Retval: return trio.from_thread.run_sync(func, *args) diff --git a/src/anyio/_core/_eventloop.py b/src/anyio/_core/_eventloop.py index b74d02b0..a9c6e825 100644 --- a/src/anyio/_core/_eventloop.py +++ b/src/anyio/_core/_eventloop.py @@ -10,6 +10,11 @@ import sniffio +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + if TYPE_CHECKING: from ..abc import AsyncBackend @@ -17,12 +22,14 @@ BACKENDS = "asyncio", "trio" T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + threadlocals = threading.local() def run( - func: Callable[..., Awaitable[T_Retval]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], backend: str = "asyncio", backend_options: dict[str, Any] | None = None, ) -> T_Retval: diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py index f51bf450..53f32339 100644 --- a/src/anyio/_core/_fileio.py +++ b/src/anyio/_core/_fileio.py @@ -15,7 +15,6 @@ AsyncIterator, Final, Generic, - cast, overload, ) @@ -211,7 +210,7 @@ async def __anext__(self) -> Path: if nextval is None: raise StopAsyncIteration from None - return Path(cast("PathLike[str]", nextval)) + return Path(nextval) class Path: @@ -518,7 +517,7 @@ def relative_to(self, *other: str | PathLike[str]) -> Path: async def readlink(self) -> Path: target = await to_thread.run_sync(os.readlink, self._path) - return Path(cast(str, target)) + return Path(target) async def rename(self, target: str | pathlib.PurePath | Path) -> Path: if isinstance(target, Path): @@ -545,9 +544,7 @@ def rglob(self, pattern: str) -> AsyncIterator[Path]: async def rmdir(self) -> None: await to_thread.run_sync(self._path.rmdir) - async def samefile( - self, other_path: str | bytes | int | pathlib.Path | Path - ) -> bool: + async def samefile(self, other_path: str | PathLike[str]) -> bool: if isinstance(other_path, Path): other_path = other_path._path diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py index 9f1660c9..4470d83d 100644 --- a/src/anyio/abc/_eventloop.py +++ b/src/anyio/abc/_eventloop.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import sys from abc import ABCMeta, abstractmethod from collections.abc import AsyncIterator, Awaitable, Mapping from os import PathLike @@ -17,6 +18,11 @@ overload, ) +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + if TYPE_CHECKING: from typing import Literal @@ -39,6 +45,7 @@ from ._testing import TestRunner T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") class AsyncBackend(metaclass=ABCMeta): @@ -46,8 +53,8 @@ class AsyncBackend(metaclass=ABCMeta): @abstractmethod def run( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: @@ -169,8 +176,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: @abstractmethod async def run_sync_in_worker_thread( cls, - func: Callable[..., T_Retval], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], abandon_on_cancel: bool = False, limiter: CapacityLimiter | None = None, ) -> T_Retval: @@ -185,8 +192,8 @@ def check_cancelled(cls) -> None: @abstractmethod def run_async_from_thread( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: pass @@ -194,7 +201,10 @@ def run_async_from_thread( @classmethod @abstractmethod def run_sync_from_thread( - cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, ) -> T_Retval: pass diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 9ea3608e..7ad4938c 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -1,15 +1,22 @@ from __future__ import annotations +import sys from abc import ABCMeta, abstractmethod from collections.abc import Awaitable, Callable from types import TracebackType from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + if TYPE_CHECKING: from .._core._tasks import CancelScope T_Retval = TypeVar("T_Retval") T_contra = TypeVar("T_contra", contravariant=True) +PosArgsT = TypeVarTuple("PosArgsT") class TaskStatus(Protocol[T_contra]): @@ -42,8 +49,8 @@ class TaskGroup(metaclass=ABCMeta): @abstractmethod def start_soon( self, - func: Callable[..., Awaitable[Any]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], name: object = None, ) -> None: """ diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 63716496..4a987031 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import threading from collections.abc import Awaitable, Callable, Generator from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait @@ -24,11 +25,19 @@ from .abc import AsyncBackend from .abc._tasks import TaskStatus +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") -T_co = TypeVar("T_co") +T_co = TypeVar("T_co", covariant=True) +PosArgsT = TypeVarTuple("PosArgsT") -def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: +def run( + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT] +) -> T_Retval: """ Call a coroutine function from a worker thread. @@ -48,7 +57,9 @@ def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: return async_backend.run_async_from_thread(func, args, token=token) -def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: +def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] +) -> T_Retval: """ Call a function in the event loop thread from a worker thread. @@ -69,8 +80,8 @@ def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): - _enter_future: Future - _exit_future: Future + _enter_future: Future[T_co] + _exit_future: Future[bool | None] _exit_event: Event _exit_exc_info: tuple[ type[BaseException] | None, BaseException | None, TracebackType | None @@ -106,8 +117,7 @@ async def run_async_cm(self) -> bool | None: def __enter__(self) -> T_co: self._enter_future = Future() self._exit_future = self._portal.start_task_soon(self.run_async_cm) - cm = self._enter_future.result() - return cast(T_co, cm) + return self._enter_future.result() def __exit__( self, @@ -182,9 +192,13 @@ async def stop(self, cancel_remaining: bool = False) -> None: self._task_group.cancel_scope.cancel() async def _call_func( - self, func: Callable, args: tuple, kwargs: dict[str, Any], future: Future + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + future: Future[T_Retval], ) -> None: - def callback(f: Future) -> None: + def callback(f: Future[T_Retval]) -> None: if f.cancelled() and self._event_loop_thread_id not in ( None, threading.get_ident(), @@ -192,15 +206,17 @@ def callback(f: Future) -> None: self.call(scope.cancel) try: - retval = func(*args, **kwargs) - if isawaitable(retval): + retval_or_awaitable = func(*args, **kwargs) + if isawaitable(retval_or_awaitable): with CancelScope() as scope: if future.cancelled(): scope.cancel() else: future.add_done_callback(callback) - retval = await retval + retval = await retval_or_awaitable + else: + retval = retval_or_awaitable except self._cancelled_exc_class: future.cancel() future.set_running_or_notify_cancel() @@ -219,11 +235,11 @@ def callback(f: Future) -> None: def _spawn_task_from_thread( self, - func: Callable, - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: """ Spawn a new task using the given callable. @@ -241,17 +257,23 @@ def _spawn_task_from_thread( raise NotImplementedError @overload - def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: + def call( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], + ) -> T_Retval: ... @overload - def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: + def call( + self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] + ) -> T_Retval: ... def call( self, - func: Callable[..., Awaitable[T_Retval] | T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], ) -> T_Retval: """ Call the given function in the event loop thread. @@ -268,22 +290,25 @@ def call( @overload def start_task_soon( self, - func: Callable[..., Awaitable[T_Retval]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], name: object = None, ) -> Future[T_Retval]: ... @overload def start_task_soon( - self, func: Callable[..., T_Retval], *args: object, name: object = None + self, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + name: object = None, ) -> Future[T_Retval]: ... def start_task_soon( self, - func: Callable[..., Awaitable[T_Retval] | T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], name: object = None, ) -> Future[T_Retval]: """ @@ -305,16 +330,16 @@ def start_task_soon( """ self._check_running() - f: Future = Future() + f: Future[T_Retval] = Future() self._spawn_task_from_thread(func, args, {}, name, f) return f def start_task( self, - func: Callable[..., Awaitable[Any]], + func: Callable[..., Awaitable[T_Retval]], *args: object, name: object = None, - ) -> tuple[Future[Any], Any]: + ) -> tuple[Future[T_Retval], Any]: """ Start a task in the portal's task group and wait until it signals for readiness. @@ -326,13 +351,13 @@ def start_task( :return: a tuple of (future, task_status_value) where the ``task_status_value`` is the value passed to ``task_status.started()`` from within the target function - :rtype: tuple[concurrent.futures.Future[Any], Any] + :rtype: tuple[concurrent.futures.Future[T_Retval], Any] .. versionadded:: 3.0 """ - def task_done(future: Future) -> None: + def task_done(future: Future[T_Retval]) -> None: if not task_status_future.done(): if future.cancelled(): task_status_future.cancel() @@ -397,7 +422,7 @@ async def run_portal() -> None: future: Future[BlockingPortal] = Future() with ThreadPoolExecutor(1) as executor: run_future = executor.submit( - _eventloop.run, + _eventloop.run, # type: ignore[arg-type] run_portal, backend=backend, backend_options=backend_options, diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index 8468f33d..e913eedb 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -3,6 +3,7 @@ import logging import re import ssl +import sys from collections.abc import Callable, Mapping from dataclasses import dataclass from functools import wraps @@ -17,7 +18,13 @@ from .._core._typedattr import TypedAttributeSet, typed_attribute from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") _PCTRTT = Tuple[Tuple[str, str], ...] _PCTRTTT = Tuple[_PCTRTT, ...] @@ -126,7 +133,7 @@ async def wrap( return wrapper async def _call_sslobject_method( - self, func: Callable[..., T_Retval], *args: object + self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] ) -> T_Retval: while True: try: diff --git a/src/anyio/to_process.py b/src/anyio/to_process.py index 2867d42d..1ff06f0b 100644 --- a/src/anyio/to_process.py +++ b/src/anyio/to_process.py @@ -18,9 +18,16 @@ from .lowlevel import RunVar, checkpoint_if_cancelled from .streams.buffered import BufferedByteReceiveStream +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + WORKER_MAX_IDLE_TIME = 300 # 5 minutes T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + _process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers") _process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar( "_process_pool_idle_workers" @@ -29,8 +36,8 @@ async def run_sync( - func: Callable[..., T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], cancellable: bool = False, limiter: CapacityLimiter | None = None, ) -> T_Retval: diff --git a/src/anyio/to_thread.py b/src/anyio/to_thread.py index d9a632e8..5070516e 100644 --- a/src/anyio/to_thread.py +++ b/src/anyio/to_thread.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from collections.abc import Callable from typing import TypeVar from warnings import warn @@ -7,12 +8,18 @@ from ._core._eventloop import get_async_backend from .abc import CapacityLimiter +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") async def run_sync( - func: Callable[..., T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], abandon_on_cancel: bool = False, cancellable: bool | None = None, limiter: CapacityLimiter | None = None, diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index 0e580462..f387e755 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -206,8 +206,8 @@ async def test_run_sync_from_thread_exception(self) -> None: exc.match("unsupported operand type") async def test_run_anyio_async_func_from_thread(self) -> None: - def worker(*args: int) -> Literal[True]: - from_thread.run(sleep, *args) + def worker(delay: float) -> Literal[True]: + from_thread.run(sleep, delay) return True assert await to_thread.run_sync(worker, 0) @@ -507,29 +507,29 @@ async def run_in_context() -> AsyncGenerator[None, None]: def test_start_no_value( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started() with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, value = portal.start_task(taskfunc) # type: ignore[arg-type] + future, value = portal.start_task(taskfunc) assert value is None assert future.result() is None def test_start_with_value( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started("foo") with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, value = portal.start_task(taskfunc) # type: ignore[arg-type] + future, value = portal.start_task(taskfunc) assert value == "foo" assert future.result() is None def test_start_crash_before_started_call( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: object) -> NoReturn: + async def taskfunc(*, task_status: object) -> NoReturn: raise Exception("foo") with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: @@ -539,7 +539,7 @@ def taskfunc(*, task_status: object) -> NoReturn: def test_start_crash_after_started_call( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> NoReturn: + async def taskfunc(*, task_status: TaskStatus) -> NoReturn: task_status.started(2) raise Exception("foo") @@ -552,24 +552,21 @@ def taskfunc(*, task_status: TaskStatus) -> NoReturn: def test_start_no_started_call( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: pass with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: with pytest.raises(RuntimeError, match="Task exited"): - portal.start_task(taskfunc) # type: ignore[arg-type] + portal.start_task(taskfunc) def test_start_with_name( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started(get_current_task().name) with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, start_value = portal.start_task( - taskfunc, # type: ignore[arg-type] - name="testname", - ) + future, start_value = portal.start_task(taskfunc, name="testname") assert start_value == "testname" def test_contextvar_propagation_sync(