-
Notifications
You must be signed in to change notification settings - Fork 0
/
group.go
143 lines (123 loc) · 3.36 KB
/
group.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package runner
import (
"context"
"errors"
"os/signal"
"syscall"
"time"
)
type (
	// runner pairs the function Wait starts with the optional hook Shutdown
	// calls to stop it.
	runner struct {
		fn         RunFunc
		shutdownFn ShutdownFunc
	}

	// RunFunc is the body of a runner. Wait runs it in its own goroutine and
	// records a non-nil return value as a runner error.
	RunFunc func() error

	// ShutdownFunc stops a runner. Shutdown invokes it with a context that
	// carries a deadline.
	ShutdownFunc func(context.Context) error

	// Group starts a set of registered runners (Wait), collects their errors
	// (Errors) and runs their shutdown hooks under a timeout (Shutdown).
	Group struct {
		runners         []runner      // in registration order
		errors          []error       // runner errors recorded by Wait
		shutdownTimeout time.Duration // upper bound on how long Shutdown waits
	}
)
// NewGroup returns an empty Group whose Shutdown waits at most shutdownTimeout
// for the registered shutdown handlers.
//
// Timeouts shorter than two seconds are raised to two seconds. Shutdown hands
// its handlers a context that expires one second before its own deadline, and
// this floor keeps that handler deadline at one second or more.
func NewGroup(shutdownTimeout time.Duration) *Group {
	const minShutdownTimeout = 2 * time.Second

	timeout := shutdownTimeout
	if timeout < minShutdownTimeout {
		timeout = minShutdownTimeout
	}
	return &Group{shutdownTimeout: timeout}
}
// Register adds a runner to the group and returns the group so that calls can
// be chained. fn is started by Wait. shutdownFn may be nil; when it is not, it
// is called by Shutdown.
func (g *Group) Register(fn RunFunc, shutdownFn ShutdownFunc) *Group {
	r := runner{fn: fn, shutdownFn: shutdownFn}
	g.runners = append(g.runners, r)
	return g
}
// Wait starts every registered runner in its own goroutine and returns as soon
// as one of the following happens:
//  1. The process receives SIGTERM or SIGINT.
//  2. One of the runners returns a non-nil error.
//  3. The input context is done.
//
// Every runner error that has been delivered by the time Wait returns is
// recorded and can be retrieved with Errors. Runners that fail after Wait has
// returned are not observed. Wait does not invoke any shutdown handler;
// Shutdown does that.
//
// When this function returns, the process no longer masks signal handling.
func (g *Group) Wait(ctx context.Context) *Group {
	ctx, stop := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT)
	defer stop()

	// The buffer holds one error per runner, so every runner goroutine can
	// deliver its error and exit even after Wait has stopped receiving.
	errCh := make(chan error, len(g.runners))
	for _, executor := range g.runners {
		exec := executor // per-iteration copy for Go versions before 1.22
		go func() {
			if err := exec.fn(); err != nil {
				errCh <- err
			}
		}()
	}

	// Wait until a signal marks the context as done or a runner fails.
	select {
	case err := <-errCh:
		g.errors = append(g.errors, err)
	case <-ctx.Done():
	}

	// select chooses randomly among ready cases, so an error may be pending
	// even if ctx.Done() was chosen, and several runners may have failed at
	// once. Record everything already buffered instead of dropping it. Each
	// runner sends at most once, so this loop is bounded by len(g.runners).
	for {
		select {
		case err := <-errCh:
			g.errors = append(g.errors, err)
		default:
			return g
		}
	}
}
// Errors returns the runner errors recorded by Wait, wrapped in an ErrGroup, or
// nil when none were recorded. The returned ErrGroup shares its Errs slice with
// the Group rather than holding a copy.
func (g *Group) Errors() error {
	if len(g.errors) == 0 {
		return nil
	}
	return ErrGroup{Errs: g.errors}
}
// Shutdown concurrently runs the shutdown handler of every runner that
// registered one (runners registered with a nil ShutdownFunc are skipped). It
// returns once all handlers have returned or the group's shutdown timeout has
// elapsed, whichever comes first. With no handlers it returns nil immediately.
//
// Handlers receive a context whose deadline is one second earlier than the
// deadline Shutdown itself waits for. A handler that honours its context
// therefore has time to report back before Shutdown gives up on it.
//
// It returns nil when every handler returned nil. Otherwise it returns an
// ErrGroup holding each non-nil handler error and, if the timeout was reached,
// an additional "shutdown timeout exceeded" error.
func (g *Group) Shutdown() error {
	handlers := make([]ShutdownFunc, 0, len(g.runners))
	for _, r := range g.runners {
		if r.shutdownFn != nil {
			handlers = append(handlers, r.shutdownFn)
		}
	}
	// Nothing to wait for. Without this early return Shutdown would block
	// until the deadline and then report a spurious timeout.
	if len(handlers) == 0 {
		return nil
	}

	shutDownCtx, cancelShutdown := context.WithTimeout(context.Background(), g.shutdownTimeout)
	defer cancelShutdown()
	// The handler context is one second shorter than shutDownCtx, so a handler
	// that respects it can return before shutDownCtx's deadline. NOTE(review):
	// this is only positive when the Group came from NewGroup, which enforces
	// a timeout of at least two seconds.
	inputShutdownCtx, cancelInput := context.WithTimeout(shutDownCtx, g.shutdownTimeout-time.Second)
	defer cancelInput()

	// The buffer holds one result per handler, so handlers that finish after
	// the timeout can still send their result and exit instead of blocking.
	shutdownErrCh := make(chan error, len(handlers))
	for _, fn := range handlers {
		fn := fn // per-iteration copy for Go versions before 1.22
		go func() {
			shutdownErrCh <- fn(inputShutdownCtx)
		}()
	}

	isTimeout, shutdownNilOrErrs := shutdown(shutDownCtx, shutdownErrCh, len(handlers))

	var groupedErr ErrGroup
	if isTimeout {
		groupedErr.Errs = append(groupedErr.Errs, errors.New("shutdown timeout exceeded"))
	}
	for _, err := range shutdownNilOrErrs {
		if err != nil {
			groupedErr.Errs = append(groupedErr.Errs, err)
		}
	}
	if len(groupedErr.Errs) == 0 {
		return nil
	}
	return groupedErr
}
// shutdown collects up to shutdownCount results (nil or error) from
// shutdownErrCh, in arrival order, stopping early if shutDownCtx is done.
//
// It reports true together with the partial results when the context ended
// before every expected result arrived. Otherwise it reports false with all
// shutdownCount results. When shutdownCount is zero it returns immediately
// without consulting the context.
func shutdown(shutDownCtx context.Context, shutdownErrCh chan error, shutdownCount int) (bool, []error) {
	var shutdownNilOrErrs []error
	for len(shutdownNilOrErrs) < shutdownCount {
		select {
		case err := <-shutdownErrCh:
			shutdownNilOrErrs = append(shutdownNilOrErrs, err)
		case <-shutDownCtx.Done():
			// select picks randomly among ready cases, so results may be
			// waiting even though the deadline case was chosen. Take everything
			// that is already available before deciding this was a timeout.
			for len(shutdownNilOrErrs) < shutdownCount {
				select {
				case err := <-shutdownErrCh:
					shutdownNilOrErrs = append(shutdownNilOrErrs, err)
				default:
					return true, shutdownNilOrErrs
				}
			}
			return false, shutdownNilOrErrs
		}
	}
	return false, shutdownNilOrErrs
}