Skip to content

Commit

Permalink
runtornado: Switch to asyncio event loop.
Browse files Browse the repository at this point in the history
Signed-off-by: Anders Kaseorg <anders@zulip.com>
  • Loading branch information
andersk authored and alexmv committed May 3, 2022
1 parent c263bfd commit 6fd1a55
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 89 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ no_implicit_reexport = false
module = [
"ahocorasick.*",
"aioapns.*",
"asgiref.*",
"bitfield.*",
"bmemcached.*",
"bson.*",
Expand Down
16 changes: 16 additions & 0 deletions zerver/lib/async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import asyncio


class NoAutoCreateEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[misc,valid-type] # https://github.com/python/typeshed/issues/7452
"""
By default asyncio.get_event_loop() automatically creates an event
loop for the main thread if one isn't currently installed. Since
Django intentionally uninstalls the event loop within
sync_to_async, that autocreation proliferates confusing extra
event loops that will never be run. It is also deprecated in
Python 3.10. This policy disables it so we don't rely on it by
accident.
"""

def get_event_loop(self) -> asyncio.AbstractEventLoop: # nocoverage
return asyncio.get_running_loop()
3 changes: 3 additions & 0 deletions zerver/lib/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ def _on_connection_open_error(
def _on_connection_closed(
self, connection: pika.connection.Connection, reason: Exception
) -> None:
if self.connection is None:
return
self._connection_failure_count = 1
retry_secs = self.CONNECTION_RETRY_SECS
self.log.warning(
Expand Down Expand Up @@ -335,6 +337,7 @@ def _on_channel_open(self, channel: Channel) -> None:
def close(self) -> None:
if self.connection is not None:
self.connection.close()
self.connection = None

def ensure_queue(self, queue_name: str, callback: Callable[[Channel], object]) -> None:
def set_qos(frame: Any) -> None:
Expand Down
79 changes: 55 additions & 24 deletions zerver/management/commands/runtornado.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import asyncio
import logging
import sys
import signal
from contextlib import AsyncExitStack
from typing import Any
from urllib.parse import SplitResult

import __main__
from asgiref.sync import async_to_sync, sync_to_async
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError, CommandParser
from tornado import autoreload, ioloop
from tornado import autoreload
from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future

settings.RUNNING_INSIDE_TORNADO = True

from zerver.lib.async_utils import NoAutoCreateEventLoopPolicy
from zerver.lib.debug import interactive_debug_listen
from zerver.tornado.application import create_tornado_application, setup_tornado_rabbitmq
from zerver.tornado.event_queue import (
add_client_gc_hook,
dump_event_queues,
get_wrapped_process_notification,
missedmessage_hook,
setup_event_queue,
Expand All @@ -23,6 +29,8 @@
if settings.USING_RABBITMQ:
from zerver.lib.queue import TornadoQueueClient, set_queue_client

asyncio.set_event_loop_policy(NoAutoCreateEventLoopPolicy())


class Command(BaseCommand):
help = "Starts a Tornado Web server wrapping Django."
Expand Down Expand Up @@ -56,53 +64,76 @@ def handle(self, *args: Any, **options: Any) -> None:
level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s"
)

def inner_run() -> None:
async def inner_run() -> None:
from django.utils import translation

translation.activate(settings.LANGUAGE_CODE)
AsyncIOMainLoop().install()
loop = asyncio.get_running_loop()
stop_fut = loop.create_future()

def stop() -> None:
if not stop_fut.done():
stop_fut.set_result(None)

def add_signal_handlers() -> None:
loop.add_signal_handler(signal.SIGINT, stop),
loop.add_signal_handler(signal.SIGTERM, stop),

# We pass display_num_errors=False, since Django will
# likely display similar output anyway.
self.check(display_num_errors=False)
print(f"Tornado server (re)started on port {port}")
def remove_signal_handlers() -> None:
loop.remove_signal_handler(signal.SIGINT),
loop.remove_signal_handler(signal.SIGTERM),

if settings.USING_RABBITMQ:
queue_client = TornadoQueueClient()
set_queue_client(queue_client)
# Process notifications received via RabbitMQ
queue_name = notify_tornado_queue_name(port)
queue_client.start_json_consumer(
queue_name, get_wrapped_process_notification(queue_name)
async with AsyncExitStack() as stack:
stack.push_async_callback(
sync_to_async(remove_signal_handlers, thread_sensitive=True)
)
await sync_to_async(add_signal_handlers, thread_sensitive=True)()

translation.activate(settings.LANGUAGE_CODE)

# We pass display_num_errors=False, since Django will
# likely display similar output anyway.
self.check(display_num_errors=False)
print(f"Tornado server (re)started on port {port}")

if settings.USING_RABBITMQ:
queue_client = TornadoQueueClient()
set_queue_client(queue_client)
# Process notifications received via RabbitMQ
queue_name = notify_tornado_queue_name(port)
stack.callback(queue_client.close)
queue_client.start_json_consumer(
queue_name, get_wrapped_process_notification(queue_name)
)

try:
# Application is an instance of Django's standard wsgi handler.
application = create_tornado_application()

# start tornado web server in single-threaded mode
http_server = httpserver.HTTPServer(application, xheaders=True)
stack.push_async_callback(
lambda: to_asyncio_future(http_server.close_all_connections())
)
stack.callback(http_server.stop)
http_server.listen(port, address=addr)

from zerver.tornado.ioloop_logging import logging_data

logging_data["port"] = str(port)
setup_event_queue(http_server, port)
await setup_event_queue(http_server, port)
stack.callback(dump_event_queues, port)
add_client_gc_hook(missedmessage_hook)
if settings.USING_RABBITMQ:
setup_tornado_rabbitmq(queue_client)

instance = ioloop.IOLoop.instance()

if hasattr(__main__, "add_reload_hook"):
autoreload.start()

instance.start()
except KeyboardInterrupt:
await stop_fut

# Monkey patch tornado.autoreload to prevent it from continuing
# to watch for changes after catching our SystemExit. Otherwise
# the user needs to press Ctrl+C twice.
__main__.wait = lambda: None

sys.exit(0)

inner_run()
async_to_sync(inner_run, force_new_loop=True)()
87 changes: 60 additions & 27 deletions zerver/tests/test_tornado.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,95 @@
import asyncio
import urllib.parse
from typing import Any, Dict, Optional
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar
from unittest import TestResult

import orjson
from asgiref.sync import async_to_sync, sync_to_async
from django.conf import settings
from django.core import signals
from django.db import close_old_connections
from django.test import override_settings
from tornado.httpclient import HTTPResponse
from tornado.testing import AsyncHTTPTestCase
from tornado.ioloop import IOLoop
from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase
from tornado.web import Application
from typing_extensions import ParamSpec

from zerver.lib.test_classes import ZulipTestCase
from zerver.tornado import event_queue
from zerver.tornado.application import create_tornado_application
from zerver.tornado.event_queue import process_event

P = ParamSpec("P")
T = TypeVar("T")


def async_to_sync_decorator(f: Callable[P, Awaitable[T]]) -> Callable[P, T]:
@wraps(f)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
return async_to_sync(f)(*args, **kwargs)

return wrapped


async def in_django_thread(f: Callable[[], T]) -> T:
return await asyncio.create_task(sync_to_async(f)())


class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase):
def setUp(self) -> None:
@async_to_sync_decorator
async def setUp(self) -> None:
super().setUp()
signals.request_started.disconnect(close_old_connections)
signals.request_finished.disconnect(close_old_connections)
self.session_cookie: Optional[Dict[str, str]] = None

def tearDown(self) -> None:
super().tearDown()
self.session_cookie = None
@async_to_sync_decorator
async def tearDown(self) -> None:
# Skip tornado.testing.AsyncTestCase.tearDown because it tries to kill
# the current task.
super(AsyncTestCase, self).tearDown()

def run(self, result: Optional[TestResult] = None) -> Optional[TestResult]:
return async_to_sync(
sync_to_async(super().run, thread_sensitive=False), force_new_loop=True
)(result)

def get_new_ioloop(self) -> IOLoop:
return AsyncIOMainLoop()

@override_settings(DEBUG=False)
def get_app(self) -> Application:
return create_tornado_application()

def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse:
async def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse:
self.add_session_cookie(kwargs)
kwargs["skip_user_agent"] = True
self.set_http_headers(kwargs)
if "HTTP_HOST" in kwargs:
kwargs["headers"]["Host"] = kwargs["HTTP_HOST"]
del kwargs["HTTP_HOST"]
return self.fetch(path, method="GET", **kwargs)
return await to_asyncio_future(
self.http_client.fetch(self.get_url(path), method="GET", **kwargs)
)

def fetch_async(self, method: str, path: str, **kwargs: Any) -> None:
async def fetch_async(self, method: str, path: str, **kwargs: Any) -> HTTPResponse:
self.add_session_cookie(kwargs)
kwargs["skip_user_agent"] = True
self.set_http_headers(kwargs)
if "HTTP_HOST" in kwargs:
kwargs["headers"]["Host"] = kwargs["HTTP_HOST"]
del kwargs["HTTP_HOST"]
self.http_client.fetch(
self.get_url(path),
self.stop,
method=method,
**kwargs,
return await to_asyncio_future(
self.http_client.fetch(self.get_url(path), method=method, **kwargs)
)

def client_get_async(self, path: str, **kwargs: Any) -> None:
async def client_get_async(self, path: str, **kwargs: Any) -> HTTPResponse:
kwargs["skip_user_agent"] = True
self.set_http_headers(kwargs)
self.fetch_async("GET", path, **kwargs)
return await self.fetch_async("GET", path, **kwargs)

def login_user(self, *args: Any, **kwargs: Any) -> None:
super().login_user(*args, **kwargs)
Expand All @@ -76,8 +108,8 @@ def add_session_cookie(self, kwargs: Dict[str, Any]) -> None:
headers.update(self.get_session_cookie())
kwargs["headers"] = headers

def create_queue(self, **kwargs: Any) -> str:
response = self.tornado_client_get(
async def create_queue(self, **kwargs: Any) -> str:
response = await self.tornado_client_get(
"/json/events?dont_block=true",
subdomain="zulip",
skip_user_agent=True,
Expand All @@ -90,22 +122,23 @@ def create_queue(self, **kwargs: Any) -> str:


class EventsTestCase(TornadoWebTestCase):
def test_create_queue(self) -> None:
self.login_user(self.example_user("hamlet"))
queue_id = self.create_queue()
@async_to_sync_decorator
async def test_create_queue(self) -> None:
await in_django_thread(lambda: self.login_user(self.example_user("hamlet")))
queue_id = await self.create_queue()
self.assertIn(queue_id, event_queue.clients)

def test_events_async(self) -> None:
user_profile = self.example_user("hamlet")
self.login_user(user_profile)
event_queue_id = self.create_queue()
@async_to_sync_decorator
async def test_events_async(self) -> None:
user_profile = await in_django_thread(lambda: self.example_user("hamlet"))
await in_django_thread(lambda: self.login_user(user_profile))
event_queue_id = await self.create_queue()
data = {
"queue_id": event_queue_id,
"last_event_id": -1,
}

path = f"/json/events?{urllib.parse.urlencode(data)}"
self.client_get_async(path)

def process_events() -> None:
users = [user_profile.id]
Expand All @@ -116,7 +149,7 @@ def process_events() -> None:
process_event(event, users)

self.io_loop.call_later(0.1, process_events)
response = self.wait()
response = await self.client_get_async(path)
self.assertEqual(response.headers["Vary"], "Accept-Language, Cookie")
data = orjson.loads(response.body)
self.assertEqual(
Expand Down
3 changes: 0 additions & 3 deletions zerver/tornado/application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import atexit

import tornado.web
from django.conf import settings
from tornado import autoreload
Expand All @@ -10,7 +8,6 @@

def setup_tornado_rabbitmq(queue_client: TornadoQueueClient) -> None: # nocoverage
# When tornado is shut down, disconnect cleanly from RabbitMQ
atexit.register(lambda: queue_client.close())
autoreload.add_reload_hook(lambda: queue_client.close())


Expand Down

0 comments on commit 6fd1a55

Please sign in to comment.