Skip to content

Commit

Permalink
fix(terraform): fix root module search (#6160)
Browse files Browse the repository at this point in the history
Co-authored-by: simar7 <1254783+simar7@users.noreply.github.com>
  • Loading branch information
nikpivkin and simar7 committed Feb 28, 2024
1 parent e1ea02c commit 1dfece8
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 60 deletions.
10 changes: 1 addition & 9 deletions pkg/iac/scanners/terraform/parser/evaluator.go
Expand Up @@ -191,18 +191,10 @@ func (e *evaluator) EvaluateAll(ctx context.Context) (terraform.Modules, map[str

e.debug.Log("Module evaluation complete.")
parseDuration += time.Since(start)
rootModule := terraform.NewModule(e.projectRootPath, e.modulePath, e.blocks, e.ignores, e.isModuleLocal())
for _, m := range modules {
m.SetParent(rootModule)
}
rootModule := terraform.NewModule(e.projectRootPath, e.modulePath, e.blocks, e.ignores)
return append(terraform.Modules{rootModule}, modules...), fsMap, parseDuration
}

func (e *evaluator) isModuleLocal() bool {
// the module source is empty only for local modules
return e.parentParser.moduleSource == ""
}

func (e *evaluator) expandBlocks(blocks terraform.Blocks) terraform.Blocks {
return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks), false)...)
}
Expand Down
78 changes: 78 additions & 0 deletions pkg/iac/scanners/terraform/parser/modules.go
@@ -0,0 +1,78 @@
package parser

import (
"context"
"path"
"sort"
"strings"

"github.com/samber/lo"
"github.com/zclconf/go-cty/cty"

"github.com/aquasecurity/trivy/pkg/iac/terraform"
)

// FindRootModules takes a list of module paths and identifies the root local modules.
// It builds a graph based on the module dependencies and determines the modules that have no incoming dependencies,
// considering them as root modules.
func (p *Parser) FindRootModules(ctx context.Context, dirs []string) ([]string, error) {
for _, dir := range dirs {
if err := p.ParseFS(ctx, dir); err != nil {
return nil, err
}
}

blocks, _, err := p.readBlocks(p.files)
if err != nil {
return nil, err
}

g := buildGraph(blocks, dirs)
rootModules := g.rootModules()
sort.Strings(rootModules)
return rootModules, nil
}

type modulesGraph map[string][]string

func buildGraph(blocks terraform.Blocks, paths []string) modulesGraph {
moduleBlocks := blocks.OfType("module")

graph := lo.SliceToMap(paths, func(p string) (string, []string) {
return p, nil
})

for _, block := range moduleBlocks {
sourceVal := block.GetAttribute("source").Value()
if sourceVal.Type() != cty.String {
continue
}

source := sourceVal.AsString()
if strings.HasPrefix(source, ".") {
filename := block.GetMetadata().Range().GetFilename()
dir := path.Dir(filename)
graph[dir] = append(graph[dir], path.Join(dir, source))
}
}

return graph
}

func (g modulesGraph) rootModules() []string {
incomingEdges := make(map[string]int)
for _, neighbors := range g {
for _, neighbor := range neighbors {
incomingEdges[neighbor]++
}
}

var roots []string
for module := range g {
if incomingEdges[module] == 0 {
roots = append(roots, module)
}
}

return roots
}
71 changes: 71 additions & 0 deletions pkg/iac/scanners/terraform/parser/modules_test.go
@@ -0,0 +1,71 @@
package parser

import (
"context"
"path"
"testing"

"github.com/aquasecurity/trivy/internal/testutil"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)

func TestFindRootModules(t *testing.T) {
tests := []struct {
name string
files map[string]string
expected []string
}{
{
name: "multiple root modules",
files: map[string]string{
"code/main.tf": `
module "this" {
count = 0
source = "./modules/s3"
}`,
"code/modules/s3/main.tf": `
module "this" {
source = "./modules/logging"
}
resource "aws_s3_bucket" "this" {
bucket = "test"
}`,
"code/modules/s3/modules/logging/main.tf": `
resource "aws_s3_bucket" "this" {
bucket = "test1"
}`,
"code/example/main.tf": `
module "this" {
source = "../modules/s3"
}`,
},
expected: []string{"code", "code/example"},
},
{
name: "without module block",
files: map[string]string{
"code/infra1/main.tf": `resource "test" "this" {}`,
"code/infra2/main.tf": `resource "test" "this" {}`,
},
expected: []string{"code/infra1", "code/infra2"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fsys := testutil.CreateFS(t, tt.files)
parser := New(fsys, "", OptionStopOnHCLError(true))

modules := lo.Map(maps.Keys(tt.files), func(p string, _ int) string {
return path.Dir(p)
})

got, err := parser.FindRootModules(context.TODO(), modules)
require.NoError(t, err)
assert.Equal(t, tt.expected, got)
})
}
}
90 changes: 90 additions & 0 deletions pkg/iac/scanners/terraform/parser/parser_test.go
Expand Up @@ -1271,6 +1271,96 @@ func TestForEachWithObjectsOfDifferentTypes(t *testing.T) {
assert.Len(t, modules, 1)
}

func TestCountMetaArgument(t *testing.T) {
tests := []struct {
name string
src string
expected int
}{
{
name: "zero resources",
src: `resource "test" "this" {
count = 0
}`,
expected: 0,
},
{
name: "several resources",
src: `resource "test" "this" {
count = 2
}`,
expected: 2,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fsys := testutil.CreateFS(t, map[string]string{
"main.tf": tt.src,
})
parser := New(fsys, "", OptionStopOnHCLError(true))
require.NoError(t, parser.ParseFS(context.TODO(), "."))

modules, _, err := parser.EvaluateAll(context.TODO())
require.NoError(t, err)
assert.Len(t, modules, 1)

resources := modules.GetResourcesByType("test")
assert.Len(t, resources, tt.expected)
})
}
}

func TestCountMetaArgumentInModule(t *testing.T) {
tests := []struct {
name string
files map[string]string
expectedCountModules int
expectedCountResources int
}{
{
name: "zero modules",
files: map[string]string{
"main.tf": `module "this" {
count = 0
source = "./modules/test"
}`,
"modules/test/main.tf": `resource "test" "this" {}`,
},
expectedCountModules: 1,
expectedCountResources: 0,
},
{
name: "several modules",
files: map[string]string{
"main.tf": `module "this" {
count = 2
source = "./modules/test"
}`,
"modules/test/main.tf": `resource "test" "this" {}`,
},
expectedCountModules: 3,
expectedCountResources: 2,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fsys := testutil.CreateFS(t, tt.files)
parser := New(fsys, "", OptionStopOnHCLError(true))
require.NoError(t, parser.ParseFS(context.TODO(), "."))

modules, _, err := parser.EvaluateAll(context.TODO())
require.NoError(t, err)

assert.Len(t, modules, tt.expectedCountModules)

resources := modules.GetResourcesByType("test")
assert.Len(t, resources, tt.expectedCountResources)
})
}
}

func TestDynamicBlocks(t *testing.T) {
t.Run("arg is list of int", func(t *testing.T) {
modules := parse(t, map[string]string{
Expand Down

0 comments on commit 1dfece8

Please sign in to comment.