Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/aleph/toolkit/shield.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import asyncio
from functools import wraps


def shielded(func):
"""
Protects a coroutine from cancellation.
"""
@wraps(func)
async def wrapped(*args, **kwargs):
return await asyncio.shield(func(*args, **kwargs))

return wrapped
3 changes: 3 additions & 0 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import List, Optional, Any, Dict, Iterable

Expand Down Expand Up @@ -31,6 +32,7 @@
format_message_dict,
PostMessage,
)
from aleph.toolkit.shield import shielded
from aleph.types.db_session import DbSessionFactory, DbSession
from aleph.types.message_status import MessageStatus
from aleph.types.sort_order import SortOrder, SortBy
Expand Down Expand Up @@ -338,6 +340,7 @@ async def _process_message(mq_message: aio_pika.abc.AbstractMessage):
return consumer_tag


@shielded
async def messages_ws(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
Expand Down
2 changes: 2 additions & 0 deletions src/aleph/web/controllers/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from aleph.schemas.pending_messages import parse_message, BasePendingMessage
from aleph.services.ipfs import IpfsService
from aleph.services.p2p.pubsub import publish as pub_p2p
from aleph.toolkit.shield import shielded
from aleph.types.message_status import (
InvalidMessageException,
MessageStatus,
Expand Down Expand Up @@ -183,6 +184,7 @@ class PubMessageResponse(BaseModel):
message_status: Optional[MessageStatus]


@shielded
async def pub_message(request: web.Request):
try:
request_data = PubMessageRequest.parse_obj(await request.json())
Expand Down