/
stream.go
187 lines (154 loc) · 4.5 KB
/
stream.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
package dispatch
import (
"context"
"sync"
"sync/atomic"
grpc "google.golang.org/grpc"
)
// Stream is the generic interface through which streaming dispatch results
// are delivered to a consumer.
type Stream[T any] interface {
	// Context returns the context associated with this stream.
	Context() context.Context

	// Publish delivers a single result to the stream, reporting any error
	// encountered while doing so.
	Publish(T) error
}
// grpcStream is the subset of a gRPC server-streaming handle used by this
// package: a typed Send method plus the embedded grpc.ServerStream.
type grpcStream[T any] interface {
	Send(T) error
	grpc.ServerStream
}
// concurrentSafeStream adapts a gRPC server stream to the Stream interface,
// serializing every Send behind a mutex.
type concurrentSafeStream[T any] struct {
	grpcStream grpcStream[T]
	mu         sync.Mutex // held for the duration of each Send
}

// WrapGRPCStream returns a Stream that forwards published results to the given
// gRPC server stream, serializing sends with a mutex. The wrapping is required
// because gRPC response streams are *not* safe for concurrent use.
// See: https://groups.google.com/g/grpc-io/c/aI6L6M4fzQ0?pli=1
func WrapGRPCStream[R any, S grpcStream[R]](grpcStream S) Stream[R] {
	wrapped := new(concurrentSafeStream[R])
	wrapped.grpcStream = grpcStream
	return wrapped
}

// Context returns the context of the underlying gRPC stream.
func (s *concurrentSafeStream[T]) Context() context.Context {
	return s.grpcStream.Context()
}

// Publish sends result on the underlying gRPC stream while holding the mutex,
// so concurrent publishers never call Send at the same time.
func (s *concurrentSafeStream[T]) Publish(result T) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	return s.grpcStream.Send(result)
}
// NewCollectingDispatchStream creates a new, empty CollectingDispatchStream
// whose Context method returns ctx.
func NewCollectingDispatchStream[T any](ctx context.Context) *CollectingDispatchStream[T] {
	return &CollectingDispatchStream[T]{ctx: ctx}
}

// CollectingDispatchStream is a dispatch stream that collects every published
// result in memory. Publish and Results both take the internal mutex, so the
// stream is safe for concurrent use.
type CollectingDispatchStream[T any] struct {
	ctx     context.Context
	results []T // guarded by mu
	mu      sync.Mutex
}

// Context returns the context supplied at construction.
func (s *CollectingDispatchStream[T]) Context() context.Context {
	return s.ctx
}

// Results returns the results published so far, in the order they were
// appended.
//
// The read happens under the same mutex as Publish, so it is safe to call
// while other goroutines are still publishing. The returned slice shares its
// elements with the stream, but its capacity is clipped to its length. As a
// result, appending to it reallocates rather than overwriting slots that later
// Publish calls will fill.
func (s *CollectingDispatchStream[T]) Results() []T {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Full slice expression: len == cap, so a caller's append cannot alias the
	// stream's backing array.
	return s.results[:len(s.results):len(s.results)]
}

// Publish appends result to the collected results. It never returns an error.
func (s *CollectingDispatchStream[T]) Publish(result T) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.results = append(s.results, result)
	return nil
}
// WrappedDispatchStream is a dispatch stream that wraps another dispatch
// stream. It can transform or filter each result before forwarding it to the
// wrapped (parent) stream.
type WrappedDispatchStream[T any] struct {
	// Stream is the parent stream that receives forwarded results.
	Stream Stream[T]

	// Ctx is the context reported by Context. When nil, Context falls back to
	// the context of Stream.
	Ctx context.Context

	// Processor, when non-nil, is applied to every published result. It
	// returns three values:
	//   - the (possibly modified) result,
	//   - whether that result should be forwarded,
	//   - an error that aborts the publish.
	// A nil Processor forwards results unchanged.
	Processor func(result T) (T, bool, error)
}

