From e7f33ba48cdb1dd8d234b41527876fdf2d6d2e3a Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Thu, 1 Jun 2023 21:29:48 +0200 Subject: [PATCH 1/5] Fix: no DB calls in message websocket Problem: the websocket uses the DB to apply filters to processed messages to determine whether they must be sent on the websocket. This leads to an exhaustion of DB connections when many websockets are open with the same filters. Solution: send the full message on the RabbitMQ topic and apply filters in Python instead of using a DB query. --- src/aleph/jobs/process_pending_messages.py | 4 +- src/aleph/schemas/api/messages.py | 10 +-- src/aleph/types/message_processing_result.py | 17 +++- src/aleph/web/controllers/messages.py | 84 ++++++++++++++------ 4 files changed, 81 insertions(+), 34 deletions(-) diff --git a/src/aleph/jobs/process_pending_messages.py b/src/aleph/jobs/process_pending_messages.py index 5a0c22d95..cf0292636 100644 --- a/src/aleph/jobs/process_pending_messages.py +++ b/src/aleph/jobs/process_pending_messages.py @@ -119,10 +119,10 @@ async def process_messages( async def publish_to_mq( self, message_iterator: AsyncIterator[Sequence[MessageProcessingResult]] ) -> AsyncIterator[Sequence[MessageProcessingResult]]: + async for processing_results in message_iterator: for result in processing_results: - body = {"item_hash": result.item_hash} - mq_message = aio_pika.Message(body=aleph_json.dumps(body)) + mq_message = aio_pika.Message(body=aleph_json.dumps(result.to_dict())) await self.mq_message_exchange.publish( mq_message, routing_key=f"{result.status.value}.{result.item_hash}", diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index e62b3fda3..e38524fe5 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -97,11 +97,6 @@ class StoreMessage( } -def format_message(message: Any) -> AlephMessage: - message_cls = MESSAGE_CLS_DICT[message.type] - return message_cls.from_orm(message) - - AlephMessage = Union[ AggregateMessage, ForgetMessage, @@ -112,6 +107,11 @@ def format_message(message: Any) -> AlephMessage: ] +def format_message(message: Any) -> AlephMessage: + message_cls = MESSAGE_CLS_DICT[message.type] + return message_cls.from_orm(message) + + class BaseMessageStatus(BaseModel): status: MessageStatus item_hash: str diff --git a/src/aleph/types/message_processing_result.py b/src/aleph/types/message_processing_result.py index 4a9de0f0d..3e71d1c04 100644 --- a/src/aleph/types/message_processing_result.py +++ b/src/aleph/types/message_processing_result.py @@ -1,6 +1,7 @@ -from typing import Protocol +from typing import Any, Dict, Protocol from aleph.db.models import PendingMessageDb, MessageDb +from aleph.schemas.api.messages import format_message from aleph.types.message_status import ( ErrorCode, MessageProcessingStatus, @@ -14,6 +15,9 @@ class MessageProcessingResult(Protocol): def item_hash(self) -> str: pass + def to_dict(self) -> Dict[str, Any]: + pass + class ProcessedMessage(MessageProcessingResult): def __init__(self, message: MessageDb, is_confirmation: bool = False): @@ -28,10 +32,11 @@ def __init__(self, message: MessageDb, is_confirmation: bool = False): def item_hash(self) -> str: return self.message.item_hash + def to_dict(self) -> Dict[str, Any]: + return {"status": self.status.value, "message": format_message(self.message)} -class FailedMessage(MessageProcessingResult): - status = MessageProcessingStatus.FAILED_WILL_RETRY +class FailedMessage(MessageProcessingResult): def __init__( self, pending_message: PendingMessageDb, error_code: ErrorCode, will_retry: bool ): @@ -48,6 +53,12 @@ def __init__( def item_hash(self) -> str: return self.pending_message.item_hash + def to_dict(self) -> Dict[str, Any]: + return { + "status": self.status.value, + "item_hash": self.item_hash, + } + class WillRetryMessage(FailedMessage): def __init__(self, pending_message: PendingMessageDb, error_code: ErrorCode): diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index bf5cf4883..6f062b6c5 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -27,14 +27,15 @@ ForgottenMessage, RejectedMessageStatus, PendingMessage, + AlephMessage, ) from aleph.types.db_session import DbSessionFactory, DbSession from aleph.types.message_status import MessageStatus from aleph.types.sort_order import SortOrder, SortBy from aleph.web.controllers.app_state_getters import ( get_session_factory_from_request, - get_mq_channel_from_request, - get_config_from_request, get_mq_ws_channel_from_request, + get_config_from_request, + get_mq_ws_channel_from_request, ) from aleph.web.controllers.utils import ( DEFAULT_MESSAGES_PER_PAGE, @@ -237,25 +238,68 @@ async def _send_history_to_ws( ws: aiohttp.web_ws.WebSocketResponse, session_factory: DbSessionFactory, history: int, - message_filters: Dict[str, Any], + query_params: WsMessageQueryParams, ) -> None: - with session_factory() as session: messages = get_matching_messages( session=session, pagination=history, include_confirmations=True, - **message_filters, + **query_params.dict(exclude_none=True), ) for message in messages: await ws.send_str(format_message(message).json()) +def message_matches_filters( + message: AlephMessage, query_params: WsMessageQueryParams +) -> bool: + if message_type := query_params.message_type: + if message.type != message_type: + return False + + # For simple filters, this reduces the amount of boilerplate + filters_by_message_field = { + "sender": "addresses", + "type": "message_type", + "item_hash": "hashes", + "ref": "refs", + "chain": "chains", + "channel": "channels", + } + + for message_field, query_field in filters_by_message_field.items(): + if user_filters := getattr(query_params, query_field): + if not isinstance(user_filters, list): + user_filters = [user_filters] + if not getattr(message, message_field) in user_filters: + return False + + # Process filters on content.content + content = getattr(message.content, "content") + if content_types := query_params.content_types: + content_type = getattr(content, "type") + if content_type not in content_types: + return False + + if content_hashes := query_params.content_hashes: + content_hash = getattr(content, "item_hash") + if content_hash not in content_hashes: + return False + + # For tags, we only need to match one filter + if query_tags := query_params.tags: + content_tags = set(getattr(content, "tags")) + if (content_tags & set(query_tags)) == set(): + return False + + return True + + async def _start_mq_consumer( ws: aiohttp.web_ws.WebSocketResponse, mq_queue: aio_pika.abc.AbstractQueue, - session_factory: DbSessionFactory, - message_filters: Dict[str, Any], + query_params: WsMessageQueryParams, ) -> aio_pika.abc.ConsumerTag: """ Starts the consumer task responsible for forwarding new aleph.im messages from @@ -263,24 +307,16 @@ async def _start_mq_consumer( :param ws: Websocket. :param mq_queue: Message queue object. - :param session_factory: DB session factory. - :param message_filters: Filters to apply to select outgoing messages. + :param query_params: Message filters specified by the caller. """ async def _process_message(mq_message: aio_pika.abc.AbstractMessage): - item_hash = aleph_json.loads(mq_message.body)["item_hash"] - # A bastardized way to apply the filters on the message as well. - # TODO: put the filter key/values in the RabbitMQ message? - with session_factory() as session: - matching_messages = get_matching_messages( - session=session, - hashes=[item_hash], - include_confirmations=True, - **message_filters, - ) + message_bytes = mq_message.body + message_dict = aleph_json.loads(message_bytes) + message = format_message(message_dict) + if message_matches_filters(message=message, query_params=query_params): try: - for message in matching_messages: - await ws.send_str(format_message(message).json()) + await ws.send_str(message_bytes.decode()) except ConnectionResetError: # We can detect the WS closing in this task in addition to the main one. # The main task will also detect the close event. @@ -307,6 +343,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: query_params = WsMessageQueryParams.parse_obj(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(body=e.json(indent=4)) + message_filters = query_params.dict(exclude_none=True) history = query_params.history @@ -316,7 +353,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: ws=ws, session_factory=session_factory, history=history, - message_filters=message_filters, + query_params=query_params, ) except ConnectionResetError: LOGGER.info("Could not send history, aborting message websocket") @@ -332,8 +369,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: consumer_tag = await _start_mq_consumer( ws=ws, mq_queue=mq_queue, - session_factory=session_factory, - message_filters=message_filters, + query_params=query_params, ) LOGGER.debug( "Started consuming mq %s for websocket. Consumer tag: %s", From e67f21cadc576430734576eb31cf98751d224bcf Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Mon, 12 Jun 2023 11:30:44 +0200 Subject: [PATCH 2/5] fix mypy --- src/aleph/schemas/api/messages.py | 5 ++--- src/aleph/web/controllers/messages.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index e38524fe5..3e35ddbd1 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -9,7 +9,6 @@ PostContent, ProgramContent, StoreContent, - AlephMessage, InstanceContent, ) from aleph_message.models import MessageType, ItemType @@ -107,7 +106,7 @@ class StoreMessage( ] -def format_message(message: Any) -> AlephMessage: +def format_message(message: Any) -> BaseMessage: message_cls = MESSAGE_CLS_DICT[message.type] return message_cls.from_orm(message) @@ -150,7 +149,7 @@ class Config: orm_mode = True status: MessageStatus = MessageStatus.PROCESSED - message: AlephMessage + message: BaseMessage class ForgottenMessage(BaseModel): diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 6f062b6c5..d0e79b27d 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -27,7 +27,7 @@ ForgottenMessage, RejectedMessageStatus, PendingMessage, - AlephMessage, + AlephMessage, BaseMessage, ) from aleph.types.db_session import DbSessionFactory, DbSession from aleph.types.message_status import MessageStatus @@ -252,7 +252,7 @@ async def _send_history_to_ws( def message_matches_filters( - message: AlephMessage, query_params: WsMessageQueryParams + message: BaseMessage, query_params: WsMessageQueryParams ) -> bool: if message_type := query_params.message_type: if message.type != message_type: From bb4818f91b6fabfef652a97be4243f49788d9be6 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Mon, 12 Jun 2023 16:31:46 +0200 Subject: [PATCH 3/5] fix tests --- src/aleph/schemas/api/messages.py | 38 +++++++++++++------- src/aleph/types/message_processing_result.py | 2 +- src/aleph/web/controllers/messages.py | 4 +-- tests/api/test_get_message.py | 2 +- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index 3e35ddbd1..9f00a51a1 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -1,5 +1,16 @@ import datetime as dt -from typing import Optional, Generic, TypeVar, Literal, List, Any, Union, Dict, Mapping +from typing import ( + Optional, + Generic, + TypeVar, + Literal, + List, + Any, + Union, + Dict, + Mapping, + Annotated, +) from aleph_message.models import ( AggregateContent, @@ -12,7 +23,7 @@ InstanceContent, ) from aleph_message.models import MessageType, ItemType -from pydantic import BaseModel +from pydantic import BaseModel, Field from pydantic.generics import GenericModel import aleph.toolkit.json as aleph_json @@ -96,19 +107,22 @@ class StoreMessage( } -AlephMessage = Union[ - AggregateMessage, - ForgetMessage, - InstanceMessage, - PostMessage, - ProgramMessage, - StoreMessage, +AlephMessage = Annotated[ + Union[ + AggregateMessage, + ForgetMessage, + InstanceMessage, + PostMessage, + ProgramMessage, + StoreMessage, + ], + Field(discriminator="type"), ] -def format_message(message: Any) -> BaseMessage: +def format_message(message: Any) -> AlephMessage: message_cls = MESSAGE_CLS_DICT[message.type] - return message_cls.from_orm(message) + return message_cls.from_orm(message) # type: ignore[return-value] class BaseMessageStatus(BaseModel): @@ -149,7 +163,7 @@ class Config: orm_mode = True status: MessageStatus = MessageStatus.PROCESSED - message: BaseMessage + message: AlephMessage class ForgottenMessage(BaseModel): diff --git a/src/aleph/types/message_processing_result.py b/src/aleph/types/message_processing_result.py index 3e71d1c04..9a386f143 100644 --- a/src/aleph/types/message_processing_result.py +++ b/src/aleph/types/message_processing_result.py @@ -33,7 +33,7 @@ def item_hash(self) -> str: return self.message.item_hash def to_dict(self) -> Dict[str, Any]: - return {"status": self.status.value, "message": format_message(self.message)} + return {"status": self.status.value, "message": format_message(self.message).dict()} class FailedMessage(MessageProcessingResult): diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index d0e79b27d..6f062b6c5 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -27,7 +27,7 @@ ForgottenMessage, RejectedMessageStatus, PendingMessage, - AlephMessage, BaseMessage, + AlephMessage, ) from aleph.types.db_session import DbSessionFactory, DbSession from aleph.types.message_status import MessageStatus @@ -252,7 +252,7 @@ async def _send_history_to_ws( def message_matches_filters( - message: BaseMessage, query_params: WsMessageQueryParams + message: AlephMessage, query_params: WsMessageQueryParams ) -> bool: if message_type := query_params.message_type: if message.type != message_type: diff --git a/tests/api/test_get_message.py b/tests/api/test_get_message.py index 447662062..c6df7b760 100644 --- a/tests/api/test_get_message.py +++ b/tests/api/test_get_message.py @@ -175,7 +175,7 @@ async def test_get_processed_message_status( response = await ccn_api_client.get( MESSAGE_URI.format(processed_message.item_hash) ) - assert response.status == 200 + assert response.status == 200, await response.text() response_json = await response.json() parsed_response = ProcessedMessageStatus.parse_obj(response_json) assert parsed_response.status == MessageStatus.PROCESSED From 9da795458f4d8e79f650a5f969f7b8c21f26e686 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Mon, 12 Jun 2023 17:54:33 +0200 Subject: [PATCH 4/5] more fixes --- src/aleph/schemas/api/messages.py | 15 +++++++++++--- src/aleph/web/controllers/messages.py | 29 +++++++++++++++++---------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index 9f00a51a1..ec16da209 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -27,6 +27,7 @@ from pydantic.generics import GenericModel import aleph.toolkit.json as aleph_json +from aleph.db.models import MessageDb from aleph.types.message_status import MessageStatus, ErrorCode MType = TypeVar("MType", bound=MessageType) @@ -120,9 +121,17 @@ class StoreMessage( ] -def format_message(message: Any) -> AlephMessage: - message_cls = MESSAGE_CLS_DICT[message.type] - return message_cls.from_orm(message) # type: ignore[return-value] +def format_message(message: MessageDb) -> AlephMessage: + message_type = message.type + + message_cls = MESSAGE_CLS_DICT[message_type] + return message_cls.from_orm(message) # type: ignore[return-value] + + +def format_message_dict(message: Dict[str, Any]) -> AlephMessage: + message_type = message.get("type") + message_cls = MESSAGE_CLS_DICT[message_type] + return message_cls.parse_obj(message) class BaseMessageStatus(BaseModel): diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 6f062b6c5..7838aef78 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -28,6 +28,8 @@ RejectedMessageStatus, PendingMessage, AlephMessage, + format_message_dict, + PostMessage, ) from aleph.types.db_session import DbSessionFactory, DbSession from aleph.types.message_status import MessageStatus @@ -162,7 +164,7 @@ class WsMessageQueryParams(BaseMessageQueryParams): ) -def format_message_dict(message: MessageDb) -> Dict[str, Any]: +def message_to_dict(message: MessageDb) -> Dict[str, Any]: message_dict = message.to_dict() message_dict["time"] = message.time.timestamp() confirmations = [ @@ -189,7 +191,7 @@ def format_response_dict( def format_response( messages: Iterable[MessageDb], pagination: int, page: int, total_messages: int ) -> web.Response: - formatted_messages = [format_message_dict(message) for message in messages] + formatted_messages = [message_to_dict(message) for message in messages] response = format_response_dict( messages=formatted_messages, @@ -275,21 +277,25 @@ def message_matches_filters( if not getattr(message, message_field) in user_filters: return False - # Process filters on content.content - content = getattr(message.content, "content") + # Process filters on content and content.content + message_content = message.content if content_types := query_params.content_types: - content_type = getattr(content, "type") + content_type = getattr(message_content, "type", None) if content_type not in content_types: return False if content_hashes := query_params.content_hashes: - content_hash = getattr(content, "item_hash") + content_hash = getattr(message_content, "item_hash", None) if content_hash not in content_hashes: return False # For tags, we only need to match one filter if query_tags := query_params.tags: - content_tags = set(getattr(content, "tags")) + nested_content = getattr(message.content, "content") + if not nested_content: + return False + + content_tags = set(getattr(nested_content, "tags", [])) if (content_tags & set(query_tags)) == set(): return False @@ -311,12 +317,13 @@ async def _start_mq_consumer( """ async def _process_message(mq_message: aio_pika.abc.AbstractMessage): - message_bytes = mq_message.body - message_dict = aleph_json.loads(message_bytes) - message = format_message(message_dict) + payload_bytes = mq_message.body + payload_dict = aleph_json.loads(payload_bytes) + message = format_message_dict(payload_dict["message"]) + if message_matches_filters(message=message, query_params=query_params): try: - await ws.send_str(message_bytes.decode()) + await ws.send_str(message.json()) except ConnectionResetError: # We can detect the WS closing in this task in addition to the main one. # The main task will also detect the close event. From faffe4e8faf9347daab2b979c54b674ee48aabb6 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Mon, 12 Jun 2023 18:07:21 +0200 Subject: [PATCH 5/5] fix mypy --- src/aleph/schemas/api/messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index ec16da209..68a999fcf 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -131,7 +131,7 @@ def format_message(message: MessageDb) -> AlephMessage: def format_message_dict(message: Dict[str, Any]) -> AlephMessage: message_type = message.get("type") message_cls = MESSAGE_CLS_DICT[message_type] - return message_cls.parse_obj(message) + return message_cls.parse_obj(message) # type: ignore[return-value] class BaseMessageStatus(BaseModel):