Skip to content

Commit

Permalink
Merge a16b350 into a1c7538
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Aug 1, 2020
2 parents a1c7538 + a16b350 commit 577a64e
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 4 deletions.
36 changes: 36 additions & 0 deletions docs/threads.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,39 @@ If you need to call a coroutine function from a worker thread, you can do this::

.. note:: The worker thread must have been spawned using :func:`~anyio.run_sync_in_worker_thread`
for this to work.

Calling asynchronous code from an external thread
-------------------------------------------------

If you need to run async code from a thread that is not a worker thread spawned by the event loop,
you need a *blocking portal*. This needs to be obtained from within the event loop thread.

One way to do this is to start a new event loop with a portal, using
:func:`~anyio.start_blocking_portal` (which takes mostly the same arguments as :func:`~anyio.run`::

from anyio import start_blocking_portal


portal = start_blocking_portal(backend='trio')
portal.call(...)

# At the end of your application, stop the portal
portal.stop_from_external_thread()

Or, you can it as a context manager if that suits your use case::

with start_blocking_portal(backend='trio') as portal:
portal.call(...)

If you already have an event loop running and wish to grant access to external threads, you can
use :func:`~anyio.create_blocking_portal` directly::

from anyio import create_blocking_portal, run


async def main():
async with create_blocking_portal() as portal:
# ...hand off the portal to external threads...
await portal.sleep_until_stopped()

anyio.run(main)
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
async generators and cancels any leftover native tasks
- Fixed ``Condition.wait()`` not working on asyncio and curio (PR by Matt Westcott)
- Added the ``anyio.aclose_forcefully()`` to close asynchronous resources as quickly as possible
- Added support for "blocking portals" which allow running functions in the event loop thread from
external threads

**1.4.0**

Expand Down
33 changes: 32 additions & 1 deletion src/anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .abc import (
Lock, Condition, Event, Semaphore, CapacityLimiter, CancelScope, TaskGroup, IPAddressType,
SocketStream, UDPSocket, ConnectedUDPSocket, IPSockAddrType, Listener, SocketListener,
AsyncResource)
AsyncResource, BlockingPortal)
from .fileio import AsyncFile
from .streams.tls import TLSStream
from .streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand Down Expand Up @@ -280,6 +280,37 @@ def current_default_worker_thread_limiter() -> CapacityLimiter:
return _get_asynclib().current_default_thread_limiter()


def create_blocking_portal() -> BlockingPortal:
"""Create a portal for running functions in the event loop thread."""
return _get_asynclib().BlockingPortal()


def start_blocking_portal(
backend: str = BACKENDS[0],
backend_options: Optional[Dict[str, Any]] = None) -> BlockingPortal:
"""
Start a new event loop in a new thread and run a blocking portal in its main task.
:param backend:
:param backend_options:
:return: a blocking portal object
"""
async def run_portal():
nonlocal portal
async with create_blocking_portal() as portal:
event.set()
await portal.sleep_until_stopped()

portal: Optional[BlockingPortal]
event = threading.Event()
kwargs = {'func': run_portal, 'backend': backend, 'backend_options': backend_options}
thread = threading.Thread(target=run, kwargs=kwargs)
thread.start()
event.wait()
return typing.cast(BlockingPortal, portal)


#
# Async file I/O
#
Expand Down
13 changes: 13 additions & 0 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import socket
import sys
from collections import OrderedDict, deque
from concurrent.futures import Future
from functools import wraps
from inspect import isgenerator
from threading import Thread
Expand Down Expand Up @@ -526,6 +527,18 @@ def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *a
return f.result()


class BlockingPortal(abc.BlockingPortal):
__slots__ = '_loop'

def __init__(self):
super().__init__()
self._loop = get_running_loop()

def _spawn_task_from_thread(self, func: Callable, args: tuple, future: Future) -> None:
asyncio.run_coroutine_threadsafe(
self._task_group.spawn(self._call_func, func, args, future), self._loop)


#
# Sockets and networking
#
Expand Down
31 changes: 31 additions & 0 deletions src/anyio/_backends/_curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,37 @@ def run_async_from_thread(func: Callable[..., T_Retval], *args) -> T_Retval:
return future.result()


class BlockingPortal(abc.BlockingPortal):
__slots__ = '_queue'

def __init__(self):
super().__init__()
self._queue = curio.UniversalQueue()

async def _process_queue(self) -> None:
while self._event_loop_thread_id or not self._queue.empty():
func, args, future = await self._queue.get()
if func is not None:
await self._task_group.spawn(self._call_func, func, args, future)

async def __aenter__(self) -> 'BlockingPortal':
await super().__aenter__()
await self._task_group.spawn(self._process_queue)
return self

async def stop(self, cancel_remaining: bool = False) -> None:
if self._event_loop_thread_id is None:
return

await super().stop(cancel_remaining)

# Wake up from queue.get()
await self._queue.put((None, None, None))

def _spawn_task_from_thread(self, func: Callable, args: tuple, future: Future) -> None:
self._queue.put((func, args, future))


