Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cns/configuration/cns_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,13 @@
"TLSPort": "10091",
"TLSSubjectName": "",
"UseHTTPS": false,
"WireserverIP": "168.63.129.16"
"WireserverIP": "168.63.129.16",
"KeyVaultSettings": {
"URL": "",
"CertificateName": "",
"RefreshIntervalInHrs": 12
},
"MSISettings": {
"ResourceID": ""
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding these to keep this documentation up to date

}
20 changes: 20 additions & 0 deletions cns/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type CNSConfig struct {
TelemetrySettings TelemetrySettings
UseHTTPS bool
WireserverIP string
KeyVaultSettings KeyVaultSettings
MSISettings MSISettings
}

type TelemetrySettings struct {
Expand Down Expand Up @@ -67,6 +69,16 @@ type ManagedSettings struct {
NodeSyncIntervalInSeconds int
}

type MSISettings struct {
ResourceID string
}

type KeyVaultSettings struct {
URL string
CertificateName string
RefreshIntervalInHrs int
}

func getConfigFilePath(cmdLineConfigPath string) (string, error) {
// If config path is set from cmd line, return that
if cmdLineConfigPath != "" {
Expand Down Expand Up @@ -144,10 +156,18 @@ func setManagedSettingDefaults(managedSettings *ManagedSettings) {
}
}

func setKeyVaultSettingsDefaults(kvs *KeyVaultSettings) {
if kvs.RefreshIntervalInHrs == 0 {
kvs.RefreshIntervalInHrs = 12 //nolint:gomnd // default times
}
}

// SetCNSConfigDefaults set default values of CNS config if not specified
func SetCNSConfigDefaults(config *CNSConfig) {
setTelemetrySettingDefaults(&config.TelemetrySettings)
setManagedSettingDefaults(&config.ManagedSettings)
setKeyVaultSettingsDefaults(&config.KeyVaultSettings)

if config.ChannelMode == "" {
config.ChannelMode = cns.Direct
}
Expand Down
9 changes: 9 additions & 0 deletions cns/configuration/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ func TestSetCNSConfigDefaults(t *testing.T) {
RefreshIntervalInSecs: 15,
SnapshotIntervalInMins: 60,
},
KeyVaultSettings: KeyVaultSettings{
RefreshIntervalInHrs: 12,
},
},
},
{
Expand All @@ -220,6 +223,9 @@ func TestSetCNSConfigDefaults(t *testing.T) {
RefreshIntervalInSecs: 1,
SnapshotIntervalInMins: 6,
},
KeyVaultSettings: KeyVaultSettings{
RefreshIntervalInHrs: 3,
},
},
want: CNSConfig{
ChannelMode: "Other",
Expand All @@ -236,6 +242,9 @@ func TestSetCNSConfigDefaults(t *testing.T) {
RefreshIntervalInSecs: 1,
SnapshotIntervalInMins: 6,
},
KeyVaultSettings: KeyVaultSettings{
RefreshIntervalInHrs: 3,
},
},
},
}
Expand Down
103 changes: 100 additions & 3 deletions cns/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@
package cns

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"

