From 6e9264d859447c5fe653882bb2bad85eacbbfc88 Mon Sep 17 00:00:00 2001 From: Ramiro Gamarra Date: Thu, 2 Jun 2022 18:23:29 -0700 Subject: [PATCH] adding a cert refresher construct --- cns/configuration/cns_config.json | 10 +- cns/configuration/configuration.go | 20 ++++ cns/configuration/configuration_test.go | 9 ++ cns/service.go | 103 ++++++++++++++++- cns/service/main.go | 10 +- common/listener.go | 134 ++++++++-------------- go.mod | 12 +- go.sum | 13 ++- keyvault/certrefresher.go | 130 +++++++++++++++++++++ keyvault/certrefresher_test.go | 145 ++++++++++++++++++++++++ server/tls/tlscertificate_retriever.go | 15 ++- 11 files changed, 494 insertions(+), 107 deletions(-) create mode 100644 keyvault/certrefresher.go create mode 100644 keyvault/certrefresher_test.go diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 560d00c7c8..fdd5fddfe6 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -20,5 +20,13 @@ "TLSPort": "10091", "TLSSubjectName": "", "UseHTTPS": false, - "WireserverIP": "168.63.129.16" + "WireserverIP": "168.63.129.16", + "KeyVaultSettings": { + "URL": "", + "CertificateName": "", + "RefreshIntervalInHrs": 12 + }, + "MSISettings": { + "ResourceID": "" + } } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 9a05d8b950..d9ea331e09 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -33,6 +33,8 @@ type CNSConfig struct { TelemetrySettings TelemetrySettings UseHTTPS bool WireserverIP string + KeyVaultSettings KeyVaultSettings + MSISettings MSISettings } type TelemetrySettings struct { @@ -67,6 +69,16 @@ type ManagedSettings struct { NodeSyncIntervalInSeconds int } +type MSISettings struct { + ResourceID string +} + +type KeyVaultSettings struct { + URL string + CertificateName string + RefreshIntervalInHrs int +} + func getConfigFilePath(cmdLineConfigPath string) (string, error) { // If config path is set from cmd line, return that if cmdLineConfigPath != "" { @@ -144,10 +156,18 @@ func setManagedSettingDefaults(managedSettings *ManagedSettings) { } } +func setKeyVaultSettingsDefaults(kvs *KeyVaultSettings) { + if kvs.RefreshIntervalInHrs == 0 { + kvs.RefreshIntervalInHrs = 12 //nolint:gomnd // default times + } +} + // SetCNSConfigDefaults set default values of CNS config if not specified func SetCNSConfigDefaults(config *CNSConfig) { setTelemetrySettingDefaults(&config.TelemetrySettings) setManagedSettingDefaults(&config.ManagedSettings) + setKeyVaultSettingsDefaults(&config.KeyVaultSettings) + if config.ChannelMode == "" { config.ChannelMode = cns.Direct } diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index eee4c61b99..3b2d1d4485 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -201,6 +201,9 @@ func TestSetCNSConfigDefaults(t *testing.T) { RefreshIntervalInSecs: 15, SnapshotIntervalInMins: 60, }, + KeyVaultSettings: KeyVaultSettings{ + RefreshIntervalInHrs: 12, + }, }, }, { @@ -220,6 +223,9 @@ func TestSetCNSConfigDefaults(t *testing.T) { RefreshIntervalInSecs: 1, SnapshotIntervalInMins: 6, }, + KeyVaultSettings: KeyVaultSettings{ + RefreshIntervalInHrs: 3, + }, }, want: CNSConfig{ ChannelMode: "Other", @@ -236,6 +242,9 @@ func TestSetCNSConfigDefaults(t *testing.T) { RefreshIntervalInSecs: 1, SnapshotIntervalInMins: 6, }, + KeyVaultSettings: KeyVaultSettings{ + RefreshIntervalInHrs: 3, + }, }, }, } diff --git a/cns/service.go b/cns/service.go index b9e1555d49..4f9e398eca 100644 --- a/cns/service.go +++ b/cns/service.go @@ -4,7 +4,10 @@ package cns import ( + "context" + "crypto/tls" "fmt" + "net" "net/http" "net/url" "strings" @@ -12,8 +15,11 @@ import ( "github.com/Azure/azure-container-networking/cns/common" "github.com/Azure/azure-container-networking/cns/logger" acn "github.com/Azure/azure-container-networking/common" + "github.com/Azure/azure-container-networking/keyvault" "github.com/Azure/azure-container-networking/log" + localtls "github.com/Azure/azure-container-networking/server/tls" "github.com/Azure/azure-container-networking/store" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/pkg/errors" ) @@ -68,19 +74,27 @@ func (service *Service) Initialize(config *common.ServiceConfig) error { if err != nil { return err } - // Create the listener. + listener, err := acn.NewListener(u) if err != nil { return err } + if config.TlsSettings.TLSPort != "" { // listener.URL.Host will always be hostname:port, passed in to CNS via CNS command // else it will default to localhost // extract hostname and override tls port. hostParts := strings.Split(listener.URL.Host, ":") - config.TlsSettings.TLSEndpoint = hostParts[0] + ":" + config.TlsSettings.TLSPort + tlsAddress := net.JoinHostPort(hostParts[0], config.TlsSettings.TLSPort) + // Start the listener and HTTP and HTTPS server. - if err = listener.StartTLS(config.ErrChan, config.TlsSettings); err != nil { + tlsConfig, err := getTLSConfig(config.TlsSettings, config.ErrChan) + if err != nil { + log.Printf("Failed to compose Tls Configuration with error: %+v", err) + return errors.Wrap(err, "could not get tls config") + } + + if err := listener.StartTLS(config.ErrChan, tlsConfig, tlsAddress); err != nil { return err } } @@ -95,6 +109,89 @@ func (service *Service) Initialize(config *common.ServiceConfig) error { return nil } +func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) { + if tlsSettings.TLSCertificatePath != "" { + return getTLSConfigFromFile(tlsSettings) + } + + if tlsSettings.KeyVaultURL != "" { + return getTLSConfigFromKeyVault(tlsSettings, errChan) + } + + return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings) +} + +func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) { + tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings) + if err != nil { + return nil, errors.Wrap(err, "failed to get certificate retriever") + } + + leafCertificate, err := tlsCertRetriever.GetCertificate() + if err != nil { + return nil, errors.Wrap(err, "failed to get certificate") + } + + if leafCertificate == nil { + return nil, errors.New("certificate retrieval returned empty") + } + + privateKey, err := tlsCertRetriever.GetPrivateKey() + if err != nil { + return nil, errors.Wrap(err, "failed to get certificate private key") + } + + tlsCert := tls.Certificate{ + Certificate: [][]byte{leafCertificate.Raw}, + PrivateKey: privateKey, + Leaf: leafCertificate, + } + + tlsConfig := &tls.Config{ + MaxVersion: tls.VersionTLS13, + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{ + tlsCert, + }, + } + + return tlsConfig, nil +} + +func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) { + credOpts := azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(tlsSettings.MSIResourceID)} + cred, err := azidentity.NewManagedIdentityCredential(&credOpts) + if err != nil { + return nil, errors.Wrap(err, "could not create managed identity credential") + } + + kvs, err := keyvault.NewShim(tlsSettings.KeyVaultURL, cred) + if err != nil { + return nil, errors.Wrap(err, "could not create new keyvault shim") + } + + ctx := context.TODO() + + cr, err := keyvault.NewCertRefresher(ctx, kvs, logger.Log, tlsSettings.KeyVaultCertificateName) + if err != nil { + return nil, errors.Wrap(err, "could not create new cert refresher") + } + + go func() { + errChan <- cr.Refresh(ctx, tlsSettings.KeyVaultCertificateRefreshInterval) + }() + + tlsConfig := tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return cr.GetCertificate(), nil + }, + } + + return &tlsConfig, nil +} + func (service *Service) StartListener(config *common.ServiceConfig) error { log.Debugf("[Azure CNS] Going to start listener: %+v", config) diff --git a/cns/service/main.go b/cns/service/main.go index ff02d4f270..89d9a30563 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -533,9 +533,13 @@ func main() { if httpRestService != nil { if cnsconfig.UseHTTPS { config.TlsSettings = localtls.TlsSettings{ - TLSSubjectName: cnsconfig.TLSSubjectName, - TLSCertificatePath: cnsconfig.TLSCertificatePath, - TLSPort: cnsconfig.TLSPort, + TLSSubjectName: cnsconfig.TLSSubjectName, + TLSCertificatePath: cnsconfig.TLSCertificatePath, + TLSPort: cnsconfig.TLSPort, + KeyVaultURL: cnsconfig.KeyVaultSettings.URL, + KeyVaultCertificateName: cnsconfig.KeyVaultSettings.CertificateName, + MSIResourceID: cnsconfig.MSISettings.ResourceID, + KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, } } diff --git a/common/listener.go b/common/listener.go index 6b79aa5313..06645741f9 100644 --- a/common/listener.go +++ b/common/listener.go @@ -6,26 +6,25 @@ package common import ( "crypto/tls" "encoding/json" - "fmt" "net" "net/http" "net/url" "os" "github.com/Azure/azure-container-networking/log" - localtls "github.com/Azure/azure-container-networking/server/tls" + "github.com/pkg/errors" ) // Listener represents an HTTP listener. type Listener struct { - URL *url.URL - protocol string - localAddress string - endpoints []string - active bool - l net.Listener - securelistener net.Listener - mux *http.ServeMux + URL *url.URL + protocol string + localAddress string + endpoints []string + active bool + listener net.Listener + tlsListener net.Listener + mux *http.ServeMux } // NewListener creates a new Listener. @@ -41,142 +40,103 @@ func NewListener(u *url.URL) (*Listener, error) { return &listener, nil } -func GetTlsConfig(tlsSettings localtls.TlsSettings) (*tls.Config, error) { - tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings) - if err != nil { - return nil, fmt.Errorf("Failed to get certificate retriever %+v", err) - } - leafCertificate, err := tlsCertRetriever.GetCertificate() - if err != nil { - return nil, fmt.Errorf("Failed to get certificate %+v", err) - } - if leafCertificate == nil { - return nil, fmt.Errorf("Certificate retrival returned empty %+v", err) - } - privateKey, err := tlsCertRetriever.GetPrivateKey() - if err != nil { - return nil, fmt.Errorf("Failed to get certificate private key %+v", err) - } - tlsCert := tls.Certificate{ - Certificate: [][]byte{leafCertificate.Raw}, - PrivateKey: privateKey, - Leaf: leafCertificate, - } - tlsConfig := &tls.Config{ - MaxVersion: tls.VersionTLS12, - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{ - tlsCert, - }, - } - return tlsConfig, nil -} - -// Start creates the listener socket and starts the HTTPS server. -func (listener *Listener) StartTLS(errChan chan<- error, tlsSettings localtls.TlsSettings) error { - tlsConfig, err := GetTlsConfig(tlsSettings) - if err != nil { - log.Printf("[Listener] Failed to compose Tls Configuration with errror: %+v", err) - return err - } +// StartTLS creates the listener socket and starts the HTTPS server. +func (l *Listener) StartTLS(errChan chan<- error, tlsConfig *tls.Config, address string) error { server := http.Server{ TLSConfig: tlsConfig, - Handler: listener.mux, + Handler: l.mux, } - // listen on a seperate endpoint for secure tls connections - listener.securelistener, err = net.Listen(listener.protocol, tlsSettings.TLSEndpoint) + // listen on a separate endpoint for secure tls connections + list, err := net.Listen(l.protocol, address) if err != nil { log.Printf("[Listener] Failed to listen on TlsEndpoint: %+v", err) return err } - log.Printf("[Listener] Started listening on tls endpoint %s.", tlsSettings.TLSEndpoint) + + l.tlsListener = list + log.Printf("[Listener] Started listening on tls endpoint %s.", address) // Launch goroutine for servicing https requests go func() { - errChan <- server.ServeTLS(listener.securelistener, "", "") + errChan <- server.ServeTLS(l.tlsListener, "", "") }() - listener.active = true + l.active = true return nil } // Start creates the listener socket and starts the HTTP server. -func (listener *Listener) Start(errChan chan<- error) error { - var err error - - // Succeed early if no socket was requested. - if listener.localAddress == "null" { - return nil - } - - listener.l, err = net.Listen(listener.protocol, listener.localAddress) +func (l *Listener) Start(errChan chan<- error) error { + list, err := net.Listen(l.protocol, l.localAddress) if err != nil { log.Printf("[Listener] Failed to listen: %+v", err) return err } - log.Printf("[Listener] Started listening on %s.", listener.localAddress) + l.listener = list + log.Printf("[Listener] Started listening on %s.", l.localAddress) // Launch goroutine for servicing requests. go func() { - errChan <- http.Serve(listener.l, listener.mux) + errChan <- http.Serve(l.listener, l.mux) }() - listener.active = true + l.active = true return nil } // Stop stops listening for requests. -func (listener *Listener) Stop() { +func (l *Listener) Stop() { // Ignore if not active. - if !listener.active { + if !l.active { return } - listener.active = false + l.active = false // Stop servicing requests. - listener.l.Close() + _ = l.listener.Close() - if listener.securelistener != nil { + if l.tlsListener != nil { // Stop servicing requests on secure listener - listener.securelistener.Close() + _ = l.tlsListener.Close() } // Delete the unix socket. - if listener.protocol == "unix" { - os.Remove(listener.localAddress) + if l.protocol == "unix" { + _ = os.Remove(l.localAddress) } - log.Printf("[Listener] Stopped listening on %s", listener.localAddress) + log.Printf("[Listener] Stopped listening on %s", l.localAddress) } // GetMux returns the HTTP mux for the listener. -func (listener *Listener) GetMux() *http.ServeMux { - return listener.mux +func (l *Listener) GetMux() *http.ServeMux { + return l.mux } // GetEndpoints returns the list of registered protocol endpoints. -func (listener *Listener) GetEndpoints() []string { - return listener.endpoints +func (l *Listener) GetEndpoints() []string { + return l.endpoints } // AddEndpoint registers a protocol endpoint. -func (listener *Listener) AddEndpoint(endpoint string) { - listener.endpoints = append(listener.endpoints, endpoint) +func (l *Listener) AddEndpoint(endpoint string) { + l.endpoints = append(l.endpoints, endpoint) } // AddHandler registers a protocol handler. -func (listener *Listener) AddHandler(path string, handler http.HandlerFunc) { - listener.mux.HandleFunc(path, handler) +func (l *Listener) AddHandler(path string, handler http.HandlerFunc) { + l.mux.HandleFunc(path, handler) } +// todo: Decode and Encode below should not be methods, just functions. They make no use of Listener fields. + // Decode receives and decodes JSON payload to a request. -func (listener *Listener) Decode(w http.ResponseWriter, r *http.Request, request interface{}) error { +func (l *Listener) Decode(w http.ResponseWriter, r *http.Request, request interface{}) error { var err error - if r.Body == nil { - err = fmt.Errorf("Request body is empty") + err = errors.New("request body is empty") } else { err = json.NewDecoder(r.Body).Decode(request) } @@ -189,7 +149,7 @@ func (listener *Listener) Decode(w http.ResponseWriter, r *http.Request, request } // Encode encodes and sends a response as JSON payload. -func (listener *Listener) Encode(w http.ResponseWriter, response interface{}) error { +func (l *Listener) Encode(w http.ResponseWriter, response interface{}) error { // Set the content type as application json w.Header().Set("Content-Type", "application/json; charset=UTF-8") err := json.NewEncoder(w).Encode(response) diff --git a/go.mod b/go.mod index 6d69a407f4..d1373c3def 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,8 @@ module github.com/Azure/azure-container-networking go 1.18 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0 github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.7.1 github.com/Masterminds/semver v1.5.0 github.com/Microsoft/go-winio v0.4.17 @@ -45,12 +46,10 @@ require ( ) require ( + code.cloudfoundry.org/clock v1.0.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.5.0 // indirect -) - -require ( - code.cloudfoundry.org/clock v1.0.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v0.4.0 // indirect github.com/PuerkitoBio/purell v1.1.1 // indirect github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -68,6 +67,7 @@ require ( github.com/go-openapi/swag v0.19.14 // indirect github.com/gofrs/uuid v3.3.0+incompatible // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/gnostic v0.5.7-v3refs // indirect github.com/google/gofuzz v1.2.0 // indirect @@ -78,6 +78,7 @@ require ( github.com/ishidawataru/sctp v0.0.0-20210226210310-f2269e66cdee // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/labstack/echo/v4 v4.7.2 github.com/labstack/gommon v0.3.1 // indirect github.com/magiconair/properties v1.8.6 // indirect @@ -92,6 +93,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.0.1 // indirect + github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect diff --git a/go.sum b/go.sum index 45c429cbfe..f7bff8ddff 100644 --- a/go.sum +++ b/go.sum @@ -44,11 +44,11 @@ code.cloudfoundry.org/clock v0.0.0-20180518195852-02e53af36e6c/go.mod h1:QD9Lzhd code.cloudfoundry.org/clock v1.0.0 h1:kFXWQM4bxYvdBw2X8BbBeXwQNgfoWv1vqAk2ZZyBN2o= code.cloudfoundry.org/clock v1.0.0/go.mod h1:QD9Lzhd/ux6eNQVUDVRJX/RKTigpewimNYBi7ivZKY8= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/Azure/azure-sdk-for-go v16.2.1+incompatible h1:KnPIugL51v3N3WwvaSmZbxukD1WuWXOiE9fRdu32f2I= github.com/Azure/azure-sdk-for-go v16.2.1+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.0 h1:Ut0ZGdOwJDw0npYEg+TLlPls3Pq6JiZaP2/aGKir7Zw= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0 h1:sVPhtT2qjO86rTUaWMr4WoES4TkjGnzcioXcnHV9s5k= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0 h1:Yoicul8bnVdQrhDMTHxdEckRGX01XvwXDHUT9zYZ3k0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0/go.mod h1:+6sju8gk8FRmSajX3Oz4G5Gm7P+mbqE9FVaXXFYTkCM= github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0 h1:jp0dGvZ7ZK0mgqnTSClMxa5xuRL7NZgHameVYF6BurY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.7.1 h1:X7FHRMKr0u5YiPnD6L/nqG64XBOcK0IYavhAHBQEmms= @@ -71,6 +71,7 @@ github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/AzureAD/microsoft-authentication-library-for-go v0.4.0 h1:WVsrXCnHlDDX8ls+tootqRE87/hL9S/g4ewig9RsD/c= +github.com/AzureAD/microsoft-authentication-library-for-go v0.4.0/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= @@ -376,7 +377,10 @@ github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXP github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -564,6 +568,7 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/labstack/echo/v4 v4.7.2 h1:Kv2/p8OaQ+M6Ex4eGimg9b9e6icoxA42JSlOR3msKtI= github.com/labstack/echo/v4 v4.7.2/go.mod h1:xkCDAdFCIf8jsFQ5NnbK7oqaF/yU1A1X20Ltm0OvSks= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= @@ -629,6 +634,7 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/montanaflynn/stats v0.6.6/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ= github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= @@ -678,6 +684,7 @@ github.com/pelletier/go-toml/v2 v2.0.1 h1:8e3L2cCQzLFi2CR4g7vGFuFxX7Jl1kKX8gW+iV github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI= +github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/keyvault/certrefresher.go b/keyvault/certrefresher.go new file mode 100644 index 0000000000..11942f5f22 --- /dev/null +++ b/keyvault/certrefresher.go @@ -0,0 +1,130 @@ +package keyvault + +import ( + "context" + //nolint:gosec // sha1 only used to display cert thumbprint in logs for cross-verification with keyvault. + "crypto/sha1" + "crypto/tls" + "fmt" + "sync" + "time" + + "github.com/avast/retry-go/v3" + "github.com/pkg/errors" +) + +type EventualExpirationErr struct { + time.Time +} + +func (e *EventualExpirationErr) Error() string { + return fmt.Sprintf("could not refresh before expiration on %s", e.Time.String()) +} + +type tlsCertFetcher interface { + GetLatestTLSCertificate(ctx context.Context, certName string) (tls.Certificate, error) +} + +type logger interface { + Printf(format string, args ...any) + Errorf(format string, args ...any) +} + +// CertRefresher offers a mechanism to present the latest version of a tls.Certificate from KeyVault, refreshed at an interval. +type CertRefresher struct { + certName string + kvc tlsCertFetcher + logger logger + + m sync.RWMutex + cert *tls.Certificate +} + +// NewCertRefresher returns a CertRefresher. When there's no error, the CertRefresher's GetCertificate method is ready +// for use, returning a valid tls.Certificate fetched from KeyVault during construction. +func NewCertRefresher(ctx context.Context, kvc tlsCertFetcher, l logger, certName string) (*CertRefresher, error) { + cf := CertRefresher{ + certName: certName, + kvc: kvc, + logger: l, + } + + cert, err := cf.kvc.GetLatestTLSCertificate(ctx, cf.certName) + if err != nil { + return nil, errors.Wrap(err, "could not fetch initial cert") + } + + cf.cert = &cert + cf.logger.Printf("initial certificate fetched: %s", &cf) + return &cf, nil +} + +func (c *CertRefresher) String() string { + return fmt.Sprintf("cert name: %s, sha1 thumbprint: %s, expiration: %s", c.certName, sha1String(c.cert.Leaf.Raw), c.cert.Leaf.NotAfter.String()) +} + +// GetCertificate returns the latest certificate fetched from KeyVault. +func (c *CertRefresher) GetCertificate() *tls.Certificate { + c.m.RLock() + defer c.m.RUnlock() + return c.cert +} + +// Refresh starts refreshing the certificate at the interval provided. +// It blocks until context is done or refreshing fails. +func (c *CertRefresher) Refresh(ctx context.Context, interval time.Duration) error { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return errors.Wrap(ctx.Err(), "refresh canceled") + case <-ticker.C: + if err := c.refresh(ctx); err != nil { + c.logger.Errorf("could not refresh before certificate expiration on %s: %v", c.cert.Leaf.NotAfter, err) + return &EventualExpirationErr{c.cert.Leaf.NotAfter} + } + } + } +} + +// refresh will attempt to fetch the latest version of a certificate, up until the current one expires. +func (c *CertRefresher) refresh(ctx context.Context) error { + certExpires := c.cert.Leaf.NotAfter + ctx, cancel := context.WithDeadline(ctx, certExpires) + defer cancel() + + var latestCert tls.Certificate + retryFn := func() (err error) { + latestCert, err = c.kvc.GetLatestTLSCertificate(ctx, c.certName) + if err != nil { + c.logger.Errorf("could not fetch latest tls certificate: %v. retrying...", err) + return errors.Wrap(err, "could not fetch latest tls certificate") + } + return nil + } + + if err := retry.Do(retryFn, retry.Context(ctx), retry.Delay(time.Second), retry.DelayType(retry.FixedDelay)); err != nil { + return errors.Wrap(err, "could not refresh cert") + } + + c.m.Lock() + defer c.m.Unlock() + + if latestCert.Leaf.Equal(c.cert.Leaf) { + c.logger.Printf("certificate unchanged. certificate %s", c) + return nil + } + + oldThumbprint := sha1String(c.cert.Leaf.Raw) + c.cert = &latestCert + c.logger.Printf("certificate refreshed. old sha1 thumbprint: %s, certificate: %s", oldThumbprint, c) + + return nil +} + +func sha1String(bs []byte) string { + //nolint:gosec // sha1 only used to display cert thumbprint in logs for cross-verification with keyvault. + return fmt.Sprintf("%X", sha1.Sum(bs)) +} diff --git a/keyvault/certrefresher_test.go b/keyvault/certrefresher_test.go new file mode 100644 index 0000000000..ed16bd4b77 --- /dev/null +++ b/keyvault/certrefresher_test.go @@ -0,0 +1,145 @@ +package keyvault + +import ( + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "sync" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCertRefresher(t *testing.T) { + ctx, cancel := testContext(t) + defer cancel() + + // returns a different cert on every invocation, until context is done + tlsFn := tlsFunc(func() (tls.Certificate, error) { + if err := ctx.Err(); err != nil { + return tls.Certificate{}, errors.Wrap(err, "context done") + } + + bs := make([]byte, 100) + _, _ = rand.Read(bs) + leaf := x509.Certificate{Raw: bs, NotAfter: time.Now().Add(time.Minute)} + + return tls.Certificate{Leaf: &leaf}, nil + }) + + cf, err := NewCertRefresher(ctx, tlsFn, testLogger{t}, "dummy") + require.NoError(t, err) + + // a new cert should be loaded roughly every second + go func() { _ = cf.Refresh(ctx, time.Second) }() + + thumbprintSet := stringSet{ts: make(map[string]struct{})} + + // spin multiple concurrent readers, collecting unique thumbprints for eventual assertion + for i := 0; i < 10; i++ { + go readAndCollect(ctx, cf, &thumbprintSet, time.Millisecond*300) + } + + waitFor := time.Second * 10 + // at least this many unique certs should eventually be seen + condFn := func() bool { return thumbprintSet.len() > 5 } + checkEvery := time.Second + + assert.Eventually(t, condFn, waitFor, checkEvery) +} + +func TestCertRefresher_RetryUntilExpiration(t *testing.T) { + ctx, cancel := testContext(t) + defer cancel() + + called := false + // returns a cert with short expiration once, then consistently errors + tlsFn := tlsFunc(func() (tls.Certificate, error) { + if called { + return tls.Certificate{}, errors.New("some error") + } + called = true + leaf := x509.Certificate{Raw: []byte{0}, NotAfter: time.Now().Add(time.Second * 5)} + return tls.Certificate{Leaf: &leaf}, nil + }) + + cf, err := NewCertRefresher(ctx, tlsFn, testLogger{t}, "dummy") + require.NoError(t, err) + + errCh := make(chan error, 1) + go func() { errCh <- cf.Refresh(ctx, time.Second) }() + + waitFor := time.Second * 10 + condFn := func() bool { + select { + case err := <-errCh: + var expErr *EventualExpirationErr + return errors.As(err, &expErr) + default: + } + return false + } + checkEvery := time.Second + + assert.Eventually(t, condFn, waitFor, checkEvery) +} + +type tlsFunc func() (tls.Certificate, error) + +func (t tlsFunc) GetLatestTLSCertificate(_ context.Context, _ string) (tls.Certificate, error) { + return t() +} + +type stringSet struct { + sync.RWMutex + ts map[string]struct{} +} + +func (s *stringSet) add(val string) { + s.Lock() + s.ts[val] = struct{}{} + s.Unlock() +} + +func (s *stringSet) len() int { + s.RLock() + defer s.RUnlock() + return len(s.ts) +} + +func readAndCollect(ctx context.Context, cf *CertRefresher, thumbprintSet *stringSet, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + cert := cf.GetCertificate() + thumbprintSet.add(sha1String(cert.Leaf.Raw)) + } + } +} + +type testLogger struct{ *testing.T } + +func (t testLogger) Printf(format string, args ...any) { + t.Logf(format, args...) +} + +func (t testLogger) Errorf(format string, args ...any) { + t.Logf(format, args...) +} + +// todo: move to a better package for reuse +func testContext(t *testing.T) (context.Context, context.CancelFunc) { + ctx := context.Background() + if deadline, ok := t.Deadline(); ok { + return context.WithDeadline(ctx, deadline) + } + return context.WithCancel(ctx) +} diff --git a/server/tls/tlscertificate_retriever.go b/server/tls/tlscertificate_retriever.go index 9e9cf8e845..28d7f6e952 100644 --- a/server/tls/tlscertificate_retriever.go +++ b/server/tls/tlscertificate_retriever.go @@ -2,12 +2,17 @@ package tls -// TlsCertificateSettins - Details related to the TLS certificate. +import "time" + +// TlsSettings - Details related to the TLS certificate. type TlsSettings struct { - TLSSubjectName string - TLSCertificatePath string - TLSEndpoint string - TLSPort string + TLSSubjectName string + TLSCertificatePath string + TLSPort string + KeyVaultURL string + KeyVaultCertificateName string + MSIResourceID string + KeyVaultCertificateRefreshInterval time.Duration } func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {