diff --git a/packages/persistent-merkle-tree/src/hasher/as-sha256.ts b/packages/persistent-merkle-tree/src/hasher/as-sha256.ts index 560ae421..95405bd2 100644 --- a/packages/persistent-merkle-tree/src/hasher/as-sha256.ts +++ b/packages/persistent-merkle-tree/src/hasher/as-sha256.ts @@ -1,5 +1,6 @@ import {digest2Bytes32, digest64HashObjects, HashObject, batchHash4HashObjectInputs} from "@chainsafe/as-sha256"; import type {Hasher} from "./types"; +import {HashComputation} from "../node"; export const hasher: Hasher = { digest64: digest2Bytes32, @@ -15,8 +16,8 @@ export const hasher: Hasher = { const batch = Math.floor(inputs.length / 8); const outputs = new Array(); for (let i = 0; i < batch; i++) { - const [out0, out1, out2, out3] = batchHash4HashObjectInputs(inputs.slice(i * 8, i * 8 + 8)); - outputs.push(out0, out1, out2, out3); + const outs = batchHash4HashObjectInputs(inputs.slice(i * 8, i * 8 + 8)); + outputs.push(...outs); } for (let i = batch * 8; i < inputs.length; i += 2) { @@ -26,4 +27,43 @@ export const hasher: Hasher = { return outputs; }, + executeHashComputations: (hashComputations: Array) => { + for (let level = hashComputations.length - 1; level >= 0; level--) { + const hcArr = hashComputations[level]; + if (!hcArr) { + // should not happen + throw Error(`no hash computations for level ${level}`); + } + + // HashComputations of the same level are safe to batch + const batch = Math.floor(hcArr.length / 4); + for (let i = 0; i < batch; i++) { + const index = i * 4; + const outs = batchHash4HashObjectInputs([ + hcArr[index].src0, + hcArr[index].src1, + hcArr[index + 1].src0, + hcArr[index + 1].src1, + hcArr[index + 2].src0, + hcArr[index + 2].src1, + hcArr[index + 3].src0, + hcArr[index + 3].src1, + ]); + if (outs.length !== 4) { + throw Error(`batchHash4HashObjectInputs returned ${outs.length} outputs, expected 4`); + } + hcArr[index].dest.applyHash(outs[0]); + hcArr[index + 1].dest.applyHash(outs[1]); + hcArr[index + 2].dest.applyHash(outs[2]); + hcArr[index + 3].dest.applyHash(outs[3]); + } + + // remaining + for (let i = batch * 4; i < hcArr.length; i++) { + const {src0, src1, dest} = hcArr[i]; + const output = digest64HashObjects(src0, src1); + dest.applyHash(output); + } + } + }, }; diff --git a/packages/persistent-merkle-tree/src/hasher/noble.ts b/packages/persistent-merkle-tree/src/hasher/noble.ts index 3f3b082c..2af966a7 100644 --- a/packages/persistent-merkle-tree/src/hasher/noble.ts +++ b/packages/persistent-merkle-tree/src/hasher/noble.ts @@ -23,4 +23,17 @@ export const hasher: Hasher = { } return outputs; }, + executeHashComputations: (hashComputations) => { + for (let level = hashComputations.length - 1; level >= 0; level--) { + const hcArr = hashComputations[level]; + if (!hcArr) { + // should not happen + throw Error(`no hash computations for level ${level}`); + } + + for (const hc of hcArr) { + hc.dest.applyHash(digest64HashObjects(hc.src0, hc.src1)); + } + } + }, }; diff --git a/packages/persistent-merkle-tree/src/hasher/types.ts b/packages/persistent-merkle-tree/src/hasher/types.ts index 6d0d4219..2fdf136c 100644 --- a/packages/persistent-merkle-tree/src/hasher/types.ts +++ b/packages/persistent-merkle-tree/src/hasher/types.ts @@ -1,4 +1,5 @@ import type {HashObject} from "@chainsafe/as-sha256/lib/hashObject"; +import {HashComputation} from "../node"; export type Hasher = { /** @@ -13,4 +14,8 @@ export type Hasher = { * Batch hash 2 * n HashObjects, return n HashObjects output */ batchHashObjects(inputs: HashObject[]): HashObject[]; + /** + * Execute a batch of HashComputations + */ + executeHashComputations(hashComputations: Array): void; }; diff --git a/packages/persistent-merkle-tree/src/node.ts b/packages/persistent-merkle-tree/src/node.ts index 3b35a623..66a288d6 100644 --- a/packages/persistent-merkle-tree/src/node.ts +++ b/packages/persistent-merkle-tree/src/node.ts @@ -86,7 +86,7 @@ export class BranchNode extends Node { batchHash(): Uint8Array { const hashComputations: HashComputation[][] = []; getHashComputations(this, 0, hashComputations); - executeHashComputations(hashComputations); + hasher.executeHashComputations(hashComputations); if (this.h0 === null) { throw Error("Root is not computed by batch"); @@ -395,33 +395,6 @@ export function bitwiseOrNodeH(node: Node, hIndex: number, value: number): void else throw Error("hIndex > 7"); } -/** - * Given an array of HashComputation, execute them from the end - * The consumer has the root node so it should be able to get the final root from there - */ -export function executeHashComputations(hashComputations: Array): void { - for (let level = hashComputations.length - 1; level >= 0; level--) { - const hcArr = hashComputations[level]; - if (!hcArr) { - // should not happen - throw Error(`no hash computations for level ${level}`); - } - // HashComputations of the same level are safe to batch - const inputs: HashObject[] = []; - const dests: Node[] = []; - for (const {src0, src1, dest} of hcArr) { - inputs.push(src0, src1); - dests.push(dest); - } - const outputs = hasher.batchHashObjects(inputs); - if (outputs.length !== dests.length) { - throw Error(`${inputs.length} inputs produce ${outputs.length} outputs, expected ${dests.length} outputs`); - } - for (let i = 0; i < outputs.length; i++) { - dests[i].applyHash(outputs[i]); - } - } -} export function getHashComputations(node: Node, offset: number, hashCompsByLevel: Array): void { if (node.h0 === null) {