Skip to content

Commit

Permalink
improve performance of total_mempool_fees() and `total_mempool_cost()` (#17107)
Browse files Browse the repository at this point in the history

* improve performance of total_mempool_fees() and total_mempool_cost()

* log when updating the mempool using the slow-path

* extend test to ensure total-cost and total-fee book-keeping stays in sync with the DB

* lower mempool size from 50 blocks to 10 blocks
  • Loading branch information
arvidn committed Dec 22, 2023
1 parent 3856324 commit 55c064a
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 39 deletions.
4 changes: 2 additions & 2 deletions chia/consensus/default_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
"3d8765d3a597ec1d99663f6c9816d915b9f68613ac94009884c4addaefcce6af"
),
MAX_VDF_WITNESS_SIZE=64,
# Size of mempool = 50x the size of block
MEMPOOL_BLOCK_BUFFER=50,
# Size of mempool = 10x the size of block
MEMPOOL_BLOCK_BUFFER=10,
# Max coin amount, fits into 64 bits
MAX_COIN_AMOUNT=uint64((1 << 64) - 1),
# Max block cost in clvm cost units
Expand Down
80 changes: 46 additions & 34 deletions chia/full_node/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,16 @@ class Mempool:
_block_height: uint32
_timestamp: uint64

_total_fee: int
_total_cost: int

def __init__(self, mempool_info: MempoolInfo, fee_estimator: FeeEstimatorInterface):
self._db_conn = sqlite3.connect(":memory:")
self._items = {}
self._block_height = uint32(0)
self._timestamp = uint64(0)
self._total_fee = 0
self._total_cost = 0

