Skip to content

Commit

Permalink
Add MempoolRemoveReason, which represents why we are removing a Mempo…
Browse files Browse the repository at this point in the history
…olItem (#14263)

* Add MempoolRemoveReason, which represents why we are removing a MempoolItem

* Add integration tests for remove_from_pool. Generalize call count tracking in mock Fee Estimator.
  • Loading branch information
aqk committed Jan 3, 2023
1 parent 559e58a commit 8364166
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 14 deletions.
14 changes: 11 additions & 3 deletions chia/full_node/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

from sortedcontainers import SortedDict
Expand All @@ -15,6 +16,12 @@
from chia.util.ints import uint64


class MempoolRemoveReason(Enum):
CONFLICT = 1
BLOCK_INCLUSION = 2
POOL_FULL = 3


class Mempool:
def __init__(self, mempool_info: MempoolInfo, fee_estimator: FeeEstimatorInterface):
self.log: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,7 +54,7 @@ def get_min_fee_rate(self, cost: int) -> float:
else:
return 0

def remove_from_pool(self, items: List[bytes32]) -> None:
def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) -> None:
"""
Removes an item from the mempool.
"""
Expand All @@ -70,7 +77,8 @@ def remove_from_pool(self, items: List[bytes32]) -> None:
self.total_mempool_cost = CLVMCost(uint64(self.total_mempool_cost - item.cost))
assert self.total_mempool_cost >= 0
info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost, datetime.now())
self.fee_estimator.remove_mempool_item(info, item)
if reason != MempoolRemoveReason.BLOCK_INCLUSION:
self.fee_estimator.remove_mempool_item(info, item)

def add_to_pool(self, item: MempoolItem) -> None:
"""
Expand All @@ -81,7 +89,7 @@ def add_to_pool(self, item: MempoolItem) -> None:
# Val is Dict[hash, MempoolItem]
fee_per_cost, val = self.sorted_spends.peekitem(index=0)
to_remove: MempoolItem = list(val.values())[0]
self.remove_from_pool([to_remove.name])
self.remove_from_pool([to_remove.name], MempoolRemoveReason.POOL_FULL)

self.spends[item.name] = item

Expand Down
6 changes: 3 additions & 3 deletions chia/full_node/mempool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from chia.full_node.bundle_tools import simple_solution_generator
from chia.full_node.fee_estimation import FeeBlockInfo, MempoolInfo
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.mempool import Mempool
from chia.full_node.mempool import Mempool, MempoolRemoveReason
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions, mempool_check_time_locks
from chia.full_node.pending_tx_cache import PendingTxCache
from chia.types.blockchain_format.coin import Coin
Expand Down Expand Up @@ -350,7 +350,7 @@ async def add_spend_bundle(
# No error, immediately add to mempool, after removing conflicting TXs.
assert item is not None
self.mempool.add_to_pool(item)
self.mempool.remove_from_pool(remove_items)
self.mempool.remove_from_pool(remove_items, MempoolRemoveReason.CONFLICT)
return item.cost, MempoolInclusionStatus.SUCCESS, None
elif item is not None:
# There is an error, but we still returned a mempool item, this means we should add to the pending pool.
Expand Down Expand Up @@ -586,7 +586,7 @@ async def new_peak(
spendbundle_ids: List[bytes32] = self.mempool.removal_coin_id_to_spendbundle_ids[
bytes32(spend.coin_id)
]
self.mempool.remove_from_pool(spendbundle_ids)
self.mempool.remove_from_pool(spendbundle_ids, MempoolRemoveReason.BLOCK_INCLUSION)
for spendbundle_id in spendbundle_ids:
self.remove_seen(spendbundle_id)
else:
Expand Down
36 changes: 29 additions & 7 deletions tests/fee_estimation/test_fee_estimation_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Dict

from chia_rs import Coin

from chia.consensus.cost_calculator import NPCResult
Expand All @@ -12,7 +14,7 @@
MempoolInfo,
)
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.mempool import Mempool
from chia.full_node.mempool import Mempool, MempoolRemoveReason
from chia.simulator.block_tools import test_constants
from chia.simulator.wallet_tools import WalletTool
from chia.types.clvm_cost import CLVMCost
Expand Down Expand Up @@ -43,19 +45,20 @@ def make_mempoolitem() -> MempoolItem:


class FeeEstimatorInterfaceIntegrationVerificationObject(FeeEstimatorInterface):
add_mempool_item_called = False
add_mempool_item_called_count: int = 0
remove_mempool_item_called_count: int = 0

def new_block(self, block_info: FeeBlockInfo) -> None:
"""A new block has been added to the blockchain"""
pass

def add_mempool_item(self, mempool_item_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
"""A MempoolItem (transaction and associated info) has been added to the mempool"""
self.add_mempool_item_called = True
self.add_mempool_item_called_count += 1

def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
"""A MempoolItem (transaction and associated info) has been removed from the mempool"""
pass
self.remove_mempool_item_called_count += 1

def estimate_fee_rate(self, *, time_offset_seconds: int) -> FeeRate:
"""time_offset_seconds: number of seconds into the future for which to estimate fee"""
Expand Down Expand Up @@ -93,12 +96,31 @@ def test_mempool_fee_estimator_add_item() -> None:
mempool = Mempool(test_mempool_info, fee_estimator)
item = make_mempoolitem()
mempool.add_to_pool(item)
assert mempool.fee_estimator.add_mempool_item_called # type: ignore[attr-defined]
assert mempool.fee_estimator.add_mempool_item_called_count == 1 # type: ignore[attr-defined]


def test_item_not_removed_if_not_added() -> None:
for reason in MempoolRemoveReason:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
item = make_mempoolitem()
mempool.remove_from_pool([item.name], reason)
assert mempool.fee_estimator.remove_mempool_item_called_count == 0 # type: ignore[attr-defined]


def test_mempool_fee_estimator_remove_item() -> None:
# reasons: Dict[MempoolRemoveReason, bool] = {}
pass
should_call_fee_estimator_remove: Dict[MempoolRemoveReason, int] = {
MempoolRemoveReason.BLOCK_INCLUSION: 0,
MempoolRemoveReason.CONFLICT: 1,
MempoolRemoveReason.POOL_FULL: 1,
}
for reason, call_count in should_call_fee_estimator_remove.items():
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
item = make_mempoolitem()
mempool.add_to_pool(item)
mempool.remove_from_pool([item.name], reason)
assert mempool.fee_estimator.remove_mempool_item_called_count == call_count # type: ignore[attr-defined]


def test_mempool_manager_fee_estimator_new_block() -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/wallet/test_wallet_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from chia.full_node.full_node_api import FullNodeAPI
from chia.full_node.mempool import MempoolRemoveReason
from chia.simulator.block_tools import BlockTools
from chia.simulator.full_node_simulator import FullNodeSimulator
from chia.simulator.time_out_assert import time_out_assert, time_out_assert_custom_interval
Expand All @@ -29,7 +30,7 @@ def assert_sb_not_in_pool(node: FullNodeAPI, sb: SpendBundle) -> None:

def evict_from_pool(node: FullNodeAPI, sb: SpendBundle) -> None:
mempool_item = node.full_node.mempool_manager.mempool.spends[sb.name()]
node.full_node.mempool_manager.mempool.remove_from_pool([mempool_item.name])
node.full_node.mempool_manager.mempool.remove_from_pool([mempool_item.name], MempoolRemoveReason.CONFLICT)
node.full_node.mempool_manager.remove_seen(sb.name())


Expand Down

0 comments on commit 8364166

Please sign in to comment.