Skip to content

Commit

Permalink
feat: switch to driver writing files
Browse files Browse the repository at this point in the history
Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
  • Loading branch information
aramase committed Apr 5, 2021
1 parent 8031a9c commit 3859e6e
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 353 deletions.
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ go 1.16

require (
github.com/Azure/azure-sdk-for-go v52.4.0+incompatible
github.com/Azure/go-autorest/autorest v0.9.6
github.com/Azure/go-autorest/autorest/adal v0.8.2
github.com/Azure/go-autorest/autorest v0.11.1
github.com/Azure/go-autorest/autorest/adal v0.9.5
github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect
github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect
github.com/google/go-cmp v0.5.2
github.com/kubernetes-csi/csi-lib-utils v0.7.1
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.6.1
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/net v0.0.0-20200707034311-ab3426394381
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
google.golang.org/grpc v1.31.0
gopkg.in/yaml.v2 v2.3.0
k8s.io/component-base v0.19.3
k8s.io/component-base v0.20.2
k8s.io/klog/v2 v2.5.0
sigs.k8s.io/secrets-store-csi-driver v0.0.20
sigs.k8s.io/secrets-store-csi-driver v0.0.21
)
322 changes: 189 additions & 133 deletions go.sum

Large diffs are not rendered by default.

74 changes: 19 additions & 55 deletions pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"fmt"
"math/big"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
Expand Down Expand Up @@ -174,7 +173,7 @@ func (p *Provider) GetServicePrincipalToken(resource string) (*adal.ServicePrinc
}

