Skip to content

Commit

Permalink
Contextvars support (#3446)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Dec 15, 2018
1 parent b27fa73 commit dd30b2a
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 42 deletions.
10 changes: 8 additions & 2 deletions aiohttp/helpers.py
Expand Up @@ -20,8 +20,8 @@
from pathlib import Path
from types import TracebackType
from typing import (Any, Callable, Dict, Iterable, Iterator, List, # noqa
Mapping, Optional, Pattern, Tuple, Type, TypeVar, Union,
cast)
Mapping, Optional, Pattern, Set, Tuple, Type, TypeVar,
Union, cast)
from urllib.parse import quote
from urllib.request import getproxies

Expand Down Expand Up @@ -50,6 +50,12 @@
from typing_extensions import ContextManager


all_tasks = asyncio.Task.all_tasks

if PY_37:
all_tasks = getattr(asyncio, 'all_tasks') # use the trick to cheat mypy


_T = TypeVar('_T')


Expand Down
139 changes: 99 additions & 40 deletions aiohttp/web.py
Expand Up @@ -11,6 +11,7 @@
web_protocol, web_request, web_response, web_routedef,
web_runner, web_server, web_urldispatcher, web_ws)
from .abc import AbstractAccessLogger
from .helpers import all_tasks
from .log import access_logger
from .web_app import * # noqa
from .web_app import Application
Expand Down Expand Up @@ -52,42 +53,33 @@
SSLContext = Any # type: ignore


def run_app(app: Union[Application, Awaitable[Application]], *,
host: Optional[str]=None,
port: Optional[int]=None,
path: Optional[str]=None,
sock: Optional[socket.socket]=None,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
print: Callable[..., None]=print,
backlog: int=128,
access_log_class: Type[AbstractAccessLogger]=AccessLogger,
access_log_format: str=AccessLogger.LOG_FORMAT,
access_log: logging.Logger=access_logger,
handle_signals: bool=True,
reuse_address: Optional[bool]=None,
reuse_port: Optional[bool]=None) -> None:
"""Run an app locally"""
loop = asyncio.get_event_loop()

async def _run_app(app: Union[Application, Awaitable[Application]], *,
host: Optional[str]=None,
port: Optional[int]=None,
path: Optional[str]=None,
sock: Optional[socket.socket]=None,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
print: Callable[..., None]=print,
backlog: int=128,
access_log_class: Type[AbstractAccessLogger]=AccessLogger,
access_log_format: str=AccessLogger.LOG_FORMAT,
access_log: logging.Logger=access_logger,
handle_signals: bool=True,
reuse_address: Optional[bool]=None,
reuse_port: Optional[bool]=None) -> None:
# A internal functio to actually do all dirty job for application running
if asyncio.iscoroutine(app):
app = loop.run_until_complete(app) # type: ignore
app = await app # type: ignore

app = cast(Application, app)

# Configure if and only if in debugging mode and using the default logger
if loop.get_debug() and access_log.name == 'aiohttp.access':
if access_log.level == logging.NOTSET:
access_log.setLevel(logging.DEBUG)
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

runner = AppRunner(app, handle_signals=handle_signals,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log)

loop.run_until_complete(runner.setup())
await runner.setup()

sites = [] # type: List[BaseSite]

Expand Down Expand Up @@ -141,21 +133,88 @@ def run_app(app: Union[Application, Awaitable[Application]], *,
ssl_context=ssl_context,
backlog=backlog))
for site in sites:
loop.run_until_complete(site.start())
try:
if print: # pragma: no branch
names = sorted(str(s.name) for s in runner.sites)
print("======== Running on {} ========\n"
"(Press CTRL+C to quit)".format(', '.join(names)))
loop.run_forever()
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
await site.start()

if print: # pragma: no branch
names = sorted(str(s.name) for s in runner.sites)
print("======== Running on {} ========\n"
"(Press CTRL+C to quit)".format(', '.join(names)))
while True:
await asyncio.sleep(3600) # sleep forever by 1 hour intervals
finally:
await runner.cleanup()


def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(
asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})


def run_app(app: Union[Application, Awaitable[Application]], *,
host: Optional[str]=None,
port: Optional[int]=None,
path: Optional[str]=None,
sock: Optional[socket.socket]=None,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
print: Callable[..., None]=print,
backlog: int=128,
access_log_class: Type[AbstractAccessLogger]=AccessLogger,
access_log_format: str=AccessLogger.LOG_FORMAT,
access_log: logging.Logger=access_logger,
handle_signals: bool=True,
reuse_address: Optional[bool]=None,
reuse_port: Optional[bool]=None) -> None:
"""Run an app locally"""
loop = asyncio.get_event_loop()

# Configure if and only if in debugging mode and using the default logger
if loop.get_debug() and access_log.name == 'aiohttp.access':
if access_log.level == logging.NOTSET:
access_log.setLevel(logging.DEBUG)
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