"github.com/Azure/azure-container-networking/cns/common"
"github.com/Azure/azure-container-networking/cns/logger"
acn "github.com/Azure/azure-container-networking/common"
"github.com/Azure/azure-container-networking/keyvault"
"github.com/Azure/azure-container-networking/log"
localtls "github.com/Azure/azure-container-networking/server/tls"
"github.com/Azure/azure-container-networking/store"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -68,19 +74,27 @@ func (service *Service) Initialize(config *common.ServiceConfig) error {
if err != nil {
return err
}
// Create the listener.

listener, err := acn.NewListener(u)
if err != nil {
return err
}

if config.TlsSettings.TLSPort != "" {
// listener.URL.Host will always be hostname:port, passed in to CNS via CNS command
// else it will default to localhost
// extract hostname and override tls port.
hostParts := strings.Split(listener.URL.Host, ":")
config.TlsSettings.TLSEndpoint = hostParts[0] + ":" + config.TlsSettings.TLSPort
tlsAddress := net.JoinHostPort(hostParts[0], config.TlsSettings.TLSPort)

// Start the listener and HTTP and HTTPS server.
if err = listener.StartTLS(config.ErrChan, config.TlsSettings); err != nil {
tlsConfig, err := getTLSConfig(config.TlsSettings, config.ErrChan)
if err != nil {
log.Printf("Failed to compose Tls Configuration with error: %+v", err)
return errors.Wrap(err, "could not get tls config")
}

if err := listener.StartTLS(config.ErrChan, tlsConfig, tlsAddress); err != nil {
return err
}
}
Expand All @@ -95,6 +109,89 @@ func (service *Service) Initialize(config *common.ServiceConfig) error {
return nil
}

func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) {
if tlsSettings.TLSCertificatePath != "" {
return getTLSConfigFromFile(tlsSettings)
}

if tlsSettings.KeyVaultURL != "" {
return getTLSConfigFromKeyVault(tlsSettings, errChan)
}

return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
}

func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
if err != nil {
return nil, errors.Wrap(err, "failed to get certificate retriever")
}

leafCertificate, err := tlsCertRetriever.GetCertificate()
if err != nil {
return nil, errors.Wrap(err, "failed to get certificate")
}

if leafCertificate == nil {
return nil, errors.New("certificate retrieval returned empty")
}

privateKey, err := tlsCertRetriever.GetPrivateKey()
if err != nil {
return nil, errors.Wrap(err, "failed to get certificate private key")
}

tlsCert := tls.Certificate{
Certificate: [][]byte{leafCertificate.Raw},
PrivateKey: privateKey,
Leaf: leafCertificate,
}

tlsConfig := &tls.Config{
MaxVersion: tls.VersionTLS13,
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{
tlsCert,
},
}

return tlsConfig, nil
}

func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) {
credOpts := azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(tlsSettings.MSIResourceID)}
cred, err := azidentity.NewManagedIdentityCredential(&credOpts)
if err != nil {
return nil, errors.Wrap(err, "could not create managed identity credential")
}

kvs, err := keyvault.NewShim(tlsSettings.KeyVaultURL, cred)
if err != nil {
return nil, errors.Wrap(err, "could not create new keyvault shim")
}

ctx := context.TODO()

cr, err := keyvault.NewCertRefresher(ctx, kvs, logger.Log, tlsSettings.KeyVaultCertificateName)
if err != nil {
return nil, errors.Wrap(err, "could not create new cert refresher")
}

go func() {
errChan <- cr.Refresh(ctx, tlsSettings.KeyVaultCertificateRefreshInterval)
}()

tlsConfig := tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
return cr.GetCertificate(), nil
},
}

return &tlsConfig, nil
}

func (service *Service) StartListener(config *common.ServiceConfig) error {
log.Debugf("[Azure CNS] Going to start listener: %+v", config)

Expand Down
10 changes: 7 additions & 3 deletions cns/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,13 @@ func main() {
if httpRestService != nil {
if cnsconfig.UseHTTPS {
config.TlsSettings = localtls.TlsSettings{
TLSSubjectName: cnsconfig.TLSSubjectName,
TLSCertificatePath: cnsconfig.TLSCertificatePath,
TLSPort: cnsconfig.TLSPort,
TLSSubjectName: cnsconfig.TLSSubjectName,
TLSCertificatePath: cnsconfig.TLSCertificatePath,
TLSPort: cnsconfig.TLSPort,
KeyVaultURL: cnsconfig.KeyVaultSettings.URL,
KeyVaultCertificateName: cnsconfig.KeyVaultSettings.CertificateName,
MSIResourceID: cnsconfig.MSISettings.ResourceID,
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
}
}

Expand Down
Loading