diff --git a/src/aleph/api_entrypoint.py b/src/aleph/api_entrypoint.py index 106e98b83..6a9153e15 100644 --- a/src/aleph/api_entrypoint.py +++ b/src/aleph/api_entrypoint.py @@ -42,7 +42,9 @@ async def configure_aiohttp_app( session_factory = make_session_factory(engine) node_cache = NodeCache( - redis_host=config.redis.host.value, redis_port=config.redis.port.value + redis_host=config.redis.host.value, + redis_port=config.redis.port.value, + message_count_cache_ttl=config.perf.message_count_cache_ttl.value, ) # TODO: find a way to close the node cache when exiting the API process, not closing it causes # a warning. diff --git a/src/aleph/commands.py b/src/aleph/commands.py index a55d6c87a..612d20e25 100644 --- a/src/aleph/commands.py +++ b/src/aleph/commands.py @@ -66,7 +66,9 @@ def run_db_migrations(config: Config): async def init_node_cache(config: Config) -> NodeCache: node_cache = NodeCache( - redis_host=config.redis.host.value, redis_port=config.redis.port.value + redis_host=config.redis.host.value, + redis_port=config.redis.port.value, + message_count_cache_ttl=config.perf.message_count_cache_ttl.value, ) return node_cache diff --git a/src/aleph/config.py b/src/aleph/config.py index 0cac0159d..afafb9f28 100644 --- a/src/aleph/config.py +++ b/src/aleph/config.py @@ -248,6 +248,10 @@ def get_defaults(): # Sentry trace sample rate. "traces_sample_rate": None, }, + "perf": { + # TTL of the cache in front of DB count queries on the messages table. + "message_count_cache_ttl": 300, + }, } diff --git a/src/aleph/jobs/fetch_pending_messages.py b/src/aleph/jobs/fetch_pending_messages.py index ce85ab879..9afabcdd5 100644 --- a/src/aleph/jobs/fetch_pending_messages.py +++ b/src/aleph/jobs/fetch_pending_messages.py @@ -171,7 +171,9 @@ async def fetch_messages_task(config: Config): async with ( NodeCache( - redis_host=config.redis.host.value, redis_port=config.redis.port.value + redis_host=config.redis.host.value, + redis_port=config.redis.port.value, + message_count_cache_ttl=config.perf.message_count_cache_ttl.value, ) as node_cache, IpfsService.new(config) as ipfs_service, ): diff --git a/src/aleph/jobs/process_pending_messages.py b/src/aleph/jobs/process_pending_messages.py index c78a73c02..e7fd15612 100644 --- a/src/aleph/jobs/process_pending_messages.py +++ b/src/aleph/jobs/process_pending_messages.py @@ -159,7 +159,9 @@ async def fetch_and_process_messages_task(config: Config): async with ( NodeCache( - redis_host=config.redis.host.value, redis_port=config.redis.port.value + redis_host=config.redis.host.value, + redis_port=config.redis.port.value, + message_count_cache_ttl=config.perf.message_count_cache_ttl.value, ) as node_cache, IpfsService.new(config) as ipfs_service, ): diff --git a/src/aleph/jobs/process_pending_txs.py b/src/aleph/jobs/process_pending_txs.py index 3c4d3e774..56b7df5ae 100644 --- a/src/aleph/jobs/process_pending_txs.py +++ b/src/aleph/jobs/process_pending_txs.py @@ -133,7 +133,9 @@ async def handle_txs_task(config: Config): async with ( NodeCache( - redis_host=config.redis.host.value, redis_port=config.redis.port.value + redis_host=config.redis.host.value, + redis_port=config.redis.port.value, + message_count_cache_ttl=config.perf.message_count_cache_ttl.value, ) as node_cache, IpfsService.new(config) as ipfs_service, ): diff --git a/src/aleph/schemas/api/accounts.py b/src/aleph/schemas/api/accounts.py index 7fb26a9b4..088642fa8 100644 --- a/src/aleph/schemas/api/accounts.py +++ b/src/aleph/schemas/api/accounts.py @@ -5,9 +5,9 @@ from aleph_message.models import Chain 
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, field_validator +from aleph.schemas.messages_query_params import DEFAULT_PAGE, LIST_FIELD_SEPARATOR from aleph.types.files import FileType from aleph.types.sort_order import SortOrder -from aleph.web.controllers.utils import DEFAULT_PAGE, LIST_FIELD_SEPARATOR class GetAccountQueryParams(BaseModel): diff --git a/src/aleph/schemas/messages_query_params.py b/src/aleph/schemas/messages_query_params.py new file mode 100644 index 000000000..4522038f4 --- /dev/null +++ b/src/aleph/schemas/messages_query_params.py @@ -0,0 +1,208 @@ +from typing import List, Optional + +from aleph_message.models import Chain, ItemHash, MessageType +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from aleph.types.message_status import MessageStatus +from aleph.types.sort_order import SortBy, SortOrder + +DEFAULT_WS_HISTORY = 10 +DEFAULT_MESSAGES_PER_PAGE = 20 +DEFAULT_PAGE = 1 +LIST_FIELD_SEPARATOR = "," + + +class BaseMessageQueryParams(BaseModel): + sort_by: SortBy = Field( + default=SortBy.TIME, + alias="sortBy", + description="Key to use to sort the messages. " + "'time' uses the message time field. " + "'tx-time' uses the first on-chain confirmation time.", + ) + sort_order: SortOrder = Field( + default=SortOrder.DESCENDING, + alias="sortOrder", + description="Order in which messages should be listed: " + "-1 means most recent messages first, 1 means older messages first.", + ) + message_type: Optional[MessageType] = Field( + default=None, + alias="msgType", + description="Message type. Deprecated: use msgTypes instead", + ) + message_types: Optional[List[MessageType]] = Field( + default=None, alias="msgTypes", description="Accepted message types." + ) + message_statuses: Optional[List[MessageStatus]] = Field( + default=[MessageStatus.PROCESSED, MessageStatus.REMOVING], + alias="msgStatuses", + description="Accepted values for the 'status' field.", + ) + addresses: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'sender' field." + ) + refs: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'content.ref' field." + ) + content_hashes: Optional[List[ItemHash]] = Field( + default=None, + alias="contentHashes", + description="Accepted values for the 'content.item_hash' field.", + ) + content_keys: Optional[List[ItemHash]] = Field( + default=None, + alias="contentKeys", + description="Accepted values for the 'content.keys' field.", + ) + content_types: Optional[List[str]] = Field( + default=None, + alias="contentTypes", + description="Accepted values for the 'content.type' field.", + ) + chains: Optional[List[Chain]] = Field( + default=None, description="Accepted values for the 'chain' field." + ) + channels: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'channel' field." + ) + tags: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'content.content.tag' field." + ) + hashes: Optional[List[ItemHash]] = Field( + default=None, description="Accepted values for the 'item_hash' field." + ) + + start_date: float = Field( + default=0, + ge=0, + alias="startDate", + description="Start date timestamp. If specified, only messages with " + "a time field greater or equal to this value will be returned.", + ) + end_date: float = Field( + default=0, + ge=0, + alias="endDate", + description="End date timestamp. 
If specified, only messages with " + "a time field lower than this value will be returned.", + ) + + start_block: int = Field( + default=0, + ge=0, + alias="startBlock", + description="Start block number. If specified, only messages with " + "a block number greater or equal to this value will be returned.", + ) + end_block: int = Field( + default=0, + ge=0, + alias="endBlock", + description="End block number. If specified, only messages with " + "a block number lower than this value will be returned.", + ) + + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = self.start_date + end_date = self.end_date + if start_date and end_date and (end_date < start_date): + raise ValueError("end date cannot be lower than start date.") + start_block = self.start_block + end_block = self.end_block + if start_block and end_block and (end_block < start_block): + raise ValueError("end block cannot be lower than start block.") + + return self + + @field_validator( + "hashes", + "addresses", + "refs", + "content_hashes", + "content_keys", + "content_types", + "chains", + "channels", + "message_types", + "message_statuses", + "tags", + mode="before", + ) + def split_str(cls, v): + if isinstance(v, str): + return v.split(LIST_FIELD_SEPARATOR) + return v + + model_config = ConfigDict(populate_by_name=True) + + +class MessageQueryParams(BaseMessageQueryParams): + pagination: int = Field( + default=DEFAULT_MESSAGES_PER_PAGE, + ge=0, + description="Maximum number of messages to return. Specifying 0 removes this limit.", + ) + page: int = Field( + default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1." + ) + + +class WsMessageQueryParams(BaseMessageQueryParams): + history: Optional[int] = Field( + DEFAULT_WS_HISTORY, + ge=0, + lt=200, + description="Historical elements to send through the websocket.", + ) + + +class MessageHashesQueryParams(BaseModel): + status: Optional[MessageStatus] = Field( + default=None, + description="Message status.", + ) + page: int = Field( + default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1." + ) + pagination: int = Field( + default=DEFAULT_MESSAGES_PER_PAGE, + ge=0, + description="Maximum number of messages to return. Specifying 0 removes this limit.", + ) + start_date: float = Field( + default=0, + ge=0, + alias="startDate", + description="Start date timestamp. If specified, only messages with " + "a time field greater or equal to this value will be returned.", + ) + end_date: float = Field( + default=0, + ge=0, + alias="endDate", + description="End date timestamp. If specified, only messages with " + "a time field lower than this value will be returned.", + ) + sort_order: SortOrder = Field( + default=SortOrder.DESCENDING, + alias="sortOrder", + description="Order in which messages should be listed: " + "-1 means most recent messages first, 1 means older messages first.", + ) + hash_only: bool = Field( + default=True, + description="By default, only hashes are returned. 
" + "Set this to false to include metadata alongside the hashes in the response.", + ) + + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = self.start_date + end_date = self.end_date + if start_date and end_date and (end_date < start_date): + raise ValueError("end date cannot be lower than start date.") + return self + + model_config = ConfigDict(populate_by_name=True) diff --git a/src/aleph/services/cache/node_cache.py b/src/aleph/services/cache/node_cache.py index dbe2637b3..c403b149d 100644 --- a/src/aleph/services/cache/node_cache.py +++ b/src/aleph/services/cache/node_cache.py @@ -1,7 +1,13 @@ -from typing import Any, List, Optional, Set +from hashlib import sha256 +from typing import Any, Dict, List, Optional, Set import redis.asyncio as redis_asyncio +import aleph.toolkit.json as aleph_json +from aleph.db.accessors.messages import count_matching_messages +from aleph.schemas.messages_query_params import MessageQueryParams +from aleph.types.db_session import DbSession + CacheKey = Any CacheValue = bytes @@ -10,9 +16,10 @@ class NodeCache: API_SERVERS_KEY = "api_servers" PUBLIC_ADDRESSES_KEY = "public_addresses" - def __init__(self, redis_host: str, redis_port: int): + def __init__(self, redis_host: str, redis_port: int, message_count_cache_ttl): self.redis_host = redis_host self.redis_port = redis_port + self.message_cache_count_ttl = message_count_cache_ttl self._redis_client: Optional[redis_asyncio.Redis] = None @@ -52,8 +59,8 @@ async def reset(self): async def get(self, key: CacheKey) -> Optional[CacheValue]: return await self.redis_client.get(key) - async def set(self, key: CacheKey, value: Any): - await self.redis_client.set(key, value) + async def set(self, key: CacheKey, value: Any, expiration: Optional[int] = None): + await self.redis_client.set(key, value, ex=expiration) async def incr(self, key: CacheKey): await self.redis_client.incr(key) @@ -82,3 +89,25 @@ async def add_public_address(self, public_address: str) -> None: async def get_public_addresses(self) -> List[str]: addresses = await self.redis_client.smembers(self.PUBLIC_ADDRESSES_KEY) return [addr.decode() for addr in addresses] + + @staticmethod + def _message_filter_id(filters: Dict[str, Any]): + filters_json = aleph_json.dumps(filters, sort_keys=True) + return sha256(filters_json).hexdigest() + + async def count_messages( + self, session: DbSession, query_params: MessageQueryParams + ) -> int: + filters = query_params.model_dump(exclude_none=True) + cache_key = f"message_count:{self._message_filter_id(filters)}" + + cached_result = await self.get(cache_key) + if cached_result is not None: + return int(cached_result.decode()) + + # Slow, can take a few seconds + n_matches = count_matching_messages(session, **filters) + + await self.set(cache_key, n_matches, expiration=self.message_cache_count_ttl) + + return n_matches diff --git a/src/aleph/toolkit/json.py b/src/aleph/toolkit/json.py index eeb89bf15..c16484b05 100644 --- a/src/aleph/toolkit/json.py +++ b/src/aleph/toolkit/json.py @@ -18,7 +18,6 @@ # serializer changes easier. SerializedJsonInput = Union[bytes, str] - # Note: JSONDecodeError is a subclass of ValueError, but the JSON module sometimes throws # raw value errors, including on NaN because of our custom parse_constant. 
diff --git a/src/aleph/toolkit/json.py b/src/aleph/toolkit/json.py
index eeb89bf15..c16484b05 100644
--- a/src/aleph/toolkit/json.py
+++ b/src/aleph/toolkit/json.py
@@ -18,7 +18,6 @@
 # serializer changes easier.
 SerializedJsonInput = Union[bytes, str]
 
-
 # Note: JSONDecodeError is a subclass of ValueError, but the JSON module sometimes throws
 # raw value errors, including on NaN because of our custom parse_constant.
 DecodeError = orjson.JSONDecodeError
@@ -55,8 +54,11 @@ def extended_json_encoder(obj: Any) -> Any:
     raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
 
 
-def dumps(obj: Any) -> bytes:
+def dumps(obj: Any, sort_keys: bool = False) -> bytes:
     try:
-        return orjson.dumps(obj)
+        opts = orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS if sort_keys else 0
+        return orjson.dumps(obj, option=opts)
     except TypeError:
-        return json.dumps(obj, default=extended_json_encoder).encode()
+        return json.dumps(
+            obj, default=extended_json_encoder, sort_keys=sort_keys
+        ).encode()
diff --git a/src/aleph/web/controllers/aggregates.py b/src/aleph/web/controllers/aggregates.py
index c9f61c3be..dd411ebf1 100644
--- a/src/aleph/web/controllers/aggregates.py
+++ b/src/aleph/web/controllers/aggregates.py
@@ -8,8 +8,7 @@
 from aleph.db.accessors.aggregates import get_aggregates_by_owner, refresh_aggregate
 from aleph.db.models import AggregateDb
-
-from .utils import LIST_FIELD_SEPARATOR
+from aleph.schemas.messages_query_params import LIST_FIELD_SEPARATOR
 
 LOGGER = logging.getLogger(__name__)
diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py
index fa7a60cc4..2b5742de7 100644
--- a/src/aleph/web/controllers/messages.py
+++ b/src/aleph/web/controllers/messages.py
@@ -1,30 +1,23 @@
 import json
 import logging
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List
 
 import aio_pika.abc
 import aiohttp.web_ws
 from aiohttp import WSMsgType, web
-from aleph_message.models import Chain, ItemHash, MessageType
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    ValidationError,
-    field_validator,
-    model_validator,
-)
+from aleph_message.models import ItemHash, MessageType
+from pydantic import ValidationError
 
 import aleph.toolkit.json as aleph_json
 from aleph.db.accessors.messages import (
     count_matching_hashes,
-    count_matching_messages,
     get_forgotten_message,
     get_matching_hashes,
     get_matching_messages,
     get_message_by_item_hash,
     get_message_status,
     get_rejected_message,
+    make_matching_messages_query,
 )
 from aleph.db.accessors.pending_messages import get_pending_messages
 from aleph.db.models import MessageDb, MessageStatusDb
@@ -45,224 +38,25 @@
     format_message,
     format_message_dict,
 )
+from aleph.schemas.messages_query_params import (
+    MessageHashesQueryParams,
+    MessageQueryParams,
+    WsMessageQueryParams,
+)
 from aleph.toolkit.shield import shielded
 from aleph.types.db_session import DbSession, DbSessionFactory
 from aleph.types.message_status import MessageStatus, RemovedMessageReason
-from aleph.types.sort_order import SortBy, SortOrder
 from aleph.web.controllers.app_state_getters import (
     get_config_from_request,
     get_mq_ws_channel_from_request,
+    get_node_cache_from_request,
     get_session_factory_from_request,
 )
-from aleph.web.controllers.utils import (
-    DEFAULT_MESSAGES_PER_PAGE,
-    DEFAULT_PAGE,
-    LIST_FIELD_SEPARATOR,
-    mq_make_aleph_message_topic_queue,
-)
+from aleph.web.controllers.utils import mq_make_aleph_message_topic_queue
 
 LOGGER = logging.getLogger(__name__)
 
-DEFAULT_WS_HISTORY = 10
-
-
-class BaseMessageQueryParams(BaseModel):
-    sort_by: SortBy = Field(
-        default=SortBy.TIME,
-        alias="sortBy",
-        description="Key to use to sort the messages. "
-        "'time' uses the message time field. 
" - "'tx-time' uses the first on-chain confirmation time.", - ) - sort_order: SortOrder = Field( - default=SortOrder.DESCENDING, - alias="sortOrder", - description="Order in which messages should be listed: " - "-1 means most recent messages first, 1 means older messages first.", - ) - message_type: Optional[MessageType] = Field( - default=None, - alias="msgType", - description="Message type. Deprecated: use msgTypes instead", - ) - message_types: Optional[List[MessageType]] = Field( - default=None, alias="msgTypes", description="Accepted message types." - ) - message_statuses: Optional[List[MessageStatus]] = Field( - default=[MessageStatus.PROCESSED, MessageStatus.REMOVING], - alias="msgStatuses", - description="Accepted values for the 'status' field.", - ) - addresses: Optional[List[str]] = Field( - default=None, description="Accepted values for the 'sender' field." - ) - refs: Optional[List[str]] = Field( - default=None, description="Accepted values for the 'content.ref' field." - ) - content_hashes: Optional[List[ItemHash]] = Field( - default=None, - alias="contentHashes", - description="Accepted values for the 'content.item_hash' field.", - ) - content_keys: Optional[List[ItemHash]] = Field( - default=None, - alias="contentKeys", - description="Accepted values for the 'content.keys' field.", - ) - content_types: Optional[List[str]] = Field( - default=None, - alias="contentTypes", - description="Accepted values for the 'content.type' field.", - ) - chains: Optional[List[Chain]] = Field( - default=None, description="Accepted values for the 'chain' field." - ) - channels: Optional[List[str]] = Field( - default=None, description="Accepted values for the 'channel' field." - ) - tags: Optional[List[str]] = Field( - default=None, description="Accepted values for the 'content.content.tag' field." - ) - hashes: Optional[List[ItemHash]] = Field( - default=None, description="Accepted values for the 'item_hash' field." - ) - - start_date: float = Field( - default=0, - ge=0, - alias="startDate", - description="Start date timestamp. If specified, only messages with " - "a time field greater or equal to this value will be returned.", - ) - end_date: float = Field( - default=0, - ge=0, - alias="endDate", - description="End date timestamp. If specified, only messages with " - "a time field lower than this value will be returned.", - ) - - start_block: int = Field( - default=0, - ge=0, - alias="startBlock", - description="Start block number. If specified, only messages with " - "a block number greater or equal to this value will be returned.", - ) - end_block: int = Field( - default=0, - ge=0, - alias="endBlock", - description="End block number. 
If specified, only messages with " - "a block number lower than this value will be returned.", - ) - - @model_validator(mode="after") - def validate_field_dependencies(self): - start_date = self.start_date - end_date = self.end_date - if start_date and end_date and (end_date < start_date): - raise ValueError("end date cannot be lower than start date.") - start_block = self.start_block - end_block = self.end_block - if start_block and end_block and (end_block < start_block): - raise ValueError("end block cannot be lower than start block.") - - return self - - @field_validator( - "hashes", - "addresses", - "refs", - "content_hashes", - "content_keys", - "content_types", - "chains", - "channels", - "message_types", - "message_statuses", - "tags", - mode="before", - ) - def split_str(cls, v): - if isinstance(v, str): - return v.split(LIST_FIELD_SEPARATOR) - return v - - model_config = ConfigDict(populate_by_name=True) - - -class MessageQueryParams(BaseMessageQueryParams): - pagination: int = Field( - default=DEFAULT_MESSAGES_PER_PAGE, - ge=0, - description="Maximum number of messages to return. Specifying 0 removes this limit.", - ) - page: int = Field( - default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1." - ) - - -class WsMessageQueryParams(BaseMessageQueryParams): - history: Optional[int] = Field( - DEFAULT_WS_HISTORY, - ge=0, - lt=200, - description="Historical elements to send through the websocket.", - ) - - -class MessageHashesQueryParams(BaseModel): - status: Optional[MessageStatus] = Field( - default=None, - description="Message status.", - ) - page: int = Field( - default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1." - ) - pagination: int = Field( - default=DEFAULT_MESSAGES_PER_PAGE, - ge=0, - description="Maximum number of messages to return. Specifying 0 removes this limit.", - ) - start_date: float = Field( - default=0, - ge=0, - alias="startDate", - description="Start date timestamp. If specified, only messages with " - "a time field greater or equal to this value will be returned.", - ) - end_date: float = Field( - default=0, - ge=0, - alias="endDate", - description="End date timestamp. If specified, only messages with " - "a time field lower than this value will be returned.", - ) - sort_order: SortOrder = Field( - default=SortOrder.DESCENDING, - alias="sortOrder", - description="Order in which messages should be listed: " - "-1 means most recent messages first, 1 means older messages first.", - ) - hash_only: bool = Field( - default=True, - description="By default, only hashes are returned. 
" - "Set this to false to include metadata alongside the hashes in the response.", - ) - - @model_validator(mode="after") - def validate_field_dependencies(self): - start_date = self.start_date - end_date = self.end_date - if start_date and end_date and (end_date < start_date): - raise ValueError("end date cannot be lower than start date.") - return self - - model_config = ConfigDict(populate_by_name=True) - - def message_to_dict(message: MessageDb) -> Dict[str, Any]: message_dict = message.to_dict(exclude={"content_type"}) message_dict["time"] = message.time.timestamp() @@ -325,11 +119,14 @@ async def view_messages_list(request: web.Request) -> web.Response: pagination_per_page = query_params.pagination session_factory = get_session_factory_from_request(request) + node_cache = get_node_cache_from_request(request) + with session_factory() as session: - messages = get_matching_messages( - session, include_confirmations=True, **find_filters + messages_query = make_matching_messages_query( + include_confirmations=True, **find_filters ) - total_msgs = count_matching_messages(session, **find_filters) + messages = (session.execute(messages_query)).scalars() + total_msgs = await node_cache.count_messages(session, query_params) return format_response( messages, diff --git a/src/aleph/web/controllers/posts.py b/src/aleph/web/controllers/posts.py index b2e4ab7c8..f186d5777 100644 --- a/src/aleph/web/controllers/posts.py +++ b/src/aleph/web/controllers/posts.py @@ -20,16 +20,14 @@ get_matching_posts_legacy, ) from aleph.db.models import ChainTxDb, message_confirmations -from aleph.types.db_session import DbSession, DbSessionFactory -from aleph.types.sort_order import SortBy, SortOrder -from aleph.web.controllers.utils import ( +from aleph.schemas.messages_query_params import ( DEFAULT_MESSAGES_PER_PAGE, DEFAULT_PAGE, LIST_FIELD_SEPARATOR, - Pagination, - cond_output, - get_path_page, ) +from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.sort_order import SortBy, SortOrder +from aleph.web.controllers.utils import Pagination, cond_output, get_path_page class PostQueryParams(BaseModel): diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index 7e0a4ac7b..b1ad90e21 100644 --- a/src/aleph/web/controllers/utils.py +++ b/src/aleph/web/controllers/utils.py @@ -17,6 +17,7 @@ import aleph.toolkit.json as aleph_json from aleph.db.accessors.files import insert_grace_period_file_pin +from aleph.schemas.messages_query_params import DEFAULT_MESSAGES_PER_PAGE from aleph.schemas.pending_messages import BasePendingMessage, parse_message from aleph.services.ipfs import IpfsService from aleph.services.p2p.pubsub import publish as pub_p2p @@ -36,10 +37,6 @@ get_p2p_client_from_request, ) -DEFAULT_MESSAGES_PER_PAGE = 20 -DEFAULT_PAGE = 1 -LIST_FIELD_SEPARATOR = "," - @overload def file_field_to_io(multi_dict: bytes) -> BytesIO: ... 
diff --git a/tests/conftest.py b/tests/conftest.py index b50351ac0..91a842bac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,6 +55,7 @@ from aleph.web import create_aiohttp_app from aleph.web.controllers.app_state_getters import ( APP_STATE_CONFIG, + APP_STATE_NODE_CACHE, APP_STATE_P2P_CLIENT, APP_STATE_SESSION_FACTORY, APP_STATE_STORAGE_SERVICE, @@ -133,7 +134,9 @@ def mock_config() -> Config: @pytest_asyncio.fixture async def node_cache(mock_config: Config): async with NodeCache( - redis_host=mock_config.redis.host.value, redis_port=mock_config.redis.port.value + redis_host=mock_config.redis.host.value, + redis_port=mock_config.redis.port.value, + message_count_cache_ttl=mock_config.perf.message_count_cache_ttl.value, ) as node_cache: yield node_cache @@ -159,13 +162,14 @@ async def test_storage_service(mock_config: Config, mocker) -> StorageService: @pytest.fixture -def ccn_test_aiohttp_app(mocker, mock_config, session_factory): +def ccn_test_aiohttp_app(mocker, mock_config, session_factory, node_cache: NodeCache): # Make aiohttp return the stack trace on 500 errors event_loop = asyncio.get_event_loop() event_loop.set_debug(True) app = create_aiohttp_app() app[APP_STATE_CONFIG] = mock_config + app[APP_STATE_NODE_CACHE] = node_cache app[APP_STATE_P2P_CLIENT] = mocker.AsyncMock() app[APP_STATE_STORAGE_SERVICE] = mocker.AsyncMock() app[APP_STATE_SESSION_FACTORY] = session_factory
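Finally, the `sort_keys` option added to `aleph.toolkit.json.dumps` is what makes the cache keys stable: with orjson's `OPT_SORT_KEYS`, equal dicts serialize to identical bytes regardless of insertion order. A standalone illustration using plain `orjson`, which pyaleph already depends on (the patched helper also enables `OPT_NON_STR_KEYS` on this path; only string keys are shown here):

```python
# Deterministic serialization with orjson's OPT_SORT_KEYS, as used by the
# sort_keys=True path of aleph.toolkit.json.dumps (standalone illustration).
import orjson

sorted_opt = orjson.OPT_SORT_KEYS
assert (
    orjson.dumps({"b": 1, "a": 2}, option=sorted_opt)
    == orjson.dumps({"a": 2, "b": 1}, option=sorted_opt)
)
# Without the option, the bytes follow insertion order and differ:
assert orjson.dumps({"b": 1, "a": 2}) != orjson.dumps({"a": 2, "b": 1})
```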