with self._db_conn:
# name means SpendBundle hash
Expand All @@ -75,8 +80,6 @@ def __init__(self, mempool_info: MempoolInfo, fee_estimator: FeeEstimatorInterfa
"""
)
self._db_conn.execute("CREATE INDEX name_idx ON tx(name)")
self._db_conn.execute("CREATE INDEX fee_sum ON tx(fee)")
self._db_conn.execute("CREATE INDEX cost_sum ON tx(cost)")
self._db_conn.execute("CREATE INDEX feerate ON tx(fee_per_cost)")
self._db_conn.execute(
"CREATE INDEX assert_before ON tx(assert_before_height, assert_before_seconds) "
Expand Down Expand Up @@ -121,16 +124,10 @@ def _row_to_item(self, row: sqlite3.Row) -> MempoolItem:
)

def total_mempool_fees(self) -> int:
with self._db_conn:
cursor = self._db_conn.execute("SELECT SUM(fee) FROM tx")
val = cursor.fetchone()[0]
return uint64(0) if val is None else uint64(val)
return self._total_fee

def total_mempool_cost(self) -> CLVMCost:
with self._db_conn:
cursor = self._db_conn.execute("SELECT SUM(cost) FROM tx")
val = cursor.fetchone()[0]
return CLVMCost(uint64(0) if val is None else uint64(val))
return CLVMCost(uint64(self._total_cost))

def all_items(self) -> Iterator[MempoolItem]:
with self._db_conn:
Expand Down Expand Up @@ -189,28 +186,28 @@ def get_min_fee_rate(self, cost: int) -> float:
Gets the minimum fpc rate that a transaction with specified cost will need in order to get included.
"""

if self.at_full_capacity(cost):
# TODO: make MempoolItem.cost be CLVMCost
current_cost = int(self.total_mempool_cost())

# Iterates through all spends in increasing fee per cost
with self._db_conn:
cursor = self._db_conn.execute("SELECT cost,fee_per_cost FROM tx ORDER BY fee_per_cost ASC, seq DESC")

item_cost: int
fee_per_cost: float
for item_cost, fee_per_cost in cursor:
current_cost -= item_cost
# Removing one at a time, until our transaction of size cost fits
if current_cost + cost <= self.mempool_info.max_size_in_cost:
return fee_per_cost

raise ValueError(
f"Transaction with cost {cost} does not fit in mempool of max cost {self.mempool_info.max_size_in_cost}"
)
else:
if not self.at_full_capacity(cost):
return 0

# TODO: make MempoolItem.cost be CLVMCost
current_cost = self._total_cost

# Iterates through all spends in increasing fee per cost
with self._db_conn:
cursor = self._db_conn.execute("SELECT cost,fee_per_cost FROM tx ORDER BY fee_per_cost ASC, seq DESC")

item_cost: int
fee_per_cost: float
for item_cost, fee_per_cost in cursor:
current_cost -= item_cost
# Removing one at a time, until our transaction of size cost fits
if current_cost + cost <= self.mempool_info.max_size_in_cost:
return fee_per_cost

raise ValueError(
f"Transaction with cost {cost} does not fit in mempool of max cost {self.mempool_info.max_size_in_cost}"
)

def new_tx_block(self, block_height: uint32, timestamp: uint64) -> None:
"""
Remove all items that became invalid because of this new height and
Expand Down Expand Up @@ -255,9 +252,19 @@ def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) ->
for batch in to_batches(items, SQLITE_MAX_VARIABLE_NUMBER):
args = ",".join(["?"] * len(batch.entries))
with self._db_conn:
cursor = self._db_conn.execute(
f"SELECT SUM(cost), SUM(fee) FROM tx WHERE name in ({args})", batch.entries
)
cost_to_remove, fee_to_remove = cursor.fetchone()

self._db_conn.execute(f"DELETE FROM tx WHERE name in ({args})", batch.entries)
self._db_conn.execute(f"DELETE FROM spends WHERE tx in ({args})", batch.entries)

self._total_cost -= cost_to_remove
self._total_fee -= fee_to_remove
assert self._total_cost >= 0
assert self._total_fee >= 0

if reason != MempoolRemoveReason.BLOCK_INCLUSION:
info = FeeMempoolInfo(
self.mempool_info, self.total_mempool_cost(), self.total_mempool_fees(), datetime.now()
Expand Down Expand Up @@ -310,14 +317,15 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]:
if fee_per_cost > item.fee_per_cost:
return Err.INVALID_FEE_LOW_FEE
to_remove.append(name)

self.remove_from_pool(to_remove, MempoolRemoveReason.EXPIRED)

# if we don't find any entries, it's OK to add this entry

total_cost = int(self.total_mempool_cost())
if total_cost + item.cost > self.mempool_info.max_size_in_cost:
if self._total_cost + item.cost > self.mempool_info.max_size_in_cost:
# pick the items with the lowest fee per cost to remove
cursor = self._db_conn.execute(
"""SELECT name FROM tx
"""SELECT name, cost, fee FROM tx

This comment has been minimized.

Copy link
@arvidn

arvidn Dec 22, 2023

Author Contributor

this was left in by mistake

WHERE name NOT IN (
SELECT name FROM (
SELECT name,
Expand All @@ -328,6 +336,7 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]:
(self.mempool_info.max_size_in_cost - item.cost,),
)
to_remove = [bytes32(row[0]) for row in cursor]

self.remove_from_pool(to_remove, MempoolRemoveReason.POOL_FULL)

# TODO: In the future, for the "fee_per_cost" field, opt for
Expand All @@ -354,6 +363,9 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]:
item.spend_bundle, item.npc_result, item.height_added_to_mempool, item.bundle_coin_spends
)

self._total_cost += item.cost
self._total_fee += item.fee

info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost(), self.total_mempool_fees(), datetime.now())
self.fee_estimator.add_mempool_item(info, MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
return None
Expand All @@ -363,7 +375,7 @@ def at_full_capacity(self, cost: int) -> bool:
Checks whether the mempool is at full capacity and cannot accept a transaction with size cost.
"""

return self.total_mempool_cost() + cost > self.mempool_info.max_size_in_cost
return self._total_cost + cost > self.mempool_info.max_size_in_cost

def create_bundle_from_mempool_items(
self, item_inclusion_filter: Callable[[bytes32], bool]
Expand Down
6 changes: 6 additions & 0 deletions chia/full_node/mempool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,12 @@ async def new_peak(
spendbundle_ids_to_remove.add(item.name)
self.mempool.remove_from_pool(list(spendbundle_ids_to_remove), MempoolRemoveReason.BLOCK_INCLUSION)
else:
log.warning(
"updating the mempool using the slow-path. "
f"peak: {self.peak.header_hash} "
f"new-peak-prev: {new_peak.prev_transaction_block_hash} "
f"coins: {'not set' if spent_coins is None else 'set'}"
)
old_pool = self.mempool
self.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator)
self.seen_bundle_hashes = {}
Expand Down
28 changes: 26 additions & 2 deletions tests/core/mempool/test_mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
spend_bundle_from_conditions,
)
from tests.core.node_height import node_height_at_least
from tests.util.misc import BenchmarkRunner
from tests.util.misc import BenchmarkRunner, invariant_check_mempool
from tests.util.time_out_assert import time_out_assert

BURN_PUZZLE_HASH = bytes32(b"0" * 32)
Expand Down Expand Up @@ -336,7 +336,9 @@ async def respond_transaction(
self.full_node.full_node_store.pending_tx_request.pop(spend_name)
if spend_name in self.full_node.full_node_store.peers_with_tx:
self.full_node.full_node_store.peers_with_tx.pop(spend_name)
return await self.full_node.add_transaction(tx.transaction, spend_name, peer, test)
ret = await self.full_node.add_transaction(tx.transaction, spend_name, peer, test)
invariant_check_mempool(self.full_node.mempool_manager.mempool)
return ret


async def next_block(full_node_1, wallet_a, bt) -> Coin:
Expand Down Expand Up @@ -580,6 +582,7 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
)
peer = await connect_and_get_peer(server_1, server_2, self_hostname)

invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)
for block in blocks:
await full_node_1.full_node.add_block(block)
await time_out_assert(60, node_height_at_least, True, full_node_1, start_height + 3)
Expand All @@ -595,12 +598,14 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
# Fee increase is insufficient, the old spendbundle must stay
self.assert_sb_in_pool(full_node_1, sb1_1)
self.assert_sb_not_in_pool(full_node_1, sb1_2)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb1_3 = await self.gen_and_send_sb(full_node_1, peer, wallet_a, coin1, fee=MEMPOOL_MIN_FEE_INCREASE)

