Skip to content

Commit

Permalink
Used TypeVarTuple and ParamSpec in several places (#652)
Browse files Browse the repository at this point in the history
Co-authored-by: Ganden Schaffner <gschaffner@pm.me>
  • Loading branch information
agronholm and gschaffner committed Dec 16, 2023
1 parent 3a4ec47 commit 89795b9
Show file tree
Hide file tree
Showing 13 changed files with 217 additions and 97 deletions.
16 changes: 16 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
Lura Skye)
- Enabled the ``Event`` and ``CapacityLimiter`` classes to be instantiated outside an
event loop thread
- Broadly improved/fixed the type annotations. Among other things, many functions and
methods that take variadic positional arguments now make use of PEP 646
``TypeVarTuple`` to allow the positional arguments to be validated by static type
checkers. These changes affected numerous methods and functions, including:

* ``anyio.run()``
* ``TaskGroup.start_soon()``
* ``anyio.from_thread.run()``
* ``anyio.to_thread.run_sync()``
* ``anyio.to_process.run_sync()``
* ``BlockingPortal.call()``
* ``BlockingPortal.start_task_soon()``
* ``BlockingPortal.start_task()``

(`#560 <https://github.com/agronholm/anyio/issues/560>`_)
- Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing
to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help
from Egor Blagov)
Expand All @@ -18,6 +33,7 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Fixed cancellation propagating on asyncio from a task group to child tasks if the task
hosting the task group is in a shielded cancel scope
(`#642 <https://github.com/agronholm/anyio/issues/642>`_)
- Fixed the type annotation of ``anyio.Path.samefile()`` to match Typeshed

**4.1.0**

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"exceptiongroup >= 1.0.2; python_version < '3.11'",
"idna >= 2.8",
"sniffio >= 1.1",
"typing_extensions >= 4.1; python_version < '3.11'",
]
dynamic = ["version"]

Expand Down
48 changes: 33 additions & 15 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,22 @@
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

if sys.version_info >= (3, 11):
from asyncio import Runner
from typing import TypeVarTuple, Unpack
else:
import contextvars
import enum
import signal
from asyncio import coroutines, events, exceptions, tasks

from exceptiongroup import BaseExceptionGroup
from typing_extensions import TypeVarTuple, Unpack

class _State(enum.Enum):
CREATED = "created"
Expand Down Expand Up @@ -271,6 +278,8 @@ def _do_shutdown(future: asyncio.futures.Future) -> None:

T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")

_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")