// MountSecretsStoreObjectContent mounts content of the secrets store object to target path
func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib map[string]string, secrets map[string]string, targetPath string, permission os.FileMode) (map[string]string, error) {
func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib map[string]string, secrets map[string]string, targetPath string, permission os.FileMode) (map[string][]byte, map[string]string, error) {
keyvaultName := strings.TrimSpace(attrib["keyvaultName"])
cloudName := strings.TrimSpace(attrib["cloudName"])
usePodIdentityStr := strings.TrimSpace(attrib["usePodIdentity"])
Expand All @@ -186,58 +185,58 @@ func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib ma
p.PodNamespace = strings.TrimSpace(attrib["csi.storage.k8s.io/pod.namespace"])

if keyvaultName == "" {
return nil, fmt.Errorf("keyvaultName is not set")
return nil, nil, fmt.Errorf("keyvaultName is not set")
}
if tenantID == "" {
return nil, fmt.Errorf("tenantId is not set")
return nil, nil, fmt.Errorf("tenantId is not set")
}
if len(usePodIdentityStr) == 0 {
usePodIdentityStr = "false"
}
usePodIdentity, err := strconv.ParseBool(usePodIdentityStr)
if err != nil {
return nil, fmt.Errorf("failed to parse usePodIdentity flag, error: %+v", err)
return nil, nil, fmt.Errorf("failed to parse usePodIdentity flag, error: %+v", err)
}
if len(useVMManagedIdentityStr) == 0 {
useVMManagedIdentityStr = "false"
}
useVMManagedIdentity, err := strconv.ParseBool(useVMManagedIdentityStr)
if err != nil {
return nil, fmt.Errorf("failed to parse useVMManagedIdentity flag, error: %+v", err)
return nil, nil, fmt.Errorf("failed to parse useVMManagedIdentity flag, error: %+v", err)
}

err = setAzureEnvironmentFilePath(cloudEnvFileName)
if err != nil {
return nil, fmt.Errorf("failed to set AZURE_ENVIRONMENT_FILEPATH env to %s, error %+v", cloudEnvFileName, err)
return nil, nil, fmt.Errorf("failed to set AZURE_ENVIRONMENT_FILEPATH env to %s, error %+v", cloudEnvFileName, err)
}
azureCloudEnv, err := ParseAzureEnvironment(cloudName)
if err != nil {
return nil, fmt.Errorf("cloudName %s is not valid, error: %v", cloudName, err)
return nil, nil, fmt.Errorf("cloudName %s is not valid, error: %v", cloudName, err)
}

p.AuthConfig, err = auth.NewConfig(usePodIdentity, useVMManagedIdentity, userAssignedIdentityID, secrets)
if err != nil {
return nil, fmt.Errorf("failed to create auth config, error: %+v", err)
return nil, nil, fmt.Errorf("failed to create auth config, error: %+v", err)
}

objectsStrings := attrib["objects"]
if objectsStrings == "" {
return nil, fmt.Errorf("objects is not set")
return nil, nil, fmt.Errorf("objects is not set")
}
klog.V(2).InfoS("objects string defined in secret provider class", "objects", objectsStrings, "pod", klog.ObjectRef{Namespace: p.PodNamespace, Name: p.PodName})

var objects StringArray
err = yaml.Unmarshal([]byte(objectsStrings), &objects)
if err != nil {
return nil, fmt.Errorf("failed to yaml unmarshal objects, error: %+v", err)
return nil, nil, fmt.Errorf("failed to yaml unmarshal objects, error: %+v", err)
}
klog.V(2).InfoS("unmarshaled objects yaml array", "objectsArray", objects.Array, "pod", klog.ObjectRef{Namespace: p.PodNamespace, Name: p.PodName})
var keyVaultObjects []KeyVaultObject
for i, object := range objects.Array {
var keyVaultObject KeyVaultObject
err = yaml.Unmarshal([]byte(object), &keyVaultObject)
if err != nil {
return nil, fmt.Errorf("unmarshal failed for keyVaultObjects at index %d, error: %+v", i, err)
return nil, nil, fmt.Errorf("unmarshal failed for keyVaultObjects at index %d, error: %+v", i, err)
}
// remove whitespace from all fields in keyVaultObject
formatKeyVaultObject(&keyVaultObject)
Expand All @@ -247,33 +246,30 @@ func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib ma
klog.InfoS("unmarshaled key vault objects", "keyVaultObjects", keyVaultObjects, "count", len(keyVaultObjects), "pod", klog.ObjectRef{Namespace: p.PodNamespace, Name: p.PodName})

if len(keyVaultObjects) == 0 {
return nil, fmt.Errorf("objects array is empty")
return nil, nil, fmt.Errorf("objects array is empty")
}
p.KeyvaultName = keyvaultName
p.AzureCloudEnvironment = azureCloudEnv
p.TenantID = tenantID

objectVersionMap := make(map[string]string)
files := make(map[string][]byte)
for _, keyVaultObject := range keyVaultObjects {
klog.InfoS("fetching object from key vault", "objectName", keyVaultObject.ObjectName, "objectType", keyVaultObject.ObjectType, "keyvault", p.KeyvaultName, "pod", klog.ObjectRef{Namespace: p.PodNamespace, Name: p.PodName})
if err := validateObjectFormat(keyVaultObject.ObjectFormat, keyVaultObject.ObjectType); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
return nil, nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
if err := validateObjectEncoding(keyVaultObject.ObjectEncoding, keyVaultObject.ObjectType); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
return nil, nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
fileName := keyVaultObject.ObjectName
if keyVaultObject.ObjectAlias != "" {
fileName = keyVaultObject.ObjectAlias
}
if err := validateFileName(fileName); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}

// fetch the object from Key Vault
content, newObjectVersion, err := p.GetKeyVaultObjectContent(ctx, keyVaultObject)
if err != nil {
return nil, err
return nil, nil, err
}

// objectUID is a unique identifier in the format <object type>/<object name>
Expand All @@ -283,15 +279,12 @@ func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib ma

objectContent, err := getContentBytes(content, keyVaultObject.ObjectType, keyVaultObject.ObjectEncoding)
if err != nil {
return nil, err
}
if err := os.WriteFile(filepath.Join(targetPath, fileName), objectContent, permission); err != nil {
return nil, errors.Wrapf(err, "failed to write file %s at %s", fileName, targetPath)
return nil, nil, err
}
klog.InfoS("successfully wrote file", "file", fileName, "pod", klog.ObjectRef{Namespace: p.PodNamespace, Name: p.PodName})
files[fileName] = objectContent
}

return objectVersionMap, nil
return files, objectVersionMap, nil
}

// GetKeyVaultObjectContent get content of the keyvault object
Expand Down Expand Up @@ -611,35 +604,6 @@ func formatKeyVaultObject(object *KeyVaultObject) {
}
}

