Added the BlockingPortalProvider class (#711)
agronholm committed Apr 7, 2024
1 parent 234e434 commit 4a23745
Showing 5 changed files with 158 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/api.rst
@@ -65,6 +65,7 @@ Running asynchronous code from other threads
.. autofunction:: anyio.from_thread.start_blocking_portal

.. autoclass:: anyio.from_thread.BlockingPortal
.. autoclass:: anyio.from_thread.BlockingPortalProvider

Async file I/O
--------------
24 changes: 24 additions & 0 deletions docs/threads.rst
@@ -225,3 +225,27 @@ and if it has, raise a cancellation exception. This can be done by simply callin
    async def foo():
        with move_on_after(3):
            await to_thread.run_sync(sync_function)


Sharing a blocking portal on demand
-----------------------------------

If you're building a synchronous API that needs to start a blocking portal on demand,
you might need a more efficient solution than just starting a blocking portal for each
call. To that end, you can use :class:`BlockingPortalProvider`::

    from anyio.from_thread import BlockingPortalProvider

    class MyAPI:
        def __init__(self, async_obj) -> None:
            self._async_obj = async_obj
            self._portal_provider = BlockingPortalProvider()

        def do_stuff(self) -> None:
            with self._portal_provider as portal:
                portal.call(self._async_obj.do_async_stuff)

Now, no matter how many threads call the ``do_stuff()`` method on a ``MyAPI`` instance
at the same time, they will all share a single blocking portal for the async calls
inside. This is much more efficient than having each call spawn its own blocking
portal.
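
For illustration, here is a minimal sketch of several threads sharing the portal via
``do_stuff()`` (the ``AsyncThing`` class below is hypothetical, standing in for any
object with asynchronous methods)::

    import threading

    class AsyncThing:
        async def do_async_stuff(self) -> None:
            print("doing async stuff")

    api = MyAPI(AsyncThing())

    # All three threads enter the same BlockingPortalProvider; the first one to
    # enter starts the portal and the last one to exit shuts it down.
    workers = [threading.Thread(target=api.do_stuff) for _ in range(3)]
    for worker in workers:
        worker.start()

    for worker in workers:
        worker.join()
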
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
@@ -5,6 +5,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Added the ``BlockingPortalProvider`` class to aid with constructing synchronous
counterparts to asynchronous interfaces that would otherwise require multiple blocking
portals
- Fixed erroneous ``RuntimeError: called 'started' twice on the same task status``
when cancelling a task in a TaskGroup created with the ``start()`` method before
the first checkpoint is reached after calling ``task_status.started()``
58 changes: 58 additions & 0 deletions src/anyio/from_thread.py
@@ -5,6 +5,7 @@
from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass, field
from inspect import isawaitable
from types import TracebackType
from typing import (
@@ -391,6 +392,63 @@ def wrap_async_context_manager(
        return _BlockingAsyncContextManager(cm, self)


@dataclass
class BlockingPortalProvider:
    """
    A manager for a blocking portal. Used as a context manager. The first thread to
    enter this context manager causes a blocking portal to be started with the specific
    parameters, and the last thread to exit causes the portal to be shut down. Thus,
    there will be exactly one blocking portal running in this context as long as at
    least one thread has entered this context manager.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options

    .. versionadded:: 4.4
    """

    backend: str = "asyncio"
    backend_options: dict[str, Any] | None = None
    _lock: threading.Lock = field(init=False, default_factory=threading.Lock)
    _leases: int = field(init=False, default=0)
    _portal: BlockingPortal = field(init=False)
    _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
        init=False, default=None
    )

    def __enter__(self) -> BlockingPortal:
        with self._lock:
            if self._portal_cm is None:
                self._portal_cm = start_blocking_portal(
                    self.backend, self.backend_options
                )
                self._portal = self._portal_cm.__enter__()

            self._leases += 1
            return self._portal

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        portal_cm: AbstractContextManager[BlockingPortal] | None = None
        with self._lock:
            assert self._portal_cm
            assert self._leases > 0
            self._leases -= 1
            if not self._leases:
                portal_cm = self._portal_cm
                self._portal_cm = None
                del self._portal

        if portal_cm:
            portal_cm.__exit__(None, None, None)


@contextmanager
def start_blocking_portal(
backend: str = "asyncio", backend_options: dict[str, Any] | None = None
Expand Down
73 changes: 72 additions & 1 deletion tests/test_to_thread.py
@@ -3,7 +3,7 @@
import asyncio
import threading
import time
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor
from contextvars import ContextVar
from functools import partial
from typing import Any, NoReturn
@@ -21,6 +21,7 @@
    to_thread,
    wait_all_tasks_blocked,
)
from anyio.from_thread import BlockingPortalProvider

pytestmark = pytest.mark.anyio

@@ -287,3 +288,73 @@ def raise_stopiteration() -> NoReturn:

    with pytest.raises(RuntimeError, match="coroutine raised StopIteration"):
        await to_thread.run_sync(raise_stopiteration)


class TestBlockingPortalProvider:
    @pytest.fixture
    def provider(
        self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
    ) -> BlockingPortalProvider:
        return BlockingPortalProvider(
            backend=anyio_backend_name, backend_options=anyio_backend_options
        )

    def test_single_thread(
        self, provider: BlockingPortalProvider, anyio_backend_name: str
    ) -> None:
        threads: set[threading.Thread] = set()

        async def check_thread() -> None:
            assert sniffio.current_async_library() == anyio_backend_name
            threads.add(threading.current_thread())

        active_threads_before = threading.active_count()
        for _ in range(3):
            with provider as portal:
                portal.call(check_thread)

        assert len(threads) == 3
        assert threading.active_count() == active_threads_before

    def test_single_thread_overlapping(
        self, provider: BlockingPortalProvider, anyio_backend_name: str
    ) -> None:
        threads: set[threading.Thread] = set()

        async def check_thread() -> None:
            assert sniffio.current_async_library() == anyio_backend_name
            threads.add(threading.current_thread())

        with provider as portal1:
            with provider as portal2:
                assert portal1 is portal2
                portal2.call(check_thread)

            portal1.call(check_thread)

        assert len(threads) == 1

    def test_multiple_threads(
        self, provider: BlockingPortalProvider, anyio_backend_name: str
    ) -> None:
        threads: set[threading.Thread] = set()
        event = Event()

        async def check_thread() -> None:
            assert sniffio.current_async_library() == anyio_backend_name
            await event.wait()
            threads.add(threading.current_thread())

        def dummy() -> None:
            with provider as portal:
                portal.call(check_thread)

        with ThreadPoolExecutor(max_workers=3) as pool:
            for _ in range(3):
                pool.submit(dummy)

            with provider as portal:
                portal.call(wait_all_tasks_blocked)
                portal.call(event.set)

        assert len(threads) == 1
