Skip to content

Commit

Permalink
fix(cloudformation): support of all SSE algorithms for s3 (#6270)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikpivkin committed Mar 7, 2024
1 parent 9361cdb commit 337cb75
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 18 deletions.
46 changes: 28 additions & 18 deletions pkg/iac/adapters/cloudformation/aws/s3/bucket.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
package s3

import (
"cmp"
"regexp"
"slices"
"strings"

s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"

"github.com/aquasecurity/trivy/pkg/iac/providers/aws/s3"
parser2 "github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/parser"
"github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/parser"
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
)

var aclConvertRegex = regexp.MustCompile(`[A-Z][^A-Z]*`)

func getBuckets(cfFile parser2.FileContext) []s3.Bucket {
func getBuckets(cfFile parser.FileContext) []s3.Bucket {
var buckets []s3.Bucket
bucketResources := cfFile.GetResourcesByType("AWS::S3::Bucket")

Expand All @@ -37,10 +41,15 @@ func getBuckets(cfFile parser2.FileContext) []s3.Bucket {

buckets = append(buckets, s3b)
}

slices.SortFunc(buckets, func(a, b s3.Bucket) int {
return cmp.Compare(a.Name.Value(), b.Name.Value())
})

return buckets
}

func getPublicAccessBlock(r *parser2.Resource) *s3.PublicAccessBlock {
func getPublicAccessBlock(r *parser.Resource) *s3.PublicAccessBlock {
if block := r.GetProperty("PublicAccessBlockConfiguration"); block.IsNil() {
return nil
}
Expand All @@ -60,8 +69,7 @@ func convertAclValue(aclValue iacTypes.StringValue) iacTypes.StringValue {
return iacTypes.String(strings.ToLower(strings.Join(matches, "-")), aclValue.GetMetadata())
}

func getLogging(r *parser2.Resource) s3.Logging {

func getLogging(r *parser.Resource) s3.Logging {
logging := s3.Logging{
Metadata: r.Metadata(),
Enabled: iacTypes.BoolDefault(false, r.Metadata()),
Expand All @@ -77,7 +85,7 @@ func getLogging(r *parser2.Resource) s3.Logging {
return logging
}

func hasVersioning(r *parser2.Resource) iacTypes.BoolValue {
func hasVersioning(r *parser.Resource) iacTypes.BoolValue {
versioningProp := r.GetProperty("VersioningConfiguration.Status")

if versioningProp.IsNil() {
Expand All @@ -92,8 +100,7 @@ func hasVersioning(r *parser2.Resource) iacTypes.BoolValue {
return iacTypes.Bool(versioningEnabled, versioningProp.Metadata())
}

func getEncryption(r *parser2.Resource, _ parser2.FileContext) s3.Encryption {

func getEncryption(r *parser.Resource, _ parser.FileContext) s3.Encryption {
encryption := s3.Encryption{
Metadata: r.Metadata(),
Enabled: iacTypes.BoolDefault(false, r.Metadata()),
Expand All @@ -103,23 +110,26 @@ func getEncryption(r *parser2.Resource, _ parser2.FileContext) s3.Encryption {

if encryptProps := r.GetProperty("BucketEncryption.ServerSideEncryptionConfiguration"); encryptProps.IsNotNil() {
for _, rule := range encryptProps.AsList() {
if algo := rule.GetProperty("ServerSideEncryptionByDefault.SSEAlgorithm"); algo.EqualTo("AES256") {
encryption.Enabled = iacTypes.Bool(true, algo.Metadata())
} else if kmsKeyProp := rule.GetProperty("ServerSideEncryptionByDefault.KMSMasterKeyID"); !kmsKeyProp.IsEmpty() && kmsKeyProp.IsString() {
encryption.KMSKeyId = kmsKeyProp.AsStringValue()
algo := rule.GetProperty("ServerSideEncryptionByDefault.SSEAlgorithm")
if algo.IsString() {
algoVal := algo.AsString()
isValidAlgo := slices.Contains(s3types.ServerSideEncryption("").Values(), s3types.ServerSideEncryption(algoVal))
encryption.Enabled = iacTypes.Bool(isValidAlgo, algo.Metadata())
encryption.Algorithm = algo.AsStringValue()
}
if encryption.Enabled.IsFalse() {
encryption.Enabled = rule.GetBoolProperty("BucketKeyEnabled", false)

kmsKeyProp := rule.GetProperty("ServerSideEncryptionByDefault.KMSMasterKeyID")
if !kmsKeyProp.IsEmpty() && kmsKeyProp.IsString() {
encryption.KMSKeyId = kmsKeyProp.AsStringValue()
}
}
}

return encryption
}

func getLifecycle(resource *parser2.Resource) []s3.Rules {
LifecycleProp := resource.GetProperty("LifecycleConfiguration")
RuleProp := LifecycleProp.GetProperty("Rules")
func getLifecycle(resource *parser.Resource) []s3.Rules {
RuleProp := resource.GetProperty("LifecycleConfiguration.Rules")

var rule []s3.Rules

Expand All @@ -136,7 +146,7 @@ func getLifecycle(resource *parser2.Resource) []s3.Rules {
return rule
}

func getWebsite(r *parser2.Resource) *s3.Website {
func getWebsite(r *parser.Resource) *s3.Website {
if block := r.GetProperty("WebsiteConfiguration"); block.IsNil() {
return nil
} else {
Expand Down
156 changes: 156 additions & 0 deletions pkg/iac/adapters/cloudformation/aws/s3/s3_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package s3

import (
"context"
"testing"

"github.com/aquasecurity/trivy/internal/testutil"
"github.com/aquasecurity/trivy/pkg/iac/providers/aws/s3"
"github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/parser"
"github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/stretchr/testify/require"
)

func TestAdapt(t *testing.T) {
tests := []struct {
name string
source string
expected s3.S3
}{
{
name: "complete s3 bucket",
source: `AWSTemplateFormatVersion: 2010-09-09
Resources:
Key:
Type: "AWS::KMS::Key"
LoggingBucket:
Type: AWS::S3::Bucket
Properties:
BucketName: logging-bucket
Bucket:
Type: AWS::S3::Bucket
Properties:
BucketName: test-bucket
BucketEncryption:
ServerSideEncryptionConfiguration:
- ServerSideEncryptionByDefault:
KMSMasterKeyID:
Fn::GetAtt:
- Key
- Arn
SSEAlgorithm: aws:kms
AccessControl: AwsExecRead
PublicAccessBlockConfiguration:
BlockPublicAcls: true
BlockPublicPolicy: true
IgnorePublicAcls: true
RestrictPublicBuckets: true
LoggingConfiguration:
DestinationBucketName: !Ref LoggingBucket
LogFilePrefix: testing-logs
LifecycleConfiguration:
Rules:
- Id: GlacierRule
Prefix: glacier
Status: Enabled
ExpirationInDays: 365
AccelerateConfiguration:
AccelerationStatus: Enabled
`,
expected: s3.S3{
Buckets: []s3.Bucket{
{
Name: types.String("logging-bucket", types.NewTestMetadata()),
},
{
Name: types.String("test-bucket", types.NewTestMetadata()),
Encryption: s3.Encryption{
Enabled: types.Bool(true, types.NewTestMetadata()),
Algorithm: types.String("aws:kms", types.NewTestMetadata()),
KMSKeyId: types.String("Key", types.NewTestMetadata()),
},
ACL: types.String("aws-exec-read", types.NewTestMetadata()),
PublicAccessBlock: &s3.PublicAccessBlock{
BlockPublicACLs: types.Bool(true, types.NewTestMetadata()),
BlockPublicPolicy: types.Bool(true, types.NewTestMetadata()),
IgnorePublicACLs: types.Bool(true, types.NewTestMetadata()),
RestrictPublicBuckets: types.Bool(true, types.NewTestMetadata()),
},
Logging: s3.Logging{
TargetBucket: types.String("LoggingBucket", types.NewTestMetadata()),
Enabled: types.Bool(true, types.NewTestMetadata()),
},
LifecycleConfiguration: []s3.Rules{
{
Status: types.String("Enabled", types.NewTestMetadata()),
},
},
AccelerateConfigurationStatus: types.String("Enabled", types.NewTestMetadata()),
},
},
},
},
{
name: "empty s3 bucket",
source: `AWSTemplateFormatVersion: 2010-09-09
Resources:
Bucket:
Type: AWS::S3::Bucket
Properties:
BucketName: test-bucket`,
expected: s3.S3{
Buckets: []s3.Bucket{
{
Name: types.String("test-bucket", types.NewTestMetadata()),
Encryption: s3.Encryption{
Enabled: types.BoolDefault(false, types.NewTestMetadata()),
},
},
},
},
},
{
name: "incorrect SSE algorithm",
source: `AWSTemplateFormatVersion: 2010-09-09
Resources:
Bucket:
Type: AWS::S3::Bucket
Properties:
BucketName: test-bucket
BucketEncryption:
ServerSideEncryptionConfiguration:
- ServerSideEncryptionByDefault:
KMSMasterKeyID: alias/my-key
SSEAlgorithm: aes256
`,
expected: s3.S3{
Buckets: []s3.Bucket{
{
Name: types.String("test-bucket", types.NewTestMetadata()),
Encryption: s3.Encryption{
Enabled: types.BoolDefault(false, types.NewTestMetadata()),
KMSKeyId: types.String("alias/my-key", types.NewTestMetadata()),
Algorithm: types.String("aes256", types.NewTestMetadata()),
},
},
},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

fsys := testutil.CreateFS(t, map[string]string{
"main.yaml": tt.source,
})

fctx, err := parser.New().ParseFile(context.TODO(), fsys, "main.yaml")
require.NoError(t, err)

adapted := Adapt(*fctx)
testutil.AssertDefsecEqual(t, tt.expected, adapted)
})
}

}

0 comments on commit 337cb75

Please sign in to comment.