Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom node hash in SimpleMerkleTree #39

Merged
merged 13 commits into from
Mar 4, 2024
28 changes: 13 additions & 15 deletions src/core.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import { keccak256 } from '@ethersproject/keccak256';
import { BytesLike, HexString, toHex, toBytes, concat, compare } from './bytes';
import { BytesLike, HexString, toHex, toBytes, compare } from './bytes';
import { NodeHash, standardNodeHash } from './hashes';
import { invariant, throwError, validateArgument } from './utils/errors';

const hashPair = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare)));

const leftChildIndex = (i: number) => 2 * i + 1;
const rightChildIndex = (i: number) => 2 * i + 2;
const parentIndex = (i: number) => (i > 0 ? Math.floor((i - 1) / 2) : throwError('Root has no parent'));
Expand All @@ -18,7 +16,7 @@ const checkLeafNode = (tree: unknown[], i: number) => void (isLeafNode(tree, i)
const checkValidMerkleNode = (node: BytesLike) =>
void (isValidMerkleNode(node) || throwError('Merkle tree nodes must be Uint8Array of length 32'));

export function makeMerkleTree(leaves: BytesLike[]): HexString[] {
export function makeMerkleTree(leaves: BytesLike[], nodeHash: NodeHash = standardNodeHash): HexString[] {
leaves.forEach(checkValidMerkleNode);

validateArgument(leaves.length !== 0, 'Expected non-zero number of leaves');
Expand All @@ -29,7 +27,7 @@ export function makeMerkleTree(leaves: BytesLike[]): HexString[] {
tree[tree.length - 1 - i] = toHex(leaf);
}
for (let i = tree.length - 1 - leaves.length; i >= 0; i--) {
tree[i] = hashPair(tree[leftChildIndex(i)]!, tree[rightChildIndex(i)]!);
tree[i] = nodeHash(tree[leftChildIndex(i)]!, tree[rightChildIndex(i)]!);
}

return tree;
Expand All @@ -46,11 +44,11 @@ export function getProof(tree: BytesLike[], index: number): HexString[] {
return proof;
}

export function processProof(leaf: BytesLike, proof: BytesLike[]): HexString {
export function processProof(leaf: BytesLike, proof: BytesLike[], nodeHash: NodeHash = standardNodeHash): HexString {
checkValidMerkleNode(leaf);
proof.forEach(checkValidMerkleNode);

return toHex(proof.reduce(hashPair, leaf));
return toHex(proof.reduce(nodeHash, leaf));
}

export interface MultiProof<T, L = T> {
Expand All @@ -68,7 +66,7 @@ export function getMultiProof(tree: BytesLike[], indices: number[]): MultiProof<
'Cannot prove duplicated index',
);

const stack = indices.concat(); // copy
const stack = Array.from(indices); // copy
const proof = [];
const proofFlags = [];

Expand Down Expand Up @@ -98,7 +96,7 @@ export function getMultiProof(tree: BytesLike[], indices: number[]): MultiProof<
};
}

export function processMultiProof(multiproof: MultiProof<BytesLike>): HexString {
export function processMultiProof(multiproof: MultiProof<BytesLike>, nodeHash: NodeHash = standardNodeHash): HexString {
multiproof.leaves.forEach(checkValidMerkleNode);
multiproof.proof.forEach(checkValidMerkleNode);

Expand All @@ -111,22 +109,22 @@ export function processMultiProof(multiproof: MultiProof<BytesLike>): HexString
'Provided leaves and multiproof are not compatible',
);

const stack = multiproof.leaves.concat(); // copy
const proof = multiproof.proof.concat(); // copy
const stack = Array.from(multiproof.leaves); // copy
const proof = Array.from(multiproof.proof); // copy

for (const flag of multiproof.proofFlags) {
const a = stack.shift();
const b = flag ? stack.shift() : proof.shift();
invariant(a !== undefined && b !== undefined);
stack.push(hashPair(a, b));
stack.push(nodeHash(a, b));
}

invariant(stack.length + proof.length === 1);

return toHex(stack.pop() ?? proof.shift()!);
}

export function isValidMerkleTree(tree: BytesLike[]): boolean {
export function isValidMerkleTree(tree: BytesLike[], nodeHash: NodeHash = standardNodeHash): boolean {
for (const [i, node] of tree.entries()) {
if (!isValidMerkleNode(node)) {
return false;
Expand All @@ -139,7 +137,7 @@ export function isValidMerkleTree(tree: BytesLike[]): boolean {
if (l < tree.length) {
return false;
}
} else if (compare(node, hashPair(tree[l]!, tree[r]!))) {
} else if (compare(node, nodeHash(tree[l]!, tree[r]!))) {
return false;
}
}
Expand Down
14 changes: 14 additions & 0 deletions src/hashes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { defaultAbiCoder } from '@ethersproject/abi';
import { keccak256 } from '@ethersproject/keccak256';
import { BytesLike, HexString, concat, compare } from './bytes';

export type LeafHash<T> = (leaf: T) => HexString;
export type NodeHash = (left: BytesLike, right: BytesLike) => HexString;

export function standardLeafHash<T extends any[]>(types: string[], value: T): HexString {
return keccak256(keccak256(defaultAbiCoder.encode(types, value)));
}

export function standardNodeHash(a: BytesLike, b: BytesLike): HexString {
return keccak256(concat([a, b].sort(compare)));
}
18 changes: 12 additions & 6 deletions src/merkletree.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
} from './core';

import { MerkleTreeOptions, defaultOptions } from './options';
import { LeafHash, NodeHash } from './hashes';
import { validateArgument, invariant } from './utils/errors';

export interface MerkleTreeData<T> {
Expand Down Expand Up @@ -40,7 +41,8 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
protected constructor(
protected readonly tree: HexString[],
protected readonly values: MerkleTreeData<T>['values'],
public readonly leafHash: MerkleTree<T>['leafHash'],
public readonly leafHash: LeafHash<T>,
protected readonly nodeHash?: NodeHash,
Amxx marked this conversation as resolved.
Show resolved Hide resolved
) {
validateArgument(
values.every(({ value }) => typeof value != 'number'),
Expand All @@ -52,7 +54,8 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
protected static prepare<T>(
values: T[],
options: MerkleTreeOptions = {},
leafHash: MerkleTree<T>['leafHash'],
leafHash: LeafHash<T>,
nodeHash?: NodeHash,
): [tree: HexString[], indexedValues: MerkleTreeData<T>['values']] {
const sortLeaves = options.sortLeaves ?? defaultOptions.sortLeaves;
const hashedValues = values.map((value, valueIndex) => ({
Expand All @@ -65,7 +68,10 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
hashedValues.sort((a, b) => compare(a.hash, b.hash));
}

const tree = makeMerkleTree(hashedValues.map(v => v.hash));
const tree = makeMerkleTree(
hashedValues.map(v => v.hash),
nodeHash,
);

const indexedValues = values.map(value => ({ value, treeIndex: 0 }));
for (const [leafIndex, { valueIndex }] of hashedValues.entries()) {
Expand Down Expand Up @@ -93,7 +99,7 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {

validate(): void {
this.values.forEach((_, i) => this._validateValueAt(i));
invariant(isValidMerkleTree(this.tree), 'Merkle tree is invalid');
invariant(isValidMerkleTree(this.tree, this.nodeHash), 'Merkle tree is invalid');
}

leafLookup(leaf: T): number {
Expand Down Expand Up @@ -171,10 +177,10 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
}

private _verify(leafHash: BytesLike, proof: BytesLike[]): boolean {
return this.root === processProof(leafHash, proof);
return this.root === processProof(leafHash, proof, this.nodeHash);
}

private _verifyMultiProof(multiproof: MultiProof<BytesLike>): boolean {
return this.root === processMultiProof(multiproof);
return this.root === processMultiProof(multiproof, this.nodeHash);
}
}
137 changes: 91 additions & 46 deletions src/simple.test.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
import { test, testProp, fc } from '@fast-check/ava';
import { HashZero as zero } from '@ethersproject/constants';
import { keccak256 } from '@ethersproject/keccak256';
import { SimpleMerkleTree } from './simple';
import { BytesLike, HexString, concat, compare } from './bytes';

const reverseNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare).reverse()));
const otherNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(reverseNodeHash(a, b)); // double hash

import { toHex } from './bytes';
import { InvalidArgumentError, InvariantError } from './utils/errors';

const leaf = fc.uint8Array({ minLength: 32, maxLength: 32 }).map(toHex);
const leaves = fc.array(leaf, { minLength: 1 });
const options = fc.record({ sortLeaves: fc.oneof(fc.constant(undefined), fc.boolean()) });
const options = fc.record({
sortLeaves: fc.oneof(fc.constant(undefined), fc.boolean()),
nodeHash: fc.oneof(fc.constant(undefined), fc.constant(reverseNodeHash)),
});

const tree = fc.tuple(leaves, options).map(([leaves, options]) => SimpleMerkleTree.of(leaves, options));
const tree = fc
.tuple(leaves, options)
.chain(([leaves, options]) => fc.tuple(fc.constant(SimpleMerkleTree.of(leaves, options)), fc.constant(options)));
const treeAndLeaf = fc.tuple(leaves, options).chain(([leaves, options]) =>
fc.tuple(
fc.constant(SimpleMerkleTree.of(leaves, options)),
fc.constant(options),
fc.nat({ max: leaves.length - 1 }).map(index => ({ value: leaves[index]!, index })),
),
);
const treeAndLeaves = fc.tuple(leaves, options).chain(([leaves, options]) =>
fc.tuple(
fc.constant(SimpleMerkleTree.of(leaves, options)),
fc.constant(options),
fc
.uniqueArray(fc.nat({ max: leaves.length - 1 }))
.map(indices => indices.map(index => ({ value: leaves[index]!, index }))),
Expand All @@ -26,48 +39,64 @@ const treeAndLeaves = fc.tuple(leaves, options).chain(([leaves, options]) =>

fc.configureGlobal({ numRuns: process.env.CI ? 10000 : 100 });

testProp('generates a valid tree', [tree], (t, tree) => {
testProp('generates a valid tree', [tree], (t, [tree]) => {
t.notThrows(() => tree.validate());
});

testProp('generates valid single proofs for all leaves', [treeAndLeaf], (t, [tree, { value: leaf, index }]) => {
const proof1 = tree.getProof(index);
const proof2 = tree.getProof(leaf);

t.deepEqual(proof1, proof2);
t.true(tree.verify(index, proof1));
t.true(tree.verify(leaf, proof1));
t.true(SimpleMerkleTree.verify(tree.root, leaf, proof1));
});
testProp(
'generates valid single proofs for all leaves',
[treeAndLeaf],
(t, [tree, options, { value: leaf, index }]) => {
const proof1 = tree.getProof(index);
const proof2 = tree.getProof(leaf);

t.deepEqual(proof1, proof2);
t.true(tree.verify(index, proof1));
t.true(tree.verify(leaf, proof1));
t.true(SimpleMerkleTree.verify(tree.root, leaf, proof1, options.nodeHash));
},
);

testProp('rejects invalid proofs', [treeAndLeaf, tree], (t, [tree, { value: leaf }], otherTree) => {
const proof = tree.getProof(leaf);
t.false(otherTree.verify(leaf, proof));
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof));
});
testProp(
'rejects invalid proofs',
[treeAndLeaf, tree],
(t, [tree, options, { value: leaf }], [otherTree, otherOptions]) => {
const proof = tree.getProof(leaf);
t.false(otherTree.verify(leaf, proof));
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof, options.nodeHash));
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof, otherOptions.nodeHash));
},
);

testProp('generates valid multiproofs', [treeAndLeaves], (t, [tree, indices]) => {
testProp('generates valid multiproofs', [treeAndLeaves], (t, [tree, options, indices]) => {
const proof1 = tree.getMultiProof(indices.map(e => e.index));
const proof2 = tree.getMultiProof(indices.map(e => e.value));

t.deepEqual(proof1, proof2);
t.true(tree.verifyMultiProof(proof1));
t.true(SimpleMerkleTree.verifyMultiProof(tree.root, proof1));
t.true(SimpleMerkleTree.verifyMultiProof(tree.root, proof1, options.nodeHash));
});

testProp('rejects invalid multiproofs', [treeAndLeaves, tree], (t, [tree, indices], otherTree) => {
const multiProof = tree.getMultiProof(indices.map(e => e.index));

t.false(otherTree.verifyMultiProof(multiProof));
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof));
});
testProp(
'rejects invalid multiproofs',
[treeAndLeaves, tree],
(t, [tree, options, indices], [otherTree, otherOptions]) => {
const multiProof = tree.getMultiProof(indices.map(e => e.index));

t.false(otherTree.verifyMultiProof(multiProof));
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof, options.nodeHash));
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof, otherOptions.nodeHash));
},
);

testProp(
'renders tree representation',
[leaves],
(t, leaves) => {
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).render());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).render());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).render());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).render());
},
{ numRuns: 1, seed: 0 },
);
Expand All @@ -78,24 +107,34 @@ testProp(
(t, leaves) => {
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).dump());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).dump());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).dump());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).dump());
},
{ numRuns: 1, seed: 0 },
);

