Skip to content

Commit

Permalink
handle traverseIndex when checking reference (#969)
Browse files Browse the repository at this point in the history
add supporting test
  • Loading branch information
Owen Rumney committed Jul 29, 2021
1 parent 8808315 commit 4b60c5e
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 5 deletions.
30 changes: 30 additions & 0 deletions example/counts/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,34 @@ variable "trust-sg-rules" {
type = "egress"
}
]
}

resource "aws_s3_bucket" "access-logs-bucket" {
count = var.enable_cloudtrail ? 1 : 0
bucket = "cloudtrail-access-logs"
acl = "private"
force_destroy = true

versioning {
enabled = true
}

server_side_encryption_configuration {
rule {
apply_server_side_encryption_by_default {
sse_algorithm = "AES256"
}
}
}
}

resource "aws_s3_bucket_public_access_block" "access-logs" {
count = var.enable_cloudtrail ? 1 : 0

bucket = aws_s3_bucket.access-logs-bucket[0].id

block_public_acls = true
block_public_policy = true
ignore_public_acls = true
restrict_public_buckets = true
}
18 changes: 18 additions & 0 deletions internal/app/tfsec/block/hclattribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,30 @@ func createDotReferenceFromTraversal(traversals ...hcl.Traversal) (*Reference, e
refParts = append(refParts, part.Name)
case hcl.TraverseAttr:
refParts = append(refParts, part.Name)
case hcl.TraverseIndex:
refParts[len(refParts)-1] = fmt.Sprintf("%s[%s]", refParts[len(refParts)-1], getIndexValue(part))
}
}
}
return newReference(refParts)
}

func getIndexValue(part hcl.TraverseIndex) string {
switch part.Key.Type() {
case cty.String:
return fmt.Sprintf("%q", part.Key.AsString())
case cty.Number:
var intVal int
if err := gocty.FromCtyValue(part.Key, &intVal); err != nil {
debug.Log("could not unpack the int, returning 0")
return "0"
}
return fmt.Sprintf("%d", intVal)
default:
return "0"
}
}

func (attr *HCLAttribute) Reference() (*Reference, error) {
if attr == nil {
return nil, fmt.Errorf("attribute is nil")
Expand Down
45 changes: 40 additions & 5 deletions internal/app/tfsec/test/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,51 @@ func Test_ResourcesWithCount(t *testing.T) {
mustExcludeResultCode: "aws-vpc-no-default-vpc",
},
{
name: "count is 1 from conditional",
name: "count is 0 from conditional",
source: `
variable "enabled" {
default = true
default = false
}
resource "aws_default_vpc" "this" {
count = var.enabled ? 1 : 0
count = var.enabled ? 1 : 0
}
`,
mustIncludeResultCode: "aws-vpc-no-default-vpc",
mustExcludeResultCode: "aws-vpc-no-default-vpc",
},
{
name: "issue 962",
source: `
resource "aws_s3_bucket" "access-logs-bucket" {
count = var.enable_cloudtrail ? 1 : 0
bucket = "cloudtrail-access-logs"
acl = "private"
force_destroy = true
versioning {
enabled = true
}
server_side_encryption_configuration {
rule {
apply_server_side_encryption_by_default {
sse_algorithm = "AES256"
}
}
}
}
resource "aws_s3_bucket_public_access_block" "access-logs" {
count = var.enable_cloudtrail ? 1 : 0
bucket = aws_s3_bucket.access-logs-bucket[0].id
block_public_acls = true
block_public_policy = true
ignore_public_acls = true
restrict_public_buckets = true
}
`,
mustExcludeResultCode: "aws-s3-specify-public-access-block",
},
{
name: "Test use of count.index",
Expand Down Expand Up @@ -138,7 +173,7 @@ variable "trust-sg-rules" {
]
}
`,
mustExcludeResultCode: "AWS018",
mustExcludeResultCode: "aws-vpc-add-decription-to-security-group",
},
}

Expand Down
20 changes: 20 additions & 0 deletions internal/app/tfsec/testutil/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ func AssertCheckCode(t *testing.T, includeCode string, excludeCode string, resul

var excludeText string

if !validateCodes(includeCode, excludeCode) {
t.Logf("Either includeCode (%s) or excludeCode (%s) was invalid ", includeCode, excludeCode)
t.FailNow()
}

for _, res := range results {
if res.RuleID == excludeCode {
foundExclude = true
Expand Down Expand Up @@ -106,3 +111,18 @@ func CreateTestFileWithModule(contents string, moduleContents string) string {

return rootPath
}

func validateCodes(includeCode, excludeCode string) bool {
if includeCode != "" {
if _, err := scanner.GetRuleById(includeCode); err != nil {
return false
}
}

if excludeCode != "" {
if _, err := scanner.GetRuleById(excludeCode); err != nil {
return false
}
}
return true
}

0 comments on commit 4b60c5e

Please sign in to comment.