// Publish runs result through Processor (if set) and forwards it to Stream.
// It does not forward when the processor returns an error or asks for the
// result to be dropped.
func (s *WrappedDispatchStream[T]) Publish(result T) error {
	if s.Processor == nil {
		return s.Stream.Publish(result)
	}

	processed, ok, err := s.Processor(result)
	if err != nil {
		return err
	}
	if !ok {
		// The processor filtered this result out; dropping it is not an error.
		return nil
	}
	return s.Stream.Publish(processed)
}

// Context returns Ctx, or the wrapped stream's context when Ctx is unset, so a
// zero Ctx never yields a nil context while a parent stream is available.
func (s *WrappedDispatchStream[T]) Context() context.Context {
	if s.Ctx == nil && s.Stream != nil {
		return s.Stream.Context()
	}
	return s.Ctx
}

// StreamWithContext returns the given dispatch stream, wrapped so that its
// Context method reports ctx instead of the stream's own context. Published
// results are forwarded unchanged.
func StreamWithContext[T any](ctx context.Context, stream Stream[T]) Stream[T] {
	return &WrappedDispatchStream[T]{
		Stream: stream,
		Ctx:    ctx,
	}
}
// NewHandlingDispatchStream returns a Stream that invokes processor for each
// published result, one call at a time. Its Context method returns ctx.
func NewHandlingDispatchStream[T any](ctx context.Context, processor func(result T) error) Stream[T] {
	return &HandlingDispatchStream[T]{ctx: ctx, processor: processor}
}

// HandlingDispatchStream is a dispatch stream that hands every published
// result to a handler function. Calls to the handler are serialized by an
// internal mutex, so the stream is safe for concurrent use.
type HandlingDispatchStream[T any] struct {
	ctx       context.Context
	processor func(result T) error
	mu        sync.Mutex // serializes calls to processor
}

// Publish invokes the handler with result while holding the mutex and returns
// the handler's error. With no handler configured, Publish does nothing and
// returns nil.
func (s *HandlingDispatchStream[T]) Publish(result T) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if handle := s.processor; handle != nil {
		return handle(result)
	}
	return nil
}

// Context returns the context supplied at construction.
func (s *HandlingDispatchStream[T]) Context() context.Context {
	return s.ctx
}
// CountingDispatchStream is a dispatch stream that forwards results to a
// wrapped stream and counts how many were forwarded successfully. The count is
// held in an atomic, so counting is safe from multiple goroutines, provided
// the wrapped stream is too.
type CountingDispatchStream[T any] struct {
	Stream Stream[T]
	count  *atomic.Uint64
}

// NewCountingDispatchStream returns a CountingDispatchStream wrapping the given
// stream, with its count starting at zero.
func NewCountingDispatchStream[T any](wrapped Stream[T]) *CountingDispatchStream[T] {
	counter := new(atomic.Uint64)
	return &CountingDispatchStream[T]{Stream: wrapped, count: counter}
}

// PublishedCount reports how many results have been successfully forwarded.
func (s *CountingDispatchStream[T]) PublishedCount() uint64 {
	return s.count.Load()
}

// Publish forwards result to the wrapped stream and increments the count only
// if that succeeds.
func (s *CountingDispatchStream[T]) Publish(result T) error {
	if err := s.Stream.Publish(result); err != nil {
		return err
	}
	s.count.Add(1)
	return nil
}

// Context returns the wrapped stream's context.
func (s *CountingDispatchStream[T]) Context() context.Context {
	return s.Stream.Context()
}
// Compile-time checks that each stream implementation in this file satisfies
// the Stream interface.
var (
	_ Stream[any] = (*concurrentSafeStream[any])(nil)
	_ Stream[any] = (*CollectingDispatchStream[any])(nil)
	_ Stream[any] = (*WrappedDispatchStream[any])(nil)
	_ Stream[any] = (*HandlingDispatchStream[any])(nil)
	_ Stream[any] = (*CountingDispatchStream[any])(nil)
)