/
fixed_merkle_tree.go
139 lines (124 loc) · 3.18 KB
/
fixed_merkle_tree.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package fMerkleTree
import (
"fmt"
)
// MerkleTree is a Merkle tree whose number of levels is fixed when it is
// constructed. It holds no state of its own: levels, layers, zero values,
// the zero element and the hash function all live in the embedded *BaseTree,
// whose fields and methods are promoted onto MerkleTree.
type MerkleTree struct {
	*BaseTree
}
// NewMerkleTree creates a tree with the given number of levels whose leaf
// layer (layer 0) is initialised from elements.
//
// zeroElement is stored on the tree before buildZeros runs, and hashFn is kept
// for all later hashing. After the leaves are set, buildHashes computes every
// layer from 1 up to levels. The elements slice is copied, so the tree never
// shares a backing array with the caller.
//
// It returns an error if levels is negative, if len(elements) exceeds the
// tree's Capacity, or if hashFn is nil.
func NewMerkleTree(levels int, elements []Element, zeroElement Element, hashFn HashFunction) (*MerkleTree, error) {
	if levels < 0 {
		// A negative depth would make the layer allocation / indexing below
		// panic; report it as an ordinary error instead.
		return nil, fmt.Errorf("levels must be non-negative")
	}
	base := &BaseTree{levels: levels}
	if len(elements) > base.Capacity() {
		return nil, fmt.Errorf("tree is full")
	}
	if hashFn == nil {
		return nil, fmt.Errorf("hash function is nil")
	}
	base.hashFn = hashFn
	base.zeroElement = zeroElement
	base.layers = make([][]Element, levels+1)
	// Copy the leaves so the tree and the caller never share a backing array:
	// writes into layers[0] (e.g. by Insert) must not leak into the caller's
	// slice, and later edits by the caller must not silently change the tree.
	base.layers[0] = append([]Element(nil), elements...)
	out := &MerkleTree{base}
	out.buildZeros()
	out.buildHashes()
	return out, nil
}
// buildHashes fills every layer above the leaves. Working upward from level 1
// to mt.levels, each layer is produced by processNodes from the layer directly
// beneath it, so lower layers must be complete before higher ones are built.
func (mt *MerkleTree) buildHashes() {
	level := 1
	for level <= mt.levels {
		below := mt.layers[level-1]
		mt.layers[level] = mt.processNodes(below, level)
		level++
	}
}
// BulkInsert inserts elements into the tree one at a time, in order, by
// calling Insert for each. It stops at the first Insert failure and returns
// that error; this method performs no rollback, so elements inserted before
// the failure are left in place. An empty (or nil) slice is a no-op.
func (mt *MerkleTree) BulkInsert(elements []Element) error {
	for i := range elements {
		err := mt.Insert(elements[i])
		if err != nil {
			return err
		}
	}
	return nil
}
// IndexOf reports the position of element within the leaf layer, as computed
// by IndexOfElement with a starting offset of 0 and a nil final argument.
// NOTE(review): the not-found return value is defined by IndexOfElement
// (presumably -1) — confirm there.
func (mt MerkleTree) IndexOf(element Element) int {
	leaves := mt.layers[0]
	const fromIndex = 0
	return IndexOfElement(leaves, element, fromIndex, nil)
}
// Proof returns the Merkle path for the leaf that IndexOf reports for
// element. It returns an error if element is not present in the leaf layer,
// or whatever error Path returns for the located index.
func (mt MerkleTree) Proof(element Element) (ProofPath, error) {
	index := mt.IndexOf(element)
	if index < 0 {
		// Assumes IndexOf signals "not found" with a negative index (e.g. -1).
		// Fail clearly here rather than handing a sentinel to Path as if it
		// were a real leaf position.
		var empty ProofPath
		return empty, fmt.Errorf("element not found")
	}
	return mt.Path(index)
}
// getTreeEdge describes the leaf at edgeIndex. The result holds the leaf
// value, its Merkle path (from Path), the index itself, and the current
// number of leaves.
//
// It returns an error in three cases:
//   - edgeIndex is outside [0, number of leaves);
//   - the leaf at that index is nil;
//   - Path fails for that index.
func (mt MerkleTree) getTreeEdge(edgeIndex int) (TreeEdge, error) {
	leaves := mt.layers[0]
	// Check both bounds: a negative index would otherwise panic on the
	// slice access below instead of returning an error.
	if edgeIndex < 0 || edgeIndex >= len(leaves) {
		return TreeEdge{}, fmt.Errorf("index out of range")
	}
	edgeElement := leaves[edgeIndex]
	if edgeElement == nil {
		return TreeEdge{}, fmt.Errorf("element not found")
	}
	edgePath, err := mt.Path(edgeIndex)
	if err != nil {
		return TreeEdge{}, err
	}
	return TreeEdge{
		EdgePath:          edgePath,
		EdgeElement:       edgeElement,
		EdgeIndex:         edgeIndex,
		EdgeElementsCount: len(leaves),
	}, nil
}
// GetTreeSlices splits the leaf layer into consecutive chunks for roughly
// count consumers. Each chunk is paired with the TreeEdge of its first leaf.
//
// The chunk size is ceil(len(leaves)/count), rounded up to an even number, so
// fewer than count chunks may be returned, and the last chunk holds whatever
// leaves remain. Each Elements slice is a copy, so modifying it does not
// change the tree. An empty tree yields an empty, non-nil result.
//
// It returns an error if count is not positive or if building any edge fails.
func (mt MerkleTree) GetTreeSlices(count int) ([]TreeSlice, error) {
	if count <= 0 {
		// count == 0 would divide by zero and a negative count produces a
		// non-positive step in the loop below.
		return nil, fmt.Errorf("count must be positive")
	}
	leaves := mt.layers[0]
	length := len(leaves)
	if length == 0 {
		return []TreeSlice{}, nil
	}
	size := length / count
	if length%count != 0 {
		size++
	}
	if size%2 != 0 {
		size++
	}
	out := make([]TreeSlice, 0, (length+size-1)/size)
	for edgeLeft := 0; edgeLeft < length; edgeLeft += size {
		edgeRight := edgeLeft + size
		// Clamp the final chunk: slicing past len would either panic or,
		// within spare capacity, silently expose elements that are not leaves.
		if edgeRight > length {
			edgeRight = length
		}
		edge, err := mt.getTreeEdge(edgeLeft)
		if err != nil {
			return nil, err
		}
		elements := append([]Element(nil), leaves[edgeLeft:edgeRight]...)
		out = append(out, TreeSlice{Edge: edge, Elements: elements})
	}
	return out, nil
}
// Serialize captures the whole tree state, intermediate layers included, by
// passing a pointer to a copy of the receiver to NewSerializedTreeState.
// Because the layers are stored, deserializing does not need to recompute any
// hashes. Elements are not converted to a plain type; that is the caller's
// responsibility.
func (mt MerkleTree) Serialize() (SerializedTreeState, error) {
	snapshot := mt
	state, err := NewSerializedTreeState(&snapshot)
	return state, err
}
// DeserializeMerkleTree rebuilds a MerkleTree from serialized state. The
// stored level count, layers and zero values are used as-is; no layers are
// rebuilt here. hashFn is attached for subsequent operations, and the first
// zero value becomes the tree's zero element.
//
// It returns an error if:
//   - hashFn is nil;
//   - the layers or zero values cannot be read from data;
//   - data contains no zero values;
//   - the root of the restored tree does not match the root recorded in data.
func DeserializeMerkleTree(data SerializedTreeState, hashFn HashFunction) (*MerkleTree, error) {
	if hashFn == nil {
		// Same requirement NewMerkleTree enforces; a nil hash function would
		// only fail later, on the first operation that needs to hash.
		return nil, fmt.Errorf("hash function is nil")
	}
	layers, err := data.GetLayers()
	if err != nil {
		// Wrap instead of printing: library code should not write to stdout,
		// and the caller decides how to report the failure.
		return nil, fmt.Errorf("get layers: %w", err)
	}
	zeros, err := data.GetZeros()
	if err != nil {
		return nil, fmt.Errorf("get zeros: %w", err)
	}
	if len(zeros) == 0 {
		// zeros[0] is used as the zero element below.
		return nil, fmt.Errorf("serialized tree has no zero values")
	}
	out := &MerkleTree{
		BaseTree: &BaseTree{
			levels:      data.GetLevels(),
			layers:      layers,
			zeros:       zeros,
			zeroElement: zeros[0],
			hashFn:      hashFn,
		},
	}
	// Reject state whose stored layers do not reproduce the recorded root.
	if !out.Root().Cmp(data.GetRoot()) {
		return nil, fmt.Errorf("root mismatch")
	}
	return out, nil
}