-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree.go
458 lines (391 loc) · 12.1 KB
/
tree.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
package pmmap
import (
"fmt"
)
// New constructs an empty persistent key-value map whose keys are hashed and
// compared by the given hasher.
//
// The type parameters are declared in the order [V, K] (the reverse of Tree)
// because K can be inferred from the hasher argument, which leaves only V to
// be spelled out by the caller.
func New[V, K any](hasher Hasher[K]) Tree[K, V] {
	// TODO: Check if type inference is better in Go 1.19
	// https://github.com/golang/go/issues/41176
	return Tree[K, V]{hasher: hasher, root: nil}
}
// Tree represents a persistent hash map from keys of type K to values of
// type V.
//
// A Tree is a small value consisting of a hasher and a root node. The update
// methods (Insert, InsertOrMerge, Remove, Merge) have value receivers and
// return the updated Tree instead of changing the receiver seen by the caller.
type Tree[K, V any] struct {
// hasher supplies the Hash and Equal operations applied to keys.
hasher Hasher[K]
// root is the top node of the underlying tree; nil represents an empty map.
root node[K, V]
}
// Lookup returns the value mapped to the provided key, together with a flag
// telling whether the key was present. The semantics are equivalent to those
// of the 2-valued lookup on regular Go maps: for a missing key the zero value
// of V and false are returned.
func (tree Tree[K, V]) Lookup(key K) (V, bool) {
	// Hashing can be expensive, so the key is hashed exactly once here and
	// the hash is handed down to the recursive helper.
	hash := tree.hasher.Hash(key)
	return lookup(tree.root, hash, key, tree.hasher)
}
// Insert returns a map in which key is mapped to value. A previous value
// stored under the same key, if any, is replaced.
func (tree Tree[K, V]) Insert(key K, value V) Tree[K, V] {
	// A nil merge function makes InsertOrMerge overwrite unconditionally.
	var overwrite MergeFunc[V]
	return tree.InsertOrMerge(key, value, overwrite)
}
// InsertOrMerge returns a map in which key is mapped to value. If a previous
// mapping (prevValue) exists for the key, the stored value becomes
// `f(value, prevValue)`; with a nil f the previous value is simply replaced.
func (tree Tree[K, V]) InsertOrMerge(key K, value V, f MergeFunc[V]) Tree[K, V] {
	hash := tree.hasher.Hash(key)
	// The "changed" flag returned by insert is not needed at this level.
	root, _ := insert(tree.root, hash, key, value, tree.hasher, f)
	return Tree[K, V]{hasher: tree.hasher, root: root}
}
// Remove returns a map without a mapping for the given key. Removing a key
// that is not present yields a map with the same contents.
func (tree Tree[K, V]) Remove(key K) Tree[K, V] {
	// TODO: Verify that removing an absent key leaves the underlying nodes
	// pointer-identical; unnecessary replacement of subtrees defeats the
	// shared-subtree shortcuts taken by Merge and Equal.
	hash := tree.hasher.Hash(key)
	root := remove(tree.root, hash, key, tree.hasher)
	return Tree[K, V]{hasher: tree.hasher, root: root}
}
// ForEach calls f once for each key-value pair in the map. Nothing is called
// for an empty map.
func (tree Tree[K, V]) ForEach(f eachFunc[K, V]) {
	if tree.root == nil {
		return
	}
	tree.root.each(f)
}
// Merge returns the union of two maps. If both maps contain a value for a
// key, the resulting map maps the key to the result of `f` on the two values.
//
// See the documentation for MergeFunc for conditions that `f` must satisfy.
// No guarantees are made about the order of arguments provided to `f`.
//
// This operation is made fast by skipping processing of shared subtrees.
// Merging a tree with itself after r updates takes linear time in r.
//
// The hasher of the receiver is used for the result.
func (tree Tree[K, V]) Merge(other Tree[K, V], f MergeFunc[V]) Tree[K, V] {
	// The equality flag reported by merge is of no use to the caller here.
	root, _ := merge(tree.root, other.root, tree.hasher, f)
	return Tree[K, V]{hasher: tree.hasher, root: root}
}
// Equal reports whether two maps hold the same set of keys with equal
// values. Keys are compared with the receiver's hasher and values with the
// provided function f. Like Merge, this operation skips processing of shared
// subtrees.
func (tree Tree[K, V]) Equal(other Tree[K, V], f func(V, V) bool) bool {
	lhs, rhs := tree.root, other.root
	return equal(lhs, rhs, tree.hasher, f)
}
// Size returns the number of key-value pairs in the map.
// NOTE: Runs in linear time in the size of the map, since every pair is
// visited and counted.
func (tree Tree[K, V]) Size() (res int) {
	count := func(K, V) { res++ }
	tree.ForEach(count)
	return res
}
// String renders the map as "tree[k1 ↦ v1 k2 ↦ v2 ...]", formatting keys and
// values with %v in the order in which ForEach visits them.
func (tree Tree[K, V]) String() string {
	var entries []string
	collect := func(k K, v V) {
		entries = append(entries, fmt.Sprintf("%v ↦ %v", k, v))
	}
	tree.ForEach(collect)
	return fmt.Sprintf("tree%s", entries)
}
// MergeFunc describes a binary operator, f, that merges two values.
// The operator must be commutative and idempotent. I.e.:
//
//	f(a, b) = f(b, a)
//	f(a, a) = a
//
// The second return value informs the caller whether a == b.
// This flag allows some optimizations in the implementation: when it is
// true, insert keeps the existing node instead of building a replacement.
type MergeFunc[V any] func(a, b V) (V, bool)
// End of public interface.

