Skip to content

Commit

Permalink
Propagate context to and from worker threads
Browse files Browse the repository at this point in the history
Fixes #387.
  • Loading branch information
agronholm committed Nov 17, 2021
1 parent c50242b commit 5de4ec1
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 5 deletions.
14 changes: 14 additions & 0 deletions docs/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,17 @@ the task group context manager? The answer is "both". In practice this means tha
exception, :exc:`~ExceptionGroup` is raised which contains both exception objects.
Unfortunately this complicates any code that wishes to catch a specific exception because it could
be wrapped in an :exc:`~ExceptionGroup`.

Context propagation
-------------------

Whenever a new task is spawned, `context`_ will be copied to the new task. It is important to note
*which* content will be copied to the newly spawned task. It is not the context of the task group's
host task that will be copied, but the context of the task that calls
:meth:`TaskGroup.start() <.abc.TaskGroup.start>` or
:meth:`TaskGroup.start_soon() <.abc.TaskGroup.start_soon>`.

.. note:: Context propagation **does not work** on asyncio when using Python 3.6, as asyncio
support for this only landed in v3.7.

.. _context: https://docs.python.org/3/library/contextvars.html
12 changes: 12 additions & 0 deletions docs/threads.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,15 @@ managers as a synchronous one::

.. note:: You cannot use wrapped async context managers in synchronous callbacks inside the event
loop thread.

Context propagation
-------------------

When running functions in worker threads, the current context is copied to the worker thread.
Therefore any context variables available on the task will also be available to the code running
on the thread. As always with context variables, any changes made to them will not propagate back
to the calling asynchronous task.

When calling asynchronous code from worker threads, context is again copied to the task that calls
the target function in the event loop thread. Note, however, that this **does not work** on asyncio
when running on Python 3.6.
6 changes: 4 additions & 2 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Fixed race condition in ``Lock`` and ``Semaphore`` classes when a task waiting on ``acquire()``
is cancelled while another task is waiting to acquire the same primitive
(`#387 <https://github.com/agronholm/anyio/issues/387>`_)
- Fixed context variables not propagating to worker threads in ``to_thread.run_sync()``
(partially fixes `#363 <https://github.com/agronholm/anyio/issues/363>`_)
- Fixed context variables not propagating to/from worker threads in ``to_thread.run_sync()``,
``from_thread.run()`` and ``from_thread.run_sync()``
(`#363 <https://github.com/agronholm/anyio/issues/363>`_; does **not** work from threads to async
on Python 3.6 + asyncio!)

**3.3.4**

Expand Down
7 changes: 6 additions & 1 deletion src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from asyncio.base_events import _run_until_complete_cb # type: ignore
from collections import OrderedDict, deque
from concurrent.futures import Future
from contextvars import copy_context
from dataclasses import dataclass
from functools import partial, wraps
from inspect import (
Expand Down Expand Up @@ -818,7 +819,11 @@ def wrapper() -> None:

f: concurrent.futures.Future[T_Retval] = Future()
loop = loop or threadlocals.loop
loop.call_soon_threadsafe(wrapper)
if sys.version_info < (3, 7):
loop.call_soon_threadsafe(copy_context().run, wrapper)
else:
loop.call_soon_threadsafe(wrapper)

return f.result()


Expand Down
25 changes: 23 additions & 2 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import socket
from concurrent.futures import Future
from contextvars import copy_context
from dataclasses import dataclass
from functools import partial
from io import IOBase
Expand Down Expand Up @@ -169,8 +170,28 @@ def wrapper() -> T_Retval:

return await run_sync(wrapper, cancellable=cancellable, limiter=limiter)

run_async_from_thread = trio.from_thread.run
run_sync_from_thread = trio.from_thread.run_sync

def run_async_from_thread(fn: Callable[..., Awaitable[T_Retval]], *args: Any) -> T_Retval:
async def wrapper() -> Optional[T_Retval]:
retval: T_Retval

async def inner() -> None:
nonlocal retval
__tracebackhide__ = True
retval = await fn(*args)

async with trio.open_nursery() as n:
ctx.run(n.start_soon, inner)

__tracebackhide__ = True
return retval

ctx = copy_context()
return trio.from_thread.run(wrapper)


def run_sync_from_thread(fn: Callable[..., T_Retval], *args: Any) -> T_Retval:
return trio.from_thread.run_sync(copy_context().run, fn, *args)


class BlockingPortal(abc.BlockingPortal):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_from_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from concurrent.futures import CancelledError
from contextlib import suppress
from contextvars import ContextVar
from typing import Any, Dict, List, NoReturn, Optional

import pytest
Expand Down Expand Up @@ -117,6 +118,21 @@ async def foo() -> None:
exc = pytest.raises(RuntimeError, from_thread.run, foo)
exc.match('This function can only be run from an AnyIO worker thread')

async def test_contextvar_propagation(self, anyio_backend_name: str) -> None:
if anyio_backend_name == 'asyncio' and sys.version_info < (3, 7):
pytest.skip('Asyncio does not propagate context before Python 3.7')

var = ContextVar('var', default=1)

async def async_func() -> int:
return var.get()

def worker() -> int:
var.set(6)
return from_thread.run(async_func)

assert await to_thread.run_sync(worker) == 6


class TestRunSyncFromThread:
def test_run_sync_from_unclaimed_thread(self) -> None:
Expand All @@ -126,6 +142,15 @@ def foo() -> None:
exc = pytest.raises(RuntimeError, from_thread.run_sync, foo)
exc.match('This function can only be run from an AnyIO worker thread')

async def test_contextvar_propagation(self) -> None:
var = ContextVar('var', default=1)

def worker() -> int:
var.set(6)
return from_thread.run_sync(var.get)

assert await to_thread.run_sync(worker) == 6


class TestBlockingPortal:
class AsyncCM:
Expand Down

0 comments on commit 5de4ec1

Please sign in to comment.