Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/aleph/jobs/process_pending_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ async def process_messages(
async def publish_to_mq(
self, message_iterator: AsyncIterator[Sequence[MessageProcessingResult]]
) -> AsyncIterator[Sequence[MessageProcessingResult]]:

async for processing_results in message_iterator:
for result in processing_results:
body = {"item_hash": result.item_hash}
mq_message = aio_pika.Message(body=aleph_json.dumps(body))
mq_message = aio_pika.Message(body=aleph_json.dumps(result.to_dict()))
await self.mq_message_exchange.publish(
mq_message,
routing_key=f"{result.status.value}.{result.item_hash}",
Expand Down
50 changes: 36 additions & 14 deletions src/aleph/schemas/api/messages.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import datetime as dt
from typing import Optional, Generic, TypeVar, Literal, List, Any, Union, Dict, Mapping
from typing import (
Optional,
Generic,
TypeVar,
Literal,
List,
Any,
Union,
Dict,
Mapping,
Annotated,
)

from aleph_message.models import (
AggregateContent,
Expand All @@ -9,14 +20,14 @@
PostContent,
ProgramContent,
StoreContent,
AlephMessage,
InstanceContent,
)
from aleph_message.models import MessageType, ItemType
from pydantic import BaseModel
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

import aleph.toolkit.json as aleph_json
from aleph.db.models import MessageDb
from aleph.types.message_status import MessageStatus, ErrorCode

MType = TypeVar("MType", bound=MessageType)
Expand Down Expand Up @@ -97,19 +108,30 @@ class StoreMessage(
}


def format_message(message: Any) -> AlephMessage:
message_cls = MESSAGE_CLS_DICT[message.type]
return message_cls.from_orm(message)
AlephMessage = Annotated[
Union[
AggregateMessage,
ForgetMessage,
InstanceMessage,
PostMessage,
ProgramMessage,
StoreMessage,
],
Field(discriminator="type"),
]


AlephMessage = Union[
AggregateMessage,
ForgetMessage,
InstanceMessage,
PostMessage,
ProgramMessage,
StoreMessage,
]
def format_message(message: MessageDb) -> AlephMessage:
message_type = message.type

message_cls = MESSAGE_CLS_DICT[message_type]
return message_cls.from_orm(message) # type: ignore[return-value]


def format_message_dict(message: Dict[str, Any]) -> AlephMessage:
message_type = message.get("type")
message_cls = MESSAGE_CLS_DICT[message_type]
return message_cls.parse_obj(message) # type: ignore[return-value]


class BaseMessageStatus(BaseModel):
Expand Down
17 changes: 14 additions & 3 deletions src/aleph/types/message_processing_result.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Protocol
from typing import Any, Dict, Protocol

from aleph.db.models import PendingMessageDb, MessageDb
from aleph.schemas.api.messages import format_message
from aleph.types.message_status import (
ErrorCode,
MessageProcessingStatus,
Expand All @@ -14,6 +15,9 @@ class MessageProcessingResult(Protocol):
def item_hash(self) -> str:
pass

def to_dict(self) -> Dict[str, Any]:
pass


class ProcessedMessage(MessageProcessingResult):
def __init__(self, message: MessageDb, is_confirmation: bool = False):
Expand All @@ -28,10 +32,11 @@ def __init__(self, message: MessageDb, is_confirmation: bool = False):
def item_hash(self) -> str:
return self.message.item_hash

def to_dict(self) -> Dict[str, Any]:
return {"status": self.status.value, "message": format_message(self.message).dict()}

class FailedMessage(MessageProcessingResult):
status = MessageProcessingStatus.FAILED_WILL_RETRY

class FailedMessage(MessageProcessingResult):
def __init__(
self, pending_message: PendingMessageDb, error_code: ErrorCode, will_retry: bool
):
Expand All @@ -48,6 +53,12 @@ def __init__(
def item_hash(self) -> str:
return self.pending_message.item_hash

def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status.value,
"item_hash": self.item_hash,
}


class WillRetryMessage(FailedMessage):
def __init__(self, pending_message: PendingMessageDb, error_code: ErrorCode):
Expand Down
95 changes: 69 additions & 26 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@
ForgottenMessage,
RejectedMessageStatus,
PendingMessage,
AlephMessage,
format_message_dict,
PostMessage,
)
from aleph.types.db_session import DbSessionFactory, DbSession
from aleph.types.message_status import MessageStatus
from aleph.types.sort_order import SortOrder, SortBy
from aleph.web.controllers.app_state_getters import (
get_session_factory_from_request,
get_mq_channel_from_request,
get_config_from_request, get_mq_ws_channel_from_request,
get_config_from_request,
get_mq_ws_channel_from_request,
)
from aleph.web.controllers.utils import (
DEFAULT_MESSAGES_PER_PAGE,
Expand Down Expand Up @@ -161,7 +164,7 @@ class WsMessageQueryParams(BaseMessageQueryParams):
)


def format_message_dict(message: MessageDb) -> Dict[str, Any]:
def message_to_dict(message: MessageDb) -> Dict[str, Any]:
message_dict = message.to_dict()
message_dict["time"] = message.time.timestamp()
confirmations = [
Expand All @@ -188,7 +191,7 @@ def format_response_dict(
def format_response(
messages: Iterable[MessageDb], pagination: int, page: int, total_messages: int
) -> web.Response:
formatted_messages = [format_message_dict(message) for message in messages]
formatted_messages = [message_to_dict(message) for message in messages]

response = format_response_dict(
messages=formatted_messages,
Expand Down Expand Up @@ -237,50 +240,90 @@ async def _send_history_to_ws(
ws: aiohttp.web_ws.WebSocketResponse,
session_factory: DbSessionFactory,
history: int,
message_filters: Dict[str, Any],
query_params: WsMessageQueryParams,
) -> None:

with session_factory() as session:
messages = get_matching_messages(
session=session,
pagination=history,
include_confirmations=True,
**message_filters,
**query_params.dict(exclude_none=True),
)
for message in messages:
await ws.send_str(format_message(message).json())


def message_matches_filters(
message: AlephMessage, query_params: WsMessageQueryParams
) -> bool:
if message_type := query_params.message_type:
if message.type != message_type:
return False

# For simple filters, this reduces the amount of boilerplate
filters_by_message_field = {
"sender": "addresses",
"type": "message_type",
"item_hash": "hashes",
"ref": "refs",
"chain": "chains",
"channel": "channels",
}

for message_field, query_field in filters_by_message_field.items():
if user_filters := getattr(query_params, query_field):
if not isinstance(user_filters, list):
user_filters = [user_filters]
if not getattr(message, message_field) in user_filters:
return False

# Process filters on content and content.content
message_content = message.content
if content_types := query_params.content_types:
content_type = getattr(message_content, "type", None)
if content_type not in content_types:
return False

if content_hashes := query_params.content_hashes:
content_hash = getattr(message_content, "item_hash", None)
if content_hash not in content_hashes:
return False

# For tags, we only need to match one filter
if query_tags := query_params.tags:
nested_content = getattr(message.content, "content")
if not nested_content:
return False

content_tags = set(getattr(nested_content, "tags", []))
if (content_tags & set(query_tags)) == set():
return False

return True


async def _start_mq_consumer(
ws: aiohttp.web_ws.WebSocketResponse,
mq_queue: aio_pika.abc.AbstractQueue,
session_factory: DbSessionFactory,
message_filters: Dict[str, Any],
query_params: WsMessageQueryParams,
) -> aio_pika.abc.ConsumerTag:
"""
Starts the consumer task responsible for forwarding new aleph.im messages from
the processing pipeline to a websocket.

:param ws: Websocket.
:param mq_queue: Message queue object.
:param session_factory: DB session factory.
:param message_filters: Filters to apply to select outgoing messages.
:param query_params: Message filters specified by the caller.
"""

async def _process_message(mq_message: aio_pika.abc.AbstractMessage):
item_hash = aleph_json.loads(mq_message.body)["item_hash"]
# A bastardized way to apply the filters on the message as well.
# TODO: put the filter key/values in the RabbitMQ message?
with session_factory() as session:
matching_messages = get_matching_messages(
session=session,
hashes=[item_hash],
include_confirmations=True,
**message_filters,
)
payload_bytes = mq_message.body
payload_dict = aleph_json.loads(payload_bytes)
message = format_message_dict(payload_dict["message"])

if message_matches_filters(message=message, query_params=query_params):
try:
for message in matching_messages:
await ws.send_str(format_message(message).json())
await ws.send_str(message.json())
except ConnectionResetError:
# We can detect the WS closing in this task in addition to the main one.
# The main task will also detect the close event.
Expand All @@ -307,6 +350,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse:
query_params = WsMessageQueryParams.parse_obj(request.query)
except ValidationError as e:
raise web.HTTPUnprocessableEntity(body=e.json(indent=4))

message_filters = query_params.dict(exclude_none=True)
history = query_params.history

Expand All @@ -316,7 +360,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse:
ws=ws,
session_factory=session_factory,
history=history,
message_filters=message_filters,
query_params=query_params,
)
except ConnectionResetError:
LOGGER.info("Could not send history, aborting message websocket")
Expand All @@ -332,8 +376,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse:
consumer_tag = await _start_mq_consumer(
ws=ws,
mq_queue=mq_queue,
session_factory=session_factory,
message_filters=message_filters,
query_params=query_params,
)
LOGGER.debug(
"Started consuming mq %s for websocket. Consumer tag: %s",
Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_get_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ async def test_get_processed_message_status(
response = await ccn_api_client.get(
MESSAGE_URI.format(processed_message.item_hash)
)
assert response.status == 200
assert response.status == 200, await response.text()
response_json = await response.json()
parsed_response = ProcessedMessageStatus.parse_obj(response_json)
assert parsed_response.status == MessageStatus.PROCESSED
Expand Down