diff --git a/go.mod b/go.mod index c5113c19e..606494105 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,17 @@ module github.com/Azure/secrets-store-csi-driver-provider-azure go 1.19 require ( - github.com/Azure/azure-sdk-for-go v68.0.0+incompatible github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 + github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates v0.8.0 + github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.9.0 + github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.11.0 github.com/Azure/go-autorest/autorest v0.11.28 github.com/Azure/go-autorest/autorest/adal v0.9.22 github.com/Azure/go-autorest/autorest/date v0.3.0 + github.com/Azure/go-autorest/autorest/to v0.4.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.8 - github.com/jongio/azidext/go/azidext v0.4.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.1 go.opentelemetry.io/otel v0.20.0 @@ -28,9 +30,8 @@ require ( require ( github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect - github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect - github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect github.com/Azure/go-autorest/logger v0.2.1 // indirect github.com/Azure/go-autorest/tracing v0.6.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 // indirect diff --git a/go.sum b/go.sum index 8e6d9b086..8222b9b2a 100644 --- a/go.sum +++ b/go.sum @@ -31,14 +31,20 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= -github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 h1:gVXuXcWd1i4C2Ruxe321aU+IKGaStvGB/S90PUPB/W8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates v0.8.0 h1:edn/e2qs1fEkPHlZqbESJWhFai9Pk/UA5eiwFUA1nwI= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates v0.8.0/go.mod h1:8eUJPoEz7doIqSwW2pAvLGhEy3mDC9o/ToCa8OZy7go= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.9.0 h1:TOFrNxfjslms5nLLIMjW7N0+zSALX4KiGsptmpb16AA= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.9.0/go.mod h1:EAyXOW1F6BTJPiK2pDvmnvxOHPxoTYWoqBeIlql+QhI= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.11.0 h1:82w8tzLcOwDP/Q35j/wEBPt0n0kVC3cjtPdD62G8UAk= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.11.0/go.mod h1:S78i9yTr4o/nXlH76bKjGUye9Z2wSxO5Tz7GoDr4vfI= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0 h1:Lg6BW0VPmCwcMlvOviL3ruHFO+H9tZNqscK0AeuFjGM= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0/go.mod h1:9V2j0jn9jDEkCkv8w/bKTNppX/d0FVA1ud77xCIP4KA= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= github.com/Azure/go-autorest/autorest v0.11.28 h1:ndAExarwr5Y+GaHE6VCaY1kyS/HwwGGyuimVhWsHOEM= @@ -53,8 +59,6 @@ github.com/Azure/go-autorest/autorest/mocks v0.4.2 h1:PGN4EDXnuQbojHbU0UWoNvmu9A github.com/Azure/go-autorest/autorest/mocks v0.4.2/go.mod h1:Vy7OitM9Kei0i1Oj+LvyAWMXJHeKH1MVlzFugfVrmyU= github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk= github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE= -github.com/Azure/go-autorest/autorest/validation v0.3.1 h1:AgyqjAd94fwNAoTjl/WQXg4VvFeRFpO+UhNyRXqF1ac= -github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E= github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg= github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= @@ -266,10 +270,7 @@ github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7P github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/jongio/azidext/go/azidext v0.4.0 h1:TOYyVFMeWGgXNhURSgrEtUCu7JAAKgsy+5C4+AEfYlw= -github.com/jongio/azidext/go/azidext v0.4.0/go.mod h1:VrlpGde5B+pPbTUxnThE5UIQQkcebdr3jrC2MmlMVSI= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index c5844d260..eb14dc993 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -15,9 +15,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/adal" - "github.com/jongio/azidext/go/azidext" "github.com/pkg/errors" "k8s.io/klog/v2" ) @@ -137,30 +135,21 @@ func NewConfig( return config, nil } -// GetAuthorizer returns an Azure authorizer based on the provided azure identity -func (c Config) GetAuthorizer(podName, podNamespace, resource, aadEndpoint, tenantID, nmiPort string) (autorest.Authorizer, error) { - var cred azcore.TokenCredential - var err error - +// GetCredential returns the azure credential to use based on the auth config +func (c Config) GetCredential(podName, podNamespace, resource, aadEndpoint, tenantID, nmiPort string) (azcore.TokenCredential, error) { // use switch case to ensure only one of the identity modes is enabled switch { case c.UsePodIdentity: - cred, err = getAuthorizerForPodIdentity(podName, podNamespace, resource, tenantID, nmiPort) + return getPodIdentityTokenCredential(podName, podNamespace, resource, tenantID, nmiPort) case c.UseVMManagedIdentity: - cred, err = getAuthorizerForManagedIdentity(c.UserAssignedIdentityID) + return getManagedIdentityTokenCredential(c.UserAssignedIdentityID) case len(c.AADClientSecret) > 0 && len(c.AADClientID) > 0: - cred, err = getAuthorizerForServicePrincipal(c.AADClientID, c.AADClientSecret, aadEndpoint, tenantID) + return getServicePrincipalTokenCredential(c.AADClientID, c.AADClientSecret, aadEndpoint, tenantID) case len(c.WorkloadIdentityClientID) > 0 && len(c.WorkloadIdentityToken) > 0: - cred, err = getAuthorizerForWorkloadIdentity(c.WorkloadIdentityClientID, c.WorkloadIdentityToken, aadEndpoint, tenantID) + return getWorkloadIdentityTokenCredential(c.WorkloadIdentityClientID, c.WorkloadIdentityToken, aadEndpoint, tenantID) default: return nil, fmt.Errorf("no identity mode is enabled") } - - if err != nil { - return nil, err - } - - return azidext.NewTokenCredentialAdapter(cred, []string{getScope(resource)}), nil } func newWorkloadIdentityCredential(tenantID, clientID, assertion string, options *workloadIdentityCredentialOptions) (azcore.TokenCredential, error) { @@ -181,7 +170,7 @@ func (w *workloadIdentityCredential) getAssertion(context.Context) (string, erro return w.assertion, nil } -func getAuthorizerForWorkloadIdentity(clientID, signedAssertion, aadEndpoint, tenantID string) (azcore.TokenCredential, error) { +func getWorkloadIdentityTokenCredential(clientID, signedAssertion, aadEndpoint, tenantID string) (azcore.TokenCredential, error) { opts := &workloadIdentityCredentialOptions{ ClientOptions: azcore.ClientOptions{ Cloud: cloud.Configuration{ @@ -192,7 +181,7 @@ func getAuthorizerForWorkloadIdentity(clientID, signedAssertion, aadEndpoint, te return newWorkloadIdentityCredential(tenantID, clientID, signedAssertion, opts) } -func getAuthorizerForServicePrincipal(clientID, secret, aadEndpoint, tenantID string) (azcore.TokenCredential, error) { +func getServicePrincipalTokenCredential(clientID, secret, aadEndpoint, tenantID string) (azcore.TokenCredential, error) { opts := &azidentity.ClientSecretCredentialOptions{ ClientOptions: azcore.ClientOptions{ Cloud: cloud.Configuration{ @@ -203,7 +192,7 @@ func getAuthorizerForServicePrincipal(clientID, secret, aadEndpoint, tenantID st return azidentity.NewClientSecretCredential(tenantID, clientID, secret, opts) } -func getAuthorizerForManagedIdentity(identityClientID string) (azcore.TokenCredential, error) { +func getManagedIdentityTokenCredential(identityClientID string) (azcore.TokenCredential, error) { opts := &azidentity.ManagedIdentityCredentialOptions{ ID: azidentity.ClientID(identityClientID), } @@ -258,7 +247,7 @@ func (c *podIdentityCredential) GetToken(ctx context.Context, _ policy.TokenRequ }, nil } -func getAuthorizerForPodIdentity(podName, podNamespace, resource, tenantID, nmiPort string) (azcore.TokenCredential, error) { +func getPodIdentityTokenCredential(podName, podNamespace, resource, tenantID, nmiPort string) (azcore.TokenCredential, error) { if len(podName) == 0 || len(podNamespace) == 0 { return nil, fmt.Errorf("pod information is not available. deploy a CSIDriver object to set podInfoOnMount: true") } diff --git a/pkg/provider/keyvault.go b/pkg/provider/keyvault.go new file mode 100644 index 000000000..f9ba8a57a --- /dev/null +++ b/pkg/provider/keyvault.go @@ -0,0 +1,174 @@ +package provider + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" + "github.com/Azure/go-autorest/autorest/date" + + "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/provider/types" +) + +type KeyVault interface { + GetSecret(ctx context.Context, name, version string) (*azsecrets.SecretBundle, error) + GetSecretVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) + GetKey(ctx context.Context, name, version string) (*azkeys.KeyBundle, error) + GetKeyVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) + GetCertificate(ctx context.Context, name, version string) (*azcertificates.CertificateBundle, error) + GetCertificateVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) +} + +// TODO(aramase): add user agent +type client struct { + secrets *azsecrets.Client + keys *azkeys.Client + certs *azcertificates.Client +} + +// NewClient creates a new KeyVault client +func NewClient(cred azcore.TokenCredential, vaultURI string) (KeyVault, error) { + secrets, err := azsecrets.NewClient(vaultURI, cred, nil) + if err != nil { + return nil, err + } + keys, err := azkeys.NewClient(vaultURI, cred, nil) + if err != nil { + return nil, err + } + certs, err := azcertificates.NewClient(vaultURI, cred, nil) + if err != nil { + return nil, err + } + + return &client{ + secrets: secrets, + keys: keys, + certs: certs, + }, nil +} + +func (c *client) GetSecret(ctx context.Context, name, version string) (*azsecrets.SecretBundle, error) { + resp, err := c.secrets.GetSecret(ctx, name, version, &azsecrets.GetSecretOptions{}) + if err != nil { + return nil, err + } + return &resp.SecretBundle, nil +} + +func (c *client) GetKey(ctx context.Context, name, version string) (*azkeys.KeyBundle, error) { + resp, err := c.keys.GetKey(ctx, name, version, &azkeys.GetKeyOptions{}) + if err != nil { + return nil, err + } + return &resp.KeyBundle, nil +} + +func (c *client) GetCertificate(ctx context.Context, name, version string) (*azcertificates.CertificateBundle, error) { + resp, err := c.certs.GetCertificate(ctx, name, version, &azcertificates.GetCertificateOptions{}) + if err != nil { + return nil, err + } + return &resp.CertificateBundle, nil +} + +func (c *client) GetSecretVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) { + pager := c.secrets.NewListSecretVersionsPager(name, &azsecrets.ListSecretVersionsOptions{}) + var versions []types.KeyVaultObjectVersion + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + for _, secret := range page.SecretListResult.Value { + if secret.Attributes == nil { + continue + } + if secret.Attributes.Enabled != nil && !*secret.Attributes.Enabled { + continue + } + + id := *secret.ID + created := date.UnixEpoch() + if secret.Attributes.Created != nil { + created = *secret.Attributes.Created + } + + versions = append(versions, types.KeyVaultObjectVersion{ + Version: id.Version(), + Created: created, + }) + } + } + + return versions, nil +} + +func (c *client) GetKeyVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) { + pager := c.keys.NewListKeyVersionsPager(name, &azkeys.ListKeyVersionsOptions{}) + var versions []types.KeyVaultObjectVersion + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + for _, key := range page.KeyListResult.Value { + if key.Attributes == nil { + continue + } + if key.Attributes.Enabled != nil && !*key.Attributes.Enabled { + continue + } + + id := *key.KID + created := date.UnixEpoch() + if key.Attributes.Created != nil { + created = *key.Attributes.Created + } + + versions = append(versions, types.KeyVaultObjectVersion{ + Version: id.Version(), + Created: created, + }) + } + } + + return versions, nil +} + +func (c *client) GetCertificateVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) { + pager := c.certs.NewListCertificateVersionsPager(name, &azcertificates.ListCertificateVersionsOptions{}) + var versions []types.KeyVaultObjectVersion + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + for _, cert := range page.CertificateListResult.Value { + if cert.Attributes == nil { + continue + } + if cert.Attributes.Enabled != nil && !*cert.Attributes.Enabled { + continue + } + + id := *cert.ID + created := date.UnixEpoch() + if cert.Attributes.Created != nil { + created = *cert.Attributes.Created + } + + versions = append(versions, types.KeyVaultObjectVersion{ + Version: id.Version(), + Created: created, + }) + } + } + + return versions, nil +} diff --git a/pkg/provider/mock_keyvault/doc.go b/pkg/provider/mock_keyvault/doc.go new file mode 100644 index 000000000..eac039694 --- /dev/null +++ b/pkg/provider/mock_keyvault/doc.go @@ -0,0 +1,4 @@ +// Run go generate to regenerate this mock. +// +//go:generate ../../../.tools/mockgen -destination keyvault_mock.go -package mock_keyvault -source ../keyvault.go +package mock_keyvault //nolint diff --git a/pkg/provider/mock_keyvault/keyvault_mock.go b/pkg/provider/mock_keyvault/keyvault_mock.go new file mode 100644 index 000000000..1c1235e7c --- /dev/null +++ b/pkg/provider/mock_keyvault/keyvault_mock.go @@ -0,0 +1,129 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../keyvault.go + +// Package mock_keyvault is a generated GoMock package. +package mock_keyvault + +import ( + context "context" + reflect "reflect" + + azcertificates "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates" + azkeys "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" + azsecrets "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" + types "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/provider/types" + gomock "github.com/golang/mock/gomock" +) + +// MockKeyVault is a mock of KeyVault interface. +type MockKeyVault struct { + ctrl *gomock.Controller + recorder *MockKeyVaultMockRecorder +} + +// MockKeyVaultMockRecorder is the mock recorder for MockKeyVault. +type MockKeyVaultMockRecorder struct { + mock *MockKeyVault +} + +// NewMockKeyVault creates a new mock instance. +func NewMockKeyVault(ctrl *gomock.Controller) *MockKeyVault { + mock := &MockKeyVault{ctrl: ctrl} + mock.recorder = &MockKeyVaultMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeyVault) EXPECT() *MockKeyVaultMockRecorder { + return m.recorder +} + +// GetCertificate mocks base method. +func (m *MockKeyVault) GetCertificate(ctx context.Context, name, version string) (*azcertificates.CertificateBundle, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCertificate", ctx, name, version) + ret0, _ := ret[0].(*azcertificates.CertificateBundle) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCertificate indicates an expected call of GetCertificate. +func (mr *MockKeyVaultMockRecorder) GetCertificate(ctx, name, version interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCertificate", reflect.TypeOf((*MockKeyVault)(nil).GetCertificate), ctx, name, version) +} + +// GetCertificateVersions mocks base method. +func (m *MockKeyVault) GetCertificateVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCertificateVersions", ctx, name) + ret0, _ := ret[0].([]types.KeyVaultObjectVersion) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCertificateVersions indicates an expected call of GetCertificateVersions. +func (mr *MockKeyVaultMockRecorder) GetCertificateVersions(ctx, name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCertificateVersions", reflect.TypeOf((*MockKeyVault)(nil).GetCertificateVersions), ctx, name) +} + +// GetKey mocks base method. +func (m *MockKeyVault) GetKey(ctx context.Context, name, version string) (*azkeys.KeyBundle, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKey", ctx, name, version) + ret0, _ := ret[0].(*azkeys.KeyBundle) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKey indicates an expected call of GetKey. +func (mr *MockKeyVaultMockRecorder) GetKey(ctx, name, version interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*MockKeyVault)(nil).GetKey), ctx, name, version) +} + +// GetKeyVersions mocks base method. +func (m *MockKeyVault) GetKeyVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKeyVersions", ctx, name) + ret0, _ := ret[0].([]types.KeyVaultObjectVersion) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKeyVersions indicates an expected call of GetKeyVersions. +func (mr *MockKeyVaultMockRecorder) GetKeyVersions(ctx, name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyVersions", reflect.TypeOf((*MockKeyVault)(nil).GetKeyVersions), ctx, name) +} + +// GetSecret mocks base method. +func (m *MockKeyVault) GetSecret(ctx context.Context, name, version string) (*azsecrets.SecretBundle, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSecret", ctx, name, version) + ret0, _ := ret[0].(*azsecrets.SecretBundle) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSecret indicates an expected call of GetSecret. +func (mr *MockKeyVaultMockRecorder) GetSecret(ctx, name, version interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSecret", reflect.TypeOf((*MockKeyVault)(nil).GetSecret), ctx, name, version) +} + +// GetSecretVersions mocks base method. +func (m *MockKeyVault) GetSecretVersions(ctx context.Context, name string) ([]types.KeyVaultObjectVersion, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSecretVersions", ctx, name) + ret0, _ := ret[0].([]types.KeyVaultObjectVersion) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSecretVersions indicates an expected call of GetSecretVersions. +func (mr *MockKeyVaultMockRecorder) GetSecretVersions(ctx, name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSecretVersions", reflect.TypeOf((*MockKeyVault)(nil).GetSecretVersions), ctx, name) +} diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index c2014ed7a..8a24b4a05 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -22,12 +22,9 @@ import ( "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/auth" "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/metrics" "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/provider/types" - "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/version" - kv "github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault" - "github.com/Azure/go-autorest/autorest" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" "github.com/Azure/go-autorest/autorest/azure" - "github.com/Azure/go-autorest/autorest/date" "github.com/pkg/errors" "golang.org/x/crypto/pkcs12" "golang.org/x/net/context" @@ -78,8 +75,8 @@ func NewProvider(constructPEMChain, writeCertAndKeyInSeparateFiles bool) Interfa } } -// ParseAzureEnvironment returns azure environment by name -func ParseAzureEnvironment(cloudName string) (*azure.Environment, error) { +// parseAzureEnvironment returns azure environment by name +func parseAzureEnvironment(cloudName string) (*azure.Environment, error) { var env azure.Environment var err error if cloudName == "" { @@ -90,20 +87,14 @@ func ParseAzureEnvironment(cloudName string) (*azure.Environment, error) { return &env, err } -func (mc *mountConfig) initializeKvClient() (*kv.BaseClient, error) { - kvClient := kv.New() +func (mc *mountConfig) initializeKvClient(vaultURI string) (KeyVault, error) { kvEndpoint := strings.TrimSuffix(mc.azureCloudEnvironment.KeyVaultEndpoint, "/") - err := kvClient.AddToUserAgent(version.GetUserAgent()) + cred, err := mc.authConfig.GetCredential(mc.podName, mc.podNamespace, kvEndpoint, mc.azureCloudEnvironment.ActiveDirectoryEndpoint, mc.tenantID, types.PodIdentityNMIPort) if err != nil { - return nil, errors.Wrapf(err, "failed to add user agent to keyvault client") - } - - kvClient.Authorizer, err = mc.GetAuthorizer(kvEndpoint) - if err != nil { - return nil, errors.Wrapf(err, "failed to get authorizer for keyvault client") + return nil, err } - return &kvClient, nil + return NewClient(cred, vaultURI) } func (mc *mountConfig) getVaultURL() (vaultURL *string, err error) { @@ -122,11 +113,6 @@ func (mc *mountConfig) getVaultURL() (vaultURL *string, err error) { return &vaultURI, nil } -// GetAuthorizer returns an Azure authorizer based on the provided azure identity -func (mc *mountConfig) GetAuthorizer(resource string) (autorest.Authorizer, error) { - return mc.authConfig.GetAuthorizer(mc.podName, mc.podNamespace, resource, mc.azureCloudEnvironment.ActiveDirectoryEndpoint, mc.tenantID, types.PodIdentityNMIPort) -} - // GetSecretsStoreObjectContent gets the objects (secret, key, certificate) from keyvault and returns the content // to the CSI driver. The driver will write the content to the file system. func (p *provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, secrets map[string]string, defaultFilePermission os.FileMode) ([]types.SecretFile, error) { @@ -162,7 +148,7 @@ func (p *provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec if err != nil { return nil, fmt.Errorf("failed to set AZURE_ENVIRONMENT_FILEPATH env to %s, error %w", cloudEnvFileName, err) } - azureCloudEnv, err := ParseAzureEnvironment(cloudName) + azureCloudEnv, err := parseAzureEnvironment(cloudName) if err != nil { return nil, fmt.Errorf("cloudName %s is not valid, error: %w", cloudName, err) } @@ -231,7 +217,7 @@ func (p *provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec klog.V(2).InfoS("vault url", "vaultName", mc.keyvaultName, "vaultURL", *vaultURL, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) // the keyvault name is per SPC and we don't need to recreate the client for every single keyvault object defined - kvClient, err := mc.initializeKvClient() + kvClient, err := mc.initializeKvClient(*vaultURL) if err != nil { return nil, errors.Wrap(err, "failed to get keyvault client") } @@ -240,14 +226,14 @@ func (p *provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec for _, keyVaultObject := range keyVaultObjects { klog.V(5).InfoS("fetching object from key vault", "objectName", keyVaultObject.ObjectName, "objectType", keyVaultObject.ObjectType, "keyvault", mc.keyvaultName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) - resolvedKvObjects, err := p.resolveObjectVersions(ctx, kvClient, keyVaultObject, *vaultURL) + resolvedKvObjects, err := p.resolveObjectVersions(ctx, kvClient, keyVaultObject) if err != nil { return nil, err } for _, resolvedKvObject := range resolvedKvObjects { // fetch the object from Key Vault - result, err := p.getKeyVaultObjectContent(ctx, kvClient, resolvedKvObject, *vaultURL) + result, err := p.getKeyVaultObjectContent(ctx, kvClient, resolvedKvObject) if err != nil { return nil, err } @@ -280,14 +266,14 @@ func (p *provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec return files, nil } -func (p *provider) resolveObjectVersions(ctx context.Context, kvClient *kv.BaseClient, kvObject types.KeyVaultObject, vaultURL string) (versions []types.KeyVaultObject, err error) { +func (p *provider) resolveObjectVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) (versions []types.KeyVaultObject, err error) { if kvObject.IsSyncingSingleVersion() { // version history less than or equal to 1 means only sync the latest and // don't add anything to the file name return []types.KeyVaultObject{kvObject}, nil } - kvObjectVersions, err := p.getKeyVaultObjectVersions(ctx, kvClient, kvObject, vaultURL) + kvObjectVersions, err := p.getKeyVaultObjectVersions(ctx, kvClient, kvObject) if err != nil { return nil, err } @@ -333,7 +319,7 @@ func getLatestNKeyVaultObjects(kvObject types.KeyVaultObject, kvObjectVersions t return objects } -func (p *provider) getKeyVaultObjectVersions(ctx context.Context, kvClient *kv.BaseClient, kvObject types.KeyVaultObject, vaultURL string) (versions types.KeyVaultObjectVersionList, err error) { +func (p *provider) getKeyVaultObjectVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) (versions types.KeyVaultObjectVersionList, err error) { start := time.Now() defer func() { var errMsg string @@ -345,127 +331,31 @@ func (p *provider) getKeyVaultObjectVersions(ctx context.Context, kvClient *kv.B switch kvObject.ObjectType { case types.VaultObjectTypeSecret: - return getSecretVersions(ctx, kvClient, vaultURL, kvObject) + return getSecretVersions(ctx, kvClient, kvObject) case types.VaultObjectTypeKey: - return getKeyVersions(ctx, kvClient, vaultURL, kvObject) + return getKeyVersions(ctx, kvClient, kvObject) case types.VaultObjectTypeCertificate: - return getCertificateVersions(ctx, kvClient, vaultURL, kvObject) + return getCertificateVersions(ctx, kvClient, kvObject) default: err := errors.Errorf("Invalid vaultObjectTypes. Should be secret, key, or cert") return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } } -func getSecretVersions(ctx context.Context, kvClient *kv.BaseClient, vaultURL string, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { - kvVersionsList, err := kvClient.GetSecretVersions(ctx, vaultURL, kvObject.ObjectName, nil) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - - secretVersions := types.KeyVaultObjectVersionList{} - - for notDone := true; notDone; notDone = kvVersionsList.NotDone() { - for _, secret := range kvVersionsList.Values() { - if secret.Attributes != nil { - objectVersion := getObjectVersion(*secret.ID) - created := date.UnixEpoch() - - if secret.Attributes.Created != nil { - created = time.Time(*secret.Attributes.Created) - } - - if secret.Attributes.Enabled != nil && *secret.Attributes.Enabled { - secretVersions = append(secretVersions, types.KeyVaultObjectVersion{ - Version: objectVersion, - Created: created, - }) - } - } - } - - err = kvVersionsList.NextWithContext(ctx) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - } - - return secretVersions, nil +func getSecretVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { + return kvClient.GetSecretVersions(ctx, kvObject.ObjectName) } -func getKeyVersions(ctx context.Context, kvClient *kv.BaseClient, vaultURL string, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { - kvVersionsList, err := kvClient.GetKeyVersions(ctx, vaultURL, kvObject.ObjectName, nil) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - - keyVersions := types.KeyVaultObjectVersionList{} - - for notDone := true; notDone; notDone = kvVersionsList.NotDone() { - for _, key := range kvVersionsList.Values() { - if key.Attributes != nil { - objectVersion := getObjectVersion(*key.Kid) - created := date.UnixEpoch() - - if key.Attributes.Created != nil { - created = time.Time(*key.Attributes.Created) - } - - if key.Attributes.Enabled != nil && *key.Attributes.Enabled { - keyVersions = append(keyVersions, types.KeyVaultObjectVersion{ - Version: objectVersion, - Created: created, - }) - } - } - } - - err = kvVersionsList.NextWithContext(ctx) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - } - - return keyVersions, nil +func getKeyVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { + return kvClient.GetKeyVersions(ctx, kvObject.ObjectName) } -func getCertificateVersions(ctx context.Context, kvClient *kv.BaseClient, vaultURL string, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { - kvVersionsList, err := kvClient.GetCertificateVersions(ctx, vaultURL, kvObject.ObjectName, nil) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - - certVersions := types.KeyVaultObjectVersionList{} - - for notDone := true; notDone; notDone = kvVersionsList.NotDone() { - for _, cert := range kvVersionsList.Values() { - if cert.Attributes != nil { - objectVersion := getObjectVersion(*cert.ID) - created := date.UnixEpoch() - - if cert.Attributes.Created != nil { - created = time.Time(*cert.Attributes.Created) - } - - if cert.Attributes.Enabled != nil && *cert.Attributes.Enabled { - certVersions = append(certVersions, types.KeyVaultObjectVersion{ - Version: objectVersion, - Created: created, - }) - } - } - } - - err = kvVersionsList.NextWithContext(ctx) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - } - - return certVersions, nil +func getCertificateVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { + return kvClient.GetCertificateVersions(ctx, kvObject.ObjectName) } // getKeyVaultObjectContent gets content of the keyvault object -func (p *provider) getKeyVaultObjectContent(ctx context.Context, kvClient *kv.BaseClient, kvObject types.KeyVaultObject, vaultURL string) (result []keyvaultObject, err error) { +func (p *provider) getKeyVaultObjectContent(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) (result []keyvaultObject, err error) { start := time.Now() defer func() { var errMsg string @@ -477,11 +367,11 @@ func (p *provider) getKeyVaultObjectContent(ctx context.Context, kvClient *kv.Ba switch kvObject.ObjectType { case types.VaultObjectTypeSecret: - return p.getSecret(ctx, kvClient, vaultURL, kvObject) + return p.getSecret(ctx, kvClient, kvObject) case types.VaultObjectTypeKey: - return p.getKey(ctx, kvClient, vaultURL, kvObject) + return p.getKey(ctx, kvClient, kvObject) case types.VaultObjectTypeCertificate: - return p.getCertificate(ctx, kvClient, vaultURL, kvObject) + return p.getCertificate(ctx, kvClient, kvObject) default: err := errors.Errorf("Invalid vaultObjectTypes. Should be secret, key, or cert") return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) @@ -489,8 +379,8 @@ func (p *provider) getKeyVaultObjectContent(ctx context.Context, kvClient *kv.Ba } // getSecret retrieves the secret from the vault -func (p *provider) getSecret(ctx context.Context, kvClient *kv.BaseClient, vaultURL string, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { - secret, err := kvClient.GetSecret(ctx, vaultURL, kvObject.ObjectName, kvObject.ObjectVersion) +func (p *provider) getSecret(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { + secret, err := kvClient.GetSecret(ctx, kvObject.ObjectName, kvObject.ObjectVersion) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } @@ -501,7 +391,8 @@ func (p *provider) getSecret(ctx context.Context, kvClient *kv.BaseClient, vault return nil, errors.Errorf("secret id is nil") } content := *secret.Value - version := getObjectVersion(*secret.ID) + id := *secret.ID + version := id.Version() result := []keyvaultObject{} // if the secret is part of a certificate, then we need to convert the certificate and key to PEM format if secret.Kid != nil && len(*secret.Kid) > 0 { @@ -538,31 +429,26 @@ func (p *provider) getSecret(ctx context.Context, kvClient *kv.BaseClient, vault } // getKey retrieves the key from the vault -func (p *provider) getKey(ctx context.Context, kvClient *kv.BaseClient, vaultURL string, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { - keybundle, err := kvClient.GetKey(ctx, vaultURL, kvObject.ObjectName, kvObject.ObjectVersion) +func (p *provider) getKey(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { + keybundle, err := kvClient.GetKey(ctx, kvObject.ObjectName, kvObject.ObjectVersion) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } if keybundle.Key == nil { return nil, errors.Errorf("key value is nil") } - if keybundle.Key.Kid == nil { + if keybundle.Key.KID == nil { return nil, errors.Errorf("key id is nil") } - version := getObjectVersion(*keybundle.Key.Kid) + + id := *keybundle.Key.KID + version := id.Version() // for object type "key" the public key is written to the file in PEM format - switch keybundle.Key.Kty { - case kv.RSA, kv.RSAHSM: - // decode the base64 bytes for n - nb, err := base64.RawURLEncoding.DecodeString(*keybundle.Key.N) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - // decode the base64 bytes for e - eb, err := base64.RawURLEncoding.DecodeString(*keybundle.Key.E) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } + switch *keybundle.Key.Kty { + case azkeys.JSONWebKeyTypeRSA, azkeys.JSONWebKeyTypeRSAHSM: + nb := keybundle.Key.N + eb := keybundle.Key.E + e := new(big.Int).SetBytes(eb).Int64() pKey := &rsa.PublicKey{ N: new(big.Int).SetBytes(nb), @@ -579,18 +465,11 @@ func (p *provider) getKey(ctx context.Context, kvClient *kv.BaseClient, vaultURL var pemData []byte pemData = append(pemData, pem.EncodeToMemory(pubKeyBlock)...) return []keyvaultObject{{content: string(pemData), version: version}}, nil - case kv.EC, kv.ECHSM: - // decode the base64 bytes for x - xb, err := base64.RawURLEncoding.DecodeString(*keybundle.Key.X) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - // decode the base64 bytes for y - yb, err := base64.RawURLEncoding.DecodeString(*keybundle.Key.Y) - if err != nil { - return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) - } - crv, err := getCurve(keybundle.Key.Crv) + case azkeys.JSONWebKeyTypeEC, azkeys.JSONWebKeyTypeECHSM: + xb := keybundle.Key.X + yb := keybundle.Key.Y + + crv, err := getCurve(*keybundle.Key.Crv) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } @@ -611,29 +490,31 @@ func (p *provider) getKey(ctx context.Context, kvClient *kv.BaseClient, vaultURL pemData = append(pemData, pem.EncodeToMemory(pubKeyBlock)...) return []keyvaultObject{{content: string(pemData), version: version}}, nil default: - err := errors.Errorf("failed to get key. key type '%s' currently not supported", keybundle.Key.Kty) + err := errors.Errorf("failed to get key. key type '%s' currently not supported", *keybundle.Key.Kty) return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } } // getCertificate retrieves the certificate from the vault -func (p *provider) getCertificate(ctx context.Context, kvClient *kv.BaseClient, vaultURL string, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { +func (p *provider) getCertificate(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { // for object type "cert" the certificate is written to the file in PEM format - certbundle, err := kvClient.GetCertificate(ctx, vaultURL, kvObject.ObjectName, kvObject.ObjectVersion) + certbundle, err := kvClient.GetCertificate(ctx, kvObject.ObjectName, kvObject.ObjectVersion) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } - if certbundle.Cer == nil { + if certbundle.CER == nil { return nil, errors.Errorf("certificate value is nil") } if certbundle.ID == nil { return nil, errors.Errorf("certificate id is nil") } - version := getObjectVersion(*certbundle.ID) + + id := *certbundle.ID + version := id.Version() certBlock := &pem.Block{ Type: types.CertificateType, - Bytes: *certbundle.Cer, + Bytes: certbundle.CER, } var pemData []byte pemData = append(pemData, pem.EncodeToMemory(certBlock)...) @@ -700,13 +581,13 @@ func (p *provider) decodePKCS12(value string) (content string, err error) { return string(pemData), nil } -func getCurve(crv kv.JSONWebKeyCurveName) (elliptic.Curve, error) { +func getCurve(crv azkeys.JSONWebKeyCurveName) (elliptic.Curve, error) { switch crv { - case kv.P256: + case azkeys.JSONWebKeyCurveNameP256: return elliptic.P256(), nil - case kv.P384: + case azkeys.JSONWebKeyCurveNameP384: return elliptic.P384(), nil - case kv.P521: + case azkeys.JSONWebKeyCurveNameP521: return elliptic.P521(), nil default: return nil, fmt.Errorf("curve %s is not supported", crv) @@ -736,16 +617,6 @@ func setAzureEnvironmentFilePath(envFileName string) error { return os.Setenv(azure.EnvironmentFilepathName, envFileName) } -// getObjectVersion parses the id to retrieve the version -// of object fetched -// example id format - https://kindkv.vault.azure.net/secrets/actual/1f304204f3624873aab40231241243eb -// TODO (aramase) follow up on https://github.com/Azure/azure-rest-api-specs/issues/10825 to provide -// a native way to obtain the version -func getObjectVersion(id string) string { - splitID := strings.Split(id, "/") - return splitID[len(splitID)-1] -} - // getContentBytes takes the given content string and returns the bytes to write to disk // If an encoding is specified it will decode the string first func getContentBytes(content, objectType, objectEncoding string) ([]byte, error) { diff --git a/pkg/provider/provider_test.go b/pkg/provider/provider_test.go index 4522cc837..2873c035a 100644 --- a/pkg/provider/provider_test.go +++ b/pkg/provider/provider_test.go @@ -7,6 +7,7 @@ import ( "crypto/elliptic" "crypto/rsa" "encoding/pem" + "errors" "fmt" "io" "os" @@ -16,15 +17,18 @@ import ( "testing" "time" - kv "github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/to" + "github.com/golang/mock/gomock" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "k8s.io/klog/v2" - "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/auth" + "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/provider/mock_keyvault" "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/provider/types" - "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/version" ) func TestGetVaultURL(t *testing.T) { @@ -65,7 +69,7 @@ func TestGetVaultURL(t *testing.T) { } for idx := range testEnvs { - azCloudEnv, err := ParseAzureEnvironment(testEnvs[idx]) + azCloudEnv, err := parseAzureEnvironment(testEnvs[idx]) if err != nil { t.Fatalf("Error parsing cloud environment %v", err) } @@ -85,7 +89,7 @@ func TestGetVaultURL(t *testing.T) { func TestParseAzureEnvironment(t *testing.T) { envNamesArray := []string{"AZURECHINACLOUD", "AZUREGERMANCLOUD", "AZUREPUBLICCLOUD", "AZUREUSGOVERNMENTCLOUD", ""} for _, envName := range envNamesArray { - azureEnv, err := ParseAzureEnvironment(envName) + azureEnv, err := parseAzureEnvironment(envName) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -97,7 +101,7 @@ func TestParseAzureEnvironment(t *testing.T) { } wrongEnvName := "AZUREWRONGCLOUD" - _, err := ParseAzureEnvironment(wrongEnvName) + _, err := parseAzureEnvironment(wrongEnvName) if err == nil { t.Fatalf("expected error for wrong azure environment name") } @@ -232,7 +236,7 @@ func TestParseAzureEnvironmentAzureStackCloud(t *testing.T) { if err != nil { t.Fatalf("expected error to be nil, got: %+v", err) } - _, err = ParseAzureEnvironment(azureStackCloudEnvName) + _, err = parseAzureEnvironment(azureStackCloudEnvName) if err == nil { t.Fatalf("expected error to be not nil as AZURE_ENVIRONMENT_FILEPATH is not set") } @@ -242,7 +246,7 @@ func TestParseAzureEnvironmentAzureStackCloud(t *testing.T) { if err != nil { t.Fatalf("expected error to be nil, got: %+v", err) } - env, err := ParseAzureEnvironment(azureStackCloudEnvName) + env, err := parseAzureEnvironment(azureStackCloudEnvName) if err != nil { t.Fatalf("expected error to be nil, got: %+v", err) } @@ -855,37 +859,6 @@ kzqEt441cQasPp5ohL5U4cJN6lAuwA== } } -func TestInitializeKVClient(t *testing.T) { - testEnvs := []azure.Environment{ - azure.PublicCloud, - azure.GermanCloud, - azure.ChinaCloud, - azure.USGovernmentCloud, - } - for i := range testEnvs { - authConfig, err := auth.NewConfig(false, false, "", "", "", map[string]string{"clientid": "id", "clientsecret": "secret"}) - assert.NoError(t, err) - - mc := &mountConfig{ - azureCloudEnvironment: &testEnvs[i], - authConfig: authConfig, - podName: "pod", - podNamespace: "default", - tenantID: "tenant", - } - - version.BuildVersion = "version" - version.BuildDate = "Now" - version.Vcs = "hash" - - kvBaseClient, err := mc.initializeKvClient() - assert.NoError(t, err) - assert.NotNil(t, kvBaseClient) - assert.NotNil(t, kvBaseClient.Authorizer) - assert.Contains(t, kvBaseClient.UserAgent, "csi-secrets-store") - } -} - func TestGetSecretsStoreObjectContent(t *testing.T) { cases := []struct { desc string @@ -1033,7 +1006,7 @@ func TestGetSecretsStoreObjectContent(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { p := NewProvider(false, false) - _, err := p.GetSecretsStoreObjectContent(context.TODO(), tc.parameters, tc.secrets, 0420) + _, err := p.GetSecretsStoreObjectContent(testContext(t), tc.parameters, tc.secrets, 0420) if tc.expectedErr { assert.NotNil(t, err) } else { @@ -1045,29 +1018,29 @@ func TestGetSecretsStoreObjectContent(t *testing.T) { func TestGetCurve(t *testing.T) { cases := []struct { - crv kv.JSONWebKeyCurveName + crv azkeys.JSONWebKeyCurveName expectedCurve elliptic.Curve expectedErr error }{ { - crv: kv.P256, + crv: azkeys.JSONWebKeyCurveNameP256, expectedCurve: elliptic.P256(), expectedErr: nil, }, { - crv: kv.P384, + crv: azkeys.JSONWebKeyCurveNameP384, expectedCurve: elliptic.P384(), expectedErr: nil, }, { - crv: kv.P521, + crv: azkeys.JSONWebKeyCurveNameP521, expectedCurve: elliptic.P521(), expectedErr: nil, }, { - crv: kv.SECP256K1, + crv: azkeys.JSONWebKeyCurveNameP256K, expectedCurve: nil, - expectedErr: fmt.Errorf("curve SECP256K1 is not supported"), + expectedErr: fmt.Errorf("curve P-256K is not supported"), }, } @@ -1181,13 +1154,6 @@ PxrUsXyXty7ERMp5QNyxjMWS+0w93FrAIw== } } -func TestGetObjectVersion(t *testing.T) { - id := "https://kindkv.vault.azure.net/secrets/secret1/c55925c29c6743dcb9bb4bf091be03b0" - expectedVersion := "c55925c29c6743dcb9bb4bf091be03b0" - actual := getObjectVersion(id) - assert.Equal(t, expectedVersion, actual) -} - func TestSplitCertAndKey(t *testing.T) { rootCACert := `-----BEGIN CERTIFICATE----- MIIBeTCCAR6gAwIBAgIRAM3RAPH7k1Q+bICMC0mzKhkwCgYIKoZIzj0EAwIwGjEY @@ -1290,3 +1256,355 @@ SIVZww73PTGisLmXfIvKvr8GBA== }) } } + +func TestGetSecret(t *testing.T) { + id := azsecrets.ID("https://test.vault.azure.net/secrets/secret1/v1") + testPFX := "MIIJ2gIBAzCCCZoGCSqGSIb3DQEHAaCCCYsEggmHMIIJgzCCBgwGCSqGSIb3DQEHAaCCBf0EggX5MIIF9TCCBfEGCyqGSIb3DQEMCgECoIIE/jCCBPowHAYKKoZIhvcNAQwBAzAOBAjyZKK5bEmydAICB9AEggTYc8Xz73uOqyAO2D/7AySispCqj1rqZa2le5o/aX1KXqajOhxoKB5NJftiBx3JvR0Bo9sjycHLWX2PZEs7wJm34ut2eblexkC2vP+Peyk6dMrVjxj56J8+QMgku5BLVX5D/XVOPrw7g77YPZ1U6YIHld9euMVkyXtnuMlLUqj2+XZjpe1tOdZwiZvqQFgaw44YOh1looS08895D77PMIKawcJliqA+5b0trIlbL7RjVJceb5g0s1QAGPtswfFykWtvVs2dvc+gsTJrtzDlVUbP6NCrbGZL89VXywdv1Ls4o63GrG4wUjvaEBzMvo3FYQLVA4XgknMNYglfxX5kTu177zLbrgVYmfFQ1uu5OR25HoQ9I9hlcQbZn7DNB8W9SxoeDhNN0a/DqKj/olj9e6hohzDIQyTAr2N3Om8DiXLUfyWDiUKSeOHp6KKWIFCynC8DsOZPPVS8dN2yjszLGItYV+g1x2L4b+EUO6gT5nweGY1Wt9+dSyRSaOkEms0hDwwvGyMk6FSZKk75MAYLskz+u3+cf9z46rpAsoarFrdAgxdb+0Azq/N0A4TiYEkCZNouJALWi0yOXSW27l5sKwlV4DyEqksUu5iHi+eGaCn+dc3zUiPISTZUSMbyiqnD5V5MEUgJQ1yUPpaJrIPuyfCW70WD4Hw9RWWKW76IwyfmbyzvUIR4rYr43COTcQ+wZ1pSOvij1Ny4iEYV/2DEesNgErDkPLJAk7TtSKLfLkkjvfL7DXtMVV8T/WLim24F15m1e0v35sehKrk9u+hwt8C1pE77q8Tu2423+7ELIYlO18Di4jRhNYooi1ySZIWojdXM6+BaFAieS10H9tmtYzMBGHKOdDmAPaehiB87MLBUlzeXe0InTOL5q9tv8lBFTbKbL7sPOd94yWpurUGjxOcF7uLgzrxf+ocdMr0EhMoCCh3GcS2iP2DqrWvAOx3dT0/iSTSnhEUlkY9OpP1hrjeidbkk9u64nEJd5Fo2y0wB6NDJThnds7wwD5vjyPUMvp2q5+zQ3Uf9dk0IHL+4sz+JJDbPwua9mbiseO5wqElDsF9culoyKKnJozBQ1+DjM7vZhTah2cgFy7U8THc7UDxrULFHSK4ue8KlN+WxzK4ebGRJ/RLSewXleTJEV9b+KfwKfRYWdITmnxn0t24lUN7skENG1qSCLujh+OdMyzXGTmo3AniK/wyS/lJaxloHd2w0aINzfr+9E/vVU+e++PUNLz7OgmI7BsqqlL1WqhvVV+wIBb5GhcvheJlxgM170t13aONf2itYDjsooOraRUN23BV2jx1Rb0LQpSFx550GtkUsHdxBpWe6YwbeDtJayjhmYtdTfDbbCrQzyTReqqzRbXoI5KnUHCLnO5uCkuOI3lLFX0Sj28eIgUucKpVQgtIqyy6mTM3tocgusEK9J53LmVbRLWTX5UrFaLopPn6S8i6UHwefz9XD3SJ1Qlj0rtTkZgPk6tw5nMskcXAiJ/jMm36IluJBp82AMaj79FnwgnxCxunYLmbTBXtKTmkMrr3nrDDoV38ynrnbu2otdZmrst0rjl1L9uuw0azQz5O4DQ1uAcXpgb21LUyOp3aS/TzWGJZtB6ne0b/37U/q3zvp1LXDwKG3yRP71J5TEhMnb4uazwgOjcvo6DGB3zATBgkqhkiG9w0BCRUxBgQEAQAAADBbBgkqhkiG9w0BCRQxTh5MAHsANgA3ADMAQQBDADkARABDAC0ANgAzAEMAQQAtADQAOQA1ADkALQA4ADkAOAAxAC0AQQA4ADgAOAA2AEQARgBGADEANgA5AEIAfTBrBgkrBgEEAYI3EQExXh5cAE0AaQBjAHIAbwBzAG8AZgB0ACAARQBuAGgAYQBuAGMAZQBkACAAQwByAHkAcAB0AG8AZwByAGEAcABoAGkAYwAgAFAAcgBvAHYAaQBkAGUAcgAgAHYAMQAuADAwggNvBgkqhkiG9w0BBwagggNgMIIDXAIBADCCA1UGCSqGSIb3DQEHATAcBgoqhkiG9w0BDAEGMA4ECEjwOIfbZPtRAgIH0ICCAyiaiiGa5xldOrZdkUKqa4kb1zLnqN5P+XRUO/bvl0Qr/JE57K9NxgcxEvkWSdI60CA7EoJ+voE3MCf0/UWOEV5di3JbRYZAsGI88bo46B/8L80pVCRQWI0ZQtdrk5gCJwCedEyy7te4eIRMf3bIjChlXuwBT6jUFw8dylLhlEDs5Br1k6h5yYrrB8KqVuSpqpR6SXxflcHxwhwZEKZp6peS+77sGRp2iF+YBk/946cUp/d/Amd9CZIO7SriZVW32sbflw7PGgB0Lwq5JbvPyUTqxWVsFLcbKMhaReWIxd5/WCMk4TObmtr9WrJ1/bWp+n/oyePQANNKdDhHSsCjRpHKuBQDKvDaL0NQkhH1lPHxHdMHVc12nbIFnz7zLzVmXSBfUnhdneQ0vZOb5oyWpM8uTLaDwykG2A6wr1/S58yNeY+C7WVr8EkvYdZdhgTIP9WEhws4X2HNG3g77yo1crmPXLW73nN7TobdwOxID5ipKHRJbqDlw69j7Z78lPHRdOjBCvvEXSSvdsAp2p56nkYsPq2yNsmUIBW3tT6kobdjEneseLYwYLlIe2jJ7vfaVjtHEk9JGKH2XrHVwPLZFx+S/w/a2dXwLzSFlR9+de11BEikA+JDeKIcRxvJmH3ZuyEIpGwN1OcnKZ+3HOKwmuj1SAmQQksxQNQcWc+5cSbPWJxC57nIUGPP4wWZjs03Nh7YOV9BpnnfdY/cVKr8wBCaOvA9raoWKyuVEUuA9lGQ9okID6Rnt/aKxVcOyan9SWJo/dH+JGsQqiFVmKBvDPK8pdPUhJe/05K06CYlyFMlyr56tTC+cua+EwsOGXbO8XBJzB84zIPczWa1btyqvw8StH15P9wFR0iKR+ZEFxLmtUaAIoJ7j9DeWNBzzpYuwaQQY6lzT3bPfF3ECTi617+p7xkULcDB0vWrApGrbOlBg4Z0GsJVwlDD+MYGf+4x9vpQu0bKa9qD/PlRS7eJF0Cjs9BNUkZUxNI8FwpSvMlD4fVSe7GMnRNQZrjhL0RcNrliOck/PLdO3mAH+HXDblgcgkRljpXkcvMoCRa1mHUGaYKKLEhKf/brMDcwHzAHBgUrDgMCGgQUO+i67chO15+HWhrm84Wq77Z3cEgEFBMn3lNZpt5o5o2neKnOZ5vNpIlB" + testCert := `-----BEGIN CERTIFICATE----- +MIIC2DCCAcACCQD9DZdcsr7kJDANBgkqhkiG9w0BAQsFADAuMRYwFAYDVQQDDA1k +ZW1vLnRlc3QuY29tMRQwEgYDVQQKDAtpbmdyZXNzLXRsczAeFw0yMDA1MjIxNjIz +MjZaFw0yMTA1MjIxNjIzMjZaMC4xFjAUBgNVBAMMDWRlbW8udGVzdC5jb20xFDAS +BgNVBAoMC2luZ3Jlc3MtdGxzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAte0os8X6ZKbEUWoFJdSfcYoSovbxPBhtisEJd/U+oOK1jKH/HMBliTv+9l6O +vIhldtt48v57mk4P0M72KT8ulXcBasNV95DNnPsEpAqs7wKrhftleeDMKPnk8VvU +6jidPy6SO6Ntbp8tJchrbfMwZW7e2y+PVweKN8QwNECQPfygBtX8jP93CG6oYvK9 +FDS45U1UcKUdxTLfXmSvORPo0HFEXLNvZxmdjSsrP0oSbasJfMr02DZb5/6MSxCb +J/FnPwdqXQH/cM6rQDLw2Is5iWn0QXPEYZMqYbtMAJoY0UEVHVHgIUb/HucQ+SEk +tt6kG3sIGKsKLiuymZGozRFNqwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQCSZNbl +WFMjnZuGiFzIZoqfKOp/Dtw48poNJkrxMBJLkiciJD6drXj8vnTQrZUuR25TIiD/ +Sq+cO+XVRcJKNP13FjFpRdyHYRtAze4TaQZSJlW2nyfeUtUQkwj2iMhv5l1UMnPG +7+Jxg56aA+IBvyE/tAQVvS0NPdq6Ht2MX6j40ERTXmS8qNdY6qi3ZCEAPazlNsUF +C6nLdViZ/vbQ+l6uEcNsEsPJ6SDTNKLkO9tU7pWCa6QBTncuFLbpDqr3Q+lvx4mv +MVw9RO3NiLuDiPQA0VfKSMrEJJUp4F88pbEax5nq525Rbp85RWkmVoc97UuFS+oc +ldGQrUHVb2/iI1fd +-----END CERTIFICATE----- +` + testPrivateKey := `-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQC17SizxfpkpsRR +agUl1J9xihKi9vE8GG2KwQl39T6g4rWMof8cwGWJO/72Xo68iGV223jy/nuaTg/Q +zvYpPy6VdwFqw1X3kM2c+wSkCqzvAquF+2V54Mwo+eTxW9TqOJ0/LpI7o21uny0l +yGtt8zBlbt7bL49XB4o3xDA0QJA9/KAG1fyM/3cIbqhi8r0UNLjlTVRwpR3FMt9e +ZK85E+jQcURcs29nGZ2NKys/ShJtqwl8yvTYNlvn/oxLEJsn8Wc/B2pdAf9wzqtA +MvDYizmJafRBc8Rhkyphu0wAmhjRQRUdUeAhRv8e5xD5ISS23qQbewgYqwouK7KZ +kajNEU2rAgMBAAECggEBAK9MJxUapkxH+RDt1KoAN+aigZSv2ADtFNhHa0VAdal2 +6jLpgbWFmhDjU6i3slfuIb6meePC3PzxTQIJ+l4COHPi6OWj9PkIeWdS5MTgWIQx +kW8Xr08CEhdFu5npv7408SgJSvTWY8Lc9BbdCM84LqD+dRTEvhzA8ikMDNq8f4CJ +hLreFUUl/udHacpMdE8mpB6vgCUliZEjBlHHC9qD2mDKgWb0cm4jkO9PcHxz8CXL +szcRV2vqTwvsJcZWcJwTzjhFxq/lUZrgbwpn60iKlov3BCRoTJBppOXi01giom3v +Wz7Y7DoFbHfizh6jyBrf3ODhKJQ3CGvS65QCS0aJ/kECgYEA4JuGC9DpQYmlzWbV +0CqJYnTcZKqcPQx/A1QZDKot0VWqF61vZIku5XuoGKGfY3eLwVZJJZqxoTlVTbuT +nNzYJe+EHzftRoUxUqXZtIh9VdirJMwCu4RMdwk705FA8+8FcTKXarKWBbAzUmFi +iINR2rlRJHVyh2cOA9hWPbEXX0sCgYEAz1qAYUIMBGnccY9mALGArWTmtyNN3DcB +9tl3/5SzfL1NlcOWsBXdZL61oTl9LhOjqHGZbf164/uFuKNHTsT1E63180UKujmV +TbHL6N6MrMctaJfgru3+XprTMd5pwjzd8huX603OtS8Gvn5gKdBRkG1ZI8CrfTl6 +sJI9YRvl7yECgYEAjUIiptHHsVkhdrIDLL1j1BEM/x6xzk9KnkxIyMdKs4n9xJBm +K0N/xAHmMT+Mn6DyuzBKJqVIq84EETQ0XQYjxpABdyTUTHK+F22JItpogRIYaLcJ +zOcitAaRoriKsh+UO6IGyqrwYTl0vY3Ty2lTlIzSNGzND81HajGn43q56UsCgYEA +pGqArZ3vZXiDgdBQ82/MNrFxd/oYfOtpNVFPI2vHvrtkT8KdM9bCjGXkI4kwR17v +QFuDa4G49hm0+KkPm9f09LvV8CXo0a1jRA4dP/Nn3IC68tqrIEo6js15dWuEtK4K +1zUmC0DRDT3SvS38FmvGoRzzt7PIxyzSqjvrS5sRgcECgYAQ6b0YsM4p+89s4ALK +BPfGIKpoIEMKUcwiT3ovRrwIu1vbu70WRcYAi5do6rwOakp3FyUcQznkeZEOAQmc +xrBy8R64vg83WMuRITAqY6vartSa3oehqUHW0YbhGDVEtSrolXEs5elArUHbpYnX +SIVZww73PTGisLmXfIvKvr8GBA== +-----END PRIVATE KEY----- +` + + cases := []struct { + desc string + initKeyVaultSecret *azsecrets.SecretBundle + inputKeyVaultObject types.KeyVaultObject + writeCertAndKeyInSeparateFiles bool + expectedKeyVaultObject []keyvaultObject + }{ + { + desc: "secret", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + Value: to.StringPtr("secret1value"), + }, + inputKeyVaultObject: types.KeyVaultObject{ + ObjectName: "secret1", + }, + expectedKeyVaultObject: []keyvaultObject{ + { + content: "secret1value", + version: "v1", + }, + }, + }, + { + desc: "secret with kid, pem cert and key", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + Value: to.StringPtr(testCert + testPrivateKey), + Kid: to.StringPtr("https://testvault.vault.azure.net/keys/secrets/secret1/v1"), + ContentType: to.StringPtr("application/x-pem-file"), + }, + inputKeyVaultObject: types.KeyVaultObject{ + ObjectName: "secret1", + }, + expectedKeyVaultObject: []keyvaultObject{ + { + content: testCert + testPrivateKey, + version: "v1", + }, + }, + }, + { + desc: "secret with kid, pfx, objectFormat=pfx", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + Value: to.StringPtr(testPFX), + Kid: to.StringPtr("https://testvault.vault.azure.net/keys/secrets/secret1/v1"), + ContentType: to.StringPtr("application/x-pkcs12"), + }, + inputKeyVaultObject: types.KeyVaultObject{ + ObjectName: "secret1", + ObjectFormat: "pfx", + }, + expectedKeyVaultObject: []keyvaultObject{ + { + content: testPFX, + version: "v1", + }, + }, + }, + { + desc: "secret with kid, pfx, objectFormat=pem", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + Value: to.StringPtr(testPFX), + Kid: to.StringPtr("https://testvault.vault.azure.net/keys/secrets/secret1/v1"), + ContentType: to.StringPtr("application/x-pkcs12"), + }, + inputKeyVaultObject: types.KeyVaultObject{ + ObjectName: "secret1", + ObjectFormat: "pem", + }, + expectedKeyVaultObject: []keyvaultObject{ + { + content: testPrivateKey + testCert, + version: "v1", + }, + }, + }, + { + desc: "secret with kid, pfx, default objectFormat", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + Value: to.StringPtr(testPFX), + Kid: to.StringPtr("https://testvault.vault.azure.net/keys/secrets/secret1/v1"), + ContentType: to.StringPtr("application/x-pkcs12"), + }, + inputKeyVaultObject: types.KeyVaultObject{ + ObjectName: "secret1", + }, + expectedKeyVaultObject: []keyvaultObject{ + { + content: testPrivateKey + testCert, + version: "v1", + }, + }, + }, + { + desc: "write cert and key in separate files", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + Value: to.StringPtr(testPFX), + Kid: to.StringPtr("https://testvault.vault.azure.net/keys/secrets/secret1/v1"), + ContentType: to.StringPtr("application/x-pkcs12"), + }, + inputKeyVaultObject: types.KeyVaultObject{ + ObjectName: "secret1", + }, + writeCertAndKeyInSeparateFiles: true, + expectedKeyVaultObject: []keyvaultObject{ + { + content: testCert, + version: "v1", + fileNameSuffix: ".crt", + }, + { + content: testPrivateKey, + version: "v1", + fileNameSuffix: ".key", + }, + { + content: testPrivateKey + testCert, + version: "v1", + }, + }, + }, + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ctx := testContext(t) + + p := &provider{writeCertAndKeyInSeparateFiles: tc.writeCertAndKeyInSeparateFiles} + kvClient := mock_keyvault.NewMockKeyVault(ctrl) + kvClient.EXPECT().GetSecret(ctx, "secret1", "").Return( + tc.initKeyVaultSecret, nil, + ) + + objs, err := p.getSecret(ctx, kvClient, tc.inputKeyVaultObject) + if err != nil { + t.Fatalf("getSecret() = %v, want nil", err) + } + if !reflect.DeepEqual(objs, tc.expectedKeyVaultObject) { + t.Errorf("getSecret() = \n%v, want \n%v", objs, tc.expectedKeyVaultObject) + } + }) + } +} + +func TestGetSecretError(t *testing.T) { + id := azsecrets.ID("https://test.vault.azure.net/secrets/secret1/v1") + + cases := []struct { + desc string + initKeyVaultSecret *azsecrets.SecretBundle + kvError error + expectedKeyVaultObject []keyvaultObject + }{ + { + desc: "keyvault get secret error", + initKeyVaultSecret: &azsecrets.SecretBundle{}, + kvError: errors.New("keyvault error"), + }, + { + desc: "secret value is nil", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + }, + }, + { + desc: "secret id is nil", + initKeyVaultSecret: &azsecrets.SecretBundle{ + Value: to.StringPtr("test"), + }, + }, + { + desc: "secret has kid, not valid pfx", + initKeyVaultSecret: &azsecrets.SecretBundle{ + ID: &id, + Value: to.StringPtr("invalid"), + Kid: to.StringPtr("https://testvault.vault.azure.net/keys/secrets/secret1/v1"), + ContentType: to.StringPtr("application/x-pkcs12"), + }, + }, + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ctx := testContext(t) + + p := &provider{} + kvClient := mock_keyvault.NewMockKeyVault(ctrl) + kvClient.EXPECT().GetSecret(ctx, "secret1", "").Return( + tc.initKeyVaultSecret, nil, + ) + + if _, err := p.getSecret(ctx, kvClient, types.KeyVaultObject{ObjectName: "secret1"}); err == nil { + t.Fatalf("getSecret() = nil, want error") + } + }) + } +} + +func TestGetCertificate(t *testing.T) { + id := azcertificates.ID("https://test.vault.azure.net/certificates/cert1/v1") + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := testContext(t) + + p := &provider{} + kvClient := mock_keyvault.NewMockKeyVault(ctrl) + kvClient.EXPECT().GetCertificate(ctx, "cert1", "").Return( + &azcertificates.CertificateBundle{ + CER: []byte("test"), + ID: &id, + }, nil, + ) + + objs, err := p.getCertificate(ctx, kvClient, types.KeyVaultObject{ObjectName: "cert1"}) + if err != nil { + t.Fatalf("getCertificate() = %v, want nil", err) + } + + expected := []keyvaultObject{ + { + content: `-----BEGIN CERTIFICATE----- +dGVzdA== +-----END CERTIFICATE----- +`, + version: "v1", + }, + } + + if !reflect.DeepEqual(objs, expected) { + t.Fatalf("getCertificate() = \n%v, want \n%v", objs, expected) + } +} + +func TestGetCertificateError(t *testing.T) { + id := azcertificates.ID("https://test.vault.azure.net/certificates/cert1/v1") + + cases := []struct { + desc string + initKeyVaultCert *azcertificates.CertificateBundle + kvError error + expectedKeyVaultObject []keyvaultObject + }{ + { + desc: "keyvault get certificate error", + initKeyVaultCert: &azcertificates.CertificateBundle{}, + kvError: errors.New("keyvault error"), + }, + { + desc: "certificate CER is nil", + initKeyVaultCert: &azcertificates.CertificateBundle{ + ID: &id, + }, + }, + { + desc: "certificate id is nil", + initKeyVaultCert: &azcertificates.CertificateBundle{ + CER: []byte("test"), + }, + }, + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ctx := testContext(t) + + p := &provider{} + kvClient := mock_keyvault.NewMockKeyVault(ctrl) + kvClient.EXPECT().GetCertificate(ctx, "cert1", "").Return( + tc.initKeyVaultCert, nil, + ) + + if _, err := p.getCertificate(ctx, kvClient, types.KeyVaultObject{ObjectName: "cert1"}); err == nil { + t.Fatalf("getCertificate() = nil, want error") + } + }) + } +} + +func testContext(t *testing.T) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return ctx +} diff --git a/test/e2e/framework/keyvault/keyvault.go b/test/e2e/framework/keyvault/keyvault.go index e79ac8955..0b1dcc302 100644 --- a/test/e2e/framework/keyvault/keyvault.go +++ b/test/e2e/framework/keyvault/keyvault.go @@ -9,9 +9,10 @@ import ( "github.com/Azure/secrets-store-csi-driver-provider-azure/test/e2e/framework" - kv "github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/to" . "github.com/onsi/ginkgo/v2" @@ -26,55 +27,49 @@ type Client interface { } type client struct { - config *framework.Config - keyvaultClient kv.BaseClient + config *framework.Config + secretsClient *azsecrets.Client } func NewClient(config *framework.Config) Client { - kvClient := kv.New() - kvEndPoint := azure.PublicCloud.KeyVaultEndpoint - if '/' == kvEndPoint[len(kvEndPoint)-1] { - kvEndPoint = kvEndPoint[:len(kvEndPoint)-1] + opts := &azidentity.ClientSecretCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + Cloud: cloud.Configuration{ + ActiveDirectoryAuthorityHost: azure.PublicCloud.ActiveDirectoryEndpoint, + }, + }, } - oauthConfig, err := getOAuthConfig(azure.PublicCloud, config.TenantID) + cred, err := azidentity.NewClientSecretCredential(config.TenantID, config.AzureClientID, config.AzureClientSecret, opts) Expect(err).To(BeNil()) - armSpt, err := adal.NewServicePrincipalToken(*oauthConfig, config.AzureClientID, config.AzureClientSecret, kvEndPoint) + c, err := azsecrets.NewClient(getVaultURL(config.KeyvaultName), cred, nil) Expect(err).To(BeNil()) - kvClient.Authorizer = autorest.NewBearerAuthorizer(armSpt) return &client{ - config: config, - keyvaultClient: kvClient, + config: config, + secretsClient: c, } } // SetSecret sets the secret in key vault func (c *client) SetSecret(name, value string) error { - By(fmt.Sprintf("Setting secret \"%s\" in keyvault \"%s\"", name, c.config.KeyvaultName)) - _, err := c.keyvaultClient.SetSecret(context.Background(), getVaultURL(c.config.KeyvaultName), name, kv.SecretSetParameters{ + params := azsecrets.SetSecretParameters{ Value: to.StringPtr(value), - }) + } + + By(fmt.Sprintf("Setting secret \"%s\" in keyvault \"%s\"", name, c.config.KeyvaultName)) + _, err := c.secretsClient.SetSecret(context.Background(), name, params, &azsecrets.SetSecretOptions{}) return err } // DeleteSecret deletes the secret in key vault func (c *client) DeleteSecret(name string) error { By(fmt.Sprintf("Deleting secret \"%s\" in keyvault \"%s\"", name, c.config.KeyvaultName)) - _, err := c.keyvaultClient.DeleteSecret(context.Background(), getVaultURL(c.config.KeyvaultName), name) + _, err := c.secretsClient.DeleteSecret(context.Background(), name, &azsecrets.DeleteSecretOptions{}) return err } -func getOAuthConfig(env azure.Environment, tenantID string) (*adal.OAuthConfig, error) { - oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, tenantID) - if err != nil { - return nil, err - } - - return oauthConfig, nil -} - func getVaultURL(vaultName string) string { return fmt.Sprintf("https://%s.%s/", vaultName, azure.PublicCloud.KeyVaultDNSSuffix) } diff --git a/test/e2e/go.mod b/test/e2e/go.mod index 1f5332eb2..32edae0d0 100644 --- a/test/e2e/go.mod +++ b/test/e2e/go.mod @@ -3,9 +3,10 @@ module github.com/Azure/secrets-store-csi-driver-provider-azure/test/e2e go 1.19 require ( - github.com/Azure/azure-sdk-for-go v68.0.0+incompatible + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 + github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.11.0 github.com/Azure/go-autorest/autorest v0.11.28 - github.com/Azure/go-autorest/autorest/adal v0.9.22 github.com/Azure/go-autorest/autorest/to v0.4.0 github.com/Azure/secrets-store-csi-driver-provider-azure v0.0.0-00010101000000-000000000000 github.com/ghodss/yaml v1.0.0 @@ -22,12 +23,11 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect + github.com/Azure/go-autorest/autorest/adal v0.9.22 // indirect github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect - github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect github.com/Azure/go-autorest/logger v0.2.1 // indirect github.com/Azure/go-autorest/tracing v0.6.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1 // indirect @@ -57,7 +57,6 @@ require ( github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/inconshreveable/mousetrap v1.0.1 // indirect - github.com/jongio/azidext/go/azidext v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kylelemons/godebug v1.1.0 // indirect diff --git a/test/e2e/go.sum b/test/e2e/go.sum index e545807ab..b7db98e0b 100644 --- a/test/e2e/go.sum +++ b/test/e2e/go.sum @@ -31,14 +31,16 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= -github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1 h1:gVXuXcWd1i4C2Ruxe321aU+IKGaStvGB/S90PUPB/W8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.1/go.mod h1:DffdKW9RFqa5VgmsjUOsS7UE7eiA5iAvYUs63bhKQ0M= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1 h1:T8quHYlUGyb/oqtSTwqlCr1ilJHrDv+ZtpSfo+hm1BU= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.11.0 h1:82w8tzLcOwDP/Q35j/wEBPt0n0kVC3cjtPdD62G8UAk= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.11.0/go.mod h1:S78i9yTr4o/nXlH76bKjGUye9Z2wSxO5Tz7GoDr4vfI= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0 h1:Lg6BW0VPmCwcMlvOviL3ruHFO+H9tZNqscK0AeuFjGM= +github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0/go.mod h1:9V2j0jn9jDEkCkv8w/bKTNppX/d0FVA1ud77xCIP4KA= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= github.com/Azure/go-autorest/autorest v0.11.28 h1:ndAExarwr5Y+GaHE6VCaY1kyS/HwwGGyuimVhWsHOEM= @@ -53,8 +55,6 @@ github.com/Azure/go-autorest/autorest/mocks v0.4.2 h1:PGN4EDXnuQbojHbU0UWoNvmu9A github.com/Azure/go-autorest/autorest/mocks v0.4.2/go.mod h1:Vy7OitM9Kei0i1Oj+LvyAWMXJHeKH1MVlzFugfVrmyU= github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk= github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE= -github.com/Azure/go-autorest/autorest/validation v0.3.1 h1:AgyqjAd94fwNAoTjl/WQXg4VvFeRFpO+UhNyRXqF1ac= -github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E= github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg= github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= @@ -302,10 +302,7 @@ github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/jongio/azidext/go/azidext v0.4.0 h1:TOYyVFMeWGgXNhURSgrEtUCu7JAAKgsy+5C4+AEfYlw= -github.com/jongio/azidext/go/azidext v0.4.0/go.mod h1:VrlpGde5B+pPbTUxnThE5UIQQkcebdr3jrC2MmlMVSI= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=