try:
loop.run_until_complete(_run_app(app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port))
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
finally:
loop.run_until_complete(runner.cleanup())
if sys.version_info >= (3, 6): # don't use PY_36 to pass mypy
if hasattr(loop, 'shutdown_asyncgens'):
_cancel_all_tasks(loop)
if sys.version_info >= (3, 6): # don't use PY_36 to pass mypy
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop.close()


def main(argv: List[str]) -> None:
Expand Down
Binary file added docs/old-logo.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 76 additions & 0 deletions docs/web_advanced.rst
Expand Up @@ -393,6 +393,82 @@ and safe to use as the key.
Otherwise, something based on your company name/url would be satisfactory (i.e.
``org.company.app``).


.. _aiohttp-web-contextvars:


ContextVars support
-------------------

Starting from Python 3.7 asyncio has :mod:`Context Variables <contextvars>` as a
context-local storage (a generalization of thread-local concept that works with asyncio
tasks also).


*aiohttp* server supports it in the following way:

* A server inherits the current task's context used when creating it.
:func:`aiohttp.web.run_app()` runs a task for handling all underlying jobs running
the app, but alternatively :ref:`aiohttp-web-app-runners` can be used.

* Application initialization / finalization events (:attr:`Application.cleanup_ctx`,
:attr:`Application.on_startup` and :attr:`Application.on_shutdown`,
:attr:`Application.on_cleanup`) are executed inside the same context.

E.g. all context modifications made on application startup a visible on teardown.

* On every request handling *aiohttp* creates a context copy. :term:`web-handler` has
all variables installed on initialization stage. But the context modification made by
a handler or middleware is invisible to another HTTP request handling call.

An example of context vars usage::

from contextvars import ContextVar

from aiohttp import web

VAR = ContextVar('VAR', default='default')


async def coro():
return VAR.get()


async def handler(request):
var = VAR.get()
VAR.set('handler')
ret = await coro()
return web.Response(text='\n'.join([var,
ret]))


async def on_startup(app):
print('on_startup', VAR.get())
VAR.set('on_startup')


async def on_cleanup(app):
print('on_cleanup', VAR.get())
VAR.set('on_cleanup')


async def init():
print('init', VAR.get())
VAR.set('init')
app = web.Application()
app.router.add_get('/', handler)

app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup)
return app


web.run_app(init())
print('done', VAR.get())

.. versionadded:: 3.5


.. _aiohttp-web-middlewares:

Middlewares
Expand Down
101 changes: 101 additions & 0 deletions tests/test_run_app.py
Expand Up @@ -14,6 +14,7 @@
import pytest

from aiohttp import web
from aiohttp.helpers import PY_37
from aiohttp.test_utils import make_mocked_coro


Expand Down Expand Up @@ -621,3 +622,103 @@ def test_run_app_default_logger_setup_only_if_unconfigured(patched_loop):
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_called_with()
mock_logger.addHandler.assert_not_called()


def test_run_app_cancels_all_pending_tasks(patched_loop):
app = web.Application()
task = None

async def on_startup(app):
nonlocal task
loop = asyncio.get_event_loop()
task = loop.create_task(asyncio.sleep(1000))

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
assert task.cancelled()


def test_run_app_cancels_done_tasks(patched_loop):
app = web.Application()
task = None

async def coro():
return 123

async def on_startup(app):
nonlocal task
loop = asyncio.get_event_loop()
task = loop.create_task(coro())

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
assert task.done()


def test_run_app_cancels_failed_tasks(patched_loop):
app = web.Application()
task = None

exc = RuntimeError("FAIL")

async def fail():
try:
await asyncio.sleep(1000)
except asyncio.CancelledError:
raise exc

async def on_startup(app):
nonlocal task
loop = asyncio.get_event_loop()
task = loop.create_task(fail())
await asyncio.sleep(0.01)

app.on_startup.append(on_startup)

exc_handler = mock.Mock()
patched_loop.set_exception_handler(exc_handler)
web.run_app(app, print=stopper(patched_loop))
assert task.done()

msg = {
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': exc,
'task': task,
}
exc_handler.assert_called_with(patched_loop, msg)


@pytest.mark.skipif(not PY_37,
reason="contextvars support is required")
def test_run_app_context_vars(patched_loop):
from contextvars import ContextVar

count = 0
VAR = ContextVar('VAR', default='default')

async def on_startup(app):
nonlocal count
assert 'init' == VAR.get()
VAR.set('on_startup')
count += 1

async def on_cleanup(app):
nonlocal count
assert 'on_startup' == VAR.get()
count += 1

async def init():
nonlocal count
assert 'default' == VAR.get()
VAR.set('init')
app = web.Application()

app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup)
count += 1
return app

web.run_app(init(), print=stopper(patched_loop))
assert count == 3

0 comments on commit dd30b2a

Please sign in to comment.