/
certificate-manager.go
140 lines (120 loc) · 3.75 KB
/
certificate-manager.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package gcp
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"regexp"
"strconv"
"sync"
"time"
"go.uber.org/zap"
)
// CertificateManager represents a certificate manager: it holds the RSA public
// keys fetched from GoogleOAuth2CertsURL, indexed by key ID ("kid"), and
// refreshes them while Run is executing. Create one with NewCertificateManager.
type CertificateManager struct {
	// Certificates maps kid -> RSA public key (kid:cert(*rsa.PublicKey)).
	// NOTE(review): Run and GetPublicKeyByKeyID only touch this map while
	// holding certificatesMutex; reading it directly from outside this type
	// without that lock is a data race. Prefer GetPublicKeyByKeyID.
	Certificates map[string]*rsa.PublicKey

	// certificatesMutex guards Certificates.
	certificatesMutex sync.Mutex

	// client performs the HTTP requests issued by Run.
	client *http.Client

	// closeCh is closed by Stop to make Run return.
	closeCh chan struct{}
}
// RespCertificates represents the data structure returned by the certificates
// endpoint (kid:cert): each key is a key ID ("kid") and each value is a
// PEM-encoded X.509 certificate, which Run decodes with pem.Decode and
// x509.ParseCertificate.
type RespCertificates map[string]string
// rgxMaxAge extracts the delta-seconds value of the max-age directive from a
// Cache-Control header value (capture group 1).
//   - (?i): directive names are case-insensitive (RFC 9111 §5.2).
//   - (?:^|[\s,]): the directive must start the value or follow a comma or
//     whitespace, so tokens such as "x-max-age" are not mistaken for it.
//   - \d+: at least one digit is required, so a bare "max-age=" does not match
//     (and therefore does not trigger a spurious integer-parse error).
var rgxMaxAge = regexp.MustCompile(`(?i)(?:^|[\s,])max-age=(\d+)`)
// GoogleOAuth2CertsURL is the default endpoint for retrieving the public keys
// used to verify a given JWT. Run reads it on every refresh. It is a variable
// rather than a constant, presumably so it can be overridden (e.g. in tests) —
// NOTE(review): confirm. Since Run reads it without synchronization, change it
// only before Run is started.
var GoogleOAuth2CertsURL = "https://www.googleapis.com/oauth2/v1/certs"
// NewCertificateManager returns a CertificateManager with no certificates
// loaded yet, an open stop channel, and an HTTP client whose requests time out
// after 5 seconds. Start Run to populate it.
func NewCertificateManager() *CertificateManager {
	httpClient := &http.Client{Timeout: 5 * time.Second}
	manager := &CertificateManager{
		Certificates: make(map[string]*rsa.PublicKey),
		client:       httpClient,
		closeCh:      make(chan struct{}),
	}
	return manager
}
// GetPublicKeyByKeyID returns the public key associated with the given kid.
//
// While no certificates have been loaded yet (e.g. Run has not completed its
// first fetch), it re-checks every 500ms, up to 10 checks in total, and then
// fails with a timeout error. Once any certificates are loaded, a kid that is
// not among them fails immediately.
func (cM *CertificateManager) GetPublicKeyByKeyID(kid string) (*rsa.PublicKey, error) {
	const (
		maxChecks = 10
		backOff   = 500 * time.Millisecond
	)
	for check := 0; check < maxChecks; check++ {
		// Sleep only between checks: sleeping after the last failed check
		// would just delay the timeout error.
		if check > 0 {
			time.Sleep(backOff)
		}

		// Snapshot the state under the lock; unlock explicitly instead of
		// deferring inside the loop.
		cM.certificatesMutex.Lock()
		loaded := len(cM.Certificates) > 0
		cert, found := cM.Certificates[kid]
		cM.certificatesMutex.Unlock()

		if !loaded {
			continue
		}
		if !found {
			return nil, fmt.Errorf("could not find certificate for kid: %s", kid)
		}
		return cert, nil
	}
	return nil, fmt.Errorf("timed out waiting for certificates to load")
}
// Run runs the certificate manager; it should most likely be executed as a
// goroutine.
//
// Run loads the certificates published at GoogleOAuth2CertsURL into
// cM.Certificates and keeps refreshing them until Stop is called, then calls
// wg.Done. Each successful refresh replaces the whole key set, so key IDs that
// are no longer published stop being accepted. The next refresh is scheduled
// from the response's Cache-Control max-age (see GetMaxAgeFromHeader). A
// failed refresh keeps the previously loaded keys and is retried after a short
// delay. Stop also cancels a request that is in flight.
func (cM *CertificateManager) Run(wg *sync.WaitGroup) {
	logger.Info("running certificate manager")
	defer logger.Info("stopped certificate manager")
	defer wg.Done()

	// retryDelay is the wait after a failed refresh and the lower bound on the
	// refresh interval, so a "max-age=0" response or a persistently failing
	// endpoint cannot turn this loop into a busy loop.
	const retryDelay = 10 * time.Second

	// Cancel the in-flight request, if any, as soon as Stop closes closeCh.
	// The helper goroutine exits when Run returns (deferred cancel).
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		select {
		case <-cM.closeCh:
			cancel()
		case <-ctx.Done():
		}
	}()

	for {
		wait := retryDelay
		if maxAge, ok := cM.refreshCertificates(ctx); ok && maxAge > retryDelay {
			wait = maxAge
		}
		logger.Info("waiting to refresh certificates", zap.Duration("seconds", wait))
		timer := time.NewTimer(wait)
		select {
		case <-cM.closeCh:
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

// refreshCertificates fetches GoogleOAuth2CertsURL once and, if at least one
// RSA key could be parsed, replaces cM.Certificates with the new key set.
// It returns the response's max-age and whether the refresh succeeded; every
// failure is logged here. The mutex is held only for the final swap, never
// during network I/O.
func (cM *CertificateManager) refreshCertificates(ctx context.Context) (time.Duration, bool) {
	// #nosec G107
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, GoogleOAuth2CertsURL, nil)
	if err != nil {
		logger.Error("could not make Google OAuth2 certs request", zap.Error(err))
		return 0, false
	}
	resp, err := cM.client.Do(req)
	if err != nil {
		logger.Error("could not retrieve Google OAuth2 certs", zap.Error(err))
		return 0, false
	}
	// Closed when this helper returns, i.e. once per fetch.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		logger.Error("could not retrieve Google OAuth2 certs", zap.Int("status", resp.StatusCode))
		return 0, false
	}

	var resCerts RespCertificates
	if err := json.NewDecoder(resp.Body).Decode(&resCerts); err != nil {
		logger.Error("could not decode Google OAuth2 certs", zap.Error(err))
		return 0, false
	}

	keys := make(map[string]*rsa.PublicKey, len(resCerts))
	for kid, encodedCert := range resCerts {
		key, err := rsaPublicKeyFromPEMCert(encodedCert)
		if err != nil {
			logger.Error("could not parse Google OAuth2 certs", zap.String("kid", kid), zap.Error(err))
			continue
		}
		keys[kid] = key
	}
	if len(keys) == 0 {
		// Keep the previously loaded keys rather than wiping them.
		logger.Error("no usable Google OAuth2 certs in response")
		return 0, false
	}

	cM.certificatesMutex.Lock()
	cM.Certificates = keys
	cM.certificatesMutex.Unlock()
	for kid := range keys {
		logger.Info("stored certificate", zap.String("kid", kid))
	}
	return GetMaxAgeFromHeader(resp.Header), true
}

