Skip to content

Commit

Permalink
feat: compute HashComputations when creating validator tree
Browse files Browse the repository at this point in the history
  • Loading branch information
twoeths committed Jun 12, 2024
1 parent 36fa2e8 commit 321a0e5
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 70 deletions.
74 changes: 64 additions & 10 deletions packages/persistent-merkle-tree/src/subtree.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {BranchNode, Node} from "./node";
import {BranchNode, HashComputationGroup, Node, arrayAtIndex, getHashComputations} from "./node";
import {zeroNode} from "./zeroNode";

export function subtreeFillToDepth(bottom: Node, depth: number): Node {
Expand Down Expand Up @@ -37,9 +37,16 @@ export function subtreeFillToLength(bottom: Node, depth: number, length: number)

/**
* WARNING: Mutates the provided nodes array.
* @param hashCompRootNode is a hacky way from ssz to set `dest` of HashComputation for BranchNodeStruct
* TODO: Don't mutate the nodes array.
* TODO - batch: check consumers of this function, can we compute HashComputationGroup when deserializing ViewDU from Uint8Array?
*/
export function subtreeFillToContents(nodes: Node[], depth: number): Node {
export function subtreeFillToContents(
nodes: Node[],
depth: number,
hashComps: HashComputationGroup | null = null,
hashCompRootNode: Node | null = null
): Node {
const maxLength = 2 ** depth;
if (nodes.length > maxLength) {
throw new Error(`nodes.length ${nodes.length} over maxIndex at depth ${depth}`);
Expand All @@ -50,30 +57,77 @@ export function subtreeFillToContents(nodes: Node[], depth: number): Node {
}

if (depth === 0) {
return nodes[0];
const node = nodes[0];
if (hashComps !== null) {
// only use hashCompRootNode for >=1 nodes where we have a rebind
getHashComputations(node, hashComps.offset, hashComps.byLevel);
}
return node;
}

if (depth === 1) {
return nodes.length > 1
? // All nodes at depth 1 available
new BranchNode(nodes[0], nodes[1])
: // Pad with zero node
new BranchNode(nodes[0], zeroNode(0));
// All nodes at depth 1 available
// If there is only one node, pad with zero node
const leftNode = nodes[0];
const rightNode = nodes.length > 1 ? nodes[1] : zeroNode(0);
const rootNode = new BranchNode(leftNode, rightNode);

if (hashComps !== null) {
const offset = hashComps.offset;
getHashComputations(leftNode, offset + 1, hashComps.byLevel);
getHashComputations(rightNode, offset + 1, hashComps.byLevel);
arrayAtIndex(hashComps.byLevel, offset).push({
src0: leftNode,
src1: rightNode,
dest: hashCompRootNode ?? rootNode,
});
}

return rootNode;
}

let count = nodes.length;

for (let d = depth; d > 0; d--) {
const countRemainder = count % 2;
const countEven = count - countRemainder;
const offset = hashComps ? hashComps.offset + d - 1 : null;

// For each depth level compute the new BranchNodes and overwrite the nodes array
for (let i = 0; i < countEven; i += 2) {
nodes[i / 2] = new BranchNode(nodes[i], nodes[i + 1]);
const left = nodes[i];
const right = nodes[i + 1];
const node = new BranchNode(left, right);
nodes[i / 2] = node;
if (offset !== null && hashComps !== null) {
arrayAtIndex(hashComps.byLevel, offset).push({
src0: left,
src1: right,
// d = 1 means we are at root node, use hashCompRootNode if possible
dest: d === 1 ? hashCompRootNode ?? node : node,
});
if (d === depth) {
// bottom up strategy so we don't need to go down the tree except for the last level
getHashComputations(left, offset + 1, hashComps.byLevel);
getHashComputations(right, offset + 1, hashComps.byLevel);
}
}
}

if (countRemainder > 0) {
nodes[countEven / 2] = new BranchNode(nodes[countEven], zeroNode(depth - d));
const left = nodes[countEven];
const right = zeroNode(depth - d);
const node = new BranchNode(left, right);
nodes[countEven / 2] = node;
if (offset !== null && hashComps !== null) {
if (d === depth) {
// only go down on the last level
getHashComputations(left, offset + 1, hashComps.byLevel);
}
// no need to getHashComputations for zero node
// no need to set hashCompRootNode here
arrayAtIndex(hashComps.byLevel, offset).push({src0: left, src1: right, dest: node});
}
}

// If there was remainer, 2 nodes are added to the count
Expand Down
65 changes: 62 additions & 3 deletions packages/persistent-merkle-tree/test/unit/subtree.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import {subtreeFillToContents, LeafNode, getNodesAtDepth} from "../../src";
import { expect } from "chai";
import {subtreeFillToContents, LeafNode, getNodesAtDepth, executeHashComputations, BranchNode, Node} from "../../src";

describe("subtreeFillToContents", function () {
// the hash computation takes time
this.timeout(5000);

describe("subtreeFillToContents", () => {
it("Simple case", () => {
function nodeNum(num: number): LeafNode {
return LeafNode.fromUint32(num);
Expand Down Expand Up @@ -35,7 +39,12 @@ describe("subtreeFillToContents", () => {
expectedNodes[i] = node;
}

const node = subtreeFillToContents(nodes, depth);
const hashComps = {
offset: 0,
byLevel: [],
};

const node = subtreeFillToContents(nodes, depth, hashComps);
const retrievedNodes = getNodesAtDepth(node, depth, 0, count);

// Assert correct
Expand All @@ -44,7 +53,57 @@ describe("subtreeFillToContents", () => {
throw Error(`Wrong node at index ${i}`);
}
}
executeHashComputations(hashComps.byLevel);
if (node.h0 === null) {
throw Error("Root node h0 is null");
}
});
}
}
});

describe("subtreeFillToContents - validator nodes", function () {
/**
* 0 root
* / \
* 1 10 11
* / \ / \
* 2 20 21 22 23
* / \ / \ / \ / \
* 3 pub with eff sla act act exit with
* / \
* 4 pub0 pub1
**/
it("should compute HashComputations for validator nodes", () => {
const numNodes = 8;
const nodesArr: Array<Node[]> = [];
for (let count = 0; count < 2; count++) {
const nodes = new Array<Node>(numNodes);
for (let i = 1; i < numNodes; i++) {
const node = LeafNode.fromUint32(i);
nodes[i] = node;
}
nodes[0] = new BranchNode(LeafNode.fromUint32(0), LeafNode.fromUint32(1));
nodesArr.push(nodes);
}

// maxChunksToDepth in ssz returns 3 for 8 nodes
const depth = 3;
const root0 = subtreeFillToContents(nodesArr[0], depth);
const hashComps = {
offset: 0,
byLevel: new Array<[]>(),
};
const node = subtreeFillToContents(nodesArr[1], depth, hashComps);
expect(hashComps.byLevel.length).to.equal(4);
expect(hashComps.byLevel[0].length).to.equal(1);
expect(hashComps.byLevel[1].length).to.equal(2);
expect(hashComps.byLevel[2].length).to.equal(4);
expect(hashComps.byLevel[3].length).to.equal(1);
executeHashComputations(hashComps.byLevel);
if (node.h0 === null) {
throw Error("Root node h0 is null");
}
expect(node.root).to.deep.equal(root0.root);
});
});
32 changes: 28 additions & 4 deletions packages/ssz/src/branchNodeStruct.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import {HashObject} from "@chainsafe/as-sha256/lib/hashObject";
import {hashObjectToUint8Array, Node} from "@chainsafe/persistent-merkle-tree";
import {
hashObjectToUint8Array,
Node,
getHashComputations,
HashComputationGroup,
} from "@chainsafe/persistent-merkle-tree";

export type ValueToNodeFn<T> = (
value: T,
hashComps: HashComputationGroup | null,
hashCompRootNode: Node | null
) => Node;

/**
* BranchNode whose children's data is represented as a struct, the backed tree is lazily computed from the struct.
Expand All @@ -13,14 +24,13 @@ export class BranchNodeStruct<T> extends Node {
* this represents the backed tree which is lazily computed from value
*/
private _rootNode: Node | null = null;
constructor(private readonly valueToNode: (value: T) => Node, readonly value: T) {
constructor(private readonly valueToNode: ValueToNodeFn<T>, readonly value: T) {
// First null value is to save an extra variable to check if a node has a root or not
super(null as unknown as number, 0, 0, 0, 0, 0, 0, 0);
this._rootNode = null;
}

get rootHashObject(): HashObject {
// return this.rootNode.rootHashObject;
if (this.h0 === null) {
super.applyHash(this.rootNode.rootHashObject);
}
Expand All @@ -43,13 +53,27 @@ export class BranchNodeStruct<T> extends Node {
return this.rootNode.right;
}

getHashComputations(hashComps: HashComputationGroup): void {
if (this.h0 !== null) {
return;
}

if (this._rootNode === null) {
// set dest of HashComputation to this node
this._rootNode = this.valueToNode(this.value, hashComps, this);
} else {
// not likely to hit this path if called from ViewDU, handle just in case
getHashComputations(this, hashComps.offset, hashComps.byLevel);
}
}

/**
* Singleton implementation to make sure there is single backed tree for this node.
* This is important for batching HashComputations
*/
private get rootNode(): Node {
if (this._rootNode === null) {
this._rootNode = this.valueToNode(this.value);
this._rootNode = this.valueToNode(this.value, null, null);
}
return this._rootNode;
}
Expand Down
12 changes: 8 additions & 4 deletions packages/ssz/src/type/containerNodeStruct.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {Node, subtreeFillToContents} from "@chainsafe/persistent-merkle-tree";
import {HashComputationGroup, Node, subtreeFillToContents} from "@chainsafe/persistent-merkle-tree";
import {Type, ByteViews} from "./abstract";
import {isCompositeType} from "./composite";
import {ContainerType, ContainerOptions, renderContainerTypeName} from "./container";
Expand Down Expand Up @@ -106,9 +106,13 @@ export class ContainerNodeStructType<Fields extends Record<string, Type<unknown>
return new BranchNodeStruct(this.valueToTree.bind(this), value);
}

private valueToTree(value: ValueOfFields<Fields>): Node {
// TODO - batch get hash computations while creating tree
private valueToTree(
value: ValueOfFields<Fields>,
hashComps: HashComputationGroup | null = null,
hashCompRootNode: Node | null = null
): Node {
const nodes = this.fieldsEntries.map(({fieldName, fieldType}) => fieldType.value_toTree(value[fieldName]));
return subtreeFillToContents(nodes, this.depth);
const rootNode = subtreeFillToContents(nodes, this.depth, hashComps, hashCompRootNode);
return rootNode;
}
}
6 changes: 3 additions & 3 deletions packages/ssz/src/viewDU/containerNodeStruct.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {HashComputationGroup, Node, getHashComputations} from "@chainsafe/persistent-merkle-tree";
import {HashComputationGroup, Node} from "@chainsafe/persistent-merkle-tree";
import {Type, ValueOf} from "../type/abstract";
import {isCompositeType} from "../type/composite";
import {BranchNodeStruct} from "../branchNodeStruct";
Expand Down Expand Up @@ -34,8 +34,8 @@ class ContainerTreeViewDU<Fields extends Record<string, Type<unknown>>> extends
this._rootNode = this.type.value_toTree(value) as BranchNodeStruct<ValueOfFields<Fields>>;
}

if (hashComps !== null && this._rootNode.h0 === null) {
getHashComputations(this._rootNode, hashComps.offset, hashComps.byLevel);
if (hashComps !== null) {
this._rootNode.getHashComputations(hashComps);
}
}

Expand Down
Loading

0 comments on commit 321a0e5

Please sign in to comment.