Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce zero-padding and cache zero-hashes in MerkleTree #2415

Merged
merged 4 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 55 additions & 8 deletions console/collections/src/merkle_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,14 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
// Compute the empty hash.
let empty_hash = path_hasher.hash_empty()?;

// Calculate the size of the tree which excludes leafless nodes.
// The minimum tree size is either a single root node or the calculated number of nodes plus
// the supplied leaves; if the number of leaves is odd (and greater than one), one extra empty
// hash is added as padding so that the last leaf still has a sibling.
let minimum_tree_size =
std::cmp::max(1, num_nodes + leaves.len() + if leaves.len() > 1 { leaves.len() % 2 } else { 0 });

// Initialize the Merkle tree.
let mut tree = vec![empty_hash; tree_size];
let mut tree = vec![empty_hash; minimum_tree_size];

// Compute and store each leaf hash.
tree[num_nodes..num_nodes + leaves.len()].copy_from_slice(&leaf_hasher.hash_leaves(leaves)?);
Expand All @@ -90,10 +96,22 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
while let Some(start) = parent(start_index) {
// Compute the end index of the current level.
let end = left_child(start);
// Construct the children for each node in the current level.
let tuples = (start..end).map(|i| (tree[left_child(i)], tree[right_child(i)])).collect::<Vec<_>>();
// Construct the children for each node in the current level; the leaves are padded, which means
// that there either are 2 children, or there are none, at which point we may stop iterating.
let tuples = (start..end)
.take_while(|&i| tree.get(left_child(i)).is_some())
.map(|i| (tree[left_child(i)], tree[right_child(i)]))
.collect::<Vec<_>>();
// Compute and store the hashes for each node in the current level.
tree[start..end].copy_from_slice(&path_hasher.hash_all_children(&tuples)?);
let num_full_nodes = tuples.len();
tree[start..][..num_full_nodes].copy_from_slice(&path_hasher.hash_all_children(&tuples)?);
// Fill every remaining childless node in this level, if there are any, with the hash of two empty children.
if start + num_full_nodes < end {
let empty_node_hash = path_hasher.hash_children(&empty_hash, &empty_hash)?;
for node in tree.iter_mut().take(end).skip(start + num_full_nodes) {
*node = empty_node_hash;
}
}
// Update the start index for the next level.
start_index = start;
}
Expand Down Expand Up @@ -144,8 +162,16 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
tree.extend(self.leaf_hashes()?);
// Extend the new Merkle tree with the new leaf hashes.
tree.extend(&self.leaf_hasher.hash_leaves(new_leaves)?);

// Calculate the size of the tree which excludes leafless nodes.
let new_number_of_leaves = self.number_of_leaves + new_leaves.len();
let minimum_tree_size = std::cmp::max(
1,
num_nodes + new_number_of_leaves + if new_number_of_leaves > 1 { new_number_of_leaves % 2 } else { 0 },
);

// Resize the new Merkle tree with empty hashes to pad up to `minimum_tree_size`.
tree.resize(tree_size, self.empty_hash);
tree.resize(minimum_tree_size, self.empty_hash);
lap!(timer, "Hashed {} new leaves", new_leaves.len());

// Initialize a start index to track the starting index of the current level.
Expand Down Expand Up @@ -453,12 +479,20 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
// Compute the number of padded levels.
let padding_depth = DEPTH - tree_depth;

// Calculate the size of the tree which excludes leafless nodes.
let minimum_tree_size = std::cmp::max(
1,
num_nodes
+ updated_number_of_leaves
+ if updated_number_of_leaves > 1 { updated_number_of_leaves % 2 } else { 0 },
);

// Initialize the Merkle tree.
let mut tree = vec![self.empty_hash; num_nodes];
// Extend the new Merkle tree with the existing leaf hashes, excluding the last 'n' leaves.
tree.extend(&self.leaf_hashes()?[..updated_number_of_leaves]);
// Resize the new Merkle tree with empty hashes to pad up to `minimum_tree_size`.
tree.resize(tree_size, self.empty_hash);
tree.resize(minimum_tree_size, self.empty_hash);
lap!(timer, "Resizing to {} leaves", updated_number_of_leaves);