#
# Sockets and networking
#
Expand Down
13 changes: 13 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import socket
import sys
from concurrent.futures import Future
from types import TracebackType
from typing import Callable, Optional, List, Type, Union, Tuple, TYPE_CHECKING, TypeVar, Generic

Expand Down Expand Up @@ -176,6 +177,18 @@ def wrapper():
run_async_from_thread = trio.from_thread.run


class BlockingPortal(abc.BlockingPortal):
__slots__ = '_token'

def __init__(self):
super().__init__()
self._token = trio.lowlevel.current_trio_token()

def _spawn_task_from_thread(self, func: Callable, args: tuple, future: Future) -> None:
return trio.from_thread.run(self._task_group.spawn, self._call_func, func, args, future,
trio_token=self._token)


#
# Sockets and networking
#
Expand Down
5 changes: 3 additions & 2 deletions src/anyio/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__all__ = ('IPAddressType', 'IPSockAddrType', 'SockAddrType', 'UDPPacketType', 'SocketStream',
__all__ = ('IPAddressType', 'IPSockAddrType', 'SockAddrType', 'UDPPacketType', 'SocketStream',
'SocketListener', 'UDPSocket', 'ConnectedUDPSocket', 'AsyncResource',
'UnreliableObjectReceiveStream', 'UnreliableObjectSendStream', 'UnreliableObjectStream',
'ObjectReceiveStream', 'ObjectSendStream', 'ObjectStream', 'ByteReceiveStream',
'ByteSendStream', 'ByteStream', 'AnyUnreliableByteReceiveStream',
'AnyUnreliableByteSendStream', 'AnyUnreliableByteStream', 'AnyByteReceiveStream',
'AnyByteSendStream', 'AnyByteStream', 'Listener', 'Event', 'Lock', 'Condition',
'Semaphore', 'CapacityLimiter', 'CancelScope', 'TaskGroup')
'Semaphore', 'CapacityLimiter', 'CancelScope', 'TaskGroup', 'BlockingPortal')

from .sockets import (
IPAddressType, IPSockAddrType, SockAddrType, UDPPacketType, SocketStream, SocketListener,
Expand All @@ -18,3 +18,4 @@
AnyUnreliableByteStream, AnyByteReceiveStream, AnyByteSendStream, AnyByteStream, Listener)
from .synchronization import Event, Lock, Condition, Semaphore, CapacityLimiter
from .tasks import CancelScope, TaskGroup
from .threads import BlockingPortal
96 changes: 96 additions & 0 deletions src/anyio/abc/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import threading
from abc import ABCMeta, abstractmethod
from concurrent.futures import Future
from inspect import iscoroutine
from typing import TypeVar, Callable

T_Retval = TypeVar('T_Retval')


class BlockingPortal(metaclass=ABCMeta):
"""An object tied that lets external threads run code in an asynchronous event loop."""

__slots__ = '_task_group', '_event_loop_thread_id', '_stop_event', '_cancelled_exc_class'

def __init__(self):
from .. import create_event, create_task_group, get_cancelled_exc_class

self._event_loop_thread_id = threading.get_ident()
self._stop_event = create_event()
self._task_group = create_task_group()
self._cancelled_exc_class = get_cancelled_exc_class()

async def __aenter__(self) -> 'BlockingPortal':
await self._task_group.__aenter__()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.stop()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.call(self.stop, exc_type is not None)

async def sleep_until_stopped(self) -> None:
"""Sleep until :meth:`stop` is called."""
await self._stop_event.wait()

async def stop(self, cancel_remaining: bool = False) -> None:
"""
Signal the portal to shut down.
This marks the portal as no longer accepting new calls and exits from
:meth:`sleep_until_stopped`.
:param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them
finish before returning
"""
self._event_loop_thread_id = None
await self._stop_event.set()
if cancel_remaining:
await self._task_group.cancel_scope.cancel()

def stop_from_external_thread(self, cancel_remaining: bool = False) -> None:
thread = self.call(threading.current_thread)
self.call(self.stop, cancel_remaining)
thread.join()

async def _call_func(self, func: Callable, args: tuple, future: Future) -> None:
try:
retval = func(*args)
if iscoroutine(retval):
future.set_result(await retval)
else:
future.set_result(retval)
except self._cancelled_exc_class:
future.cancel()
except BaseException as exc:
future.set_exception(exc)

@abstractmethod
def _spawn_task_from_thread(self, func: Callable, args: tuple, future: Future) -> None:
pass

def call(self, func: Callable[..., T_Retval], *args) -> T_Retval:
"""
Call the given function in the event loop thread.
If the callable returns a coroutine object, it is awaited on.
:param func: any callable
:raises RuntimeError: if this method is called from within the event loop thread
"""
if self._event_loop_thread_id is None:
raise RuntimeError('This portal is not running')
if self._event_loop_thread_id == threading.get_ident():
raise RuntimeError('This method cannot be called from the event loop thread')

future: Future = Future()
self._spawn_task_from_thread(func, args, future)
return future.result()
Loading

0 comments on commit 577a64e

Please sign in to comment.