Skip to content

Commit

Permalink
refactor: update validation and make it fail-fast (#917)
Browse files Browse the repository at this point in the history
Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
  • Loading branch information
aramase committed Jun 17, 2022
1 parent d90847c commit 92fa62e
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 298 deletions.
127 changes: 15 additions & 112 deletions pkg/provider/provider.go
Expand Up @@ -12,10 +12,8 @@ import (
"fmt"
"math/big"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -198,6 +196,11 @@ func (p *Provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec
}
// remove whitespace from all fields in keyVaultObject
formatKeyVaultObject(&keyVaultObject)

if err = validate(keyVaultObject); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}

keyVaultObjects = append(keyVaultObjects, keyVaultObject)
}

Expand All @@ -222,25 +225,6 @@ func (p *Provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec
files := []types.SecretFile{}
for _, keyVaultObject := range keyVaultObjects {
klog.V(5).InfoS("fetching object from key vault", "objectName", keyVaultObject.ObjectName, "objectType", keyVaultObject.ObjectType, "keyvault", mc.keyvaultName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
if err := validateObjectFormat(keyVaultObject.ObjectFormat, keyVaultObject.ObjectType); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
if err := validateObjectEncoding(keyVaultObject.ObjectEncoding, keyVaultObject.ObjectType); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
fileName := keyVaultObject.ObjectName
if keyVaultObject.ObjectAlias != "" {
fileName = keyVaultObject.ObjectAlias
}
if err := validateFileName(fileName); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}

filePermission, err := validateFilePermission(keyVaultObject.FilePermission, defaultFilePermission)
if err != nil {
return nil, err
}

// fetch the object from Key Vault
content, newObjectVersion, err := p.getKeyVaultObjectContent(ctx, kvClient, keyVaultObject, *vaultURL)
if err != nil {
Expand All @@ -255,16 +239,17 @@ func (p *Provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec
// objectUID is a unique identifier in the format <object type>/<object name>
// This is the object id the user sees in the SecretProviderClassPodStatus
objectUID := getObjectUID(keyVaultObject.ObjectName, keyVaultObject.ObjectType)
file := types.SecretFile{
Path: keyVaultObject.GetFileName(),
Content: objectContent,
UID: objectUID,
Version: newObjectVersion,
}
// the validity of file permission is already checked in the validate function above
file.FileMode, _ = keyVaultObject.GetFilePermission(defaultFilePermission)

// these files will be returned to the CSI driver as part of gRPC response
files = append(files, types.SecretFile{
Path: fileName,
Content: objectContent,
FileMode: filePermission,
UID: objectUID,
Version: newObjectVersion,
})
klog.V(5).InfoS("added file to the gRPC response", "file", fileName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
files = append(files, file)
klog.V(5).InfoS("added file to the gRPC response", "file", file.Path, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
}

return files, nil
Expand Down Expand Up @@ -531,23 +516,6 @@ func setAzureEnvironmentFilePath(envFileName string) error {
return os.Setenv(azure.EnvironmentFilepathName, envFileName)
}

// validateObjectFormat checks if the object format is valid and is supported
// for the given object type
func validateObjectFormat(objectFormat, objectType string) error {
if len(objectFormat) == 0 {
return nil
}
if !strings.EqualFold(objectFormat, types.ObjectFormatPEM) && !strings.EqualFold(objectFormat, types.ObjectFormatPFX) {
return fmt.Errorf("invalid objectFormat: %v, should be PEM or PFX", objectFormat)
}
// Azure Key Vault returns the base64 encoded binary content only for type secret
// for types cert/key, the content is always in pem format
if objectFormat == types.ObjectFormatPFX && objectType != types.VaultObjectTypeSecret {
return fmt.Errorf("PFX format only supported for objectType: secret")
}
return nil
}

// getObjectVersion parses the id to retrieve the version
// of object fetched
// example id format - https://kindkv.vault.azure.net/secrets/actual/1f304204f3624873aab40231241243eb
Expand All @@ -564,25 +532,6 @@ func getObjectUID(objectName, objectType string) string {
return fmt.Sprintf("%s/%s", objectType, objectName)
}

// validateObjectEncoding checks if the object encoding is valid and is supported
// for the given object type
func validateObjectEncoding(objectEncoding, objectType string) error {
if len(objectEncoding) == 0 {
return nil
}

// ObjectEncoding is supported only for secret types
if objectType != types.VaultObjectTypeSecret {
return fmt.Errorf("objectEncoding only supported for objectType: secret")
}

if !strings.EqualFold(objectEncoding, types.ObjectEncodingHex) && !strings.EqualFold(objectEncoding, types.ObjectEncodingBase64) && !strings.EqualFold(objectEncoding, types.ObjectEncodingUtf8) {
return fmt.Errorf("invalid objectEncoding: %v, should be hex, base64 or utf-8", objectEncoding)
}

return nil
}

// getContentBytes takes the given content string and returns the bytes to write to disk
// If an encoding is specified it will decode the string first
func getContentBytes(content, objectType, objectEncoding string) ([]byte, error) {
Expand Down Expand Up @@ -620,35 +569,6 @@ func formatKeyVaultObject(object *types.KeyVaultObject) {
}
}

// This validate will make sure fileName:
// 1. is not abs path
// 2. does not contain any '..' elements
// 3. does not start with '..'
// These checks have been implemented based on -
// [validateLocalDescendingPath] https://github.com/kubernetes/kubernetes/blob/master/pkg/apis/core/validation/validation.go#L1158-L1170
// [validatePathNoBacksteps] https://github.com/kubernetes/kubernetes/blob/master/pkg/apis/core/validation/validation.go#L1172-L1186
func validateFileName(fileName string) error {
if len(fileName) == 0 {
return fmt.Errorf("file name must not be empty")
}
// is not abs path
if filepath.IsAbs(fileName) {
return fmt.Errorf("file name must be a relative path")
}
// does not have any element which is ".."
parts := strings.Split(filepath.ToSlash(fileName), "/")
for _, item := range parts {
if item == ".." {
return fmt.Errorf("file name must not contain '..'")
}
}
// fallback logic if .. is missed in the previous check
if strings.Contains(fileName, "..") {
return fmt.Errorf("file name must not contain '..'")
}
return nil
}

type node struct {
cert *x509.Certificate
parent *node
Expand Down Expand Up @@ -744,20 +664,3 @@ func fetchCertChains(data []byte) ([]byte, error) {
}
return pemData, nil
}

// validateFilePermission checks if the given file permission is correct octal number and returns
// a. decimal equivalent of the default file permission (0644) if file permission is not provided Or
// b. decimal equivalent Or
// c. error if it's not valid
func validateFilePermission(filePermission string, defaultFilePermission os.FileMode) (int32, error) {
if filePermission == "" {
return int32(defaultFilePermission), nil
}

permission, err := strconv.ParseInt(filePermission, 8, 32)
if err != nil {
return 0, fmt.Errorf("file permission must be a valid octal number: %w", err)
}

return int32(permission), nil
}
186 changes: 0 additions & 186 deletions pkg/provider/provider_test.go
Expand Up @@ -248,104 +248,6 @@ func TestParseAzureEnvironmentAzureStackCloud(t *testing.T) {
}
}

func TestValidateObjectFormat(t *testing.T) {
cases := []struct {
desc string
objectFormat string
objectType string
expectedErr error
}{
{
desc: "no object format specified",
objectFormat: "",
objectType: "cert",
expectedErr: nil,
},
{
desc: "object format not valid",
objectFormat: "pkcs",
objectType: "secret",
expectedErr: fmt.Errorf("invalid objectFormat: pkcs, should be PEM or PFX"),
},
{
desc: "object format PFX, but object type not secret",
objectFormat: "pfx",
objectType: "cert",
expectedErr: fmt.Errorf("PFX format only supported for objectType: secret"),
},
{
desc: "object format PFX case insensitive check",
objectFormat: "PFX",
objectType: "secret",
expectedErr: nil,
},
{
desc: "valid object format and type",
objectFormat: "pfx",
objectType: "secret",
expectedErr: nil,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := validateObjectFormat(tc.objectFormat, tc.objectType)
if tc.expectedErr != nil && err.Error() != tc.expectedErr.Error() || tc.expectedErr == nil && err != nil {
t.Fatalf("expected err: %+v, got: %+v", tc.expectedErr, err)
}
})
}
}

func TestValidateObjectEncoding(t *testing.T) {
cases := []struct {
desc string
objectEncoding string
objectType string
expectedErr error
}{
{
desc: "No encoding specified",
objectEncoding: "",
objectType: "cert",
expectedErr: nil,
},
{
desc: "Invalid encoding specified",
objectEncoding: "utf-16",
objectType: "secret",
expectedErr: fmt.Errorf("invalid objectEncoding: utf-16, should be hex, base64 or utf-8"),
},
{
desc: "Object Encoding Base64, but objectType is not secret",
objectEncoding: "base64",
objectType: "cert",
expectedErr: fmt.Errorf("objectEncoding only supported for objectType: secret"),
},
{
desc: "Object Encoding case-insensitive check",
objectEncoding: "BasE64",
objectType: "secret",
expectedErr: nil,
},
{
desc: "Valid ObjectEncoding and Type",
objectEncoding: "base64",
objectType: "secret",
expectedErr: nil,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := validateObjectEncoding(tc.objectEncoding, tc.objectType)
if tc.expectedErr != nil && err.Error() != tc.expectedErr.Error() || tc.expectedErr == nil && err != nil {
t.Fatalf("expected err: %+v, got: %+v", tc.expectedErr, err)
}
})
}
}

func TestGetContentBytes(t *testing.T) {
cases := []struct {
desc string
Expand Down Expand Up @@ -480,49 +382,6 @@ func TestFormatKeyVaultObject(t *testing.T) {
}
}

func TestValidateFilePath(t *testing.T) {
cases := []struct {
desc string
fileName string
expectedErr error
}{
{
desc: "file name is absolute path",
fileName: "/secret1",
expectedErr: fmt.Errorf("file name must be a relative path"),
},
{
desc: "file name contains '..'",
fileName: "secret1/..",
expectedErr: fmt.Errorf("file name must not contain '..'"),
},
{
desc: "file name starts with '..'",
fileName: "../secret1",
expectedErr: fmt.Errorf("file name must not contain '..'"),
},
{
desc: "file name is empty",
fileName: "",
expectedErr: fmt.Errorf("file name must not be empty"),
},
{
desc: "valid file name",
fileName: "secret1",
expectedErr: nil,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := validateFileName(tc.fileName)
if tc.expectedErr != nil && err.Error() != tc.expectedErr.Error() || tc.expectedErr == nil && err != nil {
t.Fatalf("expected err: %+v, got: %+v", tc.expectedErr, err)
}
})
}
}

func TestFetchCertChain(t *testing.T) {
rootCACert := `
-----BEGIN CERTIFICATE-----
Expand Down Expand Up @@ -1050,48 +909,3 @@ func TestGetObjectVersion(t *testing.T) {
actual := getObjectVersion(id)
assert.Equal(t, expectedVersion, actual)
}

func TestValidateFilePermisssion(t *testing.T) {
cases := []struct {
desc string
filePermission string
defaultFilePermission os.FileMode
isErrorExpected bool
}{
{
desc: "valid file permission",
filePermission: "0600",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: false,
},
{
desc: "empty file permission",
filePermission: "",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: false,
},
{
desc: "invalid file permission",
filePermission: "0900",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: true,
},
{
desc: "invalid octal number",
filePermission: "900",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: true,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
_, err := validateFilePermission(tc.filePermission, tc.defaultFilePermission)
if tc.isErrorExpected {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

0 comments on commit 92fa62e

Please sign in to comment.