// Initialize a start index to track the starting index of the current level.
Expand Down Expand Up @@ -627,6 +661,7 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
let timer = timer!("MerkleTree::compute_updated_tree");

// Compute and store the hashes for each level, iterating from the penultimate level to the root level.
let empty_hash = self.path_hasher.hash_empty()?;
while let (Some(start), Some(middle)) = (parent(start_index), parent(middle_index)) {
// Compute the end index of the current level.
let end = left_child(start);
Expand All @@ -651,7 +686,12 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
if let Some(middle_precompute) = parent(middle_precompute) {
// Construct the children for the new indices in the current level.
let tuples = (middle..middle_precompute)
.map(|i| (tree[left_child(i)], tree[right_child(i)]))
.map(|i| {
(
tree.get(left_child(i)).copied().unwrap_or(empty_hash),
tree.get(right_child(i)).copied().unwrap_or(empty_hash),
)
})
.collect::<Vec<_>>();
// Process the indices that need to be computed for the current level.
// If any level requires computing more than 100 nodes, borrow the tree for performance.
Expand Down Expand Up @@ -687,7 +727,14 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
}
} else {
// Construct the children for the new indices in the current level.
let tuples = (middle..end).map(|i| (tree[left_child(i)], tree[right_child(i)])).collect::<Vec<_>>();
let tuples = (middle..end)
.map(|i| {
(
tree.get(left_child(i)).copied().unwrap_or(empty_hash),
tree.get(right_child(i)).copied().unwrap_or(empty_hash),
)
})
.collect::<Vec<_>>();
// Process the indices that need to be computed for the current level.
// If any level requires computing more than 100 nodes, borrow the tree for performance.
match tuples.len() >= 100 {
Expand Down
14 changes: 4 additions & 10 deletions console/collections/src/merkle_tree/tests/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ fn check_merkle_tree_depth_3_padded<E: Environment, LH: LeafHash<Hash = PH::Hash

// Rebuild the Merkle tree with the additional leaf.
merkle_tree.append(additional_leaves)?;
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(5, merkle_tree.number_of_leaves);

Expand All @@ -173,8 +173,6 @@ fn check_merkle_tree_depth_3_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
assert_eq!(expected_leaf3, merkle_tree.tree[10]);
assert_eq!(expected_leaf4, merkle_tree.tree[11]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[12]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[13]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[14]);

// Depth 2.
let expected_left0 = PathHash::hash_children(path_hasher, &expected_leaf0, &expected_leaf1)?;
Expand Down Expand Up @@ -258,7 +256,7 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash

// Rebuild the Merkle tree with the additional leaf.
merkle_tree.append(&[additional_leaves[0].clone()])?;
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(5, merkle_tree.number_of_leaves);

Expand All @@ -274,8 +272,6 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
assert_eq!(expected_leaf3, merkle_tree.tree[10]);
assert_eq!(expected_leaf4, merkle_tree.tree[11]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[12]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[13]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[14]);

// Depth 3.
let expected_left0 = PathHash::hash_children(path_hasher, &expected_leaf0, &expected_leaf1)?;
Expand Down Expand Up @@ -308,13 +304,13 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
// ------------------------------------------------------------------------------------------ //

// Ensure we're starting where we left off from the previous rebuild.
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(5, merkle_tree.number_of_leaves);

// Rebuild the Merkle tree with the additional leaf.
merkle_tree.append(&[additional_leaves[1].clone()])?;
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(6, merkle_tree.number_of_leaves);

Expand All @@ -331,8 +327,6 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
assert_eq!(expected_leaf3, merkle_tree.tree[10]);
assert_eq!(expected_leaf4, merkle_tree.tree[11]);
assert_eq!(expected_leaf5, merkle_tree.tree[12]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[13]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[14]);

// Depth 3.
let expected_left0 = PathHash::hash_children(path_hasher, &expected_leaf0, &expected_leaf1)?;
Expand Down