forked from hashicorp/nomad
/
sids_hook.go
278 lines (238 loc) · 7.93 KB
/
sids_hook.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
package taskrunner
import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
"github.com/hashicorp/nomad/client/consul"
"github.com/hashicorp/nomad/nomad/structs"
)
const (
// sidsHookName is the name of this hook; it is what Name returns and what
// the hook's logger is named, so it shows up in logs
sidsHookName = "consul_si_token"
// sidsBackoffBaseline is the unit of backoff between attempts to retrieve a
// Consul SI token. From attempt number 2 onward computeBackoff waits
// attempt * sidsBackoffBaseline, i.e. the wait grows linearly (not
// exponentially) with the attempt number.
sidsBackoffBaseline = 5 * time.Second
// sidsBackoffLimit is the upper bound computeBackoff places on the wait
// before any single attempt to retrieve a Consul SI token
sidsBackoffLimit = 3 * time.Minute
// sidsDerivationTimeout limits the amount of time we may spend trying to
// derive a SI token. If the hook does not get a token within this amount of
// time, the result is a failure. newSIDSHook uses it as the value of
// sidsHook.derivationTimeout.
sidsDerivationTimeout = 5 * time.Minute
// sidsTokenFile is the name of the file holding the Consul SI token inside
// the task's secret directory
sidsTokenFile = "si_token"
// sidsTokenFilePerms is the file mode passed when writing the token file
// into the secrets directory for the task (0440: read-only for owner and
// group, no access for others)
sidsTokenFilePerms = 0440
)
// sidsHookConfig carries the dependencies that newSIDSHook copies into a new
// sidsHook.
type sidsHookConfig struct {
// alloc is the allocation handed to sidsClient when deriving SI tokens
alloc *structs.Allocation
// task is the task the hook manages a token for; its Name is used to
// request the token and to find it in the response
task *structs.Task
// sidsClient is the API used to derive SI tokens
sidsClient consul.ServiceIdentityAPI
// lifecycle is used to kill the task when token derivation fails
lifecycle ti.TaskLifecycle
// logger is the parent logger; newSIDSHook derives a sub-logger named
// sidsHookName from it
logger hclog.Logger
}
// sidsHook is the Service Identities hook for managing SI tokens of connect
// enabled tasks. Its Prestart recovers a previously persisted token from the
// task's secrets directory or, when there is none, derives a new one through
// sidsClient and writes it there.
type sidsHook struct {
// alloc is the allocation passed to sidsClient when deriving SI tokens
alloc *structs.Allocation
// task is the task this hook manages a token for; its Name is used to
// request the token and to look it up in the response
task *structs.Task
// sidsClient is the Consul client [proxy] for requesting SI tokens
sidsClient consul.ServiceIdentityAPI
// lifecycle is used by this hook to kill the task when SI token derivation
// fails
lifecycle ti.TaskLifecycle
// derivationTimeout is the amount of time we may wait for Consul to successfully
// provide a SI token. Making this configurable for testing, otherwise
// default to sidsDerivationTimeout
derivationTimeout time.Duration
// logger is used to log
logger hclog.Logger
// lock is held for the whole of Prestart and guards the fields that are
// mutated after hook creation (firstRun)
lock sync.Mutex
// firstRun keeps track of whether the hook is being called for the first
// time (for this task) during the lifespan of the Nomad Client process.
firstRun bool
}
// newSIDSHook creates the SI token hook described by c.
//
// The allocation, task, Consul client and lifecycle are taken from c as-is.
// The hook logs through a sub-logger of c.logger named sidsHookName, waits up
// to sidsDerivationTimeout for a token, and starts out marked as not yet run.
func newSIDSHook(c sidsHookConfig) *sidsHook {
	hook := new(sidsHook)

	// dependencies supplied by the caller
	hook.alloc = c.alloc
	hook.task = c.task
	hook.sidsClient = c.sidsClient
	hook.lifecycle = c.lifecycle
	hook.logger = c.logger.Named(sidsHookName)

	// defaults owned by the hook itself
	hook.derivationTimeout = sidsDerivationTimeout
	hook.firstRun = true

	return hook
}
// Name returns the name of this hook, sidsHookName.
func (h *sidsHook) Name() string {
return sidsHookName
}
// Prestart makes sure a Consul SI token file exists in the task's secrets
// directory (req.TaskDir.SecretsDir).
//
// Only the first call on a given hook does any work; every later call marks
// resp.Done and returns immediately. On the first call a token already saved
// on disk is reused; otherwise a new token is derived (bounded by ctx and the
// hook's derivation timeout) and written to disk.
//
// An error from recovering, deriving or writing the token is returned as-is
// and resp.Done is left untouched. The hook's lock is held for the whole call.
func (h *sidsHook) Prestart(
	ctx context.Context,
	req *interfaces.TaskPrestartRequest,
	resp *interfaces.TaskPrestartResponse) error {

	h.lock.Lock()
	defer h.lock.Unlock()

	// later invocations on this hook instance have nothing left to do
	if h.earlyExit() {
		resp.Done = true
		return nil
	}

	secretsDir := req.TaskDir.SecretsDir

	// a token persisted by an earlier run is reused as-is
	recovered, recoverErr := h.recoverToken(secretsDir)
	if recoverErr != nil {
		return recoverErr
	}

	// nothing usable on disk: ask for a fresh token and persist it
	if recovered == "" {
		derived, deriveErr := h.deriveSIToken(ctx)
		if deriveErr != nil {
			return deriveErr
		}

		writeErr := h.writeToken(secretsDir, derived)
		if writeErr != nil {
			return writeErr
		}
	}

	h.logger.Info("derived SI token", "task", h.task.Name, "si_task", h.task.Kind.Value())
	resp.Done = true
	return nil
}
// earlyExit reports whether Prestart has already been entered on this hook
// instance. The first call returns false and clears firstRun; every call after
// that returns true.
//
// assumes h is locked
func (h *sidsHook) earlyExit() bool {
	alreadyRan := !h.firstRun
	h.firstRun = false
	return alreadyRan
}
// writeToken stores token in the file sidsTokenFile inside dir (the task's
// secrets directory), passing sidsTokenFilePerms as the file mode. A write
// failure is returned wrapped with a "failed to write SI token" prefix.
func (h *sidsHook) writeToken(dir string, token string) error {
	var (
		path    = filepath.Join(dir, sidsTokenFile)
		payload = []byte(token)
	)

	err := ioutil.WriteFile(path, payload, sidsTokenFilePerms)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to write SI token: %w", err)
}
// recoverToken reads the token previously saved as sidsTokenFile inside dir
// (the task's secrets directory).
//
// It returns the file's contents when the file can be read, and the empty
// string with a nil error when the file does not exist. Any other read failure
// is logged and returned wrapped with a "failed to recover SI token" prefix.
func (h *sidsHook) recoverToken(dir string) (string, error) {
	path := filepath.Join(dir, sidsTokenFile)
	contents, err := ioutil.ReadFile(path)

	switch {
	case err == nil:
		h.logger.Trace("recovered pre-existing SI token", "task", h.task.Name)
		return string(contents), nil

	case os.IsNotExist(err):
		// token file does not exist yet
		h.logger.Trace("no pre-existing SI token to recover", "task", h.task.Name)
		return "", nil

	default:
		h.logger.Error("failed to recover SI token", "error", err)
		return "", fmt.Errorf("failed to recover SI token: %w", err)
	}
}
// siDerivationResult is used to pass along the result of attempting to derive
// an SI token between the goroutine doing the derivation and its caller
type siDerivationResult struct {
// token is the derived SI token; it is the empty string when err is set
token string
// err is the reason derivation was given up, or nil on success
err error
}
// deriveSIToken spawns a goroutine running tryDerive and waits for either its
// result or for ctx to be done. The wait is additionally bounded by
// h.derivationTimeout.
//
// On success the derived token is returned. If the worker reports an error,
// the error is logged, the task is killed via h.kill, and the worker's error
// is returned. If ctx is cancelled or the timeout elapses first, ctx.Err() is
// returned and the task is not killed here.
func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) {
	ctx, cancel := context.WithTimeout(ctx, h.derivationTimeout)
	defer cancel()

	// The worker sends at most one result. A buffer of one lets that send
	// complete even when this function has already returned because ctx was
	// done, so the worker goroutine can never be stranded on the channel.
	resultCh := make(chan siDerivationResult, 1)

	// keep trying to get the token in the background
	go h.tryDerive(ctx, resultCh)

	// wait until we get a token, or we get a signal to quit
	select {
	case result := <-resultCh:
		if result.err != nil {
			h.logger.Error("failed to derive SI token", "error", result.err)
			h.kill(ctx, fmt.Errorf("failed to derive SI token: %w", result.err))
			return "", result.err
		}
		return result.token, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}
// kill asks the task lifecycle to kill the task, using a TaskKilling event
// that is marked as failing the task and carries reason as its display
// message. If the kill request itself fails, that failure is logged and not
// returned.
func (h *sidsHook) kill(ctx context.Context, reason error) {
	event := structs.NewTaskEvent(structs.TaskKilling).
		SetFailsTask().
		SetDisplayMessage(reason.Error())

	killErr := h.lifecycle.Kill(ctx, event)
	if killErr == nil {
		return
	}
	h.logger.Error("failed to kill task", "kill_reason", reason, "error", killErr)
}
// tryDerive repeatedly asks the Consul client for an SI token for h.task,
// waiting backoff(ctx, attempt) before each attempt, until one of:
//
//   - a token for the task is returned: the token is reported on ch
//   - the response lacks a token for the task, or the error is server-side or
//     not recoverable: the error is reported on ch
//   - backoff reports that ctx is done: nothing is reported
//
// At most one result is ever sent on ch. The send is abandoned if ctx is done
// first, so this goroutine cannot block forever on a caller that has already
// stopped listening.
func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) {
	// report delivers the final result, unless the caller has gone away
	report := func(token string, err error) {
		select {
		case ch <- siDerivationResult{token: token, err: err}:
		case <-ctx.Done():
		}
	}

	for attempt := 0; backoff(ctx, attempt); attempt++ {
		tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name})

		switch {
		case err == nil:
			token, exists := tokens[h.task.Name]
			if !exists {
				err := errors.New("response does not include token for task")
				h.logger.Error("derive SI token is missing token for task", "error", err, "task", h.task.Name)
				report("", err)
				return
			}
			report(token, nil)
			return
		case structs.IsServerSide(err):
			// the error is known to be a server problem, just die
			h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "server_side", true)
			report("", err)
			return
		case !structs.IsRecoverable(err):
			// the error is known not to be recoverable, just die
			h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "recoverable", false)
			report("", err)
			return
		default:
			// the error is marked recoverable, retry after some backoff
			h.logger.Error("failed attempt to derive SI token", "error", err, "recoverable", true)
		}
	}
}
// backoff waits for the delay computeBackoff assigns to attempt and reports
// whether the caller should go ahead with that attempt.
//
// It returns true once the delay has elapsed, and false if ctx is done before
// then. A ctx that is already done always yields false, even when the delay
// is zero.
func backoff(ctx context.Context, attempt int) bool {
	// Give cancellation priority: with a zero delay both select cases below
	// could be ready at once and the choice between them would be random.
	if ctx.Err() != nil {
		return false
	}

	// An explicit timer (rather than time.After) can be stopped, so a
	// cancelled wait does not leave a timer pending for up to the full delay.
	timer := time.NewTimer(computeBackoff(attempt))
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
// computeBackoff returns how long to wait before the given attempt:
//
//   - attempt 0: no wait
//   - attempt 1: 100ms
//   - any other attempt: attempt * sidsBackoffBaseline, but never more than
//     sidsBackoffLimit
func computeBackoff(attempt int) time.Duration {
	if attempt == 0 {
		return 0
	}
	if attempt == 1 {
		// go fast on first retry, because a unit test should be fast
		return 100 * time.Millisecond
	}

	if wait := time.Duration(attempt) * sidsBackoffBaseline; wait <= sidsBackoffLimit {
		return wait
	}
	return sidsBackoffLimit
}