// The Patricia tree implementation below is based on Okasaki & Gill,
// "Fast Mergeable Integer Maps" (1998):
// http://ittc.ku.edu/~andygill/papers/IntMap98.pdf
// eachFunc is the callback type used for iteration; it receives one
// key-value pair per invocation.
type eachFunc[K, V any] func(key K, value V)
// node is an element of the tree. The implementations in this file are
// *branch and *leaf; a nil node stands for an empty (sub)tree.
type node[K, V any] interface {
// each calls the given function for every key-value pair stored under
// the node.
each(eachFunc[K, V])
}
// keyt is the type of hashed keys. Leaves are identified by a keyt value and
// branches store their prefix and branching bit as keyt values.
type keyt = uint64
// branch is an internal tree node with two subtrees whose hashes share a
// common prefix and diverge at a single bit.
type branch[K, V any] struct {
prefix keyt // Common prefix of all keys in the left and right subtrees
// A number with exactly one set bit. The position of the bit determines
// where the prefixes of the left and right subtrees diverge.
branchBit keyt
// The two subtrees. lookup and insert descend into left when
// zeroBit(hash, branchBit) holds and into right otherwise.
left, right node[K, V]
}
// each visits every key-value pair below the branch: first all pairs of the
// left subtree, then all pairs of the right subtree.
func (b *branch[K, V]) each(f eachFunc[K, V]) {
	left, right := b.left, b.right
	left.each(f)
	right.each(f)
}
// match reports whether the bits of key below the branching bit are equal to
// the branch's prefix.
// Intuitively: does the key belong in the branch's subtree?
func (b *branch[K, V]) match(key keyt) bool {
	// branchBit has exactly one bit set, so branchBit-1 selects all bits
	// below it.
	mask := b.branchBit - 1
	return key&mask == b.prefix
}
// pair is a single key-value mapping as stored inside a leaf.
type pair[K, V any] struct {
key K
value V
}
// leaf is a tree node that holds the key-value pairs of all keys hashing to
// one particular hash value.
type leaf[K, V any] struct {
// The (shared) hash value of all keys in the leaf.
key keyt
// List of values to handle hash collisions.
// TODO: Since collisions should be rare it might be worth
// it to have a fast implementation when no collisions occur.
values []pair[K, V]
}
// copy returns a new leaf with the same hash and a freshly allocated slice
// holding the same pairs, so that the result can be modified without
// affecting the receiver's slice.
func (l *leaf[K, V]) copy() *leaf[K, V] {
	dup := &leaf[K, V]{key: l.key}
	dup.values = append(dup.values, l.values...)
	return dup
}
// each calls f for every key-value pair stored in the leaf, in slice order.
func (l *leaf[K, V]) each(f eachFunc[K, V]) {
	pairs := l.values
	for i := range pairs {
		f(pairs[i].key, pairs[i].value)
	}
}
// br is a smart branch constructor: if one of the subtrees is empty (nil) the
// other subtree is returned as is, otherwise a new branch with the given
// prefix and branching bit is built around the two subtrees.
func br[K, V any](prefix, branchBit keyt, left, right node[K, V]) node[K, V] {
	switch {
	case left == nil:
		return right
	case right == nil:
		return left
	}
	return &branch[K, V]{
		prefix:    prefix,
		branchBit: branchBit,
		left:      left,
		right:     right,
	}
}
// lookup searches the tree for key, whose hash is given by hash, and returns
// the associated value and true if it is found. Otherwise the zero value of V
// and false are returned.
//
// The tree is walked from the root downwards: at each branch the hash must
// match the branch prefix, and the branching bit selects the subtree to
// continue in. At a leaf with the right hash, the colliding keys are compared
// with hasher.Equal.
func lookup[K, V any](tree node[K, V], hash keyt, key K, hasher Hasher[K]) (ret V, found bool) {
	for cur := tree; cur != nil; {
		switch n := cur.(type) {
		case *leaf[K, V]:
			if n.key != hash {
				return
			}
			for _, pr := range n.values {
				if hasher.Equal(key, pr.key) {
					return pr.value, true
				}
			}
			return
		case *branch[K, V]:
			if !n.match(hash) {
				// The hash cannot be anywhere below this branch.
				return
			}
			if zeroBit(hash, n.branchBit) {
				cur = n.left
			} else {
				cur = n.right
			}
		default:
			panic("???")
		}
	}
	return
}
// join combines two trees t0 and t1, which have prefixes p0 and p1
// respectively, under a new branch placed at the bit where the prefixes
// diverge. The tree whose prefix has a zero at that bit becomes the left
// subtree.
// The prefixes must not be equal!
func join[K, V any](p0, p1 keyt, t0, t1 node[K, V]) node[K, V] {
	bbit := branchingBit(p0, p1)

	left, right := t0, t1
	if !zeroBit(p0, bbit) {
		left, right = t1, t0
	}
	return &branch[K, V]{
		prefix:    p0 & (bbit - 1), // bits shared by both prefixes
		branchBit: bbit,
		left:      left,
		right:     right,
	}
}
// insert returns a tree in which key (with the given hash) is mapped to
// value. The input tree is not modified; nodes off the search path are
// shared with the result.
//
// If `f` is nil the old value is always replaced with the argument value,
// otherwise the old value is replaced with `f(value, prevValue)`.
// If the returned flag is false, the returned node is (reference-)equal to
// the input node.
func insert[K, V any](tree node[K, V], hash keyt, key K, value V, hasher Hasher[K], f MergeFunc[V]) (node[K, V], bool) {
	if tree == nil {
		return &leaf[K, V]{key: hash, values: []pair[K, V]{{key, value}}}, true
	}

	// Prefix of the existing tree; only needed when the new hash does not
	// fit into it and a fresh leaf has to be joined next to it.
	var existingPrefix keyt

	switch n := tree.(type) {
	case *leaf[K, V]:
		if n.key != hash {
			existingPrefix = n.key
			break
		}
		for idx := range n.values {
			old := n.values[idx]
			if !hasher.Equal(key, old.key) {
				continue
			}
			// The key is already present: replace (or merge) its value.
			merged := value
			if f != nil {
				var same bool
				if merged, same = f(value, old.value); same {
					return n, false
				}
			}
			updated := n.copy()
			updated.values[idx].value = merged
			return updated, true
		}
		// Hash collision - append to list of values in leaf
		grown := n.copy()
		grown.values = append(grown.values, pair[K, V]{key, value})
		return grown, true

	case *branch[K, V]:
		if !n.match(hash) {
			existingPrefix = n.prefix
			break
		}
		left, right := n.left, n.right
		var changed bool
		if zeroBit(hash, n.branchBit) {
			left, changed = insert(left, hash, key, value, hasher, f)
		} else {
			right, changed = insert(right, hash, key, value, hasher, f)
		}
		if !changed {
			return n, false
		}
		return &branch[K, V]{n.prefix, n.branchBit, left, right}, true

	default:
		panic("???")
	}

	// The hash belongs next to the existing tree, not inside it: put the
	// pair into a leaf of its own and join the two.
	var fresh node[K, V] = &leaf[K, V]{key: hash, values: []pair[K, V]{{key, value}}}
	return join(hash, existingPrefix, fresh, tree), true
}
// remove returns the tree obtained from tree by deleting the mapping for key,
// whose hash is given by hash. Keys are compared with hasher.Equal.
//
// The input tree is never modified. If the key is not present, the input node
// itself is returned rather than a rebuilt copy, so that the pointer-equality
// shortcuts in merge and equal keep working on the result. If the removal
// empties a leaf, nil is returned for it and the enclosing branch collapses
// to its remaining subtree (see br).
func remove[K, V any](tree node[K, V], hash keyt, key K, hasher Hasher[K]) node[K, V] {
	if tree == nil {
		return tree
	}

	switch tree := tree.(type) {
	case *leaf[K, V]:
		if tree.key != hash {
			return tree
		}
		// Copy all pairs that do not match the key.
		var kept []pair[K, V]
		for _, pr := range tree.values {
			if !hasher.Equal(key, pr.key) {
				kept = append(kept, pr)
			}
		}
		switch len(kept) {
		case 0:
			// Nothing is left in the leaf.
			return nil
		case len(tree.values):
			// The key was not in the leaf: keep sharing the original.
			return tree
		}
		return &leaf[K, V]{tree.key, kept}

	case *branch[K, V]:
		if !tree.match(hash) {
			return tree
		}
		left, right := tree.left, tree.right
		if zeroBit(hash, tree.branchBit) {
			left = remove(left, hash, key, hasher)
			if left == tree.left {
				// Subtree unchanged, so the branch is unchanged as well.
				return tree
			}
		} else {
			right = remove(right, hash, key, hasher)
			if right == tree.right {
				return tree
			}
		}
		return br(tree.prefix, tree.branchBit, left, right)

	default:
		panic("???")
	}
}
// merge returns the union of the trees a and b, resolving keys present in
// both with f (by way of insert).
// If the returned flag is true, a and b represent equal trees.
//
// NOTE: The implementation of this function is complex because it is
// performance critical, and since the performance does not rely only on
// the implementation within this function. Using shared subtrees speeds
// up future merge/equal operations on the result, which is important.
// The implementation does not (yet) produce a result that shares maximally
// with one of the input trees. Consider `merge(s, t) = t'`:
//
//	    s          t          t'
//	   / \        / \        / \
//	  0   a      c   b      c   a
//	     / \    / \ / \    / \ / \
//	    1   3  0  2 1  3  0  2 1  3
//
// The merge of the leaf `0` and `c` returns `c` because it is a superset of
// the leaf. However, the merge of `a` and `b` returns `a` because we prefer
// the left subtree over the right (both `a` and `b` are valid return values
// as the subtrees are equal). Since `t` is not the branch `(c, a)`, we
// return a new branch `t'` when we could have just returned `t`.
// Note also that `merge(t, s) = t`.
func merge[K, V any](a, b node[K, V], hasher Hasher[K], f MergeFunc[V]) (node[K, V], bool) {
	// Trivial cases: cheap pointer-equality, or one side empty.
	switch {
	case a == b:
		return a, true
	case a == nil:
		return b, false
	case b == nil:
		return a, false
	}

	// If either side is a leaf, its pairs are inserted into the other side.
	single, haveLeaf := a.(*leaf[K, V])
	rest := b
	if !haveLeaf {
		single, haveLeaf = b.(*leaf[K, V])
		rest = a
	}
	if haveLeaf {
		merged := rest
		for _, pr := range single.values {
			merged, _ = insert(merged, single.key, pr.key, pr.value, hasher, f)
		}
		if merged != rest {
			return merged, false
		}
		if restLeaf, ok := rest.(*leaf[K, V]); ok && len(restLeaf.values) == len(single.values) {
			// The other tree is also a leaf, it did not change as a result
			// of inserting our pairs, and both leaves hold the same number
			// of pairs: the two leaves were (and are still) equal.
			return a, true
		}
		return merged, false
	}

	// Both a and b are branches.
	x, y := a.(*branch[K, V]), b.(*branch[K, V])
	if x.branchBit == y.branchBit && x.prefix == y.prefix {
		// Same shape at this level: merge the subtrees pairwise and reuse
		// one of the input branches whenever the outcome allows it.
		l, lSame := merge(x.left, y.left, hasher, f)
		r, rSame := merge(x.right, y.right, hasher, f)
		switch {
		case lSame && rSame:
			return x, true
		case l == x.left && r == x.right:
			return x, false
		case l == y.left && r == y.right:
			return y, false
		}
		return &branch[K, V]{x.prefix, x.branchBit, l, r}, false
	}

	// Order the branches so that upper has the smaller branching bit.
	upper, lower := x, y
	if upper.branchBit > lower.branchBit {
		upper, lower = lower, upper
	}
	if upper.branchBit >= lower.branchBit || !upper.match(lower.prefix) {
		// prefixes disagree
		return join(upper.prefix, lower.prefix, node[K, V](upper), node[K, V](lower)), false
	}

	// upper contains lower: merge lower into the subtree it belongs to.
	sub := node[K, V](lower)
	l, r := upper.left, upper.right
	if zeroBit(lower.prefix, upper.branchBit) {
		if l, _ = merge(l, sub, hasher, f); l == upper.left {
			return upper, false
		}
	} else {
		if r, _ = merge(r, sub, hasher, f); r == upper.right {
			return upper, false
		}
	}
	return &branch[K, V]{upper.prefix, upper.branchBit, l, r}, false
}
// equal reports whether the trees a and b contain the same keys (compared
// with hasher.Equal) mapped to equal values (compared with f).
//
// Pointer-equal nodes are considered equal without being inspected, which
// lets shared subtrees be skipped. Otherwise two leaves are equal if they
// hold the same number of pairs and every pair of a has a counterpart in b,
// and two branches are equal if prefix, branching bit and both subtrees are
// equal. A leaf is never equal to a branch.
func equal[K, V any](a, b node[K, V], hasher Hasher[K], f func(V, V) bool) bool {
	if a == b {
		return true
	}
	if a == nil || b == nil {
		return false
	}

	switch x := a.(type) {
	case *leaf[K, V]:
		y, ok := b.(*leaf[K, V])
		if !ok || len(x.values) != len(y.values) {
			return false
		}
		for _, xp := range x.values {
			matched := false
			for _, yp := range y.values {
				if !hasher.Equal(xp.key, yp.key) {
					continue
				}
				if !f(xp.value, yp.value) {
					return false
				}
				matched = true
				break
			}
			if !matched {
				// a contained a key that b did not
				return false
			}
		}
		return true

	case *branch[K, V]:
		y, ok := b.(*branch[K, V])
		if !ok {
			return false
		}
		if x.prefix != y.prefix || x.branchBit != y.branchBit {
			return false
		}
		if !equal(x.left, y.left, hasher, f) {
			return false
		}
		return equal(x.right, y.right, hasher, f)

	default:
		panic("???")
	}
}