# Fee increase is sufficiently high, sb1_1 gets replaced with sb1_3
self.assert_sb_not_in_pool(full_node_1, sb1_1)
self.assert_sb_in_pool(full_node_1, sb1_3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb2 = generate_test_spend_bundle(wallet_a, coin2, fee=MEMPOOL_MIN_FEE_INCREASE)
sb12 = SpendBundle.aggregate((sb2, sb1_3))
Expand All @@ -610,6 +615,7 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
# of coins spent in sb1_3
self.assert_sb_in_pool(full_node_1, sb12)
self.assert_sb_not_in_pool(full_node_1, sb1_3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb3 = generate_test_spend_bundle(wallet_a, coin3, fee=uint64(MEMPOOL_MIN_FEE_INCREASE * 2))
sb23 = SpendBundle.aggregate((sb2, sb3))
Expand All @@ -619,16 +625,19 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
# coins that are spent in the latter (specifically, coin1)
self.assert_sb_in_pool(full_node_1, sb12)
self.assert_sb_not_in_pool(full_node_1, sb23)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

await self.send_sb(full_node_1, sb3)
# Adding non-conflicting sb3 should succeed
self.assert_sb_in_pool(full_node_1, sb3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb4_1 = generate_test_spend_bundle(wallet_a, coin4, fee=MEMPOOL_MIN_FEE_INCREASE)
sb1234_1 = SpendBundle.aggregate((sb12, sb3, sb4_1))
await self.send_sb(full_node_1, sb1234_1)
# sb1234_1 should not be in pool as it decreases total fees per cost
self.assert_sb_not_in_pool(full_node_1, sb1234_1)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb4_2 = generate_test_spend_bundle(wallet_a, coin4, fee=uint64(MEMPOOL_MIN_FEE_INCREASE * 2))
sb1234_2 = SpendBundle.aggregate((sb12, sb3, sb4_2))
Expand All @@ -638,6 +647,7 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
self.assert_sb_in_pool(full_node_1, sb1234_2)
self.assert_sb_not_in_pool(full_node_1, sb12)
self.assert_sb_not_in_pool(full_node_1, sb3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

@pytest.mark.anyio
async def test_invalid_signature(self, one_node_one_block, wallet_a):
Expand Down Expand Up @@ -669,6 +679,7 @@ async def test_invalid_signature(self, one_node_one_block, wallet_a):
ack: TransactionAck = TransactionAck.from_bytes(res.data)
assert ack.status == MempoolInclusionStatus.FAILED.value
assert ack.error == Err.BAD_AGGREGATE_SIGNATURE.name
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

async def condition_tester(
self,
Expand Down Expand Up @@ -2764,13 +2775,16 @@ def test_full_mempool(items: List[int], add: int, expected: List[int]) -> None:
CLVMCost(uint64(100)),
)
mempool = Mempool(mempool_info, fee_estimator)
invariant_check_mempool(mempool)
fee_rate: float = 3.0
for i in items:
mempool.add_to_pool(item_cost(i, fee_rate))
fee_rate -= 0.1
invariant_check_mempool(mempool)

# now, add the item we're testing
mempool.add_to_pool(item_cost(add, 3.1))
invariant_check_mempool(mempool)

ordered_items = list(mempool.items_by_feerate())

Expand Down Expand Up @@ -2809,12 +2823,14 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L
)
mempool = Mempool(mempool_info, fee_estimator)
mempool.new_tx_block(uint32(10), uint64(100000))
invariant_check_mempool(mempool)

# fill the mempool with regular transactions (without expiration)
fee_rate: float = 3.0
for i in range(1, 20):
mempool.add_to_pool(item_cost(i, fee_rate))
fee_rate -= 0.1
invariant_check_mempool(mempool)

# now add the expiring transactions from the test case
fee_rate = 2.7
Expand All @@ -2826,6 +2842,7 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L
ret = mempool.add_to_pool(mk_item([coin], cost=cost, fee=int(cost * fee_rate), assert_before_height=15))
else:
ret = mempool.add_to_pool(mk_item([coin], cost=cost, fee=int(cost * fee_rate), assert_before_seconds=10400))
invariant_check_mempool(mempool)
if increase_fee:
fee_rate += 0.1
assert ret is None
Expand All @@ -2849,6 +2866,7 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L
print(f"- cost: {item.cost} TTL: {ttl}")

assert mempool.total_mempool_cost() > 90
invariant_check_mempool(mempool)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -2884,6 +2902,7 @@ def test_get_items_by_coin_ids(items: List[MempoolItem], coin_ids: List[bytes32]
mempool = Mempool(mempool_info, fee_estimator)
for i in items:
mempool.add_to_pool(i)
invariant_check_mempool(mempool)
result = mempool.get_items_by_coin_ids(coin_ids)
assert set(result) == set(expected)

Expand All @@ -2906,6 +2925,7 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
sb = SpendBundle.aggregate(spend_bundles)
mi = mempool_item_from_spendbundle(sb)
mempool.add_to_pool(mi)
invariant_check_mempool(mempool)
saved_cost = run_for_cost(
sb.coin_spends[0].puzzle_reveal, sb.coin_spends[0].solution, len(mi.additions), mi.cost
)
Expand All @@ -2926,9 +2946,11 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
highest_fee = 58282830
sb_high_rate = make_test_spendbundle(coins[1], fee=highest_fee)
agg_and_add_sb_returning_cost_info(mempool, [sb_A, sb_high_rate])
invariant_check_mempool(mempool)
# Create a ~2 FPC item that spends the eligible coin using the same solution A
sb_low_rate = make_test_spendbundle(coins[2], fee=highest_fee // 5)
saved_cost_on_solution_A = agg_and_add_sb_returning_cost_info(mempool, [sb_A, sb_low_rate])
invariant_check_mempool(mempool)
result = mempool.create_bundle_from_mempool_items(always)
assert result is not None
agg, _ = result
Expand All @@ -2942,6 +2964,7 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
# (which has ~10 FPC) but before sb_A2 (which has ~2 FPC)
sb_mid_rate = make_test_spendbundle(coins[i], fee=38004852 - i)
saved_cost_on_solution_B = agg_and_add_sb_returning_cost_info(mempool, [sb_B, sb_mid_rate])
invariant_check_mempool(mempool)
# We'd save more cost if we went with solution B instead of A
assert saved_cost_on_solution_B > saved_cost_on_solution_A
# If we process everything now, the 3 x ~3 FPC items get skipped because
Expand All @@ -2954,6 +2977,7 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
# We ran with solution A and missed bigger savings on solution B
assert mempool.size() == 5
assert [c.coin for c in agg.coin_spends] == [coins[0], coins[1], coins[2]]
invariant_check_mempool(mempool)


def test_get_puzzle_and_solution_for_coin_failure():
Expand Down
Loading

0 comments on commit 55c064a

Please sign in to comment.