diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py
index 13a83ff5f..34a6a3478 100644
--- a/src/aleph/web/controllers/messages.py
+++ b/src/aleph/web/controllers/messages.py
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from typing import List, Optional, Any, Dict, Iterable
 
@@ -234,62 +233,66 @@ async def view_messages_list(request: web.Request) -> web.Response:
     )
 
 
-async def _message_ws_read_from_queue(
+async def _send_history_to_ws(
     ws: aiohttp.web_ws.WebSocketResponse,
-    mq_queue: aio_pika.abc.AbstractQueue,
-    request: web.Request,
+    session_factory: DbSessionFactory,
+    history: int,
+    message_filters: Dict[str, Any],
 ) -> None:
+
+    with session_factory() as session:
+        messages = get_matching_messages(
+            session=session,
+            pagination=history,
+            include_confirmations=True,
+            **message_filters,
+        )
+        for message in messages:
+            await ws.send_str(format_message(message).json())
+
+
+async def _start_mq_consumer(
+    ws: aiohttp.web_ws.WebSocketResponse,
+    mq_queue: aio_pika.abc.AbstractQueue,
+    session_factory: DbSessionFactory,
+    message_filters: Dict[str, Any],
+) -> aio_pika.abc.ConsumerTag:
     """
-    Task receiving new aleph.im messages from the processing pipeline to a websocket.
+    Starts the consumer task responsible for forwarding new aleph.im messages from
+    the processing pipeline to a websocket.
 
     :param ws: Websocket.
    :param mq_queue: Message queue object.
-    :param request: Websocket HTTP request object.
+    :param session_factory: DB session factory.
+    :param message_filters: Filters to apply to select outgoing messages.
     """
 
-    query_params = WsMessageQueryParams.parse_obj(request.query)
-    session_factory = get_session_factory_from_request(request)
-
-    find_filters = query_params.dict(exclude_none=True)
-    history = query_params.history
-
-    if history:
+    async def _process_message(mq_message: aio_pika.abc.AbstractMessage):
+        item_hash = aleph_json.loads(mq_message.body)["item_hash"]
+        # A bastardized way to apply the filters on the message as well.
+        # TODO: put the filter key/values in the RabbitMQ message?
         with session_factory() as session:
-            messages = get_matching_messages(
+            matching_messages = get_matching_messages(
                 session=session,
-                pagination=history,
+                hashes=[item_hash],
                 include_confirmations=True,
-                **find_filters,
+                **message_filters,
             )
-            for message in messages:
-                await ws.send_str(format_message(message).json())
-
-    try:
-        async with mq_queue.iterator(no_ack=True) as queue_iter:
-            async for mq_message in queue_iter:
-                item_hash = aleph_json.loads(mq_message.body)["item_hash"]
-                # A bastardized way to apply the filters on the message as well.
-                # TODO: put the filter key/values in the RabbitMQ message?
-                with session_factory() as session:
-                    matching_messages = get_matching_messages(
-                        session=session,
-                        hashes=[item_hash],
-                        include_confirmations=True,
-                        **find_filters,
-                    )
-                    for message in matching_messages:
-                        await ws.send_str(format_message(message).json())
-
-    except ConnectionResetError:
-        # We can detect the WS closing in this task in addition to the main one.
-        # We ignore this exception to avoid the "task exception was never retrieved"
-        # warning. The main task will also detect the close event.
-        LOGGER.info("Cannot send messages because the websocket is closed")
-        pass
-
-    except asyncio.CancelledError:
-        LOGGER.info("MQ -> WS task cancelled")
-        raise
+        try:
+            for message in matching_messages:
+                await ws.send_str(format_message(message).json())
+        except ConnectionResetError:
+            # We can detect the WS closing in this task in addition to the main one.
+            # The main task will also detect the close event.
+            # We just ignore this exception to avoid the "task exception was never retrieved"
+            # warning.
+            LOGGER.info("Cannot send messages because the websocket is closed")
+
+    # Note that we use the consume pattern here instead of using the `queue.iterator()`
+    # pattern because cancelling the iterator attempts to close the queue and channel.
+    # See discussion here: https://github.com/mosquito/aio-pika/issues/358
+    consumer_tag = await mq_queue.consume(_process_message, no_ack=True)
+    return consumer_tag
 
 
 async def messages_ws(request: web.Request) -> web.WebSocketResponse:
@@ -297,39 +300,68 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse:
     await ws.prepare(request)
 
     config = get_config_from_request(request)
+    session_factory = get_session_factory_from_request(request)
     mq_channel = get_mq_channel_from_request(request)
 
+    try:
+        query_params = WsMessageQueryParams.parse_obj(request.query)
+    except ValidationError as e:
+        raise web.HTTPUnprocessableEntity(body=e.json(indent=4))
+
+    message_filters = query_params.dict(exclude_none=True)
+    history = query_params.history
+
+    if history:
+        try:
+            await _send_history_to_ws(
+                ws=ws,
+                session_factory=session_factory,
+                history=history,
+                message_filters=message_filters,
+            )
+        except ConnectionResetError:
+            LOGGER.info("Could not send history, aborting message websocket")
+            return ws
+
     mq_queue = await mq_make_aleph_message_topic_queue(
         channel=mq_channel, config=config, routing_key="processed.*"
     )
+    consumer_tag = None
 
-    # Start a task to handle outgoing traffic to the websocket.
-    queue_task = asyncio.create_task(
-        _message_ws_read_from_queue(
+    try:
+        # Start a task to handle outgoing traffic to the websocket.
+        consumer_tag = await _start_mq_consumer(
             ws=ws,
-            request=request,
             mq_queue=mq_queue,
+            session_factory=session_factory,
+            message_filters=message_filters,
+        )
+        LOGGER.debug(
+            "Started consuming mq %s for websocket. Consumer tag: %s",
+            mq_queue.name,
+            consumer_tag,
         )
-    )
 
-    # Wait for the websocket to close.
-    try:
+        # Wait for the websocket to close.
         while not ws.closed:
            # Users can potentially send anything to the websocket. Ignore these messages
             # and only handle "close" messages.
             ws_msg = await ws.receive()
-            LOGGER.info("rx ws msg: %s", str(ws_msg))
+            LOGGER.debug("rx ws msg: %s", str(ws_msg))
             if ws_msg.type == WSMsgType.CLOSE:
                 LOGGER.debug("ws close received")
                 break
 
     finally:
-        # Cancel the MQ -> ws task
-        queue_task.cancel()
-        await asyncio.wait([queue_task])
-
-        # Always delete the queue, auto-delete queues are only deleted once the channel is closed
-        # and that's not meant to happen for the API.
+        # Cancel the consumer first so that the broker stops delivering messages to
+        # this websocket, then delete the queue. Note that even if the queue is in
+        # auto-delete mode, it is only deleted automatically once the channel closes,
+        # and API channels are long-lived. We delete the queue manually to avoid
+        # keeping stale queues around.
+        if consumer_tag:
+            LOGGER.info("Deleting consumer %s (queue: %s)", consumer_tag, mq_queue.name)
+            await mq_queue.cancel(consumer_tag=consumer_tag)
+
+        LOGGER.info("Deleting queue: %s", mq_queue.name)
         await mq_queue.delete(if_unused=False, if_empty=False)
 
     return ws
diff --git a/src/aleph/web/controllers/p2p.py b/src/aleph/web/controllers/p2p.py
index f44753eab..0e4c274f4 100644
--- a/src/aleph/web/controllers/p2p.py
+++ b/src/aleph/web/controllers/p2p.py
@@ -142,21 +142,25 @@ async def pub_json(request: web.Request):
 
 
 async def _mq_read_one_message(
-    queue: aio_pika.abc.AbstractQueue, timeout: float
+    mq_queue: aio_pika.abc.AbstractQueue, timeout: float
 ) -> Optional[aio_pika.abc.AbstractIncomingMessage]:
     """
-    Believe it or not, this is the only way I found to
-    :return:
+    Consume one element from a message queue and then return.
     """
-    try:
-        async with queue.iterator(timeout=timeout, no_ack=True) as queue_iter:
-            async for message in queue_iter:
-                return message
-    except asyncio.TimeoutError:
-        pass
+    queue: asyncio.Queue = asyncio.Queue()
+
+    async def _process_message(message: aio_pika.abc.AbstractMessage):
+        await queue.put(message)
 
-    return None
+    consumer_tag = await mq_queue.consume(_process_message, no_ack=True)
+
+    try:
+        return await asyncio.wait_for(queue.get(), timeout)
+    except asyncio.TimeoutError:
+        return None
+    finally:
+        await mq_queue.cancel(consumer_tag)
 
 
 def _processing_status_to_http_status(status: MessageProcessingStatus) -> int:
diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py
index aec2bb9eb..b8267cb3b 100644
--- a/src/aleph/web/controllers/utils.py
+++ b/src/aleph/web/controllers/utils.py
@@ -152,6 +152,11 @@ async def mq_make_aleph_message_topic_queue(
         type=aio_pika.ExchangeType.TOPIC,
         auto_delete=False,
     )
-    mq_queue = await channel.declare_queue(auto_delete=True)
+    mq_queue = await channel.declare_queue(
+        auto_delete=True, exclusive=True,
+        # Expire the queue after 30 seconds without consumers. This guarantees that
+        # queues are deleted even if a bug makes the consumer crash before cleanup.
+        arguments={"x-expires": 30000}
+    )
     await mq_queue.bind(mq_message_exchange, routing_key=routing_key)
     return mq_queue
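
Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the consume()/cancel(consumer_tag) pattern adopted above, combined with the x-expires queue declaration from utils.py. It assumes a local RabbitMQ broker with default credentials and a recent aio-pika; the connection URL and callback are illustrative.

```python
import asyncio

import aio_pika


async def main() -> None:
    # Assumption: RabbitMQ reachable on localhost with default credentials.
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        # Mirrors mq_make_aleph_message_topic_queue: an exclusive, auto-delete
        # queue that the broker also drops after 30s without consumers.
        queue = await channel.declare_queue(
            auto_delete=True,
            exclusive=True,
            arguments={"x-expires": 30000},
        )

        async def on_message(message: aio_pika.abc.AbstractIncomingMessage) -> None:
            print(message.body)

        # consume() registers a callback and returns a consumer tag. Unlike
        # cancelling a queue.iterator() task, cancel(consumer_tag) stops
        # delivery without tearing down the queue or the channel.
        consumer_tag = await queue.consume(on_message, no_ack=True)
        await asyncio.sleep(5)
        await queue.cancel(consumer_tag)
        await queue.delete(if_unused=False, if_empty=False)


asyncio.run(main())
```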
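For contrast, a hedged sketch of the queue.iterator() pattern this diff removes. Per the comment in _start_mq_consumer and aio-pika issue #358, cancelling the consuming task runs the iterator's cleanup, which attempts to close the queue and channel; the exact behavior may vary across aio-pika versions.

```python
import asyncio

import aio_pika


async def iterate_queue(mq_queue: aio_pika.abc.AbstractQueue) -> None:
    # The pattern removed by this diff: consuming through an async iterator.
    async with mq_queue.iterator(no_ack=True) as queue_iter:
        async for message in queue_iter:
            print(message.body)


async def main() -> None:
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        mq_queue = await channel.declare_queue(auto_delete=True)
        task = asyncio.create_task(iterate_queue(mq_queue))
        await asyncio.sleep(1)
        # Cancelling the task raises CancelledError inside the `async with`;
        # the iterator's cleanup then tries to close the queue and channel,
        # which is what motivated switching to consume()/cancel(consumer_tag).
        task.cancel()
        await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```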