// This validate will make sure fileName:
// 1. is not abs path
// 2. does not contain any '..' elements
// 3. does not start with '..'
// These checks have been implemented based on -
// [validateLocalDescendingPath] https://github.com/kubernetes/kubernetes/blob/master/pkg/apis/core/validation/validation.go#L1158-L1170
// [validatePathNoBacksteps] https://github.com/kubernetes/kubernetes/blob/master/pkg/apis/core/validation/validation.go#L1172-L1186
func validateFileName(fileName string) error {
if len(fileName) == 0 {
return fmt.Errorf("file name must not be empty")
}
// is not abs path
if filepath.IsAbs(fileName) {
return fmt.Errorf("file name must be a relative path")
}
// does not have any element which is ".."
parts := strings.Split(filepath.ToSlash(fileName), "/")
for _, item := range parts {
if item == ".." {
return fmt.Errorf("file name must not contain '..'")
}
}
// fallback logic if .. is missed in the previous check
if strings.Contains(fileName, "..") {
return fmt.Errorf("file name must not contain '..'")
}
return nil
}

type node struct {
cert *x509.Certificate
parent *node
Expand Down
45 changes: 1 addition & 44 deletions pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,49 +524,6 @@ func TestFormatKeyVaultObject(t *testing.T) {
}
}

func TestValidateFilePath(t *testing.T) {
cases := []struct {
desc string
fileName string
expectedErr error
}{
{
desc: "file name is absolute path",
fileName: "/secret1",
expectedErr: fmt.Errorf("file name must be a relative path"),
},
{
desc: "file name contains '..'",
fileName: "secret1/..",
expectedErr: fmt.Errorf("file name must not contain '..'"),
},
{
desc: "file name starts with '..'",
fileName: "../secret1",
expectedErr: fmt.Errorf("file name must not contain '..'"),
},
{
desc: "file name is empty",
fileName: "",
expectedErr: fmt.Errorf("file name must not be empty"),
},
{
desc: "valid file name",
fileName: "secret1",
expectedErr: nil,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := validateFileName(tc.fileName)
if tc.expectedErr != nil && err.Error() != tc.expectedErr.Error() || tc.expectedErr == nil && err != nil {
t.Fatalf("expected err: %+v, got: %+v", tc.expectedErr, err)
}
})
}
}

func TestFetchCertChain(t *testing.T) {
rootCACert := `
-----BEGIN CERTIFICATE-----
Expand Down Expand Up @@ -877,7 +834,7 @@ func TestMountSecretsStoreObjectContent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "ut")
assert.NoError(t, err)

_, err = p.MountSecretsStoreObjectContent(context.TODO(), tc.parameters, tc.secrets, tmpDir, 0420)
_, _, err = p.MountSecretsStoreObjectContent(context.TODO(), tc.parameters, tc.secrets, tmpDir, 0420)
if tc.expectedErr {
assert.NotNil(t, err)
} else {
Expand Down
15 changes: 13 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
return &v1alpha1.MountResponse{}, fmt.Errorf("failed to initialize new provider, error: %v", err)
}

objectVersions, err := provider.MountSecretsStoreObjectContent(ctx, attrib, secret, req.GetTargetPath(), filePermission)
files, objectVersions, err := provider.MountSecretsStoreObjectContent(ctx, attrib, secret, req.GetTargetPath(), filePermission)
if err != nil {
klog.ErrorS(err, "failed to process mount request")
return &v1alpha1.MountResponse{}, fmt.Errorf("failed to mount objects, error: %v", err)
Expand All @@ -58,8 +58,19 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
for k, v := range objectVersions {
ov = append(ov, &v1alpha1.ObjectVersion{Id: k, Version: v})
}
var f []*v1alpha1.File
for k, v := range files {
f = append(f, &v1alpha1.File{
Path: k,
Contents: v,
Mode: int32(filePermission),
})
}

return &v1alpha1.MountResponse{ObjectVersion: ov}, nil
return &v1alpha1.MountResponse{
ObjectVersion: ov,
Files: f,
}, nil
}

func (s *CSIDriverProviderServer) Version(ctx context.Context, req *v1alpha1.VersionRequest) (*v1alpha1.VersionResponse, error) {
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ require (
k8s.io/apimachinery v0.20.2
k8s.io/client-go v0.20.2
sigs.k8s.io/controller-runtime v0.8.2
sigs.k8s.io/secrets-store-csi-driver v0.0.20
sigs.k8s.io/secrets-store-csi-driver v0.0.21
)

replace github.com/Azure/secrets-store-csi-driver-provider-azure => ../..

0 comments on commit 3859e6e

Please sign in to comment.