From efd1b76fff8c9ec2c3a45e3a63ac3d724498dfbf Mon Sep 17 00:00:00 2001
From: Olivier Desenfans
Date: Mon, 11 Apr 2022 18:12:05 +0200
Subject: [PATCH 1/3] [Jobs] Simplify pending TXs job

Simplified the pending TXs code. The task to handle a single TX is now
stateless and returns a list of DB operations. The processing of DB
operations is now the same for pending TXs and messages. Added a new
`async_batch` function to process TXs in batches.
---
 src/aleph/chains/common.py                 |   2 +-
 src/aleph/jobs/job_utils.py                |  48 +++++++-
 src/aleph/jobs/process_pending_messages.py |  34 ++----
 src/aleph/jobs/process_pending_txs.py      | 114 +++++++++---------
 src/aleph/toolkit/__init__.py              |   0
 src/aleph/toolkit/batch.py                 |  18 +++
 tests/message_processing/__init__.py       |   0
 tests/message_processing/conftest.py       |   8 ++
 tests/message_processing/load_fixtures.py  |  11 ++
 .../test_perform_db_operations.py          | 102 ++++++++++++++++
 .../test_process_pending_txs.py            |  56 ++++-----
 tests/utils/test_batch.py                  |  26 ++++
 12 files changed, 305 insertions(+), 114 deletions(-)
 create mode 100644 src/aleph/toolkit/__init__.py
 create mode 100644 src/aleph/toolkit/batch.py
 create mode 100644 tests/message_processing/__init__.py
 create mode 100644 tests/message_processing/conftest.py
 create mode 100644 tests/message_processing/load_fixtures.py
 create mode 100644 tests/message_processing/test_perform_db_operations.py
 create mode 100644 tests/utils/test_batch.py

diff --git a/src/aleph/chains/common.py b/src/aleph/chains/common.py
index 1a5e54869..de8a50b0f 100644
--- a/src/aleph/chains/common.py
+++ b/src/aleph/chains/common.py
@@ -395,7 +395,7 @@ async def get_chaindata_messages(
         if config.ipfs.enabled.value:
             # wait for 4 seconds to try to pin that
             try:
-                LOGGER.info(f"chaindatax {chaindata}")
+                LOGGER.info(f"chaindata {chaindata}")
                 await PermanentPin.register(
                     multihash=chaindata["content"],
                     reason={
diff --git a/src/aleph/jobs/job_utils.py b/src/aleph/jobs/job_utils.py
index 010d79c8b..faa520bdb 100644
--- a/src/aleph/jobs/job_utils.py
+++ b/src/aleph/jobs/job_utils.py
@@ -1,12 +1,15 @@
 import asyncio
 from typing import Dict
-from typing import Tuple
+from typing import Iterable, Tuple
 
 import aleph.config
 from aleph.model import init_db_globals
 from aleph.services.ipfs.common import init_ipfs_globals
 from aleph.services.p2p import init_p2p_client
 from configmanager import Config
+from typing import Awaitable, Callable, List
+from aleph.model.db_bulk_operation import DbBulkOperation
+from itertools import groupby
 
 
 def prepare_loop(config_values: Dict) -> Tuple[asyncio.AbstractEventLoop, Config]:
@@ -28,3 +31,46 @@ def prepare_loop(config_values: Dict) -> Tuple[asyncio.AbstractEventLoop, Config
     init_ipfs_globals(config)
     _ = init_p2p_client(config)
     return loop, config
+
+
+async def perform_db_operations(db_operations: Iterable[DbBulkOperation]) -> None:
+    # Sort the operations by collection name before grouping and executing them.
+    sorted_operations = sorted(
+        db_operations,
+        key=lambda op: op.collection.__name__,
+    )
+
+    for collection, operations in groupby(sorted_operations, lambda op: op.collection):
+        mongo_ops = [op.operation for op in operations]
+        await collection.collection.bulk_write(mongo_ops)
+
+
+async def gather_and_perform_db_operations(
+    tasks: List[Awaitable[List[DbBulkOperation]]],
+    on_error: Callable[[BaseException], None],
+) -> None:
+    """
+    Processes the results of the pending TX/message tasks.
+
+    Gathers the results of the tasks passed as input, handles exceptions,
+    and performs the DB operations.
+
+    :param tasks: Job tasks. 
Each of these tasks must return a list of + DbBulkOperation objects. + :param on_error: Error callback function. This function will be called + on each error from one of the tasks. + """ + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + errors = [op for op in task_results if isinstance(op, BaseException)] + for error in errors: + on_error(error) + + db_operations = ( + op + for operations in task_results + if not isinstance(operations, BaseException) + for op in operations + ) + + await perform_db_operations(db_operations) diff --git a/src/aleph/jobs/process_pending_messages.py b/src/aleph/jobs/process_pending_messages.py index 9028431a3..cc7881c59 100644 --- a/src/aleph/jobs/process_pending_messages.py +++ b/src/aleph/jobs/process_pending_messages.py @@ -3,22 +3,21 @@ """ import asyncio -from itertools import groupby from logging import getLogger from typing import List, Dict, Tuple import sentry_sdk +from aleph_message.models import MessageType +from pymongo import DeleteOne, DeleteMany, ASCENDING +from setproctitle import setproctitle + from aleph.chains.common import incoming, IncomingStatus from aleph.logging import setup_logging from aleph.model.db_bulk_operation import DbBulkOperation from aleph.model.pending import PendingMessage from aleph.services.p2p import singleton from aleph.types import ItemType -from pymongo import DeleteOne, DeleteMany, ASCENDING -from setproctitle import setproctitle -from aleph_message.models import MessageType - -from .job_utils import prepare_loop +from .job_utils import prepare_loop, gather_and_perform_db_operations LOGGER = getLogger("jobs.pending_messages") @@ -47,27 +46,10 @@ async def handle_pending_message( async def join_pending_message_tasks(tasks): - db_operations = await asyncio.gather(*tasks, return_exceptions=True) - - errors = [op for op in db_operations if isinstance(op, BaseException)] - for error in errors: - LOGGER.error("Error while processing message: %s", error) - - # Sort the operations by collection name before grouping and executing them. 
- db_operations = sorted( - ( - op - for operations in db_operations - if not isinstance(operations, BaseException) - for op in operations - ), - key=lambda op: op.collection.__name__, + await gather_and_perform_db_operations( + tasks, + on_error=lambda e: LOGGER.error("Error while processing message: %s", e), ) - - for collection, operations in groupby(db_operations, lambda op: op.collection): - mongo_ops = [op.operation for op in operations] - await collection.collection.bulk_write(mongo_ops) - tasks.clear() diff --git a/src/aleph/jobs/process_pending_txs.py b/src/aleph/jobs/process_pending_txs.py index fa4a15b02..aa14f137d 100644 --- a/src/aleph/jobs/process_pending_txs.py +++ b/src/aleph/jobs/process_pending_txs.py @@ -18,22 +18,25 @@ from aleph.model.pending import PendingMessage, PendingTX from aleph.network import check_message from aleph.services.p2p import singleton -from .job_utils import prepare_loop +from .job_utils import prepare_loop, gather_and_perform_db_operations +from aleph.model.db_bulk_operation import DbBulkOperation +from aleph.toolkit.batch import async_batch LOGGER = logging.getLogger("jobs.pending_txs") async def handle_pending_tx( - pending, actions_list: List, seen_ids: Optional[List] = None -): - tx_context = TxContext(**pending["context"]) + pending_tx, seen_ids: Optional[List] = None +) -> List[DbBulkOperation]: + + db_operations = [] + tx_context = TxContext(**pending_tx["context"]) LOGGER.info("%s Handling TX in block %s", tx_context.chain_name, tx_context.height) messages = await get_chaindata_messages( - pending["content"], tx_context, seen_ids=seen_ids + pending_tx["content"], tx_context, seen_ids=seen_ids ) if messages: - message_actions = list() for i, message in enumerate(messages): message["time"] = tx_context.time + (i / 1000) # force order @@ -46,76 +49,71 @@ async def handle_pending_tx( continue # we add it to the message queue... bad idea? should we process it asap? - message_actions.append( - InsertOne( - { - "message": message, - "source": dict( - chain_name=tx_context.chain_name, - tx_hash=tx_context.tx_hash, - height=tx_context.height, - check_message=True, # should we store this? - ), - } + db_operations.append( + DbBulkOperation( + collection=PendingMessage, + operation=InsertOne( + { + "message": message, + "source": dict( + chain_name=tx_context.chain_name, + tx_hash=tx_context.tx_hash, + height=tx_context.height, + check_message=True, # should we store this? + ), + } + ), ) ) await asyncio.sleep(0) - if message_actions: - await PendingMessage.collection.bulk_write(message_actions) else: LOGGER.debug("TX contains no message") if messages is not None: # bogus or handled, we remove it. 
- actions_list.append(DeleteOne({"_id": pending["_id"]})) + db_operations.append( + DbBulkOperation( + collection=PendingTX, operation=DeleteOne({"_id": pending_tx["_id"]}) + ) + ) + return db_operations -async def join_pending_txs_tasks(tasks, actions_list): - results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, BaseException): - LOGGER.exception( - "error in incoming txs task", - exc_info=(type(result), result, result.__traceback__), - ) +async def join_pending_txs_tasks(tasks): + await gather_and_perform_db_operations( + tasks, + on_error=lambda e: LOGGER.exception( + "error in incoming txs task", + exc_info=(type(e), e, e.__traceback__), + ), + ) - tasks.clear() - if len(actions_list): - await PendingTX.collection.bulk_write(actions_list) - actions_list.clear() +async def process_pending_txs(): + """ + Process chain transactions in the Pending TX queue. + """ + batch_size = 200 -async def process_pending_txs(): - """Each few minutes, try to handle message that were added to the - pending queue (Unavailable messages).""" - if not await PendingTX.collection.count_documents({}): - await asyncio.sleep(5) - return - - actions = [] - tasks = [] - seen_offchain_hashes = [] + seen_offchain_hashes = set() seen_ids = [] - i = 0 LOGGER.info("handling TXs") - async for pending in PendingTX.collection.find().sort([("context.time", 1)]): - if pending["content"]["protocol"] == "aleph-offchain": - if pending["content"]["content"] not in seen_offchain_hashes: - seen_offchain_hashes.append(pending["content"]["content"]) - else: - continue - - i += 1 - tasks.append(handle_pending_tx(pending, actions, seen_ids=seen_ids)) - - if i > 200: - await join_pending_txs_tasks(tasks, actions) - i = 0 - - await join_pending_txs_tasks(tasks, actions) + async for pending_tx_batch in async_batch( + PendingTX.collection.find().sort([("context.time", 1)]), batch_size + ): + tasks = [] + for pending_tx in pending_tx_batch: + if pending_tx["content"]["protocol"] == "aleph-offchain": + if pending_tx["content"]["content"] in seen_offchain_hashes: + continue + + seen_offchain_hashes.add(pending_tx["content"]["content"]) + tasks.append(handle_pending_tx(pending_tx, seen_ids=seen_ids)) + + await join_pending_txs_tasks(tasks) async def handle_txs_task(): diff --git a/src/aleph/toolkit/__init__.py b/src/aleph/toolkit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/aleph/toolkit/batch.py b/src/aleph/toolkit/batch.py new file mode 100644 index 000000000..db912ce98 --- /dev/null +++ b/src/aleph/toolkit/batch.py @@ -0,0 +1,18 @@ +from typing import AsyncIterator, List, TypeVar + +T = TypeVar("T") + + +async def async_batch( + async_iterable: AsyncIterator["T"], n: int +) -> AsyncIterator[List[T]]: + batch = [] + async for item in async_iterable: + batch.append(item) + if len(batch) == n: + yield batch + batch = [] + + # Yield the last batch + if batch: + yield batch diff --git a/tests/message_processing/__init__.py b/tests/message_processing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/message_processing/conftest.py b/tests/message_processing/conftest.py new file mode 100644 index 000000000..1b6b8e9fc --- /dev/null +++ b/tests/message_processing/conftest.py @@ -0,0 +1,8 @@ +import pytest + +from .load_fixtures import load_fixture_messages + + +@pytest.fixture +def fixture_messages(): + return load_fixture_messages("test-data-pending-tx-messages.json") diff --git a/tests/message_processing/load_fixtures.py 
b/tests/message_processing/load_fixtures.py new file mode 100644 index 000000000..08f58a1a8 --- /dev/null +++ b/tests/message_processing/load_fixtures.py @@ -0,0 +1,11 @@ +import json +import os +from typing import Dict, List +from pathlib import Path + + +def load_fixture_messages(fixture: str) -> List[Dict]: + fixture_path = Path(__file__).parent / "fixtures" / fixture + + with open(fixture_path) as f: + return json.load(f)["content"]["messages"] diff --git a/tests/message_processing/test_perform_db_operations.py b/tests/message_processing/test_perform_db_operations.py new file mode 100644 index 000000000..413e4d211 --- /dev/null +++ b/tests/message_processing/test_perform_db_operations.py @@ -0,0 +1,102 @@ +import pytest +from pymongo import DeleteOne +from pymongo import InsertOne + +from aleph.jobs.job_utils import perform_db_operations +from aleph.model.db_bulk_operation import DbBulkOperation +from aleph.model.pending import PendingMessage, PendingTX + +PENDING_TX = { + "content": { + "protocol": "aleph-offchain", + "version": 1, + "content": "test-data-pending-tx-messages", + }, + "context": { + "chain_name": "ETH", + "tx_hash": "0xf49cb176c1ce4f6eb7b9721303994b05074f8fadc37b5f41ac6f78bdf4b14b6c", + "time": 1632835747, + "height": 13314512, + "publisher": "0x23eC28598DCeB2f7082Cc3a9D670592DfEd6e0dC", + }, +} + + +@pytest.mark.asyncio +async def test_db_operations_insert_one(test_db): + start_count = await PendingTX.count({}) + + db_operations = [ + DbBulkOperation(collection=PendingTX, operation=InsertOne(PENDING_TX)) + ] + await perform_db_operations(db_operations) + + end_count = await PendingTX.count({}) + stored_pending_tx = await PendingTX.collection.find_one( + filter={"context.tx_hash": PENDING_TX["context"]["tx_hash"]} + ) + + assert stored_pending_tx["content"] == PENDING_TX["content"] + assert stored_pending_tx["context"] == PENDING_TX["context"] + assert end_count - start_count == 1 + + +@pytest.mark.asyncio +async def test_db_operations_delete_one(test_db): + await PendingTX.collection.insert_one(PENDING_TX) + start_count = await PendingTX.count({}) + + db_operations = [ + DbBulkOperation( + collection=PendingTX, + operation=DeleteOne( + filter={"context.tx_hash": PENDING_TX["context"]["tx_hash"]} + ), + ) + ] + await perform_db_operations(db_operations) + + end_count = await PendingTX.count({}) + assert end_count - start_count == -1 + + +@pytest.mark.asyncio +async def test_db_operations_insert_and_delete(test_db, fixture_messages): + """ + Test a typical case where we insert several messages and delete a pending TX. 
+    """
+
+    await PendingTX.collection.insert_one(PENDING_TX)
+    tx_start_count = await PendingTX.count({})
+    msg_start_count = await PendingMessage.count({})
+
+    db_operations = [
+        DbBulkOperation(collection=PendingMessage, operation=InsertOne(msg))
+        for msg in fixture_messages
+    ]
+
+    db_operations.append(
+        DbBulkOperation(
+            collection=PendingTX,
+            operation=DeleteOne(
+                filter={"context.tx_hash": PENDING_TX["context"]["tx_hash"]}
+            ),
+        )
+    )
+
+    await perform_db_operations(db_operations)
+
+    tx_end_count = await PendingTX.count({})
+    msg_end_count = await PendingMessage.count({})
+    assert tx_end_count - tx_start_count == -1
+    assert msg_end_count - msg_start_count == len(fixture_messages)
+
+    # Check each message
+    fixture_messages_by_hash = {msg["item_hash"]: msg for msg in fixture_messages}
+
+    async for pending_msg in PendingMessage.collection.find(
+        {"message.item_hash": {"$in": [msg["item_hash"] for msg in fixture_messages]}}
+    ):
+        pending_message = pending_msg["message"]
+        expected_message = fixture_messages_by_hash[pending_message["item_hash"]]
+        assert set(expected_message.items()).issubset(set(pending_message.items()))
diff --git a/tests/message_processing/test_process_pending_txs.py b/tests/message_processing/test_process_pending_txs.py
index 70acf8fb0..d3a7f7167 100644
--- a/tests/message_processing/test_process_pending_txs.py
+++ b/tests/message_processing/test_process_pending_txs.py
@@ -1,23 +1,18 @@
-import json
-from pathlib import Path
+from collections import defaultdict
 from typing import Dict, List
 
 import pytest
 from bson.objectid import ObjectId
-from pymongo import DeleteOne
+from pymongo import DeleteOne, InsertOne
 
 from aleph.jobs.process_pending_txs import handle_pending_tx
-from aleph.model.pending import PendingMessage
-
-
-def load_fixture_messages(fixture: str) -> List[Dict]:
-    fixture_path = Path(__file__).parent / "fixtures" / fixture
-    with open(fixture_path) as f:
-        return json.load(f)["content"]["messages"]
+from aleph.model.pending import PendingMessage, PendingTX
+from .load_fixtures import load_fixture_messages
 
 
 # TODO: try to replace this fixture by a get_json fixture. Currently, the pinning
 # of the message content gets in the way in the real get_chaindata_messages function. 
async def get_fixture_chaindata_messages( pending_tx_content, pending_tx_context, seen_ids: List[str] ) -> List[Dict]: @@ -47,27 +42,32 @@ async def test_process_pending_tx(mocker, test_db): }, } - actions_list = [] seen_ids = [] - await handle_pending_tx( - pending=pending_tx, actions_list=actions_list, seen_ids=seen_ids - ) + db_operations = await handle_pending_tx(pending_tx=pending_tx, seen_ids=seen_ids) - assert len(actions_list) == 1 - action = actions_list[0] - assert isinstance(action, DeleteOne) - assert action._filter == {"_id": pending_tx["_id"]} + db_operations_by_collection = defaultdict(list) + for op in db_operations: + db_operations_by_collection[op.collection].append(op) - fixture_messages = load_fixture_messages(f"{pending_tx['content']['content']}.json") - pending_messages = [m async for m in PendingMessage.collection.find()] + assert set(db_operations_by_collection.keys()) == {PendingMessage, PendingTX} + + pending_tx_ops = db_operations_by_collection[PendingTX] + assert len(pending_tx_ops) == 1 + assert isinstance(pending_tx_ops[0].operation, DeleteOne) + assert pending_tx_ops[0].operation._filter == {"_id": pending_tx["_id"]} - assert len(pending_messages) == len(fixture_messages) - fixture_messages_by_hash = {m["item_hash"]: m for m in fixture_messages} + pending_msg_ops = db_operations_by_collection[PendingMessage] + fixture_messages = load_fixture_messages(f"{pending_tx['content']['content']}.json") - for pending in pending_messages: - pending_message = pending["message"] - expected_message = fixture_messages_by_hash[pending_message["item_hash"]] + assert len(pending_msg_ops) == len(fixture_messages) + fixture_messages_by_hash = {msg["item_hash"]: msg for msg in fixture_messages} + for pending_msg_op in pending_msg_ops: + assert isinstance(pending_msg_op.operation, InsertOne) + pending_message = pending_msg_op.operation._doc["message"] + expected_message = fixture_messages_by_hash[ + pending_msg_op.operation._doc["message"]["item_hash"] + ] # TODO: currently, the pending TX job modifies the time of the message. 
- del expected_message["time"] - assert set(expected_message.items()).issubset(set(pending_message.items())) + del pending_message["time"] + assert set(pending_message.items()).issubset(set(expected_message.items())) diff --git a/tests/utils/test_batch.py b/tests/utils/test_batch.py new file mode 100644 index 000000000..8e7c0fa73 --- /dev/null +++ b/tests/utils/test_batch.py @@ -0,0 +1,26 @@ +import pytest +from aleph.toolkit.batch import async_batch + + +async def async_range(*args): + for i in range(*args): + yield i + + +@pytest.mark.asyncio +async def test_async_batch(): + # batch with a remainder + batches = [b async for b in async_batch(async_range(0, 10), 3)] + assert batches == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + + # iterable divisible by n + batches = [b async for b in async_batch(async_range(0, 4), 2)] + assert batches == [[0, 1], [2, 3]] + + # n = 1 + batches = [b async for b in async_batch(async_range(0, 5), 1)] + assert batches == [[0], [1], [2], [3], [4]] + + # n = len(iterable) + batches = [b async for b in async_batch(async_range(0, 7), 7)] + assert batches == [[0, 1, 2, 3, 4, 5, 6]] From b181ae0980b14bcb160b3de7d0b0d2ccbb063b2e Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Tue, 3 May 2022 16:47:34 +0200 Subject: [PATCH 2/3] fixes for review --- src/aleph/jobs/process_pending_txs.py | 2 +- src/aleph/toolkit/batch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aleph/jobs/process_pending_txs.py b/src/aleph/jobs/process_pending_txs.py index aa14f137d..79a6ff089 100644 --- a/src/aleph/jobs/process_pending_txs.py +++ b/src/aleph/jobs/process_pending_txs.py @@ -29,7 +29,7 @@ async def handle_pending_tx( pending_tx, seen_ids: Optional[List] = None ) -> List[DbBulkOperation]: - db_operations = [] + db_operations: List[DbBulkOperation] = [] tx_context = TxContext(**pending_tx["context"]) LOGGER.info("%s Handling TX in block %s", tx_context.chain_name, tx_context.height) diff --git a/src/aleph/toolkit/batch.py b/src/aleph/toolkit/batch.py index db912ce98..ca0910270 100644 --- a/src/aleph/toolkit/batch.py +++ b/src/aleph/toolkit/batch.py @@ -4,7 +4,7 @@ async def async_batch( - async_iterable: AsyncIterator["T"], n: int + async_iterable: AsyncIterator[T], n: int ) -> AsyncIterator[List[T]]: batch = [] async for item in async_iterable: From 22232520e0b51776d04482ede3ba4cb7a1af5d67 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Tue, 3 May 2022 23:47:28 +0200 Subject: [PATCH 3/3] use partial instead of lambda --- src/aleph/jobs/process_pending_messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/aleph/jobs/process_pending_messages.py b/src/aleph/jobs/process_pending_messages.py index cc7881c59..8eaaf804e 100644 --- a/src/aleph/jobs/process_pending_messages.py +++ b/src/aleph/jobs/process_pending_messages.py @@ -3,6 +3,7 @@ """ import asyncio +from functools import partial from logging import getLogger from typing import List, Dict, Tuple @@ -48,7 +49,7 @@ async def handle_pending_message( async def join_pending_message_tasks(tasks): await gather_and_perform_db_operations( tasks, - on_error=lambda e: LOGGER.error("Error while processing message: %s", e), + on_error=partial(LOGGER.error, "Error while processing message: %s"), ) tasks.clear()
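---

Usage note: a minimal sketch of how the helpers touched by this series
compose outside the jobs themselves. `purge_pending_txs` and the batch
size of 100 are illustrative assumptions; `async_batch`,
`perform_db_operations` and `DbBulkOperation` are the real functions and
model used above.

    from pymongo import DeleteOne

    from aleph.jobs.job_utils import perform_db_operations
    from aleph.model.db_bulk_operation import DbBulkOperation
    from aleph.model.pending import PendingTX
    from aleph.toolkit.batch import async_batch


    async def purge_pending_txs() -> None:
        # Hypothetical job: delete every queued TX. async_batch drains the
        # Motor cursor and yields lists of at most 100 documents, so the
        # whole queue is never held in memory at once.
        async for batch in async_batch(PendingTX.collection.find(), 100):
            # One DbBulkOperation per document. perform_db_operations sorts
            # and groups the operations by collection, then issues a single
            # bulk_write per collection.
            db_operations = [
                DbBulkOperation(
                    collection=PendingTX,
                    operation=DeleteOne({"_id": tx["_id"]}),
                )
                for tx in batch
            ]
            await perform_db_operations(db_operations)

Grouping by collection is what lets a mixed result set (pending message
inserts plus the deletion of the pending TX that produced them) land in one
bulk_write round trip per collection, the same path both jobs now share.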