From 4a23745badf5bf5ef7928f1e346e9986bd696d82 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?=
Date: Sun, 7 Apr 2024 12:40:22 +0300
Subject: [PATCH] Added the BlockingPortalProvider class (#711)

---
 docs/api.rst             |  1 +
 docs/threads.rst         | 24 +++++++++++++
 docs/versionhistory.rst  |  3 ++
 src/anyio/from_thread.py | 58 +++++++++++++++++++++++++++++++
 tests/test_to_thread.py  | 73 +++++++++++++++++++++++++++++++++++++++-
 5 files changed, 158 insertions(+), 1 deletion(-)

diff --git a/docs/api.rst b/docs/api.rst
index 1bd57766..129ac457 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -65,6 +65,7 @@ Running asynchronous code from other threads
 .. autofunction:: anyio.from_thread.start_blocking_portal
 
 .. autoclass:: anyio.from_thread.BlockingPortal
+.. autoclass:: anyio.from_thread.BlockingPortalProvider
 
 Async file I/O
 --------------
diff --git a/docs/threads.rst b/docs/threads.rst
index 73de1c06..2479c0f6 100644
--- a/docs/threads.rst
+++ b/docs/threads.rst
@@ -225,3 +225,27 @@ and if it has, raise a cancellation exception. This can be done by simply callin
     async def foo():
         with move_on_after(3):
             await to_thread.run_sync(sync_function)
+
+
+Sharing a blocking portal on demand
+-----------------------------------
+
+If you're building a synchronous API that needs to start a blocking portal on demand,
+you might need a more efficient solution than just starting a blocking portal for each
+call. To that end, you can use :class:`BlockingPortalProvider`::
+
+    from anyio.from_thread import BlockingPortalProvider
+
+    class MyAPI:
+        def __init__(self, async_obj) -> None:
+            self._async_obj = async_obj
+            self._portal_provider = BlockingPortalProvider()
+
+        def do_stuff(self) -> None:
+            with self._portal_provider as portal:
+                portal.call(self._async_obj.do_async_stuff)
+
+Now, no matter how many threads call the ``do_stuff()`` method on a ``MyAPI`` instance
+at the same time, the same blocking portal will be used to handle the async calls
+inside. This is much more efficient than having each call spawn its own blocking
+portal.
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index f0dcfdbd..0cf2d7ca 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -5,6 +5,9 @@ This library adheres to `Semantic Versioning 2.0 `_.
 
 **UNRELEASED**
 
+- Added the ``BlockingPortalProvider`` class to aid with constructing synchronous
+  counterparts to asynchronous interfaces that would otherwise require multiple blocking
+  portals
 - Fixed erroneous ``RuntimeError: called 'started' twice on the same task status``
   when cancelling a task in a TaskGroup created with the ``start()`` method before
   the first checkpoint is reached after calling ``task_status.started()``
diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py
index 3e889f93..88a854bb 100644
--- a/src/anyio/from_thread.py
+++ b/src/anyio/from_thread.py
@@ -5,6 +5,7 @@
 from collections.abc import Awaitable, Callable, Generator
 from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
 from contextlib import AbstractContextManager, contextmanager
+from dataclasses import dataclass, field
 from inspect import isawaitable
 from types import TracebackType
 from typing import (
@@ -391,6 +392,63 @@ def wrap_async_context_manager(
         return _BlockingAsyncContextManager(cm, self)
 
 
+@dataclass
+class BlockingPortalProvider:
+    """
+    A manager for a blocking portal. Used as a context manager. The first thread to
+    enter this context manager causes a blocking portal to be started with the specific
+    parameters, and the last thread to exit causes the portal to be shut down. Thus,
+    there will be exactly one blocking portal running in this context as long as at
+    least one thread has entered this context manager.
+
+    The parameters are the same as for :func:`~anyio.run`.
+
+    :param backend: name of the backend
+    :param backend_options: backend options
+
+    .. versionadded:: 4.4
+    """
+
+    backend: str = "asyncio"
+    backend_options: dict[str, Any] | None = None
+    _lock: threading.Lock = field(init=False, default_factory=threading.Lock)
+    _leases: int = field(init=False, default=0)
+    _portal: BlockingPortal = field(init=False)
+    _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
+        init=False, default=None
+    )
+
+    def __enter__(self) -> BlockingPortal:
+        with self._lock:
+            if self._portal_cm is None:
+                self._portal_cm = start_blocking_portal(
+                    self.backend, self.backend_options
+                )
+                self._portal = self._portal_cm.__enter__()
+
+            self._leases += 1
+            return self._portal
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        portal_cm: AbstractContextManager[BlockingPortal] | None = None
+        with self._lock:
+            assert self._portal_cm
+            assert self._leases > 0
+            self._leases -= 1
+            if not self._leases:
+                portal_cm = self._portal_cm
+                self._portal_cm = None
+                del self._portal
+
+        if portal_cm:
+            portal_cm.__exit__(None, None, None)
+
+
 @contextmanager
 def start_blocking_portal(
     backend: str = "asyncio", backend_options: dict[str, Any] | None = None
diff --git a/tests/test_to_thread.py b/tests/test_to_thread.py
index 6dc46ba7..9b80de2d 100644
--- a/tests/test_to_thread.py
+++ b/tests/test_to_thread.py
@@ -3,7 +3,7 @@
 import asyncio
 import threading
 import time
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 from contextvars import ContextVar
 from functools import partial
 from typing import Any, NoReturn
@@ -21,6 +21,7 @@
     to_thread,
     wait_all_tasks_blocked,
 )
+from anyio.from_thread import BlockingPortalProvider
 
 pytestmark = pytest.mark.anyio
 
@@ -287,3 +288,73 @@ def raise_stopiteration() -> NoReturn:
 
     with pytest.raises(RuntimeError, match="coroutine raised StopIteration"):
         await to_thread.run_sync(raise_stopiteration)
+
+
+class TestBlockingPortalProvider:
+    @pytest.fixture
+    def provider(
+        self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
+    ) -> BlockingPortalProvider:
+        return BlockingPortalProvider(
+            backend=anyio_backend_name, backend_options=anyio_backend_options
+        )
+
+    def test_single_thread(
+        self, provider: BlockingPortalProvider, anyio_backend_name: str
+    ) -> None:
+        threads: set[threading.Thread] = set()
+
+        async def check_thread() -> None:
+            assert sniffio.current_async_library() == anyio_backend_name
+            threads.add(threading.current_thread())
+
+        active_threads_before = threading.active_count()
+        for _ in range(3):
+            with provider as portal:
+                portal.call(check_thread)
+
+        assert len(threads) == 3
+        assert threading.active_count() == active_threads_before
+
+    def test_single_thread_overlapping(
+        self, provider: BlockingPortalProvider, anyio_backend_name: str
+    ) -> None:
+        threads: set[threading.Thread] = set()
+
+        async def check_thread() -> None:
+            assert sniffio.current_async_library() == anyio_backend_name
+            threads.add(threading.current_thread())
+
+        with provider as portal1:
+            with provider as portal2:
+                assert portal1 is portal2
+                portal2.call(check_thread)
+
+            portal1.call(check_thread)
+
+        assert len(threads) == 1
+
+    def test_multiple_threads(
+        self, provider: BlockingPortalProvider, anyio_backend_name: str
+    ) -> None:
+        threads: set[threading.Thread] = set()
+        event = Event()
+
+        async def check_thread() -> None:
+            assert sniffio.current_async_library() == anyio_backend_name
+            await event.wait()
+            threads.add(threading.current_thread())
+
+        def dummy() -> None:
+            with provider as portal:
+                portal.call(check_thread)
+
+        with ThreadPoolExecutor(max_workers=3) as pool:
+            for _ in range(3):
+                pool.submit(dummy)
+
+            with provider as portal:
+                portal.call(wait_all_tasks_blocked)
+                portal.call(event.set)
+
+        assert len(threads) == 1
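
The docs addition above shows the provider embedded in a class. As a standalone illustration of the same mechanism, here is a minimal sketch (not part of the patch; ``fetch_remote_value``, the worker function and the thread count are made up for the example, and it requires this change, i.e. anyio 4.4 or later): several worker threads share one blocking portal, and the portal's event loop is torn down once the last thread leaves the provider::

    import threading

    import anyio
    from anyio.from_thread import BlockingPortalProvider

    provider = BlockingPortalProvider()  # defaults to the asyncio backend

    async def fetch_remote_value(index: int) -> int:
        # Stand-in for real async work (hypothetical helper)
        await anyio.sleep(0.1)
        return index * 2

    def worker(index: int, results: list[int]) -> None:
        # All threads inside this block at the same time share a single
        # blocking portal; the last one to exit shuts it down
        with provider as portal:
            results.append(portal.call(fetch_remote_value, index))

    results: list[int] = []
    workers = [
        threading.Thread(target=worker, args=(i, results)) for i in range(5)
    ]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()
    print(sorted(results))  # [0, 2, 4, 6, 8]

Because the threads overlap, only one event loop thread is started for all five calls; if the calls did not overlap, each would briefly start and stop its own portal, which is still correct, just less efficient.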