testProp('dump and load', [tree], (t, tree) => {
const recoveredTree = SimpleMerkleTree.load(tree.dump());
recoveredTree.validate();
testProp('dump and load', [tree], (t, [tree, options]) => {
const dump = tree.dump();
const recoveredTree = SimpleMerkleTree.load(dump, options.nodeHash);
recoveredTree.validate(); // already done in load

t.is(dump.hash, options.nodeHash ? 'custom' : undefined);
t.is(tree.root, recoveredTree.root);
t.is(tree.render(), recoveredTree.render());
t.deepEqual(tree.entries(), recoveredTree.entries());
t.deepEqual(tree.dump(), recoveredTree.dump());
});

testProp('reject out of bounds value index', [tree], (t, tree) => {
testProp('reject out of bounds value index', [tree], (t, [tree]) => {
t.throws(() => tree.getProof(-1), new InvalidArgumentError('Index out of bounds'));
});

// We need at least 2 leaves for internal node hashing to come into play
testProp('reject loading dump with wrong node hash', [fc.array(leaf, { minLength: 2 })], (t, leaves) => {
const dump = SimpleMerkleTree.of(leaves, { nodeHash: reverseNodeHash }).dump();
t.throws(() => SimpleMerkleTree.load(dump, otherNodeHash), new InvariantError('Merkle tree is invalid'));
});

test('reject invalid leaf size', t => {
const invalidLeaf = '0x000000000000000000000000000000000000000000000000000000000000000000';
t.throws(() => SimpleMerkleTree.of([invalidLeaf]), {
Expand All @@ -116,22 +155,28 @@ test('reject unrecognized tree dump', t => {
});

test('reject malformed tree dump', t => {
const loadedTree1 = SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero],
values: [
{
value: '0x0000000000000000000000000000000000000000000000000000000000000001',
treeIndex: 0,
},
],
});
t.throws(() => loadedTree1.getProof(0), new InvariantError('Merkle tree does not contain the expected value'));
t.throws(
() =>
SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero],
values: [
{
value: '0x0000000000000000000000000000000000000000000000000000000000000001',
treeIndex: 0,
},
],
}),
new InvariantError('Merkle tree does not contain the expected value'),
);

const loadedTree2 = SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero, zero, zero],
values: [{ value: zero, treeIndex: 2 }],
});
t.throws(() => loadedTree2.getProof(0), new InvariantError('Unable to prove value'));
t.throws(
() =>
SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero, zero, zero],
values: [{ value: zero, treeIndex: 2 }],
}),
new InvariantError('Merkle tree is invalid'),
);
});
Loading
Loading