Skip to content

Commit b7e8fe7

Browse files
authored
[CHIA-387] DL batch upsert optimization. (#17999)
* DL batch upsert optimization. * Lint * Lint * Fix test. * Convert delete/insert to upserts. * Update data_store.py * Improve coverage. * Whitespace. * Change test to use upsert too. * Clarify test usage.
1 parent d868d9c commit b7e8fe7

File tree

2 files changed

+113
-18
lines changed

2 files changed

+113
-18
lines changed

chia/_tests/core/data_layer/test_data_store.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pytest
1818

1919
from chia._tests.core.data_layer.util import Example, add_0123_example, add_01234567_example
20-
from chia._tests.util.misc import BenchmarkRunner, Marks, datacases
20+
from chia._tests.util.misc import BenchmarkRunner, Marks, boolean_datacases, datacases
2121
from chia.data_layer.data_layer_errors import KeyNotFoundError, NodeHashError, TreeGenerationIncrementingError
2222
from chia.data_layer.data_layer_util import (
2323
DiffData,
@@ -1991,20 +1991,41 @@ async def test_insert_key_already_present(data_store: DataStore, store_id: bytes
19911991

19921992

19931993
@pytest.mark.anyio
1994-
async def test_update_keys(data_store: DataStore, store_id: bytes32) -> None:
1994+
@boolean_datacases(name="use_batch_autoinsert", false="not optimized batch insert", true="optimized batch insert")
1995+
async def test_batch_insert_key_already_present(
1996+
data_store: DataStore,
1997+
store_id: bytes32,
1998+
use_batch_autoinsert: bool,
1999+
) -> None:
2000+
key = b"foo"
2001+
value = b"bar"
2002+
changelist = [{"action": "insert", "key": key, "value": value}]
2003+
await data_store.insert_batch(store_id, changelist, Status.COMMITTED, use_batch_autoinsert)
2004+
with pytest.raises(Exception, match=f"Key already present: {key.hex()}"):
2005+
await data_store.insert_batch(store_id, changelist, Status.COMMITTED, use_batch_autoinsert)
2006+
2007+
2008+
@pytest.mark.anyio
2009+
@boolean_datacases(name="use_upsert", false="update with delete and insert", true="update with upsert")
2010+
async def test_update_keys(data_store: DataStore, store_id: bytes32, use_upsert: bool) -> None:
19952011
num_keys = 10
19962012
missing_keys = 50
19972013
num_values = 10
19982014
new_keys = 10
19992015
for value in range(num_values):
20002016
changelist: List[Dict[str, Any]] = []
20012017
bytes_value = value.to_bytes(4, byteorder="big")
2002-
for key in range(num_keys + missing_keys):
2003-
bytes_key = key.to_bytes(4, byteorder="big")
2004-
changelist.append({"action": "delete", "key": bytes_key})
2005-
for key in range(num_keys):
2006-
bytes_key = key.to_bytes(4, byteorder="big")
2007-
changelist.append({"action": "insert", "key": bytes_key, "value": bytes_value})
2018+
if use_upsert:
2019+
for key in range(num_keys):
2020+
bytes_key = key.to_bytes(4, byteorder="big")
2021+
changelist.append({"action": "upsert", "key": bytes_key, "value": bytes_value})
2022+
else:
2023+
for key in range(num_keys + missing_keys):
2024+
bytes_key = key.to_bytes(4, byteorder="big")
2025+
changelist.append({"action": "delete", "key": bytes_key})
2026+
for key in range(num_keys):
2027+
bytes_key = key.to_bytes(4, byteorder="big")
2028+
changelist.append({"action": "insert", "key": bytes_key, "value": bytes_value})
20082029

20092030
await data_store.insert_batch(
20102031
store_id=store_id,

chia/data_layer/data_store.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,9 @@ async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None)
13691369
else:
13701370
await writer.execute(query, params)
13711371

1372-
async def get_leaf_at_minimum_height(self, root_hash: bytes32) -> TerminalNode:
1372+
async def get_leaf_at_minimum_height(
1373+
self, root_hash: bytes32, hash_to_parent: Dict[bytes32, InternalNode]
1374+
) -> TerminalNode:
13731375
root_node = await self.get_node(root_hash)
13741376
queue: List[Node] = [root_node]
13751377
while True:
@@ -1378,11 +1380,29 @@ async def get_leaf_at_minimum_height(self, root_hash: bytes32) -> TerminalNode:
13781380
if isinstance(node, InternalNode):
13791381
left_node = await self.get_node(node.left_hash)
13801382
right_node = await self.get_node(node.right_hash)
1383+
hash_to_parent[left_node.hash] = node
1384+
hash_to_parent[right_node.hash] = node
13811385
queue.append(left_node)
13821386
queue.append(right_node)
13831387
elif isinstance(node, TerminalNode):
13841388
return node
13851389

1390+
async def batch_upsert(
1391+
self,
1392+
tree_id: bytes32,
1393+
hash: bytes32,
1394+
to_update_hashes: Set[bytes32],
1395+
pending_upsert_new_hashes: Dict[bytes32, bytes32],
1396+
) -> bytes32:
1397+
if hash not in to_update_hashes:
1398+
return hash
1399+
node = await self.get_node(hash)
1400+
if isinstance(node, TerminalNode):
1401+
return pending_upsert_new_hashes[hash]
1402+
new_left_hash = await self.batch_upsert(tree_id, node.left_hash, to_update_hashes, pending_upsert_new_hashes)
1403+
new_right_hash = await self.batch_upsert(tree_id, node.right_hash, to_update_hashes, pending_upsert_new_hashes)
1404+
return await self._insert_internal_node(new_left_hash, new_right_hash)
1405+
13861406
async def insert_batch(
13871407
self,
13881408
store_id: bytes32,
@@ -1410,14 +1430,19 @@ async def insert_batch(
14101430

14111431
key_hash_frequency: Dict[bytes32, int] = {}
14121432
first_action: Dict[bytes32, str] = {}
1433+
last_action: Dict[bytes32, str] = {}
1434+
14131435
for change in changelist:
14141436
key = change["key"]
14151437
hash = key_hash(key)
14161438
key_hash_frequency[hash] = key_hash_frequency.get(hash, 0) + 1
14171439
if hash not in first_action:
14181440
first_action[hash] = change["action"]
1441+
last_action[hash] = change["action"]
14191442

14201443
pending_autoinsert_hashes: List[bytes32] = []
1444+
pending_upsert_new_hashes: Dict[bytes32, bytes32] = {}
1445+
14211446
for change in changelist:
14221447
if change["action"] == "insert":
14231448
key = change["key"]
@@ -1435,8 +1460,16 @@ async def insert_batch(
14351460
if key_hash_frequency[hash] == 1 or (
14361461
key_hash_frequency[hash] == 2 and first_action[hash] == "delete"
14371462
):
1463+
old_node = await self.maybe_get_node_by_key(key, store_id)
14381464
terminal_node_hash = await self._insert_terminal_node(key, value)
1439-
pending_autoinsert_hashes.append(terminal_node_hash)
1465+
1466+
if old_node is None:
1467+
pending_autoinsert_hashes.append(terminal_node_hash)
1468+
else:
1469+
if key_hash_frequency[hash] == 1:
1470+
raise Exception(f"Key already present: {key.hex()}")
1471+
else:
1472+
pending_upsert_new_hashes[old_node.hash] = terminal_node_hash
14401473
continue
14411474
insert_result = await self.autoinsert(
14421475
key, value, store_id, True, Status.COMMITTED, root=latest_local_root
@@ -1458,17 +1491,50 @@ async def insert_batch(
14581491
latest_local_root = insert_result.root
14591492
elif change["action"] == "delete":
14601493
key = change["key"]
1494+
hash = key_hash(key)
1495+
if key_hash_frequency[hash] == 2 and last_action[hash] == "insert" and enable_batch_autoinsert:
1496+
continue
14611497
latest_local_root = await self.delete(key, store_id, True, Status.COMMITTED, root=latest_local_root)
14621498
elif change["action"] == "upsert":
14631499
key = change["key"]
14641500
new_value = change["value"]
1501+
hash = key_hash(key)
1502+
if key_hash_frequency[hash] == 1 and enable_batch_autoinsert:
1503+
terminal_node_hash = await self._insert_terminal_node(key, new_value)
1504+
old_node = await self.maybe_get_node_by_key(key, store_id)
1505+
if old_node is not None:
1506+
pending_upsert_new_hashes[old_node.hash] = terminal_node_hash
1507+
else:
1508+
pending_autoinsert_hashes.append(terminal_node_hash)
1509+
continue
14651510
insert_result = await self.upsert(
14661511
key, new_value, store_id, True, Status.COMMITTED, root=latest_local_root
14671512
)
14681513
latest_local_root = insert_result.root
14691514
else:
14701515
raise Exception(f"Operation in batch is not insert or delete: {change}")
14711516

1517+
if len(pending_upsert_new_hashes) > 0:
1518+
to_update_hashes: Set[bytes32] = set()
1519+
for hash in pending_upsert_new_hashes.keys():
1520+
while True:
1521+
if hash in to_update_hashes:
1522+
break
1523+
to_update_hashes.add(hash)
1524+
node = await self._get_one_ancestor(hash, store_id)
1525+
if node is None:
1526+
break
1527+
hash = node.hash
1528+
assert latest_local_root is not None
1529+
assert latest_local_root.node_hash is not None
1530+
new_root_hash = await self.batch_upsert(
1531+
store_id,
1532+
latest_local_root.node_hash,
1533+
to_update_hashes,
1534+
pending_upsert_new_hashes,
1535+
)
1536+
latest_local_root = await self._insert_root(store_id, new_root_hash, Status.COMMITTED)
1537+
14721538
# Start with the leaf nodes and pair them to form new nodes at the next level up, repeating this process
14731539
# in a bottom-up fashion until a single root node remains. This constructs a balanced tree from the leaves.
14741540
while len(pending_autoinsert_hashes) > 1:
@@ -1488,14 +1554,15 @@ async def insert_batch(
14881554
if latest_local_root is None or latest_local_root.node_hash is None:
14891555
await self._insert_root(store_id=store_id, node_hash=subtree_hash, status=Status.COMMITTED)
14901556
else:
1491-
min_height_leaf = await self.get_leaf_at_minimum_height(latest_local_root.node_hash)
1492-
ancestors = await self.get_ancestors_common(
1493-
node_hash=min_height_leaf.hash,
1494-
store_id=store_id,
1495-
root_hash=latest_local_root.node_hash,
1496-
generation=latest_local_root.generation,
1497-
use_optimized=True,
1498-
)
1557+
hash_to_parent: Dict[bytes32, InternalNode] = {}
1558+
min_height_leaf = await self.get_leaf_at_minimum_height(latest_local_root.node_hash, hash_to_parent)
1559+
ancestors: List[InternalNode] = []
1560+
hash = min_height_leaf.hash
1561+
while hash in hash_to_parent:
1562+
node = hash_to_parent[hash]
1563+
ancestors.append(node)
1564+
hash = node.hash
1565+
14991566
await self.update_ancestor_hashes_on_insert(
15001567
store_id=store_id,
15011568
left=min_height_leaf.hash,
@@ -1631,6 +1698,13 @@ async def get_node_by_key_latest_generation(self, key: bytes, store_id: bytes32)
16311698
assert isinstance(node, TerminalNode)
16321699
return node
16331700

1701+
async def maybe_get_node_by_key(self, key: bytes, tree_id: bytes32) -> Optional[TerminalNode]:
1702+
try:
1703+
node = await self.get_node_by_key_latest_generation(key, tree_id)
1704+
return node
1705+
except KeyNotFoundError:
1706+
return None
1707+
16341708
async def get_node_by_key(
16351709
self,
16361710
key: bytes,

0 commit comments

Comments
 (0)