Skip to content

Commit

Permalink
Add --profile option to regal lint (#361)
Browse files Browse the repository at this point in the history
This sets a profiler on eval in each goroutine, and
then accumulates the result from each into a final
report. At this point we're only including this in
the JSON output format (similar to `--metrics`) but
we'll want to support this for each format some time
later.

Fixes #334

Signed-off-by: Anders Eknert <anders@styra.com>
  • Loading branch information
anderseknert committed Oct 3, 2023
1 parent 58bb7b8 commit 28a51f7
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 3 deletions.
7 changes: 7 additions & 0 deletions cmd/lint.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type lintCommandParams struct {
debug bool
enablePrint bool
metrics bool
profile bool
disable repeatedStringFlag
disableAll bool
disableCategory repeatedStringFlag
Expand Down Expand Up @@ -141,6 +142,8 @@ func init() {
"enable print output from policy")
lintCommand.Flags().BoolVar(&params.metrics, "metrics", false,
"enable metrics reporting (currently supported only for JSON output format)")
lintCommand.Flags().BoolVar(&params.profile, "profile", false,
"enable profiling metrics to be added to reporting (currently supported only for JSON output format)")

lintCommand.Flags().VarP(&params.disable, "disable", "d",
"disable specific rule(s). This flag can be repeated.")
Expand Down Expand Up @@ -253,6 +256,10 @@ func lint(args []string, params lintCommandParams) (report.Report, error) {
m.Timer(regalmetrics.RegalConfigParse).Start()
}

if params.profile {
regal = regal.WithProfiling(true)
}

var userConfig config.Config

userConfigFile, err := readUserConfig(params, regalDir)
Expand Down
16 changes: 16 additions & 0 deletions internal/metrics/metrics.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package metrics

import (
"github.com/open-policy-agent/opa/profiler"

"github.com/styrainc/regal/pkg/report"
)

const (
RegalConfigSearch = "regal_config_search"
RegalConfigParse = "regal_config_parse"
Expand All @@ -11,3 +17,13 @@ const (
RegalLintRego = "regal_lint_rego"
RegalLintRegoAggregate = "regal_lint_rego_aggregate"
)

func FromExprStats(stats profiler.ExprStats) report.ProfileEntry {
return report.ProfileEntry{
Location: stats.Location.String(),
TotalTimeNs: stats.ExprTimeNs,
NumEval: stats.NumEval,
NumRedo: stats.NumRedo,
NumGenExpr: stats.NumGenExpr,
}
}
38 changes: 38 additions & 0 deletions pkg/linter/linter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/bundle"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/profiler"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/topdown"
"github.com/open-policy-agent/opa/topdown/print"
Expand Down Expand Up @@ -49,6 +50,7 @@ type Linter struct {
enableCategory []string
ignoreFiles []string
metrics metrics.Metrics
profiling bool
}

const regalUserConfig = "regal_user_config"
Expand Down Expand Up @@ -181,6 +183,13 @@ func (l Linter) WithPrintHook(printHook print.Hook) Linter {
return l
}

// WithProfiling enables profiling metrics.
func (l Linter) WithProfiling(enabled bool) Linter {
l.profiling = enabled

return l
}

// Lint runs the linter on provided policies.
func (l Linter) Lint(ctx context.Context) (report.Report, error) {
l.startTimer(regalmetrics.RegalLint)
Expand Down Expand Up @@ -276,6 +285,12 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) {
finalReport.Metrics = l.metrics.All()
}

if l.profiling {
finalReport.AggregateProfile = regoReport.AggregateProfile
finalReport.AggregateProfileToSortedProfile(10)
finalReport.AggregateProfile = nil
}

return finalReport, nil
}

Expand Down Expand Up @@ -480,6 +495,7 @@ func (l Linter) prepareRegoArgs(query ast.Body) []func(*rego.Rego) {
return regoArgs
}

//nolint:gocognit
func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (report.Report, error) {
l.startTimer(regalmetrics.RegalLintRego)
defer l.stopTimer(regalmetrics.RegalLintRego)
Expand Down Expand Up @@ -532,6 +548,12 @@ func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (repor
evalArgs = append(evalArgs, rego.EvalMetrics(l.metrics))
}

var prof *profiler.Profiler
if l.profiling {
prof = profiler.New()
evalArgs = append(evalArgs, rego.EvalQueryTracer(prof))
}

resultSet, err := pq.Eval(ctx, evalArgs...)
if err != nil {
errCh <- fmt.Errorf("error encountered in query evaluation %w", err)
Expand All @@ -546,12 +568,28 @@ func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (repor
return
}

if l.profiling {
// Perhaps we'll want to make this number configurable later, but do note that
// this is only the top 10 locations for a *single* file, not the final report.
profRep := prof.ReportTopNResults(10, []string{"total_time_ns"})

result.AggregateProfile = make(map[string]report.ProfileEntry)

for _, rs := range profRep {
result.AggregateProfile[rs.Location.String()] = regalmetrics.FromExprStats(rs)
}
}

mu.Lock()
aggregate.Violations = append(aggregate.Violations, result.Violations...)

for k := range result.Aggregates {
aggregate.Aggregates[k] = append(aggregate.Aggregates[k], result.Aggregates[k]...)
}

if l.profiling {
aggregate.AddProfileEntries(result.AggregateProfile)
}
mu.Unlock()
}(name)
}
Expand Down
112 changes: 109 additions & 3 deletions pkg/report/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ package report

import (
"fmt"
"io"
"sort"
"strings"
"time"

"github.com/olekukonko/tablewriter"
)

// RelatedResource provides documentation on a violation.
Expand Down Expand Up @@ -47,9 +53,109 @@ type Report struct {
Violations []Violation `json:"violations"`
// We don't have aggregates when publishing the final report (see JSONReporter), so omitempty is needed here
// to avoid surfacing a null/empty field.
Aggregates map[string][]Aggregate `json:"aggregates,omitempty"`
Summary Summary `json:"summary"`
Metrics map[string]any `json:"metrics,omitempty"`
Aggregates map[string][]Aggregate `json:"aggregates,omitempty"`
Summary Summary `json:"summary"`
Metrics map[string]any `json:"metrics,omitempty"`
AggregateProfile map[string]ProfileEntry `json:"-"`
Profile []ProfileEntry `json:"profile,omitempty"`
}

// ProfileEntry is a single entry of profiling information, keyed by location.
// This data may have been aggregated across multiple runs.
type ProfileEntry struct {
Location string `json:"location"`
TotalTimeNs int64 `json:"total_time_ns"`
NumEval int `json:"num_eval"`
NumRedo int `json:"num_redo"`
NumGenExpr int `json:"num_gen_expr"`
}

func (r *Report) AddProfileEntries(prof map[string]ProfileEntry) {
if r.AggregateProfile == nil {
r.AggregateProfile = map[string]ProfileEntry{}
}

for loc, entry := range prof {
if _, ok := r.AggregateProfile[loc]; !ok {
r.AggregateProfile[loc] = entry
} else {
profCopy := prof[loc]
profCopy.NumEval += entry.NumEval
profCopy.NumRedo += entry.NumRedo
profCopy.NumGenExpr += entry.NumGenExpr
profCopy.TotalTimeNs += entry.TotalTimeNs
profCopy.Location = entry.Location
r.AggregateProfile[loc] = profCopy
}
}
}

func (r *Report) AggregateProfileToSortedProfile(numResults int) {
r.Profile = make([]ProfileEntry, 0, len(r.AggregateProfile))

for loc, rs := range r.AggregateProfile {
rs.Location = loc

r.Profile = append(r.Profile, r.AggregateProfile[loc])
}

sort.Slice(r.Profile, func(i, j int) bool {
return r.Profile[i].TotalTimeNs > r.Profile[j].TotalTimeNs
})

if numResults <= 0 || numResults > len(r.Profile) {
return
}

r.Profile = r.Profile[:numResults]
}

// TODO: This does not belong here and is only for internal testing purposes at this point in time. Profile reports are
// currently only publicly available for the JSON reporter. Some variation of this will eventually be moved to the table
// reporter. (this code borrowed from OPA).
func (r Report) printProfile(w io.Writer) { //nolint:unused
tableProfile := generateTableProfile(w)

for i, rs := range r.Profile {
timeNs := time.Duration(rs.TotalTimeNs) * time.Nanosecond
line := []string{
timeNs.String(),
fmt.Sprintf("%d", rs.NumEval),
fmt.Sprintf("%d", rs.NumRedo),
fmt.Sprintf("%d", rs.NumGenExpr),
rs.Location,
}
tableProfile.Append(line)

if i == 0 {
tableProfile.SetFooter([]string{"", "", "", "", ""})
}
}

if tableProfile.NumLines() > 0 {
tableProfile.Render()
}
}

func generateTableWithKeys(writer io.Writer, keys ...string) *tablewriter.Table { //nolint:unused
table := tablewriter.NewWriter(writer)
aligns := make([]int, 0, len(keys))
hdrs := make([]string, 0, len(keys))

for _, k := range keys {
hdrs = append(hdrs, strings.Title(k)) //nolint:staticcheck // SA1019, no unicode here
aligns = append(aligns, tablewriter.ALIGN_LEFT)
}

table.SetHeader(hdrs)
table.SetAlignment(tablewriter.ALIGN_CENTER)
table.SetColumnAlignment(aligns)

return table
}

func generateTableProfile(writer io.Writer) *tablewriter.Table { //nolint:unused
return generateTableWithKeys(writer, "Time", "Num Eval", "Num Redo", "Num Gen Expr", "Location")
}

// ViolationsFileCount returns the number of files containing violations.
Expand Down

0 comments on commit 28a51f7

Please sign in to comment.