forked from cloudflare/cfssl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.go
370 lines (319 loc) · 9.53 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
package transport
import (
	"crypto/tls"
	"fmt"
	"net"
	"os"
	"time"

	"github.com/cloudflare/backoff"
	"github.com/cloudflare/cfssl/csr"
	"github.com/cloudflare/cfssl/errors"
	"github.com/cloudflare/cfssl/log"
	"github.com/cloudflare/cfssl/revoke"
	"github.com/cloudflare/cfssl/transport/ca"
	"github.com/cloudflare/cfssl/transport/core"
	"github.com/cloudflare/cfssl/transport/kp"
	"github.com/cloudflare/cfssl/transport/roots"
)
// envOrDefault looks up the environment variable key and returns its
// value, falling back to def when the variable is unset or set to the
// empty string.
func envOrDefault(key, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}
var (
	// NewKeyProvider is the function used to build key providers
	// from some identity. It is an exported variable, so it can be
	// replaced to plug in a different key provider.
	NewKeyProvider = func(id *core.Identity) (kp.KeyProvider, error) {
		provider, err := kp.NewStandardProvider(id)
		if err != nil {
			// Return an untyped nil: handing back the (nil) concrete
			// pointer would yield a non-nil interface value that
			// wraps a nil pointer.
			return nil, err
		}
		return provider, nil
	}

	// NewCA is used to load a configuration for a certificate
	// authority. Like NewKeyProvider, it can be replaced. The
	// default builds a CFSSL-backed CA from the identity, passing
	// nil as the second argument to ca.NewCFSSLProvider.
	NewCA = func(id *core.Identity) (ca.CertificateAuthority, error) {
		authority, err := ca.NewCFSSLProvider(id, nil)
		if err != nil {
			// Same typed-nil concern as NewKeyProvider above.
			return nil, err
		}
		return authority, nil
	}
)
// A Transport is capable of providing transport-layer security using
// TLS. It bundles the key provider, certificate authority and trust
// stores used to build tls.Config values, together with the settings
// that govern certificate refresh (Before, Backoff) and revocation
// checking of peers (RevokeSoftFail).
type Transport struct {
	// Before defines how long before the certificate expires the
	// transport should start attempting to refresh the
	// certificate. For example, if this is 24h, then 24 hours
	// before the certificate expires the Transport will start
	// attempting to replace it.
	//
	// NOTE(review): Lifespan already subtracts Before from the time
	// remaining, and RefreshKeys compares that result against Before
	// again, so a direct RefreshKeys call reissues once less than
	// 2*Before of validity remains. AutoUpdate first waits for
	// Lifespan() to elapse, so it refreshes at the point described
	// above.
	Before time.Duration

	// Provider contains a key management provider. The transport
	// uses it to load, generate, store and hand out its key pair and
	// certificate.
	Provider kp.KeyProvider

	// CA contains a mechanism for obtaining signed certificates.
	// RefreshKeys submits the provider's CSR to it via SignCSR.
	CA ca.CertificateAuthority

	// TrustStore contains the certificates trusted by this
	// transport. Its pool is used as RootCAs in the client and
	// client-auth server TLS configurations.
	TrustStore *roots.TrustStore

	// ClientTrustStore contains the certificate authorities to
	// use in verifying client authentication certificates. New
	// leaves it nil when the identity lists no client roots.
	ClientTrustStore *roots.TrustStore

	// Identity contains information about the entity that will be
	// used to construct certificates. Its Request supplies the key
	// request used for key generation and the CSR sent to the CA.
	Identity *core.Identity

	// Backoff is used to control the behaviour of a Transport
	// when it is attempting to automatically update a certificate
	// as part of AutoUpdate. It supplies the delay between failed
	// attempts and is reset after each successful update.
	Backoff *backoff.Backoff

	// RevokeSoftFail, if true, will cause a failure to check
	// revocation (such that the revocation status of a
	// certificate cannot be checked) to not be treated as an
	// error. Certificates known to be revoked are rejected by
	// Dial regardless of this setting.
	RevokeSoftFail bool
}
// TLSClientAuthClientConfig returns a new client authentication TLS
// configuration for a client that presents the transport's
// certificate when connecting to the named host.
//
// The server is verified against TrustStore and must present a
// certificate valid for host. TLS 1.2 is the minimum version, and only
// core.CipherSuites are offered. Any error from obtaining the
// transport's certificate is returned unchanged.
func (tr *Transport) TLSClientAuthClientConfig(host string) (*tls.Config, error) {
	cert, err := tr.getCertificate()
	if err != nil {
		return nil, err
	}

	cfg := &tls.Config{
		ServerName:   host,
		RootCAs:      tr.TrustStore.Pool(),
		Certificates: []tls.Certificate{cert},
		MinVersion:   tls.VersionTLS12,
		CipherSuites: core.CipherSuites,
		// crypto/tls consults ClientAuth only on the server side; it
		// has no effect on this client configuration.
		ClientAuth: tls.RequireAndVerifyClientCert,
	}
	return cfg, nil
}
// TLSClientAuthServerConfig returns a new client authentication TLS
// configuration for servers expecting mutually authenticated
// clients. Clients must present a certificate that verifies against
// the transport's ClientTrustStore.
//
// New only sets ClientTrustStore when the identity lists client roots.
// If it is nil, an error is returned rather than building a
// configuration that cannot verify clients.
func (tr *Transport) TLSClientAuthServerConfig() (*tls.Config, error) {
	// Check before touching the key provider: without a client trust
	// store the resulting config is unusable, and calling Pool() on a
	// nil store would dereference nil.
	if tr.ClientTrustStore == nil {
		return nil, fmt.Errorf("transport: client authentication requires a client trust store, but none is configured")
	}

	cert, err := tr.getCertificate()
	if err != nil {
		return nil, err
	}

	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		RootCAs:      tr.TrustStore.Pool(),
		ClientCAs:    tr.ClientTrustStore.Pool(),
		ClientAuth:   tls.RequireAndVerifyClientCert,
		CipherSuites: core.CipherSuites,
		MinVersion:   tls.VersionTLS12,
	}, nil
}
// TLSServerConfig is a general server configuration that should be
// used for non-client authentication purposes, such as HTTPS.
//
// The returned config serves the transport's certificate, requires at
// least TLS 1.2, and restricts negotiation to core.CipherSuites. It
// does not request client certificates. Any error from obtaining the
// certificate is returned unchanged.
func (tr *Transport) TLSServerConfig() (*tls.Config, error) {
	cert, err := tr.getCertificate()
	if err != nil {
		return nil, err
	}

	cfg := &tls.Config{
		MinVersion:   tls.VersionTLS12,
		CipherSuites: core.CipherSuites,
	}
	cfg.Certificates = []tls.Certificate{cert}
	return cfg, nil
}
// New builds a new transport from an identity and a before time. The
// before time tells the transport how long before the certificate
// expires to start attempting to update when auto-updating. If before
// is longer than the certificate's lifetime, every update check will
// trigger a new certificate to be generated.
//
// The identity's Roots populate TrustStore. ClientRoots, when present,
// populate ClientTrustStore; otherwise it is left nil. The key
// provider and CA are built with the package-level NewKeyProvider and
// NewCA hooks. An error is returned if identity is nil or if any of
// these steps fails.
func New(before time.Duration, identity *core.Identity) (*Transport, error) {
	// identity is dereferenced immediately below; report a nil value
	// as an error instead of panicking.
	if identity == nil {
		return nil, fmt.Errorf("transport: cannot create a transport from a nil identity")
	}

	tr := &Transport{
		Before:   before,
		Identity: identity,
		Backoff:  &backoff.Backoff{},
	}

	store, err := roots.New(identity.Roots)
	if err != nil {
		return nil, err
	}
	tr.TrustStore = store

	if len(identity.ClientRoots) > 0 {
		store, err = roots.New(identity.ClientRoots)
		if err != nil {
			return nil, err
		}
		tr.ClientTrustStore = store
	}

	tr.Provider, err = NewKeyProvider(identity)
	if err != nil {
		return nil, err
	}

	tr.CA, err = NewCA(identity)
	if err != nil {
		return nil, err
	}

	return tr, nil
}
// Lifespan returns how long remains until the transport's certificate
// enters its refresh window: the time from now until the certificate's
// NotAfter minus tr.Before. It returns 0 if no certificate is present,
// if the certificate has already expired, or if the refresh window has
// already been reached.
func (tr *Transport) Lifespan() time.Duration {
	cert := tr.Provider.Certificate()
	if cert == nil {
		return 0
	}

	current := time.Now()
	if current.After(cert.NotAfter) {
		return 0
	}

	remaining := cert.NotAfter.Sub(current.Add(tr.Before))
	log.Debugf(" LIFESPAN:\t%s", remaining)
	if remaining < 0 {
		remaining = 0
	}
	return remaining
}
// RefreshKeys will make sure the Transport has loaded keys and has a
// valid certificate. It will handle any persistence, check that the
// certificate is valid (i.e. that its expiry date is within the
// Before date), and handle certificate reissuance as needed.
//
// If the provider is not ready, its key pair is loaded. If loading
// fails for any reason other than kp.ErrCertificateUnavailable, a
// fresh key is generated from the identity's key request (or a basic
// default request).
//
// A new certificate is then requested from the CA when Lifespan()
// reports zero (no certificate, expired, or already within Before of
// expiry) or reports less than Before. The signed certificate is
// handed to the provider and, for persistent providers, stored.
//
// Whenever a step fails and the provider's SignalFailure returns true,
// the whole refresh is retried.
func (tr *Transport) RefreshKeys() (err error) {
	if !tr.Provider.Ready() {
		log.Debug("key and certificate aren't ready, loading")
		err = tr.Provider.Load()
		if err != nil && err != kp.ErrCertificateUnavailable {
			log.Debugf("failed to load keypair: %v", err)
			kr := tr.Identity.Request.KeyRequest
			if kr == nil {
				kr = csr.NewBasicKeyRequest()
			}

			err = tr.Provider.Generate(kr.Algo(), kr.Size())
			if err != nil {
				log.Debugf("failed to generate key: %v", err)
				return err
			}
		}
	}

	lifespan := tr.Lifespan()
	// Lifespan reports 0 for a missing or expired certificate. Test for
	// that explicitly: when Before <= 0, "lifespan < tr.Before" alone is
	// never true, and such a certificate would never be replaced.
	if lifespan == 0 || lifespan < tr.Before {
		log.Debugf("transport's certificate is out of date (lifespan %s)", lifespan)
		req, err := tr.Provider.CertificateRequest(tr.Identity.Request)
		if err != nil {
			log.Debugf("couldn't get a CSR: %v", err)
			if tr.Provider.SignalFailure(err) {
				return tr.RefreshKeys()
			}
			return err
		}

		log.Debug("requesting certificate from CA")
		cert, err := tr.CA.SignCSR(req)
		if err != nil {
			if tr.Provider.SignalFailure(err) {
				return tr.RefreshKeys()
			}
			log.Debugf("failed to get the certificate signed: %v", err)
			return err
		}

		log.Debug("giving the certificate to the provider")
		err = tr.Provider.SetCertificatePEM(cert)
		if err != nil {
			log.Debugf("failed to set the provider's certificate: %v", err)
			if tr.Provider.SignalFailure(err) {
				return tr.RefreshKeys()
			}
			return err
		}

		if tr.Provider.Persistent() {
			log.Debug("storing the certificate")
			err = tr.Provider.Store()

			if err != nil {
				log.Debugf("the provider failed to store the certificate: %v", err)
				if tr.Provider.SignalFailure(err) {
					return tr.RefreshKeys()
				}
				return err
			}
		}
	}

	return nil
}
// getCertificate returns the transport's key pair as a tls.Certificate.
// If the provider is not ready, RefreshKeys runs first to load or issue
// one, and any error from it is returned with a zero certificate.
// Errors from building the key pair are logged and returned together
// with whatever value X509KeyPair produced.
func (tr *Transport) getCertificate() (tls.Certificate, error) {
	if !tr.Provider.Ready() {
		log.Debug("transport isn't ready; attempting to refresh keypair")
		if refreshErr := tr.RefreshKeys(); refreshErr != nil {
			log.Debugf("transport couldn't get a certificate: %v", refreshErr)
			return tls.Certificate{}, refreshErr
		}
	}

	pair, pairErr := tr.Provider.X509KeyPair()
	if pairErr != nil {
		log.Debugf("couldn't generate an X.509 keypair: %v", pairErr)
	}
	return pair, pairErr
}
// Dial initiates a TLS connection to an outbound server. It returns a
// TLS connection to the server.
//
// If address cannot be split into host and port, it is treated as a
// bare hostname and port 443 is used. The connection presents the
// transport's client certificate (see TLSClientAuthClientConfig).
//
// After the handshake, every certificate in every verified chain is
// checked for revocation. A revoked certificate always fails the dial.
// A certificate whose revocation status cannot be determined fails it
// unless tr.RevokeSoftFail is set. On any such failure the connection
// is closed before the error is returned.
func Dial(address string, tr *Transport) (*tls.Conn, error) {
	host, _, err := net.SplitHostPort(address)
	if err != nil {
		// Assume address is a hostname, and that it should
		// use the HTTPS port number.
		host = address
		address = net.JoinHostPort(address, "443")
	}

	cfg, err := tr.TLSClientAuthClientConfig(host)
	if err != nil {
		return nil, err
	}

	conn, err := tls.Dial("tcp", address, cfg)
	if err != nil {
		return nil, err
	}

	state := conn.ConnectionState()
	if len(state.VerifiedChains) == 0 {
		// Best-effort close: the caller never receives conn, so it
		// must not be leaked. The verification error is what matters.
		_ = conn.Close()
		return nil, errors.New(errors.CertificateError, errors.VerifyFailed)
	}

	for _, chain := range state.VerifiedChains {
		for _, cert := range chain {
			revoked, ok := revoke.VerifyCertificate(cert)
			if revoked || (!ok && !tr.RevokeSoftFail) {
				_ = conn.Close()
				return nil, errors.New(errors.CertificateError, errors.VerifyFailed)
			}
		}
	}

	return conn, nil
}
// AutoUpdate will automatically update the transport's certificate. If
// a non-nil certUpdates chan is provided, it will receive timestamps
// for reissued certificates. If errChan is non-nil, any errors that
// occur in the updater will be passed along.
//
// The loop never returns on its own. It sleeps until Lifespan()
// elapses, retries RefreshKeys with tr.Backoff delays until it
// succeeds, resets the backoff, and repeats. Sends on certUpdates and
// errChan are blocking, so a caller that supplies a channel must keep
// receiving from it. A panic inside the loop is recovered and logged,
// which ends the updater.
func (tr *Transport) AutoUpdate(certUpdates chan<- time.Time, errChan chan<- error) {
	defer func() {
		if r := recover(); r != nil {
			log.Criticalf("AutoUpdate panicked: %v", r)
		}
	}()

	for {
		// Wait until it's time to update the certificate. Lifespan is
		// evaluated once per cycle so the timer and the poll target
		// describe the same deadline.
		wait := tr.Lifespan()
		target := time.Now().Add(wait)
		if PollInterval == 0 {
			<-time.After(wait)
		} else {
			pollWait(target)
		}

		// Keep trying to update the certificate until it's
		// ready.
		for {
			log.Debug("attempting to refresh keypair")
			err := tr.RefreshKeys()
			if err == nil {
				break
			}

			delay := tr.Backoff.Duration()
			log.Debugf("failed to update certificate, will try again in %s", delay)
			if errChan != nil {
				errChan <- err
			}

			<-time.After(delay)
		}

		log.Debug("certificate updated")
		if certUpdates != nil {
			certUpdates <- time.Now()
		}

		tr.Backoff.Reset()
	}
}