// rsaPublicKeyFromPEMCert returns the RSA public key of a single PEM-encoded
// X.509 certificate, or an error if the input is not such a certificate.
func rsaPublicKeyFromPEMCert(encoded string) (*rsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(encoded))
	if block == nil {
		return nil, fmt.Errorf("no PEM block found")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, err
	}
	key, ok := cert.PublicKey.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("certificate public key is %T, not *rsa.PublicKey", cert.PublicKey)
	}
	return key, nil
}
// Stop stops the certificate manager by closing closeCh, which makes Run
// return.
//
// Unlike a bare close, Stop is safe to call more than once: later calls are
// no-ops instead of panicking with "close of closed channel". The check and
// the close happen under certificatesMutex so that concurrent Stop calls
// cannot both see the channel as open. As a result, Stop may briefly wait if
// another method currently holds that mutex.
func (cM *CertificateManager) Stop() {
	cM.certificatesMutex.Lock()
	defer cM.certificatesMutex.Unlock()
	select {
	case <-cM.closeCh:
		// Already stopped.
	default:
		close(cM.closeCh)
	}
}
// GetMaxAgeFromHeader returns the max-age value from the cache-control header,
// defaulting to 1800 seconds.
//
// The max-age directive of the Cache-Control header is interpreted as a
// number of seconds. If the header is missing, has no max-age directive, or
// its value cannot be parsed, 1800 seconds (30 minutes) is returned; parse
// failures are logged. Values too large to be represented as a time.Duration
// are clamped to the largest representable duration instead of overflowing.
func GetMaxAgeFromHeader(header http.Header) time.Duration {
	const (
		defaultMaxAge = 1800 * time.Second
		maxDuration   = time.Duration(1<<63 - 1)
	)

	cacheControl := header.Get("cache-control")
	if cacheControl == "" {
		return defaultMaxAge
	}
	matches := rgxMaxAge.FindStringSubmatch(cacheControl)
	if len(matches) < 2 {
		return defaultMaxAge
	}
	seconds, err := strconv.ParseInt(matches[1], 10, 64)
	if err != nil {
		logger.Error("could not parse int", zap.String("in", matches[1]), zap.Error(err))
		return defaultMaxAge
	}
	// Multiplying by time.Second would silently overflow int64 for values
	// beyond ~292 years, producing a negative or meaningless wait.
	if seconds > int64(maxDuration/time.Second) {
		return maxDuration
	}
	return time.Duration(seconds) * time.Second
}