Skip to content

Commit

Permalink
Propagate PEP 567 context with @executor too
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed May 23, 2022
1 parent 0782774 commit 574e94d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 25 deletions.
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Expand Up @@ -3,6 +3,11 @@ Version history

This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Changed ``@executor`` to propagate the `PEP 567`_ context to the worker thread, just
like ``Context.call_in_executor()``

**4.9.1** (2022-05-22)

- Fixed type annotation for ``@context_teardown``
Expand Down
48 changes: 30 additions & 18 deletions src/asphalt/core/context.py
Expand Up @@ -25,6 +25,7 @@
import warnings
from asyncio import (
AbstractEventLoop,
Future,
current_task,
get_event_loop,
get_running_loop,
Expand Down Expand Up @@ -824,7 +825,7 @@ def threadpool(self, executor: Union[Executor, str, None] = None):
return asyncio_extras.threadpool(executor)


def executor(arg: Union[Executor, str, Callable] = None):
def executor(arg: Union[Executor, str, None, Callable[..., T_Retval]] = None):
"""
Decorate a function so that it runs in an :class:`~concurrent.futures.Executor`.
Expand All @@ -842,35 +843,46 @@ def should_run_in_executor():
def should_run_in_executor(ctx):
...
:param arg: a callable to decorate, an :class:`~concurrent.futures.Executor` instance, the
resource name of one or ``None`` to use the event loop's default executor
:param arg: a callable to decorate, an :class:`~concurrent.futures.Executor`
instance, the resource name of one or ``None`` to use the event loop's default
executor
:return: the wrapped function
"""
function: Callable[..., T_Retval]

def outer_wrapper(func: Callable):
@wraps(func)
def inner_wrapper(*args, **kwargs):
def inner_wrapper(*args, **kwargs) -> Future[T_Retval]:
executor: Executor | None
if isinstance(executor_arg, str):
try:
ctx = next(arg for arg in args[:2] if isinstance(arg, Context))
ctx = next(a for a in args[:2] if isinstance(a, Context))
except StopIteration:
raise RuntimeError(
"the first positional argument to {}() has to be a Context "
"instance".format(callable_name(func))
f"the first positional argument to {callable_name(function)}() has "
f"to be a Context instance"
) from None

executor = ctx.require_resource(Executor, resource_name)
return asyncio_extras.call_in_executor(
func, *args, executor=executor, **kwargs
)
executor = ctx.require_resource(Executor, executor_arg)
else:
executor = executor_arg

current_context()
callback: partial[T_Retval] = partial(
copy_context().run, function, *args, **kwargs
)
return get_running_loop().run_in_executor(executor, callback)

return inner_wrapper
def outer_wrapper(func: Callable[..., T_Retval]) -> Callable[..., Future[T_Retval]]:
nonlocal function
function = func
return wraps(func)(inner_wrapper)

if isinstance(arg, str):
resource_name = arg
if arg is None or isinstance(arg, (Executor, str)):
executor_arg = arg
return outer_wrapper

return asyncio_extras.threadpool(arg)
else:
executor_arg = None
return outer_wrapper(arg)


@overload
Expand Down
23 changes: 16 additions & 7 deletions tests/test_context.py
Expand Up @@ -553,7 +553,9 @@ async def test_threadpool(self, context: Context) -> None:
assert current_thread() is not event_loop_thread

@pytest.mark.asyncio
async def test_threadpool_named_executor(self, context, special_executor):
async def test_threadpool_named_executor(
self, context: Context, special_executor: Executor
) -> None:
special_executor_thread = special_executor.submit(current_thread).result()
async with context.threadpool("special"):
assert current_thread() is special_executor_thread
Expand All @@ -565,27 +567,34 @@ async def test_no_arguments(self, context: Context) -> None:
@executor
def runs_in_default_worker() -> None:
assert current_thread() is not event_loop_thread
current_context()

event_loop_thread = current_thread()
await runs_in_default_worker()
async with context:
await runs_in_default_worker()

@pytest.mark.asyncio
async def test_named_executor(self, context, special_executor):
async def test_named_executor(
self, context: Context, special_executor: Executor
) -> None:
@executor("special")
def runs_in_default_worker(ctx: Context) -> None:
assert current_thread() is special_executor_thread
assert current_context() is ctx

special_executor_thread = special_executor.submit(current_thread).result()
await runs_in_default_worker(context)
async with context:
await runs_in_default_worker(context)

@pytest.mark.asyncio
async def test_executor_missing_context(self, event_loop, context):
async def test_executor_missing_context(self, context: Context):
@executor("special")
def runs_in_default_worker() -> None:
pass
current_context()

with pytest.raises(RuntimeError) as exc:
await runs_in_default_worker()
async with context:
await runs_in_default_worker()

exc.match(
r"the first positional argument to %s\(\) has to be a Context instance"
Expand Down

0 comments on commit 574e94d

Please sign in to comment.