From 6ae5e0fbdac144a08680cd8bae0c1242b715ed80 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Wed, 14 Jun 2023 11:42:48 +0200 Subject: [PATCH] Fix: prevent cancellation of the message websocket Problem: resources are not cleaned up properly when the message websocket handler is cancelled. This is because aiohttp appears to raise a CancellationError. Solution: use `asyncio.shield()` on the whole handler. Added a decorator to avoid splitting the implementation of shielded functions. Protected the POST /messages endpoint as well as it also uses a message queue. fix --- src/aleph/toolkit/shield.py | 13 +++++++++++++ src/aleph/web/controllers/messages.py | 3 +++ src/aleph/web/controllers/p2p.py | 2 ++ 3 files changed, 18 insertions(+) create mode 100644 src/aleph/toolkit/shield.py diff --git a/src/aleph/toolkit/shield.py b/src/aleph/toolkit/shield.py new file mode 100644 index 000000000..7c5806224 --- /dev/null +++ b/src/aleph/toolkit/shield.py @@ -0,0 +1,13 @@ +import asyncio +from functools import wraps + + +def shielded(func): + """ + Protects a coroutine from cancellation. + """ + @wraps(func) + async def wrapped(*args, **kwargs): + return await asyncio.shield(func(*args, **kwargs)) + + return wrapped diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 7838aef78..61ea0f75f 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import List, Optional, Any, Dict, Iterable @@ -31,6 +32,7 @@ format_message_dict, PostMessage, ) +from aleph.toolkit.shield import shielded from aleph.types.db_session import DbSessionFactory, DbSession from aleph.types.message_status import MessageStatus from aleph.types.sort_order import SortOrder, SortBy @@ -338,6 +340,7 @@ async def _process_message(mq_message: aio_pika.abc.AbstractMessage): return consumer_tag +@shielded async def messages_ws(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) diff --git a/src/aleph/web/controllers/p2p.py b/src/aleph/web/controllers/p2p.py index 3abecb8bb..2a0406480 100644 --- a/src/aleph/web/controllers/p2p.py +++ b/src/aleph/web/controllers/p2p.py @@ -13,6 +13,7 @@ from aleph.schemas.pending_messages import parse_message, BasePendingMessage from aleph.services.ipfs import IpfsService from aleph.services.p2p.pubsub import publish as pub_p2p +from aleph.toolkit.shield import shielded from aleph.types.message_status import ( InvalidMessageException, MessageStatus, @@ -183,6 +184,7 @@ class PubMessageResponse(BaseModel): message_status: Optional[MessageStatus] +@shielded async def pub_message(request: web.Request): try: request_data = PubMessageRequest.parse_obj(await request.json())