Expand Down Expand Up @@ -682,8 +691,8 @@ async def __aexit__(

def _spawn(
self,
func: Callable[..., Awaitable[Any]],
args: tuple,
func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
args: tuple[Unpack[PosArgsT]],
name: object,
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
Expand Down Expand Up @@ -752,13 +761,16 @@ def task_done(_task: asyncio.Task) -> None:
return task

def start_soon(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
*args: Unpack[PosArgsT],
name: object = None,
) -> None:
self._spawn(func, args, name)

async def start(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
) -> None:
) -> Any:
future: asyncio.Future = asyncio.Future()
task = self._spawn(func, args, name, future)

Expand Down Expand Up @@ -875,11 +887,11 @@ def __init__(self) -> None:

def _spawn_task_from_thread(
self,
func: Callable,
args: tuple[Any, ...],
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
future: Future,
future: Future[T_Retval],
) -> None:
AsyncIOBackend.run_sync_from_thread(
partial(self._task_group.start_soon, name=name),
Expand Down Expand Up @@ -1883,7 +1895,10 @@ async def _run_tests_and_fixtures(
future.set_result(retval)

async def _call_in_runner_task(
self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
self,
func: Callable[P, Awaitable[T_Retval]],
*args: P.args,
**kwargs: P.kwargs,
) -> T_Retval:
if not self._runner_task:
self._send_stream, receive_stream = create_memory_object_stream[
Expand Down Expand Up @@ -1949,8 +1964,8 @@ class AsyncIOBackend(AsyncBackend):
@classmethod
def run(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
options: dict[str, Any],
) -> T_Retval:
Expand Down Expand Up @@ -2062,8 +2077,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
@classmethod
async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
func: Callable[[Unpack[PosArgsT]], T_Retval],
args: tuple[Unpack[PosArgsT]],
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
Expand Down Expand Up @@ -2133,8 +2148,8 @@ def check_cancelled(cls) -> None:
@classmethod
def run_async_from_thread(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple[Any, ...],
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
args: tuple[Unpack[PosArgsT]],
token: object,
) -> T_Retval:
async def task_wrapper(scope: CancelScope) -> T_Retval:
Expand All @@ -2160,7 +2175,10 @@ async def task_wrapper(scope: CancelScope) -> T_Retval:

@classmethod
def run_sync_from_thread(
cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
cls,
func: Callable[[Unpack[PosArgsT]], T_Retval],
args: tuple[Unpack[PosArgsT]],
token: object,
) -> T_Retval:
@wraps(func)
def wrapper() -> None:
Expand Down
51 changes: 36 additions & 15 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,22 @@
from ..abc._eventloop import AsyncBackend
from ..streams.memory import MemoryObjectSendStream

if sys.version_info < (3, 11):
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from exceptiongroup import BaseExceptionGroup
from typing_extensions import TypeVarTuple, Unpack

T = TypeVar("T")
T_Retval = TypeVar("T_Retval")
T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")


#
Expand Down Expand Up @@ -167,7 +177,12 @@ async def __aexit__(
finally:
self._active = False

def start_soon(self, func: Callable, *args: object, name: object = None) -> None:
def start_soon(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
*args: Unpack[PosArgsT],
name: object = None,
) -> None:
if not self._active:
raise RuntimeError(
"This task group is not active; no new tasks can be started."
Expand All @@ -177,7 +192,7 @@ def start_soon(self, func: Callable, *args: object, name: object = None) -> None

async def start(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
) -> object:
) -> Any:
if not self._active:
raise RuntimeError(
"This task group is not active; no new tasks can be started."
Expand All @@ -201,11 +216,11 @@ def __init__(self) -> None:

def _spawn_task_from_thread(
self,
func: Callable,
args: tuple,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
future: Future,
future: Future[T_Retval],
) -> None:
trio.from_thread.run_sync(
partial(self._task_group.start_soon, name=name),
Expand Down Expand Up @@ -724,7 +739,7 @@ class TestRunner(abc.TestRunner):
def __init__(self, **options: Any) -> None:
from queue import Queue

self._call_queue: Queue[Callable[..., object]] = Queue()
self._call_queue: Queue[Callable[[], object]] = Queue()
self._send_stream: MemoryObjectSendStream | None = None
self._options = options

Expand Down Expand Up @@ -754,7 +769,10 @@ def _main_task_finished(self, outcome: object) -> None:
self._send_stream = None

def _call_in_runner_task(
self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
self,
func: Callable[P, Awaitable[T_Retval]],
*args: P.args,
**kwargs: P.kwargs,
) -> T_Retval:
if self._send_stream is None:
trio.lowlevel.start_guest_run(
Expand Down Expand Up @@ -808,8 +826,8 @@ class TrioBackend(AsyncBackend):
@classmethod
def run(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
options: dict[str, Any],
) -> T_Retval:
Expand Down Expand Up @@ -868,8 +886,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
@classmethod
async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
func: Callable[[Unpack[PosArgsT]], T_Retval],
args: tuple[Unpack[PosArgsT]],
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
Expand All @@ -891,15 +909,18 @@ def check_cancelled(cls) -> None:
@classmethod
def run_async_from_thread(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple[Any, ...],
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
args: tuple[Unpack[PosArgsT]],
token: object,
) -> T_Retval:
return trio.from_thread.run(func, *args)

@classmethod
def run_sync_from_thread(
cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
cls,
func: Callable[[Unpack[PosArgsT]], T_Retval],
args: tuple[Unpack[PosArgsT]],
token: object,
) -> T_Retval:
return trio.from_thread.run_sync(func, *args)

Expand Down
11 changes: 9 additions & 2 deletions src/anyio/_core/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,26 @@

import sniffio

if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack

if TYPE_CHECKING:
from ..abc import AsyncBackend

# This must be updated when new backends are introduced
BACKENDS = "asyncio", "trio"

T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")

threadlocals = threading.local()


def run(
func: Callable[..., Awaitable[T_Retval]],
*args: object,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
*args: Unpack[PosArgsT],
backend: str = "asyncio",
backend_options: dict[str, Any] | None = None,
) -> T_Retval:
Expand Down
9 changes: 3 additions & 6 deletions src/anyio/_core/_fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
AsyncIterator,
Final,
Generic,
cast,
overload,
)

Expand Down Expand Up @@ -211,7 +210,7 @@ async def __anext__(self) -> Path:
if nextval is None:
raise StopAsyncIteration from None

return Path(cast("PathLike[str]", nextval))
return Path(nextval)


class Path:
Expand Down Expand Up @@ -518,7 +517,7 @@ def relative_to(self, *other: str | PathLike[str]) -> Path:

async def readlink(self) -> Path:
target = await to_thread.run_sync(os.readlink, self._path)
return Path(cast(str, target))
return Path(target)

async def rename(self, target: str | pathlib.PurePath | Path) -> Path:
if isinstance(target, Path):
Expand All @@ -545,9 +544,7 @@ def rglob(self, pattern: str) -> AsyncIterator[Path]:
async def rmdir(self) -> None:
await to_thread.run_sync(self._path.rmdir)

async def samefile(
self, other_path: str | bytes | int | pathlib.Path | Path
) -> bool:
async def samefile(self, other_path: str | PathLike[str]) -> bool:
if isinstance(other_path, Path):
other_path = other_path._path

Expand Down

0 comments on commit 89795b9

Please sign in to comment.