From fdb5be3317eb25f92ef2e4343f56b4d042d347e5 Mon Sep 17 00:00:00 2001 From: arvidn Date: Fri, 22 Dec 2023 15:29:20 +0100 Subject: [PATCH 1/3] update type annotation for CoinStore.get_coin_records to support both List and Set --- chia/full_node/coin_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chia/full_node/coin_store.py b/chia/full_node/coin_store.py index 492331cefc50..dcfbef07a951 100644 --- a/chia/full_node/coin_store.py +++ b/chia/full_node/coin_store.py @@ -146,7 +146,7 @@ async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]: return CoinRecord(coin, row[0], row[1], row[2], row[6]) return None - async def get_coin_records(self, names: List[bytes32]) -> List[CoinRecord]: + async def get_coin_records(self, names: Collection[bytes32]) -> List[CoinRecord]: if len(names) == 0: return [] From 14e483cbe63af33fb26535f039545bc63fcca0a7 Mon Sep 17 00:00:00 2001 From: arvidn Date: Thu, 21 Dec 2023 15:13:39 +0100 Subject: [PATCH 2/3] update the mempool to fetch multiple coin records per query --- benchmarks/mempool-long-lived.py | 11 +- benchmarks/mempool.py | 15 ++- chia/clvm/spend_sim.py | 2 +- chia/full_node/full_node.py | 2 +- chia/full_node/mempool_manager.py | 45 +++++--- .../mempool/test_mempool_fee_estimator.py | 2 +- tests/core/mempool/test_mempool_manager.py | 100 +++++++++++------- .../test_fee_estimation_integration.py | 6 +- 8 files changed, 120 insertions(+), 63 deletions(-) diff --git a/benchmarks/mempool-long-lived.py b/benchmarks/mempool-long-lived.py index 38db69ecdd30..7f78e594f164 100644 --- a/benchmarks/mempool-long-lived.py +++ b/benchmarks/mempool-long-lived.py @@ -3,7 +3,7 @@ import asyncio from dataclasses import dataclass from time import monotonic -from typing import Dict, Optional +from typing import Collection, Dict, List, Optional from chia_rs import G2Element from clvm.casts import int_to_bytes @@ -81,8 +81,13 @@ def fake_block_record(block_height: uint32, timestamp: uint64) -> BenchBlockReco async def run_mempool_benchmark() -> None: coin_records: Dict[bytes32, CoinRecord] = {} - async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]: - return coin_records.get(coin_id) + async def get_coin_record(coin_ids: Collection[bytes32]) -> List[CoinRecord]: + ret: List[CoinRecord] = [] + for name in coin_ids: + r = coin_records.get(name) + if r is not None: + ret.append(r) + return ret timestamp = uint64(1631794488) diff --git a/benchmarks/mempool.py b/benchmarks/mempool.py index 76ec63e64224..861944ae4f05 100644 --- a/benchmarks/mempool.py +++ b/benchmarks/mempool.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from subprocess import check_call from time import monotonic -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Collection, Dict, Iterator, List, Optional, Tuple from chia.consensus.coinbase import create_farmer_coin, create_pool_coin from chia.consensus.default_constants import DEFAULT_CONSTANTS @@ -78,8 +78,13 @@ def fake_block_record(block_height: uint32, timestamp: uint64) -> BenchBlockReco async def run_mempool_benchmark() -> None: all_coins: Dict[bytes32, CoinRecord] = {} - async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]: - return all_coins.get(coin_id) + async def get_coin_records(coin_ids: Collection[bytes32]) -> List[CoinRecord]: + ret: List[CoinRecord] = [] + for name in coin_ids: + r = all_coins.get(name) + if r is not None: + ret.append(r) + return ret wt = WalletTool(DEFAULT_CONSTANTS) @@ -156,7 +161,7 @@ async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]: else: print("\n== Multi-threaded") - mempool = MempoolManager(get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded) + mempool = MempoolManager(get_coin_records, DEFAULT_CONSTANTS, single_threaded=single_threaded) height = start_height rec = fake_block_record(height, timestamp) @@ -186,7 +191,7 @@ async def add_spend_bundles(spend_bundles: List[SpendBundle]) -> None: print(f" time: {stop - start:0.4f}s") print(f" per call: {(stop - start) / total_bundles * 1000:0.2f}ms") - mempool = MempoolManager(get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded) + mempool = MempoolManager(get_coin_records, DEFAULT_CONSTANTS, single_threaded=single_threaded) height = start_height rec = fake_block_record(height, timestamp) diff --git a/chia/clvm/spend_sim.py b/chia/clvm/spend_sim.py index c2d7835364e9..a6d296010046 100644 --- a/chia/clvm/spend_sim.py +++ b/chia/clvm/spend_sim.py @@ -156,7 +156,7 @@ async def create( self.db_wrapper = await DBWrapper2.create(database=uri, uri=True, reader_count=1, db_version=2) self.coin_store = await CoinStore.create(self.db_wrapper) - self.mempool_manager = MempoolManager(self.coin_store.get_coin_record, defaults) + self.mempool_manager = MempoolManager(self.coin_store.get_coin_records, defaults) self.defaults = defaults # Load the next data if there is any diff --git a/chia/full_node/full_node.py b/chia/full_node/full_node.py index e2e0cf65c326..b97ed827a6c7 100644 --- a/chia/full_node/full_node.py +++ b/chia/full_node/full_node.py @@ -276,7 +276,7 @@ async def manage(self) -> AsyncIterator[None]: ) self._mempool_manager = MempoolManager( - get_coin_record=self.coin_store.get_coin_record, + get_coin_records=self.coin_store.get_coin_records, consensus_constants=self.constants, multiprocessing_context=self.multiprocessing_context, single_threaded=single_threaded, diff --git a/chia/full_node/mempool_manager.py b/chia/full_node/mempool_manager.py index 4baaa7c0bc89..6a40acf7a068 100644 --- a/chia/full_node/mempool_manager.py +++ b/chia/full_node/mempool_manager.py @@ -7,7 +7,7 @@ from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass from multiprocessing.context import BaseContext -from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple, TypeVar +from typing import Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple, TypeVar from chia_rs import ELIGIBLE_FOR_DEDUP, GTElement from chiabip158 import PyBIP158 @@ -146,7 +146,7 @@ class MempoolManager: pool: Executor constants: ConsensusConstants seen_bundle_hashes: Dict[bytes32, bytes32] - get_coin_record: Callable[[bytes32], Awaitable[Optional[CoinRecord]]] + get_coin_records: Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]] nonzero_fee_minimum_fpc: int mempool_max_total_cost: int # a cache of MempoolItems that conflict with existing items in the pool @@ -159,7 +159,7 @@ class MempoolManager: def __init__( self, - get_coin_record: Callable[[bytes32], Awaitable[Optional[CoinRecord]]], + get_coin_records: Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]], consensus_constants: ConsensusConstants, multiprocessing_context: Optional[BaseContext] = None, *, @@ -170,7 +170,7 @@ def __init__( # Keep track of seen spend_bundles self.seen_bundle_hashes: Dict[bytes32, bytes32] = {} - self.get_coin_record = get_coin_record + self.get_coin_records = get_coin_records # The fee per cost must be above this amount to consider the fee "nonzero", and thus able to kick out other # transactions. This prevents spam. This is equivalent to 0.055 XCH per block, or about 0.00005 XCH for two @@ -303,7 +303,12 @@ async def pre_validate_spendbundle( return ret async def add_spend_bundle( - self, new_spend: SpendBundle, npc_result: NPCResult, spend_name: bytes32, first_added_height: uint32 + self, + new_spend: SpendBundle, + npc_result: NPCResult, + spend_name: bytes32, + first_added_height: uint32, + get_coin_records: Optional[Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]]] = None, ) -> Tuple[Optional[uint64], MempoolInclusionStatus, Optional[Err]]: """ Validates and adds to mempool a new_spend with the given NPCResult, and spend_name, and the current mempool. @@ -327,8 +332,14 @@ async def add_spend_bundle( if existing_item is not None: return existing_item.cost, MempoolInclusionStatus.SUCCESS, None + if get_coin_records is None: + get_coin_records = self.get_coin_records err, item, remove_items = await self.validate_spend_bundle( - new_spend, npc_result, spend_name, first_added_height + new_spend, + npc_result, + spend_name, + first_added_height, + get_coin_records, ) if err is None: # No error, immediately add to mempool, after removing conflicting TXs. @@ -358,6 +369,7 @@ async def validate_spend_bundle( npc_result: NPCResult, spend_name: bytes32, first_added_height: uint32, + get_coin_records: Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]], ) -> Tuple[Optional[Err], Optional[MempoolItem], List[bytes32]]: """ Validates new_spend with the given NPCResult, and spend_name, and the current mempool. The mempool should @@ -420,11 +432,14 @@ async def validate_spend_bundle( removal_record_dict: Dict[bytes32, CoinRecord] = {} removal_amount: int = 0 + removal_records = await get_coin_records(removal_names) + for record in removal_records: + removal_record_dict[record.coin.name()] = record + for name in removal_names: - removal_record = await self.get_coin_record(name) - if removal_record is None and name not in additions_dict: + if name not in removal_record_dict and name not in additions_dict: return Err.UNKNOWN_UNSPENT, None, [] - elif name in additions_dict: + if name in additions_dict: removal_coin = additions_dict[name] # The timestamp and block-height of this coin being spent needs # to be consistent with what we use to check time-lock @@ -440,10 +455,10 @@ async def validate_spend_bundle( False, self.peak.timestamp, ) - - assert removal_record is not None + removal_record_dict[name] = removal_record + else: + removal_record = removal_record_dict[name] removal_amount = removal_amount + removal_record.coin.amount - removal_record_dict[name] = removal_record fees = uint64(removal_amount - addition_amount) @@ -642,7 +657,11 @@ async def new_peak( txs_added = [] for item in potential_txs.values(): cost, status, error = await self.add_spend_bundle( - item.spend_bundle, item.npc_result, item.spend_bundle_name, item.height_added_to_mempool + item.spend_bundle, + item.npc_result, + item.spend_bundle_name, + item.height_added_to_mempool, + self.get_coin_records, ) if status == MempoolInclusionStatus.SUCCESS: txs_added.append((item.spend_bundle, item.npc_result, item.spend_bundle_name)) diff --git a/tests/core/mempool/test_mempool_fee_estimator.py b/tests/core/mempool/test_mempool_fee_estimator.py index 6dc2e7ce9946..86377d5b90e5 100644 --- a/tests/core/mempool/test_mempool_fee_estimator.py +++ b/tests/core/mempool/test_mempool_fee_estimator.py @@ -63,7 +63,7 @@ async def test_basics() -> None: async def test_fee_increase() -> None: async with DBConnection(db_version=2) as db_wrapper: coin_store = await CoinStore.create(db_wrapper) - mempool_manager = MempoolManager(coin_store.get_coin_record, test_constants) + mempool_manager = MempoolManager(coin_store.get_coin_records, test_constants) assert test_constants.MAX_BLOCK_COST_CLVM == mempool_manager.constants.MAX_BLOCK_COST_CLVM btc_fee_estimator: BitcoinFeeEstimator = mempool_manager.mempool.fee_estimator # type: ignore fee_tracker = btc_fee_estimator.get_tracker() diff --git a/tests/core/mempool/test_mempool_manager.py b/tests/core/mempool/test_mempool_manager.py index 7c64de390c27..3c7d355839f8 100644 --- a/tests/core/mempool/test_mempool_manager.py +++ b/tests/core/mempool/test_mempool_manager.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple import pytest from chia_rs import ELIGIBLE_FOR_DEDUP, G1Element, G2Element @@ -85,17 +85,24 @@ def is_transaction_block(self) -> bool: return self.timestamp is not None -async def zero_calls_get_coin_record(_: bytes32) -> Optional[CoinRecord]: - assert False +async def zero_calls_get_coin_records(coin_ids: Collection[bytes32]) -> List[CoinRecord]: + assert len(coin_ids) == 0 + return [] -async def get_coin_record_for_test_coins(coin_id: bytes32) -> Optional[CoinRecord]: +async def get_coin_records_for_test_coins(coin_ids: Collection[bytes32]) -> List[CoinRecord]: test_coin_records = { TEST_COIN_ID: TEST_COIN_RECORD, TEST_COIN_ID2: TEST_COIN_RECORD2, TEST_COIN_ID3: TEST_COIN_RECORD3, } - return test_coin_records.get(coin_id) + + ret: List[CoinRecord] = [] + for name in coin_ids: + r = test_coin_records.get(name) + if r is not None: + ret.append(r) + return ret def height_hash(height: int) -> bytes32: @@ -113,13 +120,13 @@ def create_test_block_record(*, height: uint32 = TEST_HEIGHT, timestamp: uint64 async def instantiate_mempool_manager( - get_coin_record: Callable[[bytes32], Awaitable[Optional[CoinRecord]]], + get_coin_records: Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]], *, block_height: uint32 = TEST_HEIGHT, block_timestamp: uint64 = TEST_TIMESTAMP, constants: ConsensusConstants = DEFAULT_CONSTANTS, ) -> MempoolManager: - mempool_manager = MempoolManager(get_coin_record, constants) + mempool_manager = MempoolManager(get_coin_records, constants) test_block_record = create_test_block_record(height=block_height, timestamp=block_timestamp) await mempool_manager.new_peak(test_block_record, None) invariant_check_mempool(mempool_manager.mempool) @@ -134,10 +141,15 @@ async def setup_mempool_with_coins(*, coin_amounts: List[int]) -> Tuple[MempoolM coins.append(coin) test_coin_records[coin.name()] = CoinRecord(coin, uint32(0), uint32(0), False, uint64(0)) - async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]: - return test_coin_records.get(coin_id) + async def get_coin_records(coin_ids: Collection[bytes32]) -> List[CoinRecord]: + ret: List[CoinRecord] = [] + for name in coin_ids: + r = test_coin_records.get(name) + if r is not None: + ret.append(r) + return ret - mempool_manager = await instantiate_mempool_manager(get_coin_record) + mempool_manager = await instantiate_mempool_manager(get_coin_records) return (mempool_manager, coins) @@ -393,7 +405,7 @@ def mempool_item_from_spendbundle(spend_bundle: SpendBundle) -> MempoolItem: @pytest.mark.anyio async def test_empty_spend_bundle() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) sb = SpendBundle([], G2Element()) with pytest.raises(ValidationError, match="INVALID_SPEND_BUNDLE"): await mempool_manager.pre_validate_spendbundle(sb, None, sb.name()) @@ -401,7 +413,7 @@ async def test_empty_spend_bundle() -> None: @pytest.mark.anyio async def test_negative_addition_amount() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, -1]] sb = spend_bundle_from_conditions(conditions) with pytest.raises(ValidationError, match="COIN_AMOUNT_NEGATIVE"): @@ -410,7 +422,7 @@ async def test_negative_addition_amount() -> None: @pytest.mark.anyio async def test_valid_addition_amount() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) max_amount = mempool_manager.constants.MAX_COIN_AMOUNT conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, max_amount]] coin = Coin(IDENTITY_PUZZLE_HASH, IDENTITY_PUZZLE_HASH, max_amount) @@ -421,7 +433,7 @@ async def test_valid_addition_amount() -> None: @pytest.mark.anyio async def test_too_big_addition_amount() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) max_amount = mempool_manager.constants.MAX_COIN_AMOUNT conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, max_amount + 1]] sb = spend_bundle_from_conditions(conditions) @@ -431,7 +443,7 @@ async def test_too_big_addition_amount() -> None: @pytest.mark.anyio async def test_duplicate_output() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) conditions = [ [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1], [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1], @@ -443,7 +455,7 @@ async def test_duplicate_output() -> None: @pytest.mark.anyio async def test_block_cost_exceeds_max() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) conditions = [] for i in range(2400): conditions.append([ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, i]) @@ -454,7 +466,7 @@ async def test_block_cost_exceeds_max() -> None: @pytest.mark.anyio async def test_double_spend_prevalidation() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1]] sb = spend_bundle_from_conditions(conditions) sb_twice: SpendBundle = SpendBundle.aggregate([sb, sb]) @@ -464,7 +476,7 @@ async def test_double_spend_prevalidation() -> None: @pytest.mark.anyio async def test_minting_coin() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, TEST_COIN_AMOUNT]] sb = spend_bundle_from_conditions(conditions) npc_result = await mempool_manager.pre_validate_spendbundle(sb, None, sb.name()) @@ -477,7 +489,7 @@ async def test_minting_coin() -> None: @pytest.mark.anyio async def test_reserve_fee_condition() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) conditions = [[ConditionOpcode.RESERVE_FEE, TEST_COIN_AMOUNT]] sb = spend_bundle_from_conditions(conditions) npc_result = await mempool_manager.pre_validate_spendbundle(sb, None, sb.name()) @@ -490,10 +502,10 @@ async def test_reserve_fee_condition() -> None: @pytest.mark.anyio async def test_unknown_unspent() -> None: - async def get_coin_record(_: bytes32) -> Optional[CoinRecord]: - return None + async def get_coin_records(_: Collection[bytes32]) -> List[CoinRecord]: + return [] - mempool_manager = await instantiate_mempool_manager(get_coin_record) + mempool_manager = await instantiate_mempool_manager(get_coin_records) conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1]] _, _, result = await generate_and_add_spendbundle(mempool_manager, conditions) assert result == (None, MempoolInclusionStatus.FAILED, Err.UNKNOWN_UNSPENT) @@ -501,7 +513,7 @@ async def get_coin_record(_: bytes32) -> Optional[CoinRecord]: @pytest.mark.anyio async def test_same_sb_twice_with_eligible_coin() -> None: - mempool_manager = await instantiate_mempool_manager(get_coin_record_for_test_coins) + mempool_manager = await instantiate_mempool_manager(get_coin_records_for_test_coins) sb1_conditions = [ [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1], [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 2], @@ -525,7 +537,7 @@ async def test_same_sb_twice_with_eligible_coin() -> None: @pytest.mark.anyio async def test_sb_twice_with_eligible_coin_and_different_spends_order() -> None: - mempool_manager = await instantiate_mempool_manager(get_coin_record_for_test_coins) + mempool_manager = await instantiate_mempool_manager(get_coin_records_for_test_coins) sb1_conditions = [ [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1], [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 2], @@ -622,7 +634,7 @@ async def test_ephemeral_timelock( expected_error: Optional[Err], ) -> None: mempool_manager = await instantiate_mempool_manager( - get_coin_record=get_coin_record_for_test_coins, + get_coin_records=get_coin_records_for_test_coins, block_height=uint32(5), block_timestamp=uint64(10050), constants=DEFAULT_CONSTANTS, @@ -635,7 +647,7 @@ async def test_ephemeral_timelock( # sb spends TEST_COIN and creates created_coin which gets spent too sb = SpendBundle.aggregate([sb1, sb2]) # We shouldn't have a record of this ephemeral coin - assert await get_coin_record_for_test_coins(created_coin.name()) is None + assert await get_coin_records_for_test_coins([created_coin.name()]) == [] try: _, status, error = await add_spendbundle(mempool_manager, sb, sb.name()) assert (status, error) == (expected_status, expected_error) @@ -833,7 +845,7 @@ def test_can_replace(existing_items: List[MempoolItem], new_item: MempoolItem, e @pytest.mark.anyio async def test_get_items_not_in_filter() -> None: - mempool_manager = await instantiate_mempool_manager(get_coin_record_for_test_coins) + mempool_manager = await instantiate_mempool_manager(get_coin_records_for_test_coins) conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1]] sb1, sb1_name, _ = await generate_and_add_spendbundle(mempool_manager, conditions) conditions2 = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 2]] @@ -880,10 +892,15 @@ async def test_get_items_not_in_filter() -> None: async def test_total_mempool_fees() -> None: coin_records: Dict[bytes32, CoinRecord] = {} - async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]: - return coin_records.get(coin_id) + async def get_coin_records(coin_ids: Collection[bytes32]) -> List[CoinRecord]: + ret: List[CoinRecord] = [] + for name in coin_ids: + r = coin_records.get(name) + if r is not None: + ret.append(r) + return ret - mempool_manager = await instantiate_mempool_manager(get_coin_record) + mempool_manager = await instantiate_mempool_manager(get_coin_records) conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1]] # the limit of total fees in the mempool is 2^63 @@ -998,11 +1015,17 @@ async def make_and_send_big_cost_sb(coin: Coin) -> None: async def test_assert_before_expiration( opcode: ConditionOpcode, arg: int, expect_eviction: bool, expect_limit: Optional[int] ) -> None: - async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]: - return {TEST_COIN.name(): CoinRecord(TEST_COIN, uint32(5), uint32(0), False, uint64(9900))}.get(coin_id) + async def get_coin_records(coin_ids: Collection[bytes32]) -> List[CoinRecord]: + all_coins = {TEST_COIN.name(): CoinRecord(TEST_COIN, uint32(5), uint32(0), False, uint64(9900))} + ret: List[CoinRecord] = [] + for name in coin_ids: + r = all_coins.get(name) + if r is not None: + ret.append(r) + return ret mempool_manager = await instantiate_mempool_manager( - get_coin_record, + get_coin_records, block_height=uint32(10), block_timestamp=uint64(10000), constants=DEFAULT_CONSTANTS, @@ -1357,10 +1380,15 @@ async def test_coin_spending_different_ways_then_finding_it_spent_in_new_peak(ne coin_id = coin.name() test_coin_records = {coin_id: CoinRecord(coin, uint32(0), uint32(0), False, uint64(0))} - async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]: - return test_coin_records.get(coin_id) + async def get_coin_records(coin_ids: Collection[bytes32]) -> List[CoinRecord]: + ret: List[CoinRecord] = [] + for name in coin_ids: + r = test_coin_records.get(name) + if r is not None: + ret.append(r) + return ret - mempool_manager = await instantiate_mempool_manager(get_coin_record) + mempool_manager = await instantiate_mempool_manager(get_coin_records) # Create a bunch of mempool items that spend the coin in different ways for i in range(3): _, _, result = await generate_and_add_spendbundle( diff --git a/tests/fee_estimation/test_fee_estimation_integration.py b/tests/fee_estimation/test_fee_estimation_integration.py index c5f843970870..64463ea05b3d 100644 --- a/tests/fee_estimation/test_fee_estimation_integration.py +++ b/tests/fee_estimation/test_fee_estimation_integration.py @@ -29,7 +29,7 @@ from tests.core.mempool.test_mempool_manager import ( create_test_block_record, instantiate_mempool_manager, - zero_calls_get_coin_record, + zero_calls_get_coin_records, ) @@ -217,7 +217,7 @@ def test_current_block_height_new_block_then_new_height() -> None: @pytest.mark.anyio async def test_mm_new_peak_changes_fee_estimator_block_height() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) block2 = create_test_block_record(height=uint32(2)) await mempool_manager.new_peak(block2, None) assert mempool_manager.mempool.fee_estimator.block_height == uint32(2) # type: ignore[attr-defined] @@ -225,7 +225,7 @@ async def test_mm_new_peak_changes_fee_estimator_block_height() -> None: @pytest.mark.anyio async def test_mm_calls_new_block_height() -> None: - mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record) + mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_records) new_block_height_called = False def test_new_block_height_called(self: FeeEstimatorInterface, height: uint32) -> None: From d08861ac9bdbb2d76f9df4042b0a33c0d36b15ce Mon Sep 17 00:00:00 2001 From: arvidn Date: Thu, 21 Dec 2023 18:32:06 +0100 Subject: [PATCH 3/3] optimize the slow-path of updating the mempool by fetching all coin records up-front, in a single sql query --- chia/full_node/mempool_manager.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/chia/full_node/mempool_manager.py b/chia/full_node/mempool_manager.py index 6a40acf7a068..04e34ed2f0ef 100644 --- a/chia/full_node/mempool_manager.py +++ b/chia/full_node/mempool_manager.py @@ -638,9 +638,35 @@ async def new_peak( old_pool = self.mempool self.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator) self.seen_bundle_hashes = {} + + # in order to make this a bit quicker, we look-up all the spends in + # a single query, rather than one at a time. + coin_records: Dict[bytes32, CoinRecord] = {} + + removals: Set[bytes32] = set() + for item in old_pool.all_items(): + for s in item.spend_bundle.coin_spends: + removals.add(s.coin.name()) + + for record in await self.get_coin_records(removals): + name = record.coin.name() + coin_records[name] = record + + async def local_get_coin_records(names: Collection[bytes32]) -> List[CoinRecord]: + ret: List[CoinRecord] = [] + for name in names: + r = coin_records.get(name) + if r is not None: + ret.append(r) + return ret + for item in old_pool.all_items(): _, result, err = await self.add_spend_bundle( - item.spend_bundle, item.npc_result, item.spend_bundle_name, item.height_added_to_mempool + item.spend_bundle, + item.npc_result, + item.spend_bundle_name, + item.height_added_to_mempool, + local_get_coin_records, ) # Only add to `seen` if inclusion worked, so it can be resubmitted in case of a reorg if result == MempoolInclusionStatus.SUCCESS: