-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.go
166 lines (139 loc) · 3.09 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package main
import (
"context"
"errors"
"log"
"strings"
"sync"
"time"
"golang.org/x/sync/semaphore"
)
// producer emits every element of strings, in order, on the returned
// channel from a background goroutine. The channel is closed once the last
// element has been sent, or as soon as ctx is cancelled while a send is
// pending. The returned error is currently always nil.
func producer(ctx context.Context, strings []string) (<-chan string, error) {
	out := make(chan string)

	emit := func() {
		defer close(out)
		for i := 0; i < len(strings); i++ {
			select {
			case out <- strings[i]:
			case <-ctx.Done():
				return
			}
		}
	}
	go emit()

	return out, nil
}
// sink drains values, logging each one, until values is closed or ctx is
// done. Any non-nil error received on errors is logged and triggers
// cancelFunc, which stops the rest of the pipeline; sink itself then returns
// via ctx.Done.
//
// When errors is closed it is disabled (set to nil) rather than read again:
// a closed channel is always ready in a select and would otherwise make the
// loop busy-spin on zero values until values closes.
func sink(ctx context.Context, cancelFunc context.CancelFunc, values <-chan string, errors <-chan error) {
	for {
		select {
		case <-ctx.Done():
			log.Print(ctx.Err().Error())
			return
		case err, ok := <-errors:
			if !ok {
				// Receiving from a nil channel blocks forever, so this case
				// is effectively removed from the select from now on.
				errors = nil
				continue
			}
			if err != nil {
				log.Println("error: ", err.Error())
				cancelFunc()
			}
		case val, ok := <-values:
			if !ok {
				log.Print("done")
				return
			}
			log.Printf("sink: %s", val)
		}
	}
}
// step applies fn to every value received from inputChannel. It returns two
// channels: one carrying fn's successful results and one carrying fn's
// errors.
//
// At most limit calls to fn run concurrently, bounded by a weighted
// semaphore. Results are therefore emitted in completion order, not input
// order. Before calling fn, each worker waits a fixed demo delay so the
// batching is visible in the logs. The delay is skipped if ctx is cancelled.
//
// Both returned channels are closed only after inputChannel is closed (or
// ctx is done) AND every in-flight worker has returned. This guarantees no
// worker ever sends on a closed channel. Pending sends are abandoned when
// ctx is cancelled, so workers cannot block forever once downstream stops
// reading.
func step[In any, Out any](
	ctx context.Context,
	inputChannel <-chan In,
	fn func(In) (Out, error),
) (chan Out, chan error) {
	outputChannel := make(chan Out)
	errorChannel := make(chan error)
	limit := int64(2)
	// Use all CPU cores to maximize efficiency. We'll set the limit to 2 so you
	// can see the values being processed in batches of 2 at a time, in parallel
	// limit := int64(runtime.NumCPU())
	sem1 := semaphore.NewWeighted(limit)
	// Artificial per-item latency so the demo output shows parallel batches.
	const delay = 3 * time.Second

	go func() {
		var workers sync.WaitGroup
		// Deferred calls run LIFO: first wait for every worker, then close
		// errorChannel, then outputChannel (same close order as before).
		defer close(outputChannel)
		defer close(errorChannel)
		defer workers.Wait()

		for {
			select {
			case <-ctx.Done():
				// `break` here would only leave the select and spin the for
				// loop on the already-closed Done channel.
				return
			case s, ok := <-inputChannel:
				if !ok {
					return
				}
				if err := sem1.Acquire(ctx, 1); err != nil {
					// Acquire fails only once ctx is done; stop consuming.
					log.Printf("Failed to acquire semaphore: %v", err)
					return
				}
				workers.Add(1)
				go func(s In) {
					defer workers.Done()
					defer sem1.Release(1)

					timer := time.NewTimer(delay)
					select {
					case <-ctx.Done():
						timer.Stop()
						return
					case <-timer.C:
					}

					result, err := fn(s)
					if err != nil {
						select {
						case errorChannel <- err:
						case <-ctx.Done():
						}
						return
					}
					select {
					case outputChannel <- result:
					case <-ctx.Done():
					}
				}(s)
			}
		}
	}()
	return outputChannel, errorChannel
}
// Merge fans in every channel in cs onto a single returned channel. Values
// are forwarded in whatever order they arrive. Each input has its own
// forwarding goroutine, which stops when its input is closed or when ctx is
// done. The returned channel is closed once all forwarders have stopped.
func Merge[T any](ctx context.Context, cs ...<-chan T) <-chan T {
	merged := make(chan T)

	var pending sync.WaitGroup
	pending.Add(len(cs))

	forward := func(src <-chan T) {
		defer pending.Done()
		for {
			v, ok := <-src
			if !ok {
				return
			}
			select {
			case <-ctx.Done():
				return
			case merged <- v:
			}
		}
	}

	for i := range cs {
		go forward(cs[i])
	}

	go func() {
		pending.Wait()
		close(merged)
	}()

	return merged
}
// transformA is the first pipeline stage. It logs its input and returns the
// lower-cased string; it never returns an error.
func transformA(s string) (string, error) {
	log.Println("transformA input: ", s)
	lowered := strings.ToLower(s)
	return lowered, nil
}
// transformB is the second pipeline stage. It logs its input and returns it
// title-cased. It deliberately fails with "oh no" when the input is exactly
// "foo", so the demo can show an error cancelling the pipeline.
func transformB(s string) (string, error) {
	log.Println("transformB input: ", s)
	switch s {
	// Comment this out to see the pipeline finish successfully
	case "foo":
		return "", errors.New("oh no")
	default:
		// NOTE(review): strings.Title is deprecated since Go 1.18; the
		// recommended replacement is golang.org/x/text/cases, which this
		// file does not import. Kept to preserve identical output.
		return strings.Title(s), nil
	}
}
// main wires up the demo pipeline:
// producer -> step(transformA) -> step(transformB) -> sink.
// The error channels of both steps are merged and fed to sink, which cancels
// the shared context on the first error.
func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	source := []string{"FOO", "BAR", "BAX"}
	readStream, err := producer(ctx, source)
	if err != nil {
		log.Fatal(err)
	}

	lowered, lowerErrs := step(ctx, readStream, transformA)
	titled, titleErrs := step(ctx, lowered, transformB)

	sink(ctx, cancel, titled, Merge(ctx, lowerErrs, titleErrs))
}