Skip to content

Commit

Permalink
Implement Aggregates Collection and Usage (#323)
Browse files Browse the repository at this point in the history
* Implement aggregates feature: #282

Co-authored-by: Sebastian Esponda <842946+sesponda@users.noreply.github.com>
  • Loading branch information
sesponda and sesponda committed Sep 23, 2023
1 parent bbdceb7 commit 0098ff7
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 13 deletions.
7 changes: 7 additions & 0 deletions bundle/regal/main.rego
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ report contains violation if {
not ignored(violation, ignore_directives)
}

# Collect aggregate entries produced by custom rules: for every custom rule
# (category/title) that is not configured at level "ignore" and does not
# exclude the file currently being linted, include each element of that
# rule's `aggregate` set. These entries are later fed back to rules that
# need cross-file context.
# NOTE(review): `category`/`title` are bound by the data.custom.regal.rules
# reference; OPA reorders body literals for safety, so the function calls
# above it still evaluate with bound arguments.
aggregate contains aggregate if {
some category, title
config.for_rule(category, title).level != "ignore"
not config.excluded_file(category, title, input.regal.file.name)
aggregate := data.custom.regal.rules[category][title].aggregate[_]
}

ignored(violation, directives) if {
ignored_rules := directives[violation.location.row]
violation.title in ignored_rules
Expand Down
46 changes: 46 additions & 0 deletions e2e/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,52 @@ func TestLintRuleNamingConventionFromCustomCategory(t *testing.T) {
}
}

// TestAggregatesAreCollectedAndUsed verifies that aggregate entries exported
// by a custom rule are collected across files (yielding zero violations when
// all files are linted together) and that aggregation is skipped when only a
// single file is linted (yielding one violation).
func TestAggregatesAreCollectedAndUsed(t *testing.T) {
	t.Parallel()

	cwd := must(os.Getwd)
	basedir := cwd + "/testdata/aggregates"

	t.Run("Zero violations expected", func(t *testing.T) {
		var outBuf, errBuf bytes.Buffer

		// Linting the whole directory makes every file visible to the custom
		// rule, so its aggregate-based check is satisfied.
		err := regal(&outBuf, &errBuf)("lint", "--format", "json", basedir+"/rego", "--rules", basedir+"/rules/custom_rules_using_aggregates.rego")

		if want, got := 0, ExitStatus(err); want != got {
			t.Errorf("expected exit status %d, got %d", want, got)
		}

		if want, got := "", errBuf.String(); want != got {
			t.Errorf("expected stderr %q, got %q", want, got)
		}
	})

	t.Run("One violation expected", func(t *testing.T) {
		var outBuf, errBuf bytes.Buffer

		// By sending a single file to the command, we skip the aggregates computation, so we expect one violation
		err := regal(&outBuf, &errBuf)("lint", "--format", "json", basedir+"/rego/policy_1.rego", "--rules", basedir+"/rules/custom_rules_using_aggregates.rego")

		if want, got := 3, ExitStatus(err); want != got {
			t.Errorf("expected exit status %d, got %d", want, got)
		}

		if want, got := "", errBuf.String(); want != got {
			t.Errorf("expected stderr %q, got %q", want, got)
		}

		var rep report.Report

		if err = json.Unmarshal(outBuf.Bytes(), &rep); err != nil {
			t.Fatalf("expected JSON response, got %v", outBuf.String())
		}

		if rep.Summary.NumViolations != 1 {
			t.Errorf("expected 1 violation, got %d", rep.Summary.NumViolations)
		}
	})
}

func TestTestRegalBundledBundle(t *testing.T) {
t.Parallel()

Expand Down
3 changes: 3 additions & 0 deletions e2e/testdata/aggregates/rego/policy_1.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Test fixture for the aggregates e2e test: a minimal policy file whose only
# purpose is to be counted by the aggregate-collecting custom rule.
package mypolicy1.public

my_policy_1 := true
3 changes: 3 additions & 0 deletions e2e/testdata/aggregates/rego/policy_2.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Test fixture for the aggregates e2e test: a second minimal policy file so
# that the custom rule sees two files when the whole directory is linted.
package mypolicy2.public

export := []
20 changes: 20 additions & 0 deletions e2e/testdata/aggregates/rules/custom_rules_using_aggregates.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# METADATA
# description: Collect data in aggregates and validate it
package custom.regal.rules.testcase["aggregates"]

import future.keywords

import data.regal.result

# Each linted file contributes one aggregate entry carrying its file name.
aggregate contains entry if {
	entry := {"file": input.regal.file.name}
}

# Fail unless the collected aggregates show that two files were processed.
# When linting a single file, aggregation is skipped, so this reports.
report contains violation if {
	not two_files_processed
	violation := result.fail(rego.metadata.chain(), {})
}

# True when the aggregate entries reference exactly two files.
# (Uses `if` for consistency with the other rules in this file, enabled by
# the future.keywords import above.)
two_files_processed if {
	files := [x | x = input.aggregate[_].file]
	count(files) == 2
}
145 changes: 133 additions & 12 deletions pkg/linter/linter.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ type Linter struct {
metrics metrics.Metrics
}

// QueryInputBuilder builds the value used as Rego evaluation input for a
// single file, given the file's name, its raw content, and its parsed module.
type QueryInputBuilder func(name string, content string, module *ast.Module) (map[string]any, error)

// ReportCollector consumes the report produced by evaluating one file.
type ReportCollector func(report report.Report)

const regalUserConfig = "regal_user_config"

// NewLinter creates a new Regal linter.
Expand Down Expand Up @@ -180,7 +184,7 @@ var query = ast.MustParseBody("violations = data.regal.main.report") //nolint:go
func (l Linter) Lint(ctx context.Context) (report.Report, error) {
l.startTimer(regalmetrics.RegalLint)

aggregate := report.Report{}
aggregateReport := report.Report{}

if len(l.inputPaths) == 0 && l.inputModules == nil {
return report.Report{}, errors.New("nothing provided to lint")
Expand Down Expand Up @@ -240,29 +244,39 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) {
return report.Report{}, fmt.Errorf("failed to lint using Go rules: %w", err)
}

aggregate.Violations = append(aggregate.Violations, goReport.Violations...)
aggregateReport.Violations = append(aggregateReport.Violations, goReport.Violations...)

var aggregates []report.Aggregate

if len(input.Modules) > 1 {
// No need to collect aggregates if there's only one file
aggregates, err = l.collectAggregates(ctx, input)
if err != nil {
return report.Report{}, fmt.Errorf("failed to collect aggregates using Rego rules: %w", err)
}
}

regoReport, err := l.lintWithRegoRules(ctx, input)
regoReport, err := l.lintWithRegoRules(ctx, input, aggregates)
if err != nil {
return report.Report{}, fmt.Errorf("failed to lint using Rego rules: %w", err)
}

aggregate.Violations = append(aggregate.Violations, regoReport.Violations...)
aggregateReport.Violations = append(aggregateReport.Violations, regoReport.Violations...)

aggregate.Summary = report.Summary{
aggregateReport.Summary = report.Summary{
FilesScanned: len(input.FileNames),
FilesFailed: len(aggregate.ViolationsFileCount()),
FilesFailed: len(aggregateReport.ViolationsFileCount()),
FilesSkipped: 0,
NumViolations: len(aggregate.Violations),
NumViolations: len(aggregateReport.Violations),
}

if l.metrics != nil {
l.metrics.Timer(regalmetrics.RegalLint).Stop()

aggregate.Metrics = l.metrics.All()
aggregateReport.Metrics = l.metrics.All()
}

return aggregate, nil
return aggregateReport, nil
}

func (l Linter) lintWithGoRules(ctx context.Context, input rules.Input) (report.Report, error) {
Expand Down Expand Up @@ -414,7 +428,7 @@ func (l Linter) paramsToRulesConfig() map[string]any {
}
}

func (l Linter) prepareRegoArgs() []func(*rego.Rego) {
func (l Linter) prepareRegoArgs(query ast.Body) []func(*rego.Rego) {
var regoArgs []func(*rego.Rego)

roots := []string{"eval"}
Expand Down Expand Up @@ -466,14 +480,16 @@ func (l Linter) prepareRegoArgs() []func(*rego.Rego) {
return regoArgs
}

func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (report.Report, error) {
func (l Linter) lintWithRegoRules(
ctx context.Context, input rules.Input, aggregates []report.Aggregate,
) (report.Report, error) {
l.startTimer(regalmetrics.RegalLintRego)
defer l.stopTimer(regalmetrics.RegalLintRego)

ctx, cancel := context.WithCancel(ctx)
defer cancel()

regoArgs := l.prepareRegoArgs()
regoArgs := l.prepareRegoArgs(query)

linterQuery, err := rego.New(regoArgs...).PrepareForEval(ctx)
if err != nil {
Expand Down Expand Up @@ -502,6 +518,10 @@ func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (repor
return
}

if len(aggregates) > 0 {
enhancedAST["aggregate"] = aggregates
}

evalArgs := []rego.EvalOption{
rego.EvalInput(enhancedAST),
}
Expand Down Expand Up @@ -738,6 +758,107 @@ func (l Linter) getBundleByName(name string) (*bundle.Bundle, error) {
return nil, fmt.Errorf("no regal bundle found")
}

// collectAggregates evaluates the data.regal.main.aggregate query against
// every input file and returns the combined aggregate entries, which are
// later provided as input to rules needing cross-file (global) context.
func (l Linter) collectAggregates(ctx context.Context, input rules.Input) ([]report.Aggregate, error) {
	var aggregates []report.Aggregate

	regoArgs := l.prepareRegoArgs(ast.MustParseBody("aggregates = data.regal.main.aggregate"))

	linterQuery, err := rego.New(regoArgs...).PrepareForEval(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed preparing query for linting: %w", err)
	}

	err = l.evalAndCollect(ctx, input, linterQuery,
		// query input builder — use the content and module passed in rather
		// than re-looking them up in the input maps.
		func(name string, content string, module *ast.Module) (map[string]any, error) {
			enhanced, err := parse.EnhanceAST(name, content, module)
			if err != nil {
				return nil,
					fmt.Errorf("could not enhance AST when building input during lint with Rego rules: %w", err)
			}

			return enhanced, nil
		},
		// result collector — evalAndCollect guarantees sequential calls, so
		// appending without extra locking here is safe.
		func(r report.Report) {
			aggregates = append(aggregates, r.Aggregates...)
		},
	)
	if err != nil {
		return nil, err
	}

	return aggregates, nil
}

// evalAndCollect processes each file in input.FileNames in a goroutine, with the given Rego query, building the eval
// input using the provided function. Collects the results via the provided collector. The collector is guaranteed to
// run sequentially via a mutex.
func (l Linter) evalAndCollect(ctx context.Context, input rules.Input, query rego.PreparedEvalQuery,
	queryInputBuilder QueryInputBuilder,
	reportCollector ReportCollector,
) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var wg sync.WaitGroup

	var mu sync.Mutex

	// Buffered to capacity so every worker can send its error without
	// blocking: with an unbuffered channel, any worker failing after the
	// first error has been consumed (and this function has returned) would
	// block on send forever, leaking the goroutine.
	errCh := make(chan error, len(input.FileNames))

	doneCh := make(chan struct{})

	for _, name := range input.FileNames {
		wg.Add(1)

		go func(name string) {
			defer wg.Done()

			queryInput, err := queryInputBuilder(name, input.FileContent[name], input.Modules[name])
			if err != nil {
				errCh <- fmt.Errorf("failed building query input: %w", err)

				return
			}

			resultSet, err := query.Eval(ctx, rego.EvalInput(queryInput))
			if err != nil {
				errCh <- fmt.Errorf("error encountered in query evaluation %w", err)

				return
			}

			result, err := resultSetToReport(resultSet)
			if err != nil {
				errCh <- fmt.Errorf("failed to convert result set to report: %w", err)

				return
			}

			// Collectors are not required to be thread-safe: serialize calls.
			mu.Lock()
			reportCollector(result)
			mu.Unlock()
		}(name)
	}

	go func() {
		wg.Wait()
		// close (rather than send) so this goroutine never blocks if the
		// select below already returned via the error or ctx case.
		close(doneCh)
	}()

	select {
	case <-ctx.Done():
		return fmt.Errorf("context cancelled: %w", ctx.Err())
	case err := <-errCh:
		return fmt.Errorf("error encountered in rule evaluation %w", err)
	case <-doneCh:
		return nil
	}
}

func (l Linter) startTimer(name string) {
if l.metrics != nil {
l.metrics.Timer(name).Start()
Expand Down
11 changes: 10 additions & 1 deletion pkg/report/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ type Violation struct {
Location Location `json:"location,omitempty"`
}

// An Aggregate is data collected by some rule while processing a file AST, to be used later by other rules needing a
// global context (i.e. broader than per-file).
// Rule authors are expected to collect the minimum needed data, to avoid performance problems
// while working with large Rego code repositories.
type Aggregate map[string]any

type Summary struct {
FilesScanned int `json:"files_scanned"`
FilesFailed int `json:"files_failed"`
Expand All @@ -38,7 +44,10 @@ type Summary struct {

// Report is the aggregate of Violations as returned by a linter run,
// together with a Summary and optional Aggregates and Metrics.
type Report struct {
	Violations []Violation `json:"violations"`
	// We don't have aggregates when publishing the final report (see JSONReporter), so omitempty is needed here
	// to avoid surfacing a null/empty field.
	Aggregates []Aggregate    `json:"aggregates,omitempty"`
	Summary    Summary        `json:"summary"`
	Metrics    map[string]any `json:"metrics,omitempty"`
}
Expand Down

0 comments on commit 0098ff7

Please sign in to comment.