Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor workers #3471

Merged
merged 4 commits into from
Dec 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/3471.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Use the same task for app initialization and web server handling in gunicorn workers.
It allows to use Python3.7 context vars smoothly.
42 changes: 23 additions & 19 deletions aiohttp/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import signal
import sys
from types import FrameType
from typing import Any, Optional # noqa
from typing import Any, Awaitable, Callable, Optional, Union # noqa

from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
from gunicorn.workers import base

from aiohttp import web

from .helpers import set_result
from .web_app import Application
from .web_log import AccessLogger

try:
Expand All @@ -37,7 +38,6 @@ class GunicornWebWorker(base.Worker):
def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover
super().__init__(*args, **kw)

self._runner = None # type: Optional[web.AppRunner]
self._task = None # type: Optional[asyncio.Task[None]]
self.exit_code = 0
self._notify_waiter = None # type: Optional[asyncio.Future[bool]]
Expand All @@ -52,35 +52,39 @@ def init_process(self) -> None:
super().init_process()

def run(self) -> None:
access_log = self.log.access_log if self.cfg.accesslog else None
params = dict(
logger=self.log,
keepalive_timeout=self.cfg.keepalive,
access_log=access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
if asyncio.iscoroutinefunction(self.wsgi): # type: ignore
self.wsgi = self.loop.run_until_complete(
self.wsgi()) # type: ignore
self._runner = web.AppRunner(self.wsgi, **params)
self.loop.run_until_complete(self._runner.setup())
self._task = self.loop.create_task(self._run())

try: # ignore all finalization problems
self.loop.run_until_complete(self._task)
except Exception as error:
self.log.exception(error)
except Exception:
self.log.exception("Exception in gunicorn worker")
if sys.version_info >= (3, 6):
if hasattr(self.loop, 'shutdown_asyncgens'):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()

sys.exit(self.exit_code)

async def _run(self) -> None:
if isinstance(self.wsgi, Application):
app = self.wsgi
elif asyncio.iscoroutinefunction(self.wsgi):
app = await self.wsgi()
else:
raise RuntimeError("wsgi app should be either Application or "
"async function returning Application, got {}"
.format(self.wsgi))
access_log = self.log.access_log if self.cfg.accesslog else None
runner = web.AppRunner(app,
logger=self.log,
keepalive_timeout=self.cfg.keepalive,
access_log=access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
await runner.setup()

ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None

runner = self._runner
runner = runner
assert runner is not None
server = runner.server
assert server is not None
Expand Down
70 changes: 25 additions & 45 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest

from aiohttp import web
from aiohttp.test_utils import make_mocked_coro

base_worker = pytest.importorskip('aiohttp.worker')

Expand Down Expand Up @@ -42,13 +41,15 @@ def __init__(self):
self.wsgi = web.Application()


class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): # type: ignore # noqa
class AsyncioWorker(BaseTestWorker, # type: ignore
base_worker.GunicornWebWorker):
pass


PARAMS = [AsyncioWorker]
if uvloop is not None:
class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): # type: ignore # noqa
class UvloopWorker(BaseTestWorker, # type: ignore
base_worker.GunicornUVLoopWebWorker):
pass

PARAMS.append(UvloopWorker)
Expand Down Expand Up @@ -78,30 +79,47 @@ def test_run(worker, loop) -> None:
worker.log = mock.Mock()
worker.cfg = mock.Mock()
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
worker.cfg.is_ssl = False
worker.sockets = []

worker.loop = loop
worker._run = make_mocked_coro(None)
with pytest.raises(SystemExit):
worker.run()
assert worker._run.called
worker.log.exception.assert_not_called()
assert loop.is_closed()


def test_run_async_factory(worker, loop) -> None:
worker.log = mock.Mock()
worker.cfg = mock.Mock()
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
worker.cfg.is_ssl = False
worker.sockets = []
app = worker.wsgi

async def make_app():
return app
worker.wsgi = make_app

worker.loop = loop
worker._run = make_mocked_coro(None)
worker.alive = False
with pytest.raises(SystemExit):
worker.run()
worker.log.exception.assert_not_called()
assert loop.is_closed()


def test_run_not_app(worker, loop) -> None:
worker.log = mock.Mock()
worker.cfg = mock.Mock()
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT

worker.loop = loop
worker.wsgi = "not-app"
worker.alive = False
with pytest.raises(SystemExit):
worker.run()
assert worker._run.called
worker.log.exception.assert_called_with('Exception in gunicorn worker')
assert loop.is_closed()


Expand Down Expand Up @@ -197,15 +215,11 @@ async def test__run_ok_parent_changed(worker, loop,
worker.cfg.max_requests = 0
worker.cfg.is_ssl = False

worker._runner = web.AppRunner(worker.wsgi)
await worker._runner.setup()

await worker._run()

worker.notify.assert_called_with()
worker.log.info.assert_called_with("Parent changed, shutting down: %s",
worker)
assert worker._runner.server is None


async def test__run_exc(worker, loop, aiohttp_unused_port) -> None:
Expand All @@ -223,9 +237,6 @@ async def test__run_exc(worker, loop, aiohttp_unused_port) -> None:
worker.cfg.max_requests = 0
worker.cfg.is_ssl = False

worker._runner = web.AppRunner(worker.wsgi)
await worker._runner.setup()

def raiser():
waiter = worker._notify_waiter
worker.alive = False
Expand All @@ -235,37 +246,6 @@ def raiser():
await worker._run()

worker.notify.assert_called_with()
assert worker._runner.server is None


async def test__run_ok_max_requests_exceeded(worker, loop,
aiohttp_unused_port):
skip_if_no_dict(loop)

worker.ppid = os.getppid()
worker.alive = True
worker.servers = {}
sock = socket.socket()
addr = ('localhost', aiohttp_unused_port())
sock.bind(addr)
worker.sockets = [sock]
worker.log = mock.Mock()
worker.loop = loop
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
worker.cfg.max_requests = 10
worker.cfg.is_ssl = False

worker._runner = web.AppRunner(worker.wsgi)
await worker._runner.setup()
worker._runner.server.requests_count = 30

await worker._run()

worker.notify.assert_called_with()
worker.log.info.assert_called_with("Max requests, shutting down: %s",
worker)

assert worker._runner.server is None


def test__create_ssl_context_without_certs_and_ciphers(worker) -> None:
Expand Down