forked from go-zookeeper/zk
/
batch_tree_walker.go
139 lines (119 loc) · 3.91 KB
/
batch_tree_walker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package zk
import (
"context"
gopath "path"
)
// BatchVisitorFunc is a function that is called for each batch of nodes visited.
type BatchVisitorFunc func(paths []string) error
// BatchVisitorCtxFunc is like BatchVisitorFunc, but it takes a context.
type BatchVisitorCtxFunc func(ctx context.Context, paths []string) error
// NewBatchTreeWalker returns a new BatchTreeWalker for the given connection, root path and batch size.
func NewBatchTreeWalker(conn *Conn, path string, batchSize int) *BatchTreeWalker {
if batchSize <= 0 {
batchSize = 1 // Must be at least 1.
}
return &BatchTreeWalker{
conn: conn,
path: path,
batchSize: batchSize,
}
}
// BatchTreeWalker provides traversal of a tree of nodes rooted at a specific path.
// It fetches children in batches to reduce the number of round trips.
// The batch size is configurable.
type BatchTreeWalker struct {
conn *Conn
path string
batchSize int
}
// Walk begins traversing the tree and calls the visitor function for each node visited.
func (w *BatchTreeWalker) Walk(visitor BatchVisitorFunc) error {
vc := func(ctx context.Context, paths []string) error {
return visitor(paths)
}
return w.WalkCtx(context.Background(), vc)
}
func (w *BatchTreeWalker) WalkCtx(ctx context.Context, visitor BatchVisitorCtxFunc) error {
return w.walkBatch(ctx, []string{w.path}, visitor)
}
// WalkChan begins traversing the tree and sends the results to the returned channel.
// The channel will be buffered with the given size.
// The channel is closed when the traversal is complete.
// If an error occurs, an error event will be sent to the channel before it is closed.
func (w *BatchTreeWalker) WalkChan(bufferSize int) <-chan VisitEvent {
return w.WalkChanCtx(context.Background(), bufferSize)
}
// WalkChanCtx is like WalkChan, but it takes a context that can be used to cancel the walk.
func (w *BatchTreeWalker) WalkChanCtx(ctx context.Context, bufferSize int) <-chan VisitEvent {
ch := make(chan VisitEvent, bufferSize)
visitor := func(ctx context.Context, paths []string) error {
for _, p := range paths {
select {
case <-ctx.Done():
return ctx.Err()
case ch <- VisitEvent{Path: p}:
}
}
return nil
}
go func() {
defer close(ch)
if err := w.WalkCtx(ctx, visitor); err != nil {
ch <- VisitEvent{Err: err}
}
}()
return ch
}
// walkBatch recursively walks the tree in batches.
// It calls the visitor function for each batch of nodes visited.
// It fetches children in batches to reduce the number of round trips.
func (w *BatchTreeWalker) walkBatch(ctx context.Context, paths []string, visitor BatchVisitorCtxFunc) error {
// Execute the visitor function on all paths.
if err := visitor(ctx, paths); err != nil {
return err
}
// Fetch children of all paths.
children, err := w.fetchChildrenBatch(ctx, paths)
if err != nil {
return err
}
var batch []string
for i, p := range paths {
for _, c := range children[i] {
batch = append(batch, gopath.Join(p, c))
if len(batch) >= w.batchSize {
// Recursively walk the batch.
if err = w.walkBatch(ctx, batch, visitor); err != nil {
return err
}
batch = nil
}
}
}
if len(batch) > 0 {
// Recursively walk the last batch.
if err = w.walkBatch(ctx, batch, visitor); err != nil {
return err
}
}
return nil
}
// fetchChildrenBatch fetches the children of all paths in a single batch.
func (w *BatchTreeWalker) fetchChildrenBatch(ctx context.Context, paths []string) ([][]string, error) {
requests := make([]any, len(paths))
for i, p := range paths {
requests[i] = &GetChildrenRequest{Path: p}
}
responses, err := w.conn.MultiReadCtx(ctx, requests...)
if err != nil && err != ErrNoNode { // Treat ErrNoNode as empty children.
return nil, err
}
children := make([][]string, len(responses))
for i, r := range responses {
if r.Error == ErrNoNode {
continue // Treat ErrNoNode as empty children.
}
children[i] = r.Children
}
return children, nil
}