diff --git a/integration/aws_cloud_test.go b/integration/aws_cloud_test.go index e992c758f87b..3b1ca9568fab 100644 --- a/integration/aws_cloud_test.go +++ b/integration/aws_cloud_test.go @@ -8,13 +8,14 @@ import ( "testing" "time" - awscommands "github.com/aquasecurity/trivy/pkg/cloud/aws/commands" - "github.com/aquasecurity/trivy/pkg/flag" dockercontainer "github.com/docker/docker/api/types/container" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - testcontainers "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/localstack" + + awscommands "github.com/aquasecurity/trivy/pkg/cloud/aws/commands" + "github.com/aquasecurity/trivy/pkg/flag" ) func TestAwsCommandRun(t *testing.T) { diff --git a/integration/client_server_test.go b/integration/client_server_test.go index 352fd8253444..2b21479e0f16 100644 --- a/integration/client_server_test.go +++ b/integration/client_server_test.go @@ -5,6 +5,7 @@ package integration import ( "context" "fmt" + "github.com/aquasecurity/trivy/pkg/types" "os" "path/filepath" "strings" @@ -15,16 +16,15 @@ import ( "github.com/docker/go-connections/nat" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - testcontainers "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go" "github.com/aquasecurity/trivy/pkg/report" - "github.com/aquasecurity/trivy/pkg/uuid" ) type csArgs struct { Command string RemoteAddrOption string - Format string + Format types.Format TemplatePath string IgnoreUnfixed bool Severity []string @@ -265,19 +265,15 @@ func TestClientServer(t *testing.T) { addr, cacheDir := setup(t, setupOptions{}) - for _, c := range tests { - t.Run(c.name, func(t *testing.T) { - osArgs, outputFile := setupClient(t, c.args, addr, cacheDir, c.golden) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden) - if c.args.secretConfig != "" { - osArgs = append(osArgs, "--secret-config", c.args.secretConfig) + if tt.args.secretConfig != "" { + osArgs = append(osArgs, "--secret-config", tt.args.secretConfig) } - // - err := execute(osArgs) - require.NoError(t, err) - - compareReports(t, c.golden, outputFile, nil) + runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{}) }) } } @@ -389,19 +385,9 @@ func TestClientServerWithFormat(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Setenv("AWS_REGION", "test-region") t.Setenv("AWS_ACCOUNT_ID", "123456789012") - osArgs, outputFile := setupClient(t, tt.args, addr, cacheDir, tt.golden) - - // Run Trivy client - err := execute(osArgs) - require.NoError(t, err) - - want, err := os.ReadFile(tt.golden) - require.NoError(t, err) - - got, err := os.ReadFile(outputFile) - require.NoError(t, err) + osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden) - assert.EqualValues(t, string(want), string(got)) + runTest(t, osArgs, tt.golden, "", tt.args.Format, runOptions{}) }) } } @@ -425,21 +411,16 @@ func TestClientServerWithCycloneDX(t *testing.T) { addr, cacheDir := setup(t, setupOptions{}) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d") - - osArgs, outputFile := setupClient(t, tt.args, addr, cacheDir, tt.golden) - - // Run Trivy client - err := execute(osArgs) - require.NoError(t, err) - - compareCycloneDX(t, tt.golden, outputFile) + osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden) + runTest(t, osArgs, tt.golden, "", types.FormatCycloneDX, runOptions{ + fakeUUID: "3ff14136-e09f-4df9-80ea-%012d", + }) }) } } func TestClientServerWithToken(t *testing.T) { - cases := []struct { + tests := []struct { name string args csArgs golden string @@ -481,20 +462,10 @@ func TestClientServerWithToken(t *testing.T) { tokenHeader: serverTokenHeader, }) - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - osArgs, outputFile := setupClient(t, c.args, addr, cacheDir, c.golden) - - // Run Trivy client - err := execute(osArgs) - if c.wantErr != "" { - require.Error(t, err, c.name) - assert.Contains(t, err.Error(), c.wantErr, c.name) - return - } - - require.NoError(t, err, c.name) - compareReports(t, c.golden, outputFile, nil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden) + runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{wantErr: tt.wantErr}) }) } } @@ -517,25 +488,22 @@ func TestClientServerWithRedis(t *testing.T) { golden := "testdata/alpine-39.json.golden" t.Run("alpine 3.9", func(t *testing.T) { - osArgs, outputFile := setupClient(t, testArgs, addr, cacheDir, golden) + osArgs := setupClient(t, testArgs, addr, cacheDir, golden) // Run Trivy client - err := execute(osArgs) - require.NoError(t, err) - - compareReports(t, golden, outputFile, nil) + runTest(t, osArgs, golden, "", types.FormatJSON, runOptions{}) }) // Terminate the Redis container require.NoError(t, redisC.Terminate(ctx)) t.Run("sad path", func(t *testing.T) { - osArgs, _ := setupClient(t, testArgs, addr, cacheDir, golden) + osArgs := setupClient(t, testArgs, addr, cacheDir, golden) // Run Trivy client - err := execute(osArgs) - require.Error(t, err) - assert.Contains(t, err.Error(), "unable to store cache") + runTest(t, osArgs, "", "", types.FormatJSON, runOptions{ + wantErr: "unable to store cache", + }) }) } @@ -595,7 +563,7 @@ func setupServer(addr, token, tokenHeader, cacheDir, cacheBackend string) []stri return osArgs } -func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden string) ([]string, string) { +func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden string) []string { if c.Command == "" { c.Command = "image" } @@ -612,7 +580,7 @@ func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden st } if c.Format != "" { - osArgs = append(osArgs, "--format", c.Format) + osArgs = append(osArgs, "--format", string(c.Format)) if c.TemplatePath != "" { osArgs = append(osArgs, "--template", c.TemplatePath) } @@ -642,19 +610,11 @@ func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden st osArgs = append(osArgs, "--input", c.Input) } - // Set up the output file - outputFile := filepath.Join(t.TempDir(), "output.json") - if *update { - outputFile = golden - } - - osArgs = append(osArgs, "--output", outputFile) - if c.Target != "" { osArgs = append(osArgs, c.Target) } - return osArgs, outputFile + return osArgs } func setupRedis(t *testing.T, ctx context.Context) (testcontainers.Container, string) { diff --git a/integration/config_test.go b/integration/config_test.go new file mode 100644 index 000000000000..b7c58e8238d6 --- /dev/null +++ b/integration/config_test.go @@ -0,0 +1,230 @@ +//go:build integration + +package integration + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/aquasecurity/trivy/pkg/types" +) + +// TestConfiguration tests the configuration of the CLI flags, environmental variables, and config file +func TestConfiguration(t *testing.T) { + type args struct { + input string + flags map[string]string + envs map[string]string + configFile string + } + type test struct { + name string + args args + golden string + wantErr string + } + + tests := []test{ + { + name: "skip files", + args: args{ + input: "testdata/fixtures/repo/gomod", + flags: map[string]string{ + "scanners": "vuln", + "skip-files": "path/to/dummy,testdata/fixtures/repo/gomod/submod2/go.mod", + }, + envs: map[string]string{ + "TRIVY_SCANNERS": "vuln", + "TRIVY_SKIP_FILES": "path/to/dummy,testdata/fixtures/repo/gomod/submod2/go.mod", + }, + configFile: `--- +scan: + scanners: + - vuln + skip-files: + - path/to/dummy + - testdata/fixtures/repo/gomod/submod2/go.mod +`, + }, + golden: "testdata/gomod-skip.json.golden", + }, + { + name: "dockerfile with custom file pattern", + args: args{ + input: "testdata/fixtures/repo/dockerfile_file_pattern", + flags: map[string]string{ + "scanners": "misconfig", + "file-patterns": "dockerfile:Customfile", + "namespaces": "testing", + }, + envs: map[string]string{ + "TRIVY_SCANNERS": "misconfig", + "TRIVY_FILE_PATTERNS": "dockerfile:Customfile", + "TRIVY_NAMESPACES": "testing", + }, + configFile: `--- +scan: + scanners: + - misconfig + file-patterns: + - dockerfile:Customfile +rego: + skip-policy-update: true + namespaces: + - testing +`, + }, + golden: "testdata/dockerfile_file_pattern.json.golden", + }, + { + name: "key alias", // "--scanners" vs "--security-checks" + args: args{ + input: "testdata/fixtures/repo/gomod", + flags: map[string]string{ + "security-checks": "vuln", + }, + envs: map[string]string{ + "TRIVY_SECURITY_CHECKS": "vuln", + }, + configFile: `--- +scan: + security-checks: + - vuln +`, + }, + golden: "testdata/gomod.json.golden", + }, + { + name: "value alias", // "--scanners vuln" vs "--scanners vulnerability" + args: args{ + input: "testdata/fixtures/repo/gomod", + flags: map[string]string{ + "scanners": "vulnerability", + }, + envs: map[string]string{ + "TRIVY_SCANNERS": "vulnerability", + }, + configFile: `--- +scan: + scanners: + - vulnerability +`, + }, + golden: "testdata/gomod.json.golden", + }, + { + name: "invalid value", + args: args{ + input: "testdata/fixtures/repo/gomod", + flags: map[string]string{ + "scanners": "vulnerability", + "severity": "CRITICAL,INVALID", + }, + envs: map[string]string{ + "TRIVY_SCANNERS": "vulnerability", + "TRIVY_SEVERITY": "CRITICAL,INVALID", + }, + configFile: `--- +scan: + scanners: + - vulnerability +severity: + - CRITICAL + - INVALID +`, + }, + wantErr: `invalid argument "[CRITICAL INVALID]" for "--severity" flag`, + }, + } + + // Set up testing DB + cacheDir := initDB(t) + + // Set a temp dir so that modules will not be loaded + t.Setenv("XDG_DATA_HOME", cacheDir) + + for _, tt := range tests { + command := "repo" + + t.Run(tt.name+" with CLI flags", func(t *testing.T) { + osArgs := []string{ + "--format", + "json", + "--cache-dir", + cacheDir, + "--skip-db-update", + "--skip-policy-update", + command, + tt.args.input, + } + for key, value := range tt.args.flags { + osArgs = append(osArgs, "--"+key, value) + } + + // Set up the output file + outputFile := filepath.Join(t.TempDir(), "output.json") + osArgs = append(osArgs, "--output", outputFile) + + runTest(t, osArgs, tt.golden, outputFile, types.FormatJSON, runOptions{ + wantErr: tt.wantErr, + }) + }) + + t.Run(tt.name+" with environmental variables", func(t *testing.T) { + // Set up the output file + outputFile := filepath.Join(t.TempDir(), "output.json") + + t.Setenv("TRIVY_OUTPUT", outputFile) + t.Setenv("TRIVY_FORMAT", "json") + t.Setenv("TRIVY_CACHE_DIR", cacheDir) + t.Setenv("TRIVY_SKIP_DB_UPDATE", "true") + t.Setenv("TRIVY_SKIP_POLICY_UPDATE", "true") + for key, value := range tt.args.envs { + t.Setenv(key, value) + } + + osArgs := []string{ + command, + tt.args.input, + } + + runTest(t, osArgs, tt.golden, outputFile, types.FormatJSON, runOptions{ + wantErr: tt.wantErr, + }) + }) + + t.Run(tt.name+" with config file", func(t *testing.T) { + // Set up the output file + outputFile := filepath.Join(t.TempDir(), "output.json") + + configFile := tt.args.configFile + configFile = configFile + fmt.Sprintf(` +format: json +output: %s +cache: + dir: %s +db: + skip-update: true +`, outputFile, cacheDir) + + configPath := filepath.Join(t.TempDir(), "trivy.yaml") + err := os.WriteFile(configPath, []byte(configFile), 0444) + require.NoError(t, err) + + osArgs := []string{ + command, + "--config", + configPath, + tt.args.input, + } + + runTest(t, osArgs, tt.golden, outputFile, types.FormatJSON, runOptions{ + wantErr: tt.wantErr, + }) + }) + } +} diff --git a/integration/docker_engine_test.go b/integration/docker_engine_test.go index c69f6d34a0a7..97222d08d328 100644 --- a/integration/docker_engine_test.go +++ b/integration/docker_engine_test.go @@ -5,9 +5,9 @@ package integration import ( "context" + "github.com/aquasecurity/trivy/pkg/types" "io" "os" - "path/filepath" "strings" "testing" @@ -40,18 +40,24 @@ func TestDockerEngine(t *testing.T) { golden: "testdata/alpine-39.json.golden", }, { - name: "alpine:3.9, with high and critical severity", - severity: []string{"HIGH", "CRITICAL"}, + name: "alpine:3.9, with high and critical severity", + severity: []string{ + "HIGH", + "CRITICAL", + }, imageTag: "ghcr.io/aquasecurity/trivy-test-images:alpine-39", input: "testdata/fixtures/images/alpine-39.tar.gz", golden: "testdata/alpine-39-high-critical.json.golden", }, { - name: "alpine:3.9, with .trivyignore", - imageTag: "ghcr.io/aquasecurity/trivy-test-images:alpine-39", - ignoreIDs: []string{"CVE-2019-1549", "CVE-2019-14697"}, - input: "testdata/fixtures/images/alpine-39.tar.gz", - golden: "testdata/alpine-39-ignore-cveids.json.golden", + name: "alpine:3.9, with .trivyignore", + imageTag: "ghcr.io/aquasecurity/trivy-test-images:alpine-39", + ignoreIDs: []string{ + "CVE-2019-1549", + "CVE-2019-14697", + }, + input: "testdata/fixtures/images/alpine-39.tar.gz", + golden: "testdata/alpine-39-ignore-cveids.json.golden", }, { name: "alpine:3.10", @@ -244,13 +250,28 @@ func TestDockerEngine(t *testing.T) { // tag our image to something unique err = cli.ImageTag(ctx, tt.imageTag, tt.input) require.NoError(t, err, tt.name) - } - tmpDir := t.TempDir() - output := filepath.Join(tmpDir, "result.json") + // cleanup + t.Cleanup(func() { + _, err = cli.ImageRemove(ctx, tt.input, api.ImageRemoveOptions{ + Force: true, + PruneChildren: true, + }) + _, err = cli.ImageRemove(ctx, tt.imageTag, api.ImageRemoveOptions{ + Force: true, + PruneChildren: true, + }) + assert.NoError(t, err, tt.name) + }) + } - osArgs := []string{"--cache-dir", cacheDir, "image", - "--skip-update", "--format=json", "--output", output} + osArgs := []string{ + "--cache-dir", + cacheDir, + "image", + "--skip-update", + "--format=json", + } if tt.ignoreUnfixed { osArgs = append(osArgs, "--ignore-unfixed") @@ -258,12 +279,18 @@ func TestDockerEngine(t *testing.T) { if len(tt.ignoreStatus) != 0 { osArgs = append(osArgs, - []string{"--ignore-status", strings.Join(tt.ignoreStatus, ",")}..., + []string{ + "--ignore-status", + strings.Join(tt.ignoreStatus, ","), + }..., ) } if len(tt.severity) != 0 { osArgs = append(osArgs, - []string{"--severity", strings.Join(tt.severity, ",")}..., + []string{ + "--severity", + strings.Join(tt.severity, ","), + }..., ) } if len(tt.ignoreIDs) != 0 { @@ -275,28 +302,7 @@ func TestDockerEngine(t *testing.T) { osArgs = append(osArgs, tt.input) // Run Trivy - err = execute(osArgs) - if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr, tt.name) - return - } - - assert.NoError(t, err, tt.name) - - // check for vulnerability output info - compareReports(t, tt.golden, output, nil) - - // cleanup - _, err = cli.ImageRemove(ctx, tt.input, api.ImageRemoveOptions{ - Force: true, - PruneChildren: true, - }) - _, err = cli.ImageRemove(ctx, tt.imageTag, api.ImageRemoveOptions{ - Force: true, - PruneChildren: true, - }) - assert.NoError(t, err, tt.name) + runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{wantErr: tt.wantErr}) }) } } diff --git a/integration/integration_test.go b/integration/integration_test.go index 05f9c094cd03..43fe3ac8c820 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -7,7 +7,6 @@ import ( "encoding/json" "flag" "fmt" - "github.com/aquasecurity/trivy/pkg/clock" "io" "net" "os" @@ -22,15 +21,18 @@ import ( spdxjson "github.com/spdx/tools-golang/json" "github.com/spdx/tools-golang/spdx" "github.com/spdx/tools-golang/spdxlib" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xeipuuv/gojsonschema" "github.com/aquasecurity/trivy-db/pkg/db" "github.com/aquasecurity/trivy-db/pkg/metadata" + "github.com/aquasecurity/trivy/pkg/clock" "github.com/aquasecurity/trivy/pkg/commands" "github.com/aquasecurity/trivy/pkg/dbtest" "github.com/aquasecurity/trivy/pkg/types" + "github.com/aquasecurity/trivy/pkg/uuid" _ "modernc.org/sqlite" ) @@ -190,7 +192,56 @@ func readSpdxJson(t *testing.T, filePath string) *spdx.Document { return bom } +type runOptions struct { + wantErr string + override func(want, got *types.Report) + fakeUUID string +} + +// runTest runs Trivy with the given args and compares the output with the golden file. +// If outputFile is empty, the output file is created in a temporary directory. +// If update is true, the golden file is updated. +func runTest(t *testing.T, osArgs []string, wantFile, outputFile string, format types.Format, opts runOptions) { + if opts.fakeUUID != "" { + uuid.SetFakeUUID(t, opts.fakeUUID) + } + + if outputFile == "" { + // Set up the output file + outputFile = filepath.Join(t.TempDir(), "output.json") + if *update && opts.override == nil { + outputFile = wantFile + } + } + osArgs = append(osArgs, "--output", outputFile) + + // Run Trivy + err := execute(osArgs) + if opts.wantErr != "" { + require.ErrorContains(t, err, opts.wantErr) + return + } + require.NoError(t, err) + + // Compare want and got + switch format { + case types.FormatCycloneDX: + compareCycloneDX(t, wantFile, outputFile) + case types.FormatSPDXJSON: + compareSPDXJson(t, wantFile, outputFile) + case types.FormatJSON: + compareReports(t, wantFile, outputFile, opts.override) + case types.FormatTemplate, types.FormatSarif, types.FormatGitHub: + compareRawFiles(t, wantFile, outputFile) + default: + require.Fail(t, "invalid format", "format: %s", format) + } +} + func execute(osArgs []string) error { + // viper.XXX() (e.g. viper.ReadInConfig()) affects the global state, so we need to reset it after each test. + defer viper.Reset() + // Set a fake time ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC)) @@ -203,11 +254,19 @@ func execute(osArgs []string) error { return app.ExecuteContext(ctx) } -func compareReports(t *testing.T, wantFile, gotFile string, override func(*types.Report)) { +func compareRawFiles(t *testing.T, wantFile, gotFile string) { + want, err := os.ReadFile(wantFile) + require.NoError(t, err) + got, err := os.ReadFile(gotFile) + require.NoError(t, err) + assert.EqualValues(t, string(want), string(got)) +} + +func compareReports(t *testing.T, wantFile, gotFile string, override func(want, got *types.Report)) { want := readReport(t, wantFile) got := readReport(t, gotFile) if override != nil { - override(&want) + override(&want, &got) } assert.Equal(t, want, got) } diff --git a/integration/module_test.go b/integration/module_test.go index c2fd3928922b..16745da1c8fd 100644 --- a/integration/module_test.go +++ b/integration/module_test.go @@ -3,11 +3,10 @@ package integration import ( + "github.com/aquasecurity/trivy/pkg/types" "path/filepath" "testing" - "github.com/stretchr/testify/require" - "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/scanner/post" ) @@ -51,27 +50,13 @@ func TestModule(t *testing.T) { tt.input, } - // Set up the output file - outputFile := filepath.Join(t.TempDir(), "output.json") - if *update { - outputFile = tt.golden - } - - osArgs = append(osArgs, []string{ - "--output", - outputFile, - }...) - - // Run Trivy - err := execute(osArgs) - require.NoError(t, err) - defer func() { + t.Cleanup(func() { analyzer.DeregisterAnalyzer("spring4shell") post.DeregisterPostScanner("spring4shell") - }() + }) - // Compare want and got - compareReports(t, tt.golden, outputFile, nil) + // Run Trivy + runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{}) }) } } diff --git a/integration/registry_test.go b/integration/registry_test.go index 0d343a075d71..b62865667dc3 100644 --- a/integration/registry_test.go +++ b/integration/registry_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package integration @@ -11,6 +10,7 @@ import ( "crypto/x509" "encoding/json" "fmt" + "github.com/aquasecurity/trivy/pkg/types" "io" "net/http" "net/url" @@ -24,9 +24,8 @@ import ( "github.com/google/go-containerregistry/pkg/name" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/tarball" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - testcontainers "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) @@ -62,7 +61,10 @@ func setupRegistry(ctx context.Context, baseDir string, authURL *url.URL) (testc HostConfigModifier: func(hostConfig *dockercontainer.HostConfig) { hostConfig.AutoRemove = true }, - WaitingFor: wait.ForLog("listening on [::]:5443"), + WaitingFor: wait.ForHTTP("v2").WithTLS(true).WithAllowInsecure(true). + WithStatusCodeMatcher(func(status int) bool { + return status == http.StatusUnauthorized + }), } registryC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ @@ -191,62 +193,50 @@ func TestRegistry(t *testing.T) { imageRef, err := name.ParseReference(s) require.NoError(t, err) - // 1. Load a test image from the tar file, tag it and push to the test registry. + // Load a test image from the tar file, tag it and push to the test registry. err = replicateImage(imageRef, tc.imageFile, auth) require.NoError(t, err) - // 2. Scan it - resultFile, err := scan(t, imageRef, baseDir, tc.golden, tc.option) - - if tc.wantErr != "" { - require.Error(t, err) - require.Contains(t, err.Error(), tc.wantErr, err) - return - } - require.NoError(t, err) - - // 3. Read want and got - want := readReport(t, tc.golden) - got := readReport(t, resultFile) - - // 4 Update some dynamic fields - want.ArtifactName = s - for i := range want.Results { - want.Results[i].Target = fmt.Sprintf("%s (alpine 3.10.2)", s) - } - - // 5. Compare want and got - assert.Equal(t, want, got) + osArgs, err := scan(t, imageRef, baseDir, tc.golden, tc.option) + + // Run Trivy + runTest(t, osArgs, tc.golden, "", types.FormatJSON, runOptions{ + wantErr: tc.wantErr, + override: func(_, got *types.Report) { + got.ArtifactName = tc.imageName + for i := range got.Results { + got.Results[i].Target = fmt.Sprintf("%s (alpine 3.10.2)", tc.imageName) + } + }, + }) }) } } -func scan(t *testing.T, imageRef name.Reference, baseDir, goldenFile string, opt registryOption) (string, error) { +func scan(t *testing.T, imageRef name.Reference, baseDir, goldenFile string, opt registryOption) ([]string, error) { // Set up testing DB cacheDir := initDB(t) // Set a temp dir so that modules will not be loaded t.Setenv("XDG_DATA_HOME", cacheDir) - // Setup the output file - outputFile := filepath.Join(t.TempDir(), "output.json") - if *update { - outputFile = goldenFile - } - // Setup env if err := setupEnv(t, imageRef, baseDir, opt); err != nil { - return "", err + return nil, err } - osArgs := []string{"-q", "--cache-dir", cacheDir, "image", "--format", "json", "--skip-update", - "--output", outputFile, imageRef.Name()} - - // Run Trivy - if err := execute(osArgs); err != nil { - return "", err + osArgs := []string{ + "-q", + "--cache-dir", + cacheDir, + "image", + "--format", + "json", + "--skip-update", + imageRef.Name(), } - return outputFile, nil + + return osArgs, nil } func setupEnv(t *testing.T, imageRef name.Reference, baseDir string, opt registryOption) error { diff --git a/integration/repo_test.go b/integration/repo_test.go index 68321a57effe..ba11aa9ccb0f 100644 --- a/integration/repo_test.go +++ b/integration/repo_test.go @@ -5,15 +5,12 @@ package integration import ( "fmt" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "os" - "path/filepath" "strings" "testing" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/types" - "github.com/aquasecurity/trivy/pkg/uuid" ) // TestRepository tests `trivy repo` with the local code repositories @@ -40,7 +37,7 @@ func TestRepository(t *testing.T) { name string args args golden string - override func(*types.Report) + override func(want, got *types.Report) }{ { name: "gomod", @@ -372,8 +369,8 @@ func TestRepository(t *testing.T) { skipFiles: []string{"testdata/fixtures/repo/gomod/submod2/go.mod"}, }, golden: "testdata/gomod-skip.json.golden", - override: func(report *types.Report) { - report.ArtifactType = ftypes.ArtifactFilesystem + override: func(want, _ *types.Report) { + want.ArtifactType = ftypes.ArtifactFilesystem }, }, { @@ -386,8 +383,8 @@ func TestRepository(t *testing.T) { input: "testdata/fixtures/repo/custom-policy", }, golden: "testdata/dockerfile-custom-policies.json.golden", - override: func(report *types.Report) { - report.ArtifactType = ftypes.ArtifactFilesystem + override: func(want, got *types.Report) { + want.ArtifactType = ftypes.ArtifactFilesystem }, }, } @@ -400,7 +397,6 @@ func TestRepository(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - command := "repo" if tt.args.command != "" { command = tt.args.command @@ -423,6 +419,7 @@ func TestRepository(t *testing.T) { "--parallel", fmt.Sprint(tt.args.parallel), "--offline-scan", + tt.args.input, } if tt.args.scanner != "" { @@ -478,12 +475,6 @@ func TestRepository(t *testing.T) { } } - // Setup the output file - outputFile := filepath.Join(t.TempDir(), "output.json") - if *update && tt.override == nil { - outputFile = tt.golden - } - if tt.args.listAllPkgs { osArgs = append(osArgs, "--list-all-pkgs") } @@ -496,26 +487,10 @@ func TestRepository(t *testing.T) { osArgs = append(osArgs, "--secret-config", tt.args.secretConfig) } - osArgs = append(osArgs, "--output", outputFile) - osArgs = append(osArgs, tt.args.input) - - uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d") - - // Run "trivy repo" - err := execute(osArgs) - require.NoError(t, err) - - // Compare want and got - switch format { - case types.FormatCycloneDX: - compareCycloneDX(t, tt.golden, outputFile) - case types.FormatSPDXJSON: - compareSPDXJson(t, tt.golden, outputFile) - case types.FormatJSON: - compareReports(t, tt.golden, outputFile, tt.override) - default: - require.Fail(t, "invalid format", "format: %s", format) - } + runTest(t, osArgs, tt.golden, "", format, runOptions{ + fakeUUID: "3ff14136-e09f-4df9-80ea-%012d", + override: tt.override, + }) }) } } diff --git a/integration/standalone_tar_test.go b/integration/standalone_tar_test.go index 570acef48bd8..67cd869ebf1a 100644 --- a/integration/standalone_tar_test.go +++ b/integration/standalone_tar_test.go @@ -3,6 +3,7 @@ package integration import ( + "github.com/aquasecurity/trivy/pkg/types" "os" "path/filepath" "strings" @@ -17,28 +18,28 @@ func TestTar(t *testing.T) { IgnoreUnfixed bool Severity []string IgnoreIDs []string - Format string + Format types.Format Input string SkipDirs []string SkipFiles []string } tests := []struct { - name string - testArgs args - golden string + name string + args args + golden string }{ { name: "alpine 3.9", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/alpine-39.tar.gz", }, golden: "testdata/alpine-39.json.golden", }, { name: "alpine 3.9 with skip dirs", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/alpine-39.tar.gz", SkipDirs: []string{ "/etc", @@ -48,8 +49,8 @@ func TestTar(t *testing.T) { }, { name: "alpine 3.9 with skip files", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/alpine-39.tar.gz", SkipFiles: []string{ "/etc", @@ -132,224 +133,224 @@ func TestTar(t *testing.T) { }, { name: "alpine 3.9 with high and critical severity", - testArgs: args{ + args: args{ IgnoreUnfixed: true, Severity: []string{ "HIGH", "CRITICAL", }, - Format: "json", + Format: types.FormatJSON, Input: "testdata/fixtures/images/alpine-39.tar.gz", }, golden: "testdata/alpine-39-high-critical.json.golden", }, { name: "alpine 3.9 with .trivyignore", - testArgs: args{ + args: args{ IgnoreUnfixed: false, IgnoreIDs: []string{ "CVE-2019-1549", "CVE-2019-14697", }, - Format: "json", + Format: types.FormatJSON, Input: "testdata/fixtures/images/alpine-39.tar.gz", }, golden: "testdata/alpine-39-ignore-cveids.json.golden", }, { name: "alpine 3.10", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/alpine-310.tar.gz", }, golden: "testdata/alpine-310.json.golden", }, { name: "alpine distroless", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/alpine-distroless.tar.gz", }, golden: "testdata/alpine-distroless.json.golden", }, { name: "amazon linux 1", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/amazon-1.tar.gz", }, golden: "testdata/amazon-1.json.golden", }, { name: "amazon linux 2", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/amazon-2.tar.gz", }, golden: "testdata/amazon-2.json.golden", }, { name: "debian buster/10", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/debian-buster.tar.gz", }, golden: "testdata/debian-buster.json.golden", }, { name: "debian buster/10 with --ignore-unfixed option", - testArgs: args{ + args: args{ IgnoreUnfixed: true, - Format: "json", + Format: types.FormatJSON, Input: "testdata/fixtures/images/debian-buster.tar.gz", }, golden: "testdata/debian-buster-ignore-unfixed.json.golden", }, { name: "debian stretch/9", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/debian-stretch.tar.gz", }, golden: "testdata/debian-stretch.json.golden", }, { name: "ubuntu 18.04", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/ubuntu-1804.tar.gz", }, golden: "testdata/ubuntu-1804.json.golden", }, { name: "ubuntu 18.04 with --ignore-unfixed option", - testArgs: args{ + args: args{ IgnoreUnfixed: true, - Format: "json", + Format: types.FormatJSON, Input: "testdata/fixtures/images/ubuntu-1804.tar.gz", }, golden: "testdata/ubuntu-1804-ignore-unfixed.json.golden", }, { name: "centos 7", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/centos-7.tar.gz", }, golden: "testdata/centos-7.json.golden", }, { name: "centos 7with --ignore-unfixed option", - testArgs: args{ + args: args{ IgnoreUnfixed: true, - Format: "json", + Format: types.FormatJSON, Input: "testdata/fixtures/images/centos-7.tar.gz", }, golden: "testdata/centos-7-ignore-unfixed.json.golden", }, { name: "centos 7 with medium severity", - testArgs: args{ + args: args{ IgnoreUnfixed: true, Severity: []string{"MEDIUM"}, - Format: "json", + Format: types.FormatJSON, Input: "testdata/fixtures/images/centos-7.tar.gz", }, golden: "testdata/centos-7-medium.json.golden", }, { name: "centos 6", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/centos-6.tar.gz", }, golden: "testdata/centos-6.json.golden", }, { name: "ubi 7", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/ubi-7.tar.gz", }, golden: "testdata/ubi-7.json.golden", }, { name: "almalinux 8", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/almalinux-8.tar.gz", }, golden: "testdata/almalinux-8.json.golden", }, { name: "rocky linux 8", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/rockylinux-8.tar.gz", }, golden: "testdata/rockylinux-8.json.golden", }, { name: "distroless base", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/distroless-base.tar.gz", }, golden: "testdata/distroless-base.json.golden", }, { name: "distroless python27", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/distroless-python27.tar.gz", }, golden: "testdata/distroless-python27.json.golden", }, { name: "oracle linux 8", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/oraclelinux-8.tar.gz", }, golden: "testdata/oraclelinux-8.json.golden", }, { name: "opensuse leap 15.1", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/opensuse-leap-151.tar.gz", }, golden: "testdata/opensuse-leap-151.json.golden", }, { name: "photon 3.0", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/photon-30.tar.gz", }, golden: "testdata/photon-30.json.golden", }, { name: "CBL-Mariner 1.0", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/mariner-1.0.tar.gz", }, golden: "testdata/mariner-1.0.json.golden", }, { name: "busybox with Cargo.lock integration", - testArgs: args{ - Format: "json", + args: args{ + Format: types.FormatJSON, Input: "testdata/fixtures/images/busybox-with-lockfile.tar.gz", }, golden: "testdata/busybox-with-lockfile.json.golden", }, { name: "fluentd with RubyGems", - testArgs: args{ + args: args{ IgnoreUnfixed: true, - Format: "json", + Format: types.FormatJSON, Input: "testdata/fixtures/images/fluentd-multiple-lockfiles.tar.gz", }, golden: "testdata/fluentd-gems.json.golden", @@ -370,55 +371,40 @@ func TestTar(t *testing.T) { "image", "-q", "--format", - tt.testArgs.Format, + string(tt.args.Format), "--skip-update", } - if tt.testArgs.IgnoreUnfixed { + if tt.args.IgnoreUnfixed { osArgs = append(osArgs, "--ignore-unfixed") } - if len(tt.testArgs.Severity) != 0 { - osArgs = append(osArgs, "--severity", strings.Join(tt.testArgs.Severity, ",")) + if len(tt.args.Severity) != 0 { + osArgs = append(osArgs, "--severity", strings.Join(tt.args.Severity, ",")) } - if len(tt.testArgs.IgnoreIDs) != 0 { + if len(tt.args.IgnoreIDs) != 0 { trivyIgnore := ".trivyignore" - err := os.WriteFile(trivyIgnore, []byte(strings.Join(tt.testArgs.IgnoreIDs, "\n")), 0444) + err := os.WriteFile(trivyIgnore, []byte(strings.Join(tt.args.IgnoreIDs, "\n")), 0444) assert.NoError(t, err, "failed to write .trivyignore") defer os.Remove(trivyIgnore) } - if tt.testArgs.Input != "" { - osArgs = append(osArgs, "--input", tt.testArgs.Input) + if tt.args.Input != "" { + osArgs = append(osArgs, "--input", tt.args.Input) } - if len(tt.testArgs.SkipFiles) != 0 { - for _, skipFile := range tt.testArgs.SkipFiles { + if len(tt.args.SkipFiles) != 0 { + for _, skipFile := range tt.args.SkipFiles { osArgs = append(osArgs, "--skip-files", skipFile) } } - if len(tt.testArgs.SkipDirs) != 0 { - for _, skipDir := range tt.testArgs.SkipDirs { + if len(tt.args.SkipDirs) != 0 { + for _, skipDir := range tt.args.SkipDirs { osArgs = append(osArgs, "--skip-dirs", skipDir) } } - // Set up the output file - outputFile := filepath.Join(t.TempDir(), "output.json") - if *update { - outputFile = tt.golden - } - - osArgs = append(osArgs, []string{ - "--output", - outputFile, - }...) - // Run Trivy - err := execute(osArgs) - require.NoError(t, err) - - // Compare want and got - compareReports(t, tt.golden, outputFile, nil) + runTest(t, osArgs, tt.golden, "", tt.args.Format, runOptions{}) }) } } @@ -479,8 +465,6 @@ func TestTarWithEnv(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - osArgs := []string{"image"} - t.Setenv("TRIVY_FORMAT", tt.testArgs.Format) t.Setenv("TRIVY_CACHE_DIR", cacheDir) t.Setenv("TRIVY_QUIET", "true") @@ -493,27 +477,15 @@ func TestTarWithEnv(t *testing.T) { t.Setenv("TRIVY_SEVERITY", strings.Join(tt.testArgs.Severity, ",")) } if tt.testArgs.Input != "" { - osArgs = append(osArgs, "--input", tt.testArgs.Input) + t.Setenv("TRIVY_INPUT", tt.testArgs.Input) } if len(tt.testArgs.SkipDirs) != 0 { t.Setenv("TRIVY_SKIP_DIRS", strings.Join(tt.testArgs.SkipDirs, ",")) } - // Set up the output file - outputFile := filepath.Join(t.TempDir(), "output.json") - - osArgs = append(osArgs, []string{ - "--output", - outputFile, - }...) - // Run Trivy - err := execute(osArgs) - require.NoError(t, err) - - // Compare want and got - compareReports(t, tt.golden, outputFile, nil) + runTest(t, []string{"image"}, tt.golden, "", types.FormatJSON, runOptions{}) }) } } @@ -531,13 +503,13 @@ func TestTarWithConfigFile(t *testing.T) { configFile: `quiet: true format: json severity: - - HIGH - - CRITICAL + - HIGH + - CRITICAL vulnerability: - type: - - os + type: + - os cache: - dir: /should/be/overwritten + dir: /should/be/overwritten `, golden: "testdata/alpine-39-high-critical.json.golden", }, @@ -547,9 +519,9 @@ cache: configFile: `quiet: true format: json vulnerability: - ignore-unfixed: true + ignore-unfixed: true cache: - dir: /should/be/overwritten + dir: /should/be/overwritten `, golden: "testdata/debian-buster-ignore-unfixed.json.golden", }, @@ -563,10 +535,7 @@ cache: for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tmpDir := t.TempDir() - outputFile := filepath.Join(tmpDir, "output.json") - configPath := filepath.Join(tmpDir, "trivy.yaml") - + configPath := filepath.Join(t.TempDir(), "trivy.yaml") err := os.WriteFile(configPath, []byte(tt.configFile), 0600) require.NoError(t, err) @@ -579,16 +548,10 @@ cache: configPath, "--input", tt.input, - "--output", - outputFile, } // Run Trivy - err = execute(osArgs) - require.NoError(t, err) - - // Compare want and got - compareReports(t, tt.golden, outputFile, nil) + runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{}) }) } } diff --git a/integration/testdata/alpine-310-registry.json.golden b/integration/testdata/alpine-310-registry.json.golden index cf451bbd7bfb..51100d633f53 100644 --- a/integration/testdata/alpine-310-registry.json.golden +++ b/integration/testdata/alpine-310-registry.json.golden @@ -1,7 +1,7 @@ { "SchemaVersion": 2, "CreatedAt": "2021-08-25T12:20:30.000000005Z", - "ArtifactName": "localhost:53869/alpine:3.10", + "ArtifactName": "alpine:3.10", "ArtifactType": "container_image", "Metadata": { "OS": { @@ -14,10 +14,10 @@ "sha256:03901b4a2ea88eeaad62dbe59b072b28b6efa00491962b8741081c5df50c65e0" ], "RepoTags": [ - "localhost:53869/alpine:3.10" + "alpine:3.10" ], "RepoDigests": [ - "localhost:53869/alpine@sha256:b1c5a500182b21d0bfa5a584a8526b56d8be316f89e87d951be04abed2446e60" + "alpine@sha256:b1c5a500182b21d0bfa5a584a8526b56d8be316f89e87d951be04abed2446e60" ], "ImageConfig": { "architecture": "amd64", @@ -56,7 +56,7 @@ }, "Results": [ { - "Target": "localhost:53869/alpine:3.10 (alpine 3.10.2)", + "Target": "alpine:3.10 (alpine 3.10.2)", "Class": "os-pkgs", "Type": "alpine", "Vulnerabilities": [ diff --git a/integration/vm_test.go b/integration/vm_test.go index fab87cc070a2..7ccc85f03994 100644 --- a/integration/vm_test.go +++ b/integration/vm_test.go @@ -3,12 +3,10 @@ package integration import ( - "os" "path/filepath" + "strings" "testing" - "github.com/stretchr/testify/require" - "github.com/aquasecurity/trivy/internal/testutil" "github.com/aquasecurity/trivy/pkg/types" ) @@ -66,10 +64,6 @@ func TestVM(t *testing.T) { // Set up testing DB cacheDir := initDB(t) - // Keep the current working directory - currentDir, err := os.Getwd() - require.NoError(t, err) - const imageFile = "disk.img" for _, tt := range tests { @@ -86,34 +80,22 @@ func TestVM(t *testing.T) { tt.args.format, } - tmpDir := t.TempDir() - - // Set up the output file - outputFile := filepath.Join(tmpDir, "output.json") - if *update { - outputFile = filepath.Join(currentDir, tt.golden) - } - - // Get the absolute path of the golden file - goldenFile, err := filepath.Abs(tt.golden) - require.NoError(t, err) - // Decompress the gzipped image file - imagePath := filepath.Join(tmpDir, imageFile) + imagePath := filepath.Join(t.TempDir(), imageFile) testutil.DecompressSparseGzip(t, tt.args.input, imagePath) - // Change the current working directory so that targets in the result could be the same as golden files. - err = os.Chdir(tmpDir) - require.NoError(t, err) - defer os.Chdir(currentDir) - - osArgs = append(osArgs, "--output", outputFile) - osArgs = append(osArgs, imageFile) + osArgs = append(osArgs, imagePath) // Run "trivy vm" - err = execute(osArgs) - require.NoError(t, err) - compareReports(t, goldenFile, outputFile, nil) + runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{ + override: func(_, got *types.Report) { + got.ArtifactName = "disk.img" + for i := range got.Results { + lastIndex := strings.LastIndex(got.Results[i].Target, "/") + got.Results[i].Target = got.Results[i].Target[lastIndex+1:] + } + }, + }) }) } } diff --git a/pkg/commands/app.go b/pkg/commands/app.go index dac78354075b..50ee64c617f4 100644 --- a/pkg/commands/app.go +++ b/pkg/commands/app.go @@ -190,7 +190,10 @@ func NewRootCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { return err } - globalOptions := globalFlags.ToOptions() + globalOptions, err := globalFlags.ToOptions() + if err != nil { + return err + } // Initialize logger if err := log.InitLogger(globalOptions.Debug, globalOptions.Quiet); err != nil { @@ -200,7 +203,11 @@ func NewRootCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - globalOptions := globalFlags.ToOptions() + globalOptions, err := globalFlags.ToOptions() + if err != nil { + return err + } + if globalOptions.ShowVersion { // Customize version output return showVersion(globalOptions.CacheDir, versionFormat, cmd.OutOrStdout()) @@ -223,20 +230,21 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { scanFlagGroup.IncludeDevDeps = nil // disable '--include-dev-deps' reportFlagGroup := flag.NewReportFlagGroup() - report := flag.ReportFormatFlag + report := flag.ReportFormatFlag.Clone() report.Default = "summary" // override the default value as the summary is preferred for the compliance report report.Usage = "specify a format for the compliance report." // "--report" works only with "--compliance" - reportFlagGroup.ReportFormat = &report + reportFlagGroup.ReportFormat = report - compliance := flag.ComplianceFlag + compliance := flag.ComplianceFlag.Clone() compliance.Values = []string{types.ComplianceDockerCIS} - reportFlagGroup.Compliance = &compliance // override usage as the accepted values differ for each subcommand. + reportFlagGroup.Compliance = compliance // override usage as the accepted values differ for each subcommand. misconfFlagGroup := flag.NewMisconfFlagGroup() misconfFlagGroup.CloudformationParamVars = nil // disable '--cf-params' misconfFlagGroup.TerraformTFVars = nil // disable '--tf-vars' imageFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific @@ -292,7 +300,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { return validateArgs(cmd, args) }, RunE: func(cmd *cobra.Command, args []string) error { - options, err := imageFlags.ToOptions(args, globalFlags) + options, err := imageFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -311,12 +319,13 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { reportFlagGroup := flag.NewReportFlagGroup() - reportFormat := flag.ReportFormatFlag + reportFormat := flag.ReportFormatFlag.Clone() reportFormat.Usage = "specify a compliance report format for the output" // @TODO: support --report summary for non compliance reports - reportFlagGroup.ReportFormat = &reportFormat + reportFlagGroup.ReportFormat = reportFormat reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' fsFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(), @@ -351,7 +360,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := fsFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := fsFlags.ToOptions(args, globalFlags) + options, err := fsFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -370,6 +379,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { rootfsFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(), @@ -410,7 +420,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := rootfsFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := rootfsFlags.ToOptions(args, globalFlags) + options, err := rootfsFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -428,6 +438,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { repoFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), LicenseFlagGroup: flag.NewLicenseFlagGroup(), @@ -465,7 +476,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := repoFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := repoFlags.ToOptions(args, globalFlags) + options, err := repoFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -483,6 +494,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { convertFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, ScanFlagGroup: &flag.ScanFlagGroup{}, ReportFlagGroup: flag.NewReportFlagGroup(), } @@ -505,7 +517,7 @@ func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := convertFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - opts, err := convertFlags.ToOptions(args, globalFlags) + opts, err := convertFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -525,7 +537,7 @@ func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { // NewClientCommand returns the 'client' subcommand that is deprecated func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { remoteFlags := flag.NewClientFlags() - remoteAddr := flag.Flag{ + remoteAddr := flag.Flag[string]{ Name: "remote", ConfigName: "server.addr", Shorthand: "", @@ -535,6 +547,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { remoteFlags.ServerAddr = &remoteAddr // disable '--server' and enable '--remote' instead. clientFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), @@ -562,7 +575,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := clientFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := clientFlags.ToOptions(args, globalFlags) + options, err := clientFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -580,6 +593,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { serverFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), ModuleFlagGroup: flag.NewModuleFlagGroup(), @@ -608,7 +622,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := serverFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := serverFlags.ToOptions(args, globalFlags) + options, err := serverFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -629,18 +643,19 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { reportFlagGroup.DependencyTree = nil // disable '--dependency-tree' reportFlagGroup.ListAllPkgs = nil // disable '--list-all-pkgs' reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' - reportFormat := flag.ReportFormatFlag + reportFormat := flag.ReportFormatFlag.Clone() reportFormat.Usage = "specify a compliance report format for the output" // @TODO: support --report summary for non compliance reports - reportFlagGroup.ReportFormat = &reportFormat + reportFlagGroup.ReportFormat = reportFormat scanFlags := &flag.ScanFlagGroup{ // Enable only '--skip-dirs' and '--skip-files' and disable other flags - SkipDirs: &flag.SkipDirsFlag, - SkipFiles: &flag.SkipFilesFlag, - FilePatterns: &flag.FilePatternsFlag, + SkipDirs: flag.SkipDirsFlag.Clone(), + SkipFiles: flag.SkipFilesFlag.Clone(), + FilePatterns: flag.FilePatternsFlag.Clone(), } configFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), ModuleFlagGroup: flag.NewModuleFlagGroup(), @@ -648,7 +663,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { RegoFlagGroup: flag.NewRegoFlagGroup(), K8sFlagGroup: &flag.K8sFlagGroup{ // disable unneeded flags - K8sVersion: &flag.K8sVersionFlag, + K8sVersion: flag.K8sVersionFlag.Clone(), }, ReportFlagGroup: reportFlagGroup, ScanFlagGroup: scanFlags, @@ -669,7 +684,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := configFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := configFlags.ToOptions(args, globalFlags) + options, err := configFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -796,6 +811,7 @@ func NewPluginCommand() *cobra.Command { func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { moduleFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, ModuleFlagGroup: flag.NewModuleFlagGroup(), } @@ -827,7 +843,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { } repo := args[0] - opts, err := moduleFlags.ToOptions(args, globalFlags) + opts, err := moduleFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -851,7 +867,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { } repo := args[0] - opts, err := moduleFlags.ToOptions(args, globalFlags) + opts, err := moduleFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -866,7 +882,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { scanFlags := flag.NewScanFlagGroup() - scanners := flag.ScannersFlag + scanners := flag.ScannersFlag.Clone() // overwrite the default scanners scanners.Values = xstrings.ToStringSlice(types.Scanners{ types.VulnerabilityScanner, @@ -875,36 +891,37 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { types.RBACScanner, }) scanners.Default = scanners.Values - scanFlags.Scanners = &scanners + scanFlags.Scanners = scanners scanFlags.IncludeDevDeps = nil // disable '--include-dev-deps' // required only SourceFlag - imageFlags := &flag.ImageFlagGroup{ImageSources: &flag.SourceFlag} + imageFlags := &flag.ImageFlagGroup{ImageSources: flag.SourceFlag.Clone()} reportFlagGroup := flag.NewReportFlagGroup() - compliance := flag.ComplianceFlag + compliance := flag.ComplianceFlag.Clone() compliance.Values = []string{ types.ComplianceK8sNsa, types.ComplianceK8sCIS, types.ComplianceK8sPSSBaseline, types.ComplianceK8sPSSRestricted, } - reportFlagGroup.Compliance = &compliance // override usage as the accepted values differ for each subcommand. - reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' + reportFlagGroup.Compliance = compliance // override usage as the accepted values differ for each subcommand. + reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' - formatFlag := flag.FormatFlag + formatFlag := flag.FormatFlag.Clone() formatFlag.Values = xstrings.ToStringSlice([]types.Format{ types.FormatTable, types.FormatJSON, types.FormatCycloneDX, }) - reportFlagGroup.Format = &formatFlag + reportFlagGroup.Format = formatFlag misconfFlagGroup := flag.NewMisconfFlagGroup() misconfFlagGroup.CloudformationParamVars = nil // disable '--cf-params' misconfFlagGroup.TerraformTFVars = nil // disable '--tf-vars' k8sFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), ImageFlagGroup: imageFlags, @@ -945,7 +962,7 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := k8sFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - opts, err := k8sFlags.ToOptions(args, globalFlags) + opts, err := k8sFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -973,6 +990,7 @@ func NewAWSCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' awsFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, AWSFlagGroup: flag.NewAWSFlagGroup(), CloudFlagGroup: flag.NewCloudFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), @@ -1014,7 +1032,7 @@ The following services are supported: return nil }, RunE: func(cmd *cobra.Command, args []string) error { - opts, err := awsFlags.ToOptions(args, globalFlags) + opts, err := awsFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -1036,6 +1054,7 @@ The following services are supported: func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { vmFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), MisconfFlagGroup: flag.NewMisconfFlagGroup(), @@ -1046,10 +1065,9 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { SecretFlagGroup: flag.NewSecretFlagGroup(), VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(), AWSFlagGroup: &flag.AWSFlagGroup{ - Region: &flag.Flag{ + Region: &flag.Flag[string]{ Name: "aws-region", ConfigName: "aws.region", - Default: "", Usage: "AWS region to scan", }, }, @@ -1080,7 +1098,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := vmFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := vmFlags.ToOptions(args, globalFlags) + options, err := vmFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -1111,6 +1129,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { scanFlagGroup.Parallel = nil // disable '--parallel' sbomFlags := &flag.Flags{ + GlobalFlagGroup: globalFlags, CacheFlagGroup: flag.NewCacheFlagGroup(), DBFlagGroup: flag.NewDBFlagGroup(), RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode @@ -1140,7 +1159,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := sbomFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := sbomFlags.ToOptions(args, globalFlags) + options, err := sbomFlags.ToOptions(args) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -1168,7 +1187,10 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { GroupID: groupUtility, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - options := globalFlags.ToOptions() + options, err := globalFlags.ToOptions() + if err != nil { + return err + } return showVersion(options.CacheDir, versionFormat, cmd.OutOrStdout()) }, SilenceErrors: true, diff --git a/pkg/commands/app_test.go b/pkg/commands/app_test.go index a8e94882ff9e..0a4651d9e63d 100644 --- a/pkg/commands/app_test.go +++ b/pkg/commands/app_test.go @@ -250,7 +250,7 @@ func TestFlags(t *testing.T) { "--format", "foo", }, - wantErr: `invalid argument "foo" for "-f, --format" flag`, + wantErr: `invalid argument "foo" for "--format" flag`, }, } @@ -262,16 +262,21 @@ func TestFlags(t *testing.T) { rootCmd.SetOut(io.Discard) flags := &flag.Flags{ + GlobalFlagGroup: globalFlags, ReportFlagGroup: flag.NewReportFlagGroup(), } cmd := &cobra.Command{ Use: "test", RunE: func(cmd *cobra.Command, args []string) error { // Bind - require.NoError(t, flags.Bind(cmd)) + if err := flags.Bind(cmd); err != nil { + return err + } - options, err := flags.ToOptions(args, globalFlags) - require.NoError(t, err) + options, err := flags.ToOptions(args) + if err != nil { + return err + } assert.Equal(t, tt.want.format, options.Format) assert.Equal(t, tt.want.severities, options.Severities) diff --git a/pkg/flag/aws_flags.go b/pkg/flag/aws_flags.go index 95f2d756ea01..6801d8b1ae5a 100644 --- a/pkg/flag/aws_flags.go +++ b/pkg/flag/aws_flags.go @@ -1,51 +1,45 @@ package flag var ( - awsRegionFlag = Flag{ + awsRegionFlag = Flag[string]{ Name: "region", ConfigName: "cloud.aws.region", - Default: "", Usage: "AWS Region to scan", } - awsEndpointFlag = Flag{ + awsEndpointFlag = Flag[string]{ Name: "endpoint", ConfigName: "cloud.aws.endpoint", - Default: "", Usage: "AWS Endpoint override", } - awsServiceFlag = Flag{ + awsServiceFlag = Flag[[]string]{ Name: "service", ConfigName: "cloud.aws.service", - Default: []string{}, Usage: "Only scan AWS Service(s) specified with this flag. Can specify multiple services using --service A --service B etc.", } - awsSkipServicesFlag = Flag{ + awsSkipServicesFlag = Flag[[]string]{ Name: "skip-service", ConfigName: "cloud.aws.skip-service", - Default: []string{}, Usage: "Skip selected AWS Service(s) specified with this flag. Can specify multiple services using --skip-service A --skip-service B etc.", } - awsAccountFlag = Flag{ + awsAccountFlag = Flag[string]{ Name: "account", ConfigName: "cloud.aws.account", - Default: "", Usage: "The AWS account to scan. It's useful to specify this when reviewing cached results for multiple accounts.", } - awsARNFlag = Flag{ + awsARNFlag = Flag[string]{ Name: "arn", ConfigName: "cloud.aws.arn", - Default: "", Usage: "The AWS ARN to show results for. Useful to filter results once a scan is cached.", } ) type AWSFlagGroup struct { - Region *Flag - Endpoint *Flag - Services *Flag - SkipServices *Flag - Account *Flag - ARN *Flag + Region *Flag[string] + Endpoint *Flag[string] + Services *Flag[[]string] + SkipServices *Flag[[]string] + Account *Flag[string] + ARN *Flag[string] } type AWSOptions struct { @@ -59,12 +53,12 @@ type AWSOptions struct { func NewAWSFlagGroup() *AWSFlagGroup { return &AWSFlagGroup{ - Region: &awsRegionFlag, - Endpoint: &awsEndpointFlag, - Services: &awsServiceFlag, - SkipServices: &awsSkipServicesFlag, - Account: &awsAccountFlag, - ARN: &awsARNFlag, + Region: awsRegionFlag.Clone(), + Endpoint: awsEndpointFlag.Clone(), + Services: awsServiceFlag.Clone(), + SkipServices: awsSkipServicesFlag.Clone(), + Account: awsAccountFlag.Clone(), + ARN: awsARNFlag.Clone(), } } @@ -72,17 +66,27 @@ func (f *AWSFlagGroup) Name() string { return "AWS" } -func (f *AWSFlagGroup) Flags() []*Flag { - return []*Flag{f.Region, f.Endpoint, f.Services, f.SkipServices, f.Account, f.ARN} +func (f *AWSFlagGroup) Flags() []Flagger { + return []Flagger{ + f.Region, + f.Endpoint, + f.Services, + f.SkipServices, + f.Account, + f.ARN, + } } -func (f *AWSFlagGroup) ToOptions() AWSOptions { - return AWSOptions{ - Region: getString(f.Region), - Endpoint: getString(f.Endpoint), - Services: getStringSlice(f.Services), - SkipServices: getStringSlice(f.SkipServices), - Account: getString(f.Account), - ARN: getString(f.ARN), +func (f *AWSFlagGroup) ToOptions() (AWSOptions, error) { + if err := parseFlags(f); err != nil { + return AWSOptions{}, err } + return AWSOptions{ + Region: f.Region.Value(), + Endpoint: f.Endpoint.Value(), + Services: f.Services.Value(), + SkipServices: f.SkipServices.Value(), + Account: f.Account.Value(), + ARN: f.ARN.Value(), + }, nil } diff --git a/pkg/flag/cache_flags.go b/pkg/flag/cache_flags.go index a92cd97dcc68..c259d5b9e963 100644 --- a/pkg/flag/cache_flags.go +++ b/pkg/flag/cache_flags.go @@ -19,60 +19,54 @@ import ( // cert: cert.pem // key: key.pem var ( - ClearCacheFlag = Flag{ + ClearCacheFlag = Flag[bool]{ Name: "clear-cache", ConfigName: "cache.clear", - Default: false, Usage: "clear image caches without scanning", } - CacheBackendFlag = Flag{ + CacheBackendFlag = Flag[string]{ Name: "cache-backend", ConfigName: "cache.backend", Default: "fs", Usage: "cache backend (e.g. redis://localhost:6379)", } - CacheTTLFlag = Flag{ + CacheTTLFlag = Flag[time.Duration]{ Name: "cache-ttl", ConfigName: "cache.ttl", - Default: time.Duration(0), Usage: "cache TTL when using redis as cache backend", } - RedisTLSFlag = Flag{ + RedisTLSFlag = Flag[bool]{ Name: "redis-tls", ConfigName: "cache.redis.tls", - Default: false, Usage: "enable redis TLS with public certificates, if using redis as cache backend", } - RedisCACertFlag = Flag{ + RedisCACertFlag = Flag[string]{ Name: "redis-ca", ConfigName: "cache.redis.ca", - Default: "", Usage: "redis ca file location, if using redis as cache backend", } - RedisCertFlag = Flag{ + RedisCertFlag = Flag[string]{ Name: "redis-cert", ConfigName: "cache.redis.cert", - Default: "", Usage: "redis certificate file location, if using redis as cache backend", } - RedisKeyFlag = Flag{ + RedisKeyFlag = Flag[string]{ Name: "redis-key", ConfigName: "cache.redis.key", - Default: "", Usage: "redis key file location, if using redis as cache backend", } ) // CacheFlagGroup composes common printer flag structs used for commands requiring cache logic. type CacheFlagGroup struct { - ClearCache *Flag - CacheBackend *Flag - CacheTTL *Flag - - RedisTLS *Flag - RedisCACert *Flag - RedisCert *Flag - RedisKey *Flag + ClearCache *Flag[bool] + CacheBackend *Flag[string] + CacheTTL *Flag[time.Duration] + + RedisTLS *Flag[bool] + RedisCACert *Flag[string] + RedisCert *Flag[string] + RedisKey *Flag[string] } type CacheOptions struct { @@ -93,13 +87,13 @@ type RedisOptions struct { // NewCacheFlagGroup returns a default CacheFlagGroup func NewCacheFlagGroup() *CacheFlagGroup { return &CacheFlagGroup{ - ClearCache: &ClearCacheFlag, - CacheBackend: &CacheBackendFlag, - CacheTTL: &CacheTTLFlag, - RedisTLS: &RedisTLSFlag, - RedisCACert: &RedisCACertFlag, - RedisCert: &RedisCertFlag, - RedisKey: &RedisKeyFlag, + ClearCache: ClearCacheFlag.Clone(), + CacheBackend: CacheBackendFlag.Clone(), + CacheTTL: CacheTTLFlag.Clone(), + RedisTLS: RedisTLSFlag.Clone(), + RedisCACert: RedisCACertFlag.Clone(), + RedisCert: RedisCertFlag.Clone(), + RedisKey: RedisKeyFlag.Clone(), } } @@ -107,16 +101,28 @@ func (fg *CacheFlagGroup) Name() string { return "Cache" } -func (fg *CacheFlagGroup) Flags() []*Flag { - return []*Flag{fg.ClearCache, fg.CacheBackend, fg.CacheTTL, fg.RedisTLS, fg.RedisCACert, fg.RedisCert, fg.RedisKey} +func (fg *CacheFlagGroup) Flags() []Flagger { + return []Flagger{ + fg.ClearCache, + fg.CacheBackend, + fg.CacheTTL, + fg.RedisTLS, + fg.RedisCACert, + fg.RedisCert, + fg.RedisKey, + } } func (fg *CacheFlagGroup) ToOptions() (CacheOptions, error) { - cacheBackend := getString(fg.CacheBackend) + if err := parseFlags(fg); err != nil { + return CacheOptions{}, err + } + + cacheBackend := fg.CacheBackend.Value() redisOptions := RedisOptions{ - RedisCACert: getString(fg.RedisCACert), - RedisCert: getString(fg.RedisCert), - RedisKey: getString(fg.RedisKey), + RedisCACert: fg.RedisCACert.Value(), + RedisCert: fg.RedisCert.Value(), + RedisKey: fg.RedisKey.Value(), } // "redis://" or "fs" are allowed for now @@ -133,10 +139,10 @@ func (fg *CacheFlagGroup) ToOptions() (CacheOptions, error) { } return CacheOptions{ - ClearCache: getBool(fg.ClearCache), + ClearCache: fg.ClearCache.Value(), CacheBackend: cacheBackend, - CacheTTL: getDuration(fg.CacheTTL), - RedisTLS: getBool(fg.RedisTLS), + CacheTTL: fg.CacheTTL.Value(), + RedisTLS: fg.RedisTLS.Value(), RedisOptions: redisOptions, }, nil } diff --git a/pkg/flag/cache_flags_test.go b/pkg/flag/cache_flags_test.go index c2d01992272a..db54c75b13e8 100644 --- a/pkg/flag/cache_flags_test.go +++ b/pkg/flag/cache_flags_test.go @@ -108,13 +108,13 @@ func TestCacheFlagGroup_ToOptions(t *testing.T) { viper.Set(flag.RedisKeyFlag.ConfigName, tt.fields.RedisKey) f := &flag.CacheFlagGroup{ - ClearCache: &flag.ClearCacheFlag, - CacheBackend: &flag.CacheBackendFlag, - CacheTTL: &flag.CacheTTLFlag, - RedisTLS: &flag.RedisTLSFlag, - RedisCACert: &flag.RedisCACertFlag, - RedisCert: &flag.RedisCertFlag, - RedisKey: &flag.RedisKeyFlag, + ClearCache: flag.ClearCacheFlag.Clone(), + CacheBackend: flag.CacheBackendFlag.Clone(), + CacheTTL: flag.CacheTTLFlag.Clone(), + RedisTLS: flag.RedisTLSFlag.Clone(), + RedisCACert: flag.RedisCACertFlag.Clone(), + RedisCert: flag.RedisCertFlag.Clone(), + RedisKey: flag.RedisKeyFlag.Clone(), } got, err := f.ToOptions() diff --git a/pkg/flag/cloud_flags.go b/pkg/flag/cloud_flags.go index eed81230f6da..dfff5a997f7a 100644 --- a/pkg/flag/cloud_flags.go +++ b/pkg/flag/cloud_flags.go @@ -3,13 +3,12 @@ package flag import "time" var ( - cloudUpdateCacheFlag = Flag{ + cloudUpdateCacheFlag = Flag[bool]{ Name: "update-cache", ConfigName: "cloud.update-cache", - Default: false, Usage: "Update the cache for the applicable cloud provider instead of using cached results.", } - cloudMaxCacheAgeFlag = Flag{ + cloudMaxCacheAgeFlag = Flag[time.Duration]{ Name: "max-cache-age", ConfigName: "cloud.max-cache-age", Default: time.Hour * 24, @@ -18,8 +17,8 @@ var ( ) type CloudFlagGroup struct { - UpdateCache *Flag - MaxCacheAge *Flag + UpdateCache *Flag[bool] + MaxCacheAge *Flag[time.Duration] } type CloudOptions struct { @@ -29,8 +28,8 @@ type CloudOptions struct { func NewCloudFlagGroup() *CloudFlagGroup { return &CloudFlagGroup{ - UpdateCache: &cloudUpdateCacheFlag, - MaxCacheAge: &cloudMaxCacheAgeFlag, + UpdateCache: cloudUpdateCacheFlag.Clone(), + MaxCacheAge: cloudMaxCacheAgeFlag.Clone(), } } @@ -38,13 +37,19 @@ func (f *CloudFlagGroup) Name() string { return "Cloud" } -func (f *CloudFlagGroup) Flags() []*Flag { - return []*Flag{f.UpdateCache, f.MaxCacheAge} +func (f *CloudFlagGroup) Flags() []Flagger { + return []Flagger{ + f.UpdateCache, + f.MaxCacheAge, + } } -func (f *CloudFlagGroup) ToOptions() CloudOptions { - return CloudOptions{ - UpdateCache: getBool(f.UpdateCache), - MaxCacheAge: getDuration(f.MaxCacheAge), +func (f *CloudFlagGroup) ToOptions() (CloudOptions, error) { + if err := parseFlags(f); err != nil { + return CloudOptions{}, err } + return CloudOptions{ + UpdateCache: f.UpdateCache.Value(), + MaxCacheAge: f.MaxCacheAge.Value(), + }, nil } diff --git a/pkg/flag/db_flags.go b/pkg/flag/db_flags.go index 2e5d58f3a6a1..23fb42a72bec 100644 --- a/pkg/flag/db_flags.go +++ b/pkg/flag/db_flags.go @@ -10,22 +10,19 @@ const defaultDBRepository = "ghcr.io/aquasecurity/trivy-db" const defaultJavaDBRepository = "ghcr.io/aquasecurity/trivy-java-db" var ( - ResetFlag = Flag{ + ResetFlag = Flag[bool]{ Name: "reset", ConfigName: "reset", - Default: false, Usage: "remove all caches and database", } - DownloadDBOnlyFlag = Flag{ + DownloadDBOnlyFlag = Flag[bool]{ Name: "download-db-only", ConfigName: "db.download-only", - Default: false, Usage: "download/update vulnerability database but don't run a scan", } - SkipDBUpdateFlag = Flag{ + SkipDBUpdateFlag = Flag[bool]{ Name: "skip-db-update", ConfigName: "db.skip-update", - Default: false, Usage: "skip updating vulnerability database", Aliases: []Alias{ { @@ -34,40 +31,36 @@ var ( }, }, } - DownloadJavaDBOnlyFlag = Flag{ + DownloadJavaDBOnlyFlag = Flag[bool]{ Name: "download-java-db-only", ConfigName: "db.download-java-only", - Default: false, Usage: "download/update Java index database but don't run a scan", } - SkipJavaDBUpdateFlag = Flag{ + SkipJavaDBUpdateFlag = Flag[bool]{ Name: "skip-java-db-update", ConfigName: "db.java-skip-update", - Default: false, Usage: "skip updating Java index database", } - NoProgressFlag = Flag{ + NoProgressFlag = Flag[bool]{ Name: "no-progress", ConfigName: "db.no-progress", - Default: false, Usage: "suppress progress bar", } - DBRepositoryFlag = Flag{ + DBRepositoryFlag = Flag[string]{ Name: "db-repository", ConfigName: "db.repository", Default: defaultDBRepository, Usage: "OCI repository to retrieve trivy-db from", } - JavaDBRepositoryFlag = Flag{ + JavaDBRepositoryFlag = Flag[string]{ Name: "java-db-repository", ConfigName: "db.java-repository", Default: defaultJavaDBRepository, Usage: "OCI repository to retrieve trivy-java-db from", } - LightFlag = Flag{ + LightFlag = Flag[bool]{ Name: "light", ConfigName: "db.light", - Default: false, Usage: "deprecated", Deprecated: true, } @@ -75,15 +68,15 @@ var ( // DBFlagGroup composes common printer flag structs used for commands requiring DB logic. type DBFlagGroup struct { - Reset *Flag - DownloadDBOnly *Flag - SkipDBUpdate *Flag - DownloadJavaDBOnly *Flag - SkipJavaDBUpdate *Flag - NoProgress *Flag - DBRepository *Flag - JavaDBRepository *Flag - Light *Flag // deprecated + Reset *Flag[bool] + DownloadDBOnly *Flag[bool] + SkipDBUpdate *Flag[bool] + DownloadJavaDBOnly *Flag[bool] + SkipJavaDBUpdate *Flag[bool] + NoProgress *Flag[bool] + DBRepository *Flag[string] + JavaDBRepository *Flag[string] + Light *Flag[bool] // deprecated } type DBOptions struct { @@ -101,15 +94,15 @@ type DBOptions struct { // NewDBFlagGroup returns a default DBFlagGroup func NewDBFlagGroup() *DBFlagGroup { return &DBFlagGroup{ - Reset: &ResetFlag, - DownloadDBOnly: &DownloadDBOnlyFlag, - SkipDBUpdate: &SkipDBUpdateFlag, - DownloadJavaDBOnly: &DownloadJavaDBOnlyFlag, - SkipJavaDBUpdate: &SkipJavaDBUpdateFlag, - Light: &LightFlag, - NoProgress: &NoProgressFlag, - DBRepository: &DBRepositoryFlag, - JavaDBRepository: &JavaDBRepositoryFlag, + Reset: ResetFlag.Clone(), + DownloadDBOnly: DownloadDBOnlyFlag.Clone(), + SkipDBUpdate: SkipDBUpdateFlag.Clone(), + DownloadJavaDBOnly: DownloadJavaDBOnlyFlag.Clone(), + SkipJavaDBUpdate: SkipJavaDBUpdateFlag.Clone(), + Light: LightFlag.Clone(), + NoProgress: NoProgressFlag.Clone(), + DBRepository: DBRepositoryFlag.Clone(), + JavaDBRepository: JavaDBRepositoryFlag.Clone(), } } @@ -117,8 +110,8 @@ func (f *DBFlagGroup) Name() string { return "DB" } -func (f *DBFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *DBFlagGroup) Flags() []Flagger { + return []Flagger{ f.Reset, f.DownloadDBOnly, f.SkipDBUpdate, @@ -132,11 +125,15 @@ func (f *DBFlagGroup) Flags() []*Flag { } func (f *DBFlagGroup) ToOptions() (DBOptions, error) { - skipDBUpdate := getBool(f.SkipDBUpdate) - skipJavaDBUpdate := getBool(f.SkipJavaDBUpdate) - downloadDBOnly := getBool(f.DownloadDBOnly) - downloadJavaDBOnly := getBool(f.DownloadJavaDBOnly) - light := getBool(f.Light) + if err := parseFlags(f); err != nil { + return DBOptions{}, err + } + + skipDBUpdate := f.SkipDBUpdate.Value() + skipJavaDBUpdate := f.SkipJavaDBUpdate.Value() + downloadDBOnly := f.DownloadDBOnly.Value() + downloadJavaDBOnly := f.DownloadJavaDBOnly.Value() + light := f.Light.Value() if downloadDBOnly && skipDBUpdate { return DBOptions{}, xerrors.New("--skip-db-update and --download-db-only options can not be specified both") @@ -149,14 +146,14 @@ func (f *DBFlagGroup) ToOptions() (DBOptions, error) { } return DBOptions{ - Reset: getBool(f.Reset), + Reset: f.Reset.Value(), DownloadDBOnly: downloadDBOnly, SkipDBUpdate: skipDBUpdate, DownloadJavaDBOnly: downloadJavaDBOnly, SkipJavaDBUpdate: skipJavaDBUpdate, Light: light, - NoProgress: getBool(f.NoProgress), - DBRepository: getString(f.DBRepository), - JavaDBRepository: getString(f.JavaDBRepository), + NoProgress: f.NoProgress.Value(), + DBRepository: f.DBRepository.Value(), + JavaDBRepository: f.JavaDBRepository.Value(), }, nil } diff --git a/pkg/flag/db_flags_test.go b/pkg/flag/db_flags_test.go index d1ce7c65cbc4..c590ed49f7a3 100644 --- a/pkg/flag/db_flags_test.go +++ b/pkg/flag/db_flags_test.go @@ -74,9 +74,9 @@ func TestDBFlagGroup_ToOptions(t *testing.T) { // Assert options f := &flag.DBFlagGroup{ - DownloadDBOnly: &flag.DownloadDBOnlyFlag, - SkipDBUpdate: &flag.SkipDBUpdateFlag, - Light: &flag.LightFlag, + DownloadDBOnly: flag.DownloadDBOnlyFlag.Clone(), + SkipDBUpdate: flag.SkipDBUpdateFlag.Clone(), + Light: flag.LightFlag.Clone(), } got, err := f.ToOptions() tt.assertion(t, err) diff --git a/pkg/flag/global_flags.go b/pkg/flag/global_flags.go index a9b5440d835b..aa4851c657f0 100644 --- a/pkg/flag/global_flags.go +++ b/pkg/flag/global_flags.go @@ -10,7 +10,7 @@ import ( ) var ( - ConfigFileFlag = Flag{ + ConfigFileFlag = Flag[string]{ Name: "config", ConfigName: "config", Shorthand: "c", @@ -18,55 +18,50 @@ var ( Usage: "config path", Persistent: true, } - ShowVersionFlag = Flag{ + ShowVersionFlag = Flag[bool]{ Name: "version", ConfigName: "version", Shorthand: "v", - Default: false, Usage: "show version", Persistent: true, } - QuietFlag = Flag{ + QuietFlag = Flag[bool]{ Name: "quiet", ConfigName: "quiet", Shorthand: "q", - Default: false, Usage: "suppress progress bar and log output", Persistent: true, } - DebugFlag = Flag{ + DebugFlag = Flag[bool]{ Name: "debug", ConfigName: "debug", Shorthand: "d", - Default: false, Usage: "debug mode", Persistent: true, } - InsecureFlag = Flag{ + InsecureFlag = Flag[bool]{ Name: "insecure", ConfigName: "insecure", - Default: false, Usage: "allow insecure server connections", Persistent: true, } - TimeoutFlag = Flag{ + TimeoutFlag = Flag[time.Duration]{ Name: "timeout", ConfigName: "timeout", Default: time.Second * 300, // 5 mins Usage: "timeout", Persistent: true, } - CacheDirFlag = Flag{ + CacheDirFlag = Flag[string]{ Name: "cache-dir", ConfigName: "cache.dir", Default: fsutils.CacheDir(), Usage: "cache directory", Persistent: true, } - GenerateDefaultConfigFlag = Flag{ + GenerateDefaultConfigFlag = Flag[bool]{ Name: "generate-default-config", ConfigName: "generate-default-config", - Default: false, Usage: "write the default config to trivy-default.yaml", Persistent: true, } @@ -74,14 +69,14 @@ var ( // GlobalFlagGroup composes global flags type GlobalFlagGroup struct { - ConfigFile *Flag - ShowVersion *Flag // spf13/cobra can't override the logic of version printing like VersionPrinter in urfave/cli. -v needs to be defined ourselves. - Quiet *Flag - Debug *Flag - Insecure *Flag - Timeout *Flag - CacheDir *Flag - GenerateDefaultConfig *Flag + ConfigFile *Flag[string] + ShowVersion *Flag[bool] // spf13/cobra can't override the logic of version printing like VersionPrinter in urfave/cli. -v needs to be defined ourselves. + Quiet *Flag[bool] + Debug *Flag[bool] + Insecure *Flag[bool] + Timeout *Flag[time.Duration] + CacheDir *Flag[string] + GenerateDefaultConfig *Flag[bool] } // GlobalOptions defines flags and other configuration parameters for all the subcommands @@ -98,19 +93,23 @@ type GlobalOptions struct { func NewGlobalFlagGroup() *GlobalFlagGroup { return &GlobalFlagGroup{ - ConfigFile: &ConfigFileFlag, - ShowVersion: &ShowVersionFlag, - Quiet: &QuietFlag, - Debug: &DebugFlag, - Insecure: &InsecureFlag, - Timeout: &TimeoutFlag, - CacheDir: &CacheDirFlag, - GenerateDefaultConfig: &GenerateDefaultConfigFlag, + ConfigFile: ConfigFileFlag.Clone(), + ShowVersion: ShowVersionFlag.Clone(), + Quiet: QuietFlag.Clone(), + Debug: DebugFlag.Clone(), + Insecure: InsecureFlag.Clone(), + Timeout: TimeoutFlag.Clone(), + CacheDir: CacheDirFlag.Clone(), + GenerateDefaultConfig: GenerateDefaultConfigFlag.Clone(), } } -func (f *GlobalFlagGroup) flags() []*Flag { - return []*Flag{ +func (f *GlobalFlagGroup) Name() string { + return "global" +} + +func (f *GlobalFlagGroup) Flags() []Flagger { + return []Flagger{ f.ConfigFile, f.ShowVersion, f.Quiet, @@ -123,32 +122,36 @@ func (f *GlobalFlagGroup) flags() []*Flag { } func (f *GlobalFlagGroup) AddFlags(cmd *cobra.Command) { - for _, flag := range f.flags() { - addFlag(cmd, flag) + for _, flag := range f.Flags() { + flag.Add(cmd) } } func (f *GlobalFlagGroup) Bind(cmd *cobra.Command) error { - for _, flag := range f.flags() { - if err := bind(cmd, flag); err != nil { + for _, flag := range f.Flags() { + if err := flag.Bind(cmd); err != nil { return err } } return nil } -func (f *GlobalFlagGroup) ToOptions() GlobalOptions { +func (f *GlobalFlagGroup) ToOptions() (GlobalOptions, error) { + if err := parseFlags(f); err != nil { + return GlobalOptions{}, err + } + // Keep TRIVY_NON_SSL for backward compatibility - insecure := getBool(f.Insecure) || os.Getenv("TRIVY_NON_SSL") != "" + insecure := f.Insecure.Value() || os.Getenv("TRIVY_NON_SSL") != "" return GlobalOptions{ - ConfigFile: getString(f.ConfigFile), - ShowVersion: getBool(f.ShowVersion), - Quiet: getBool(f.Quiet), - Debug: getBool(f.Debug), + ConfigFile: f.ConfigFile.Value(), + ShowVersion: f.ShowVersion.Value(), + Quiet: f.Quiet.Value(), + Debug: f.Debug.Value(), Insecure: insecure, - Timeout: getDuration(f.Timeout), - CacheDir: getString(f.CacheDir), - GenerateDefaultConfig: getBool(f.GenerateDefaultConfig), - } + Timeout: f.Timeout.Value(), + CacheDir: f.CacheDir.Value(), + GenerateDefaultConfig: f.GenerateDefaultConfig.Value(), + }, nil } diff --git a/pkg/flag/image_flags.go b/pkg/flag/image_flags.go index b94162aeb9a7..93caccd48834 100644 --- a/pkg/flag/image_flags.go +++ b/pkg/flag/image_flags.go @@ -15,41 +15,37 @@ import ( // input: "/path/to/alpine" var ( - ImageConfigScannersFlag = Flag{ + ImageConfigScannersFlag = Flag[[]string]{ Name: "image-config-scanners", ConfigName: "image.image-config-scanners", - Default: []string{}, Values: xstrings.ToStringSlice(types.Scanners{ types.MisconfigScanner, types.SecretScanner, }), Usage: "comma-separated list of what security issues to detect on container image configurations", } - ScanRemovedPkgsFlag = Flag{ + ScanRemovedPkgsFlag = Flag[bool]{ Name: "removed-pkgs", ConfigName: "image.removed-pkgs", - Default: false, Usage: "detect vulnerabilities of removed packages (only for Alpine)", } - InputFlag = Flag{ + InputFlag = Flag[string]{ Name: "input", ConfigName: "image.input", - Default: "", Usage: "input file path instead of image name", } - PlatformFlag = Flag{ + PlatformFlag = Flag[string]{ Name: "platform", ConfigName: "image.platform", - Default: "", Usage: "set platform in the form os/arch if image is multi-platform capable", } - DockerHostFlag = Flag{ + DockerHostFlag = Flag[string]{ Name: "docker-host", ConfigName: "image.docker.host", Default: "", Usage: "unix domain socket path to use for docker scanning", } - SourceFlag = Flag{ + SourceFlag = Flag[[]string]{ Name: "image-src", ConfigName: "image.source", Default: xstrings.ToStringSlice(ftypes.AllImageSources), @@ -59,12 +55,12 @@ var ( ) type ImageFlagGroup struct { - Input *Flag // local image archive - ImageConfigScanners *Flag - ScanRemovedPkgs *Flag - Platform *Flag - DockerHost *Flag - ImageSources *Flag + Input *Flag[string] // local image archive + ImageConfigScanners *Flag[[]string] + ScanRemovedPkgs *Flag[bool] + Platform *Flag[string] + DockerHost *Flag[string] + ImageSources *Flag[[]string] } type ImageOptions struct { @@ -78,12 +74,12 @@ type ImageOptions struct { func NewImageFlagGroup() *ImageFlagGroup { return &ImageFlagGroup{ - Input: &InputFlag, - ImageConfigScanners: &ImageConfigScannersFlag, - ScanRemovedPkgs: &ScanRemovedPkgsFlag, - Platform: &PlatformFlag, - DockerHost: &DockerHostFlag, - ImageSources: &SourceFlag, + Input: InputFlag.Clone(), + ImageConfigScanners: ImageConfigScannersFlag.Clone(), + ScanRemovedPkgs: ScanRemovedPkgsFlag.Clone(), + Platform: PlatformFlag.Clone(), + DockerHost: DockerHostFlag.Clone(), + ImageSources: SourceFlag.Clone(), } } @@ -91,8 +87,8 @@ func (f *ImageFlagGroup) Name() string { return "Image" } -func (f *ImageFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *ImageFlagGroup) Flags() []Flagger { + return []Flagger{ f.Input, f.ImageConfigScanners, f.ScanRemovedPkgs, @@ -103,8 +99,12 @@ func (f *ImageFlagGroup) Flags() []*Flag { } func (f *ImageFlagGroup) ToOptions() (ImageOptions, error) { + if err := parseFlags(f); err != nil { + return ImageOptions{}, err + } + var platform ftypes.Platform - if p := getString(f.Platform); p != "" { + if p := f.Platform.Value(); p != "" { pl, err := v1.ParsePlatform(p) if err != nil { return ImageOptions{}, xerrors.Errorf("unable to parse platform: %w", err) @@ -116,11 +116,11 @@ func (f *ImageFlagGroup) ToOptions() (ImageOptions, error) { } return ImageOptions{ - Input: getString(f.Input), - ImageConfigScanners: getUnderlyingStringSlice[types.Scanner](f.ImageConfigScanners), - ScanRemovedPkgs: getBool(f.ScanRemovedPkgs), + Input: f.Input.Value(), + ImageConfigScanners: xstrings.ToTSlice[types.Scanner](f.ImageConfigScanners.Value()), + ScanRemovedPkgs: f.ScanRemovedPkgs.Value(), Platform: platform, - DockerHost: getString(f.DockerHost), - ImageSources: getUnderlyingStringSlice[ftypes.ImageSource](f.ImageSources), + DockerHost: f.DockerHost.Value(), + ImageSources: xstrings.ToTSlice[ftypes.ImageSource](f.ImageSources.Value()), }, nil } diff --git a/pkg/flag/kubernetes_flags.go b/pkg/flag/kubernetes_flags.go index e1405b501d37..7a87040ba698 100644 --- a/pkg/flag/kubernetes_flags.go +++ b/pkg/flag/kubernetes_flags.go @@ -10,29 +10,26 @@ import ( ) var ( - ClusterContextFlag = Flag{ + ClusterContextFlag = Flag[string]{ Name: "context", ConfigName: "kubernetes.context", - Default: "", Usage: "specify a context to scan", Aliases: []Alias{ {Name: "ctx"}, }, } - K8sNamespaceFlag = Flag{ + K8sNamespaceFlag = Flag[string]{ Name: "namespace", ConfigName: "kubernetes.namespace", Shorthand: "n", - Default: "", Usage: "specify a namespace to scan", } - KubeConfigFlag = Flag{ + KubeConfigFlag = Flag[string]{ Name: "kubeconfig", ConfigName: "kubernetes.kubeconfig", - Default: "", Usage: "specify the kubeconfig file path to use", } - ComponentsFlag = Flag{ + ComponentsFlag = Flag[[]string]{ Name: "components", ConfigName: "kubernetes.components", Default: []string{ @@ -45,56 +42,51 @@ var ( }, Usage: "specify which components to scan", } - K8sVersionFlag = Flag{ + K8sVersionFlag = Flag[string]{ Name: "k8s-version", ConfigName: "kubernetes.k8s.version", - Default: "", Usage: "specify k8s version to validate outdated api by it (example: 1.21.0)", } - TolerationsFlag = Flag{ + TolerationsFlag = Flag[[]string]{ Name: "tolerations", ConfigName: "kubernetes.tolerations", - Default: []string{}, Usage: "specify node-collector job tolerations (example: key1=value1:NoExecute,key2=value2:NoSchedule)", } - AllNamespaces = Flag{ + AllNamespaces = Flag[bool]{ Name: "all-namespaces", ConfigName: "kubernetes.all.namespaces", Shorthand: "A", - Default: false, Usage: "fetch resources from all cluster namespaces", } - NodeCollectorNamespace = Flag{ + NodeCollectorNamespace = Flag[string]{ Name: "node-collector-namespace", ConfigName: "node.collector.namespace", Default: "trivy-temp", Usage: "specify the namespace in which the node-collector job should be deployed", } - ExcludeOwned = Flag{ + ExcludeOwned = Flag[bool]{ Name: "exclude-owned", ConfigName: "kubernetes.exclude.owned", - Default: false, Usage: "exclude resources that have an owner reference", } - ExcludeNodes = Flag{ + ExcludeNodes = Flag[[]string]{ Name: "exclude-nodes", - ConfigName: "exclude.nodes", - Default: []string{}, + ConfigName: "kubernetes.exclude.nodes", Usage: "indicate the node labels that the node-collector job should exclude from scanning (example: kubernetes.io/arch:arm64,team:dev)", } - NodeCollectorImageRef = Flag{ + NodeCollectorImageRef = Flag[string]{ Name: "node-collector-imageref", - ConfigName: "node.collector.imageref", + ConfigName: "kubernetes.node.collector.imageref", Default: "ghcr.io/aquasecurity/node-collector:0.0.9", Usage: "indicate the image reference for the node-collector scan job", } - QPS = Flag{ + QPS = Flag[float64]{ Name: "qps", ConfigName: "kubernetes.qps", Default: 5.0, Usage: "specify the maximum QPS to the master from this client", } - Burst = Flag{ + Burst = Flag[int]{ Name: "burst", ConfigName: "kubernetes.burst", Default: 10, @@ -103,19 +95,19 @@ var ( ) type K8sFlagGroup struct { - ClusterContext *Flag - Namespace *Flag - KubeConfig *Flag - Components *Flag - K8sVersion *Flag - Tolerations *Flag - NodeCollectorImageRef *Flag - AllNamespaces *Flag - NodeCollectorNamespace *Flag - ExcludeOwned *Flag - ExcludeNodes *Flag - QPS *Flag - Burst *Flag + ClusterContext *Flag[string] + Namespace *Flag[string] + KubeConfig *Flag[string] + Components *Flag[[]string] + K8sVersion *Flag[string] + Tolerations *Flag[[]string] + NodeCollectorImageRef *Flag[string] + AllNamespaces *Flag[bool] + NodeCollectorNamespace *Flag[string] + ExcludeOwned *Flag[bool] + ExcludeNodes *Flag[[]string] + QPS *Flag[float64] + Burst *Flag[int] } type K8sOptions struct { @@ -136,19 +128,19 @@ type K8sOptions struct { func NewK8sFlagGroup() *K8sFlagGroup { return &K8sFlagGroup{ - ClusterContext: &ClusterContextFlag, - Namespace: &K8sNamespaceFlag, - KubeConfig: &KubeConfigFlag, - Components: &ComponentsFlag, - K8sVersion: &K8sVersionFlag, - Tolerations: &TolerationsFlag, - AllNamespaces: &AllNamespaces, - NodeCollectorNamespace: &NodeCollectorNamespace, - ExcludeOwned: &ExcludeOwned, - ExcludeNodes: &ExcludeNodes, - NodeCollectorImageRef: &NodeCollectorImageRef, - QPS: &QPS, - Burst: &Burst, + ClusterContext: ClusterContextFlag.Clone(), + Namespace: K8sNamespaceFlag.Clone(), + KubeConfig: KubeConfigFlag.Clone(), + Components: ComponentsFlag.Clone(), + K8sVersion: K8sVersionFlag.Clone(), + Tolerations: TolerationsFlag.Clone(), + AllNamespaces: AllNamespaces.Clone(), + NodeCollectorNamespace: NodeCollectorNamespace.Clone(), + ExcludeOwned: ExcludeOwned.Clone(), + ExcludeNodes: ExcludeNodes.Clone(), + NodeCollectorImageRef: NodeCollectorImageRef.Clone(), + QPS: QPS.Clone(), + Burst: Burst.Clone(), } } @@ -156,8 +148,8 @@ func (f *K8sFlagGroup) Name() string { return "Kubernetes" } -func (f *K8sFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *K8sFlagGroup) Flags() []Flagger { + return []Flagger{ f.ClusterContext, f.Namespace, f.KubeConfig, @@ -175,13 +167,17 @@ func (f *K8sFlagGroup) Flags() []*Flag { } func (f *K8sFlagGroup) ToOptions() (K8sOptions, error) { - tolerations, err := optionToTolerations(getStringSlice(f.Tolerations)) + if err := parseFlags(f); err != nil { + return K8sOptions{}, err + } + + tolerations, err := optionToTolerations(f.Tolerations.Value()) if err != nil { return K8sOptions{}, err } exludeNodeLabels := make(map[string]string) - exludeNodes := getStringSlice(f.ExcludeNodes) + exludeNodes := f.ExcludeNodes.Value() for _, exludeNodeValue := range exludeNodes { excludeNodeParts := strings.Split(exludeNodeValue, ":") if len(excludeNodeParts) != 2 { @@ -191,17 +187,19 @@ func (f *K8sFlagGroup) ToOptions() (K8sOptions, error) { } return K8sOptions{ - ClusterContext: getString(f.ClusterContext), - Namespace: getString(f.Namespace), - KubeConfig: getString(f.KubeConfig), - Components: getStringSlice(f.Components), - K8sVersion: getString(f.K8sVersion), + ClusterContext: f.ClusterContext.Value(), + Namespace: f.Namespace.Value(), + KubeConfig: f.KubeConfig.Value(), + Components: f.Components.Value(), + K8sVersion: f.K8sVersion.Value(), Tolerations: tolerations, - AllNamespaces: getBool(f.AllNamespaces), - NodeCollectorNamespace: getString(f.NodeCollectorNamespace), - ExcludeOwned: getBool(f.ExcludeOwned), + AllNamespaces: f.AllNamespaces.Value(), + NodeCollectorNamespace: f.NodeCollectorNamespace.Value(), + ExcludeOwned: f.ExcludeOwned.Value(), ExcludeNodes: exludeNodeLabels, - NodeCollectorImageRef: getString(f.NodeCollectorImageRef), + NodeCollectorImageRef: f.NodeCollectorImageRef.Value(), + QPS: float32(f.QPS.Value()), + Burst: f.Burst.Value(), }, nil } diff --git a/pkg/flag/license_flags.go b/pkg/flag/license_flags.go index 907aaca287c5..5f4e148af5b2 100644 --- a/pkg/flag/license_flags.go +++ b/pkg/flag/license_flags.go @@ -6,19 +6,17 @@ import ( ) var ( - LicenseFull = Flag{ + LicenseFull = Flag[bool]{ Name: "license-full", ConfigName: "license.full", - Default: false, Usage: "eagerly look for licenses in source code headers and license files", } - IgnoredLicenses = Flag{ + IgnoredLicenses = Flag[[]string]{ Name: "ignored-licenses", ConfigName: "license.ignored", - Default: []string{}, Usage: "specify a list of license to ignore", } - LicenseConfidenceLevel = Flag{ + LicenseConfidenceLevel = Flag[float64]{ Name: "license-confidence-level", ConfigName: "license.confidenceLevel", Default: 0.9, @@ -26,37 +24,37 @@ var ( } // LicenseForbidden is an option only in a config file - LicenseForbidden = Flag{ + LicenseForbidden = Flag[[]string]{ ConfigName: "license.forbidden", Default: licensing.ForbiddenLicenses, Usage: "forbidden licenses", } // LicenseRestricted is an option only in a config file - LicenseRestricted = Flag{ + LicenseRestricted = Flag[[]string]{ ConfigName: "license.restricted", Default: licensing.RestrictedLicenses, Usage: "restricted licenses", } // LicenseReciprocal is an option only in a config file - LicenseReciprocal = Flag{ + LicenseReciprocal = Flag[[]string]{ ConfigName: "license.reciprocal", Default: licensing.ReciprocalLicenses, Usage: "reciprocal licenses", } // LicenseNotice is an option only in a config file - LicenseNotice = Flag{ + LicenseNotice = Flag[[]string]{ ConfigName: "license.notice", Default: licensing.NoticeLicenses, Usage: "notice licenses", } // LicensePermissive is an option only in a config file - LicensePermissive = Flag{ + LicensePermissive = Flag[[]string]{ ConfigName: "license.permissive", Default: licensing.PermissiveLicenses, Usage: "permissive licenses", } // LicenseUnencumbered is an option only in a config file - LicenseUnencumbered = Flag{ + LicenseUnencumbered = Flag[[]string]{ ConfigName: "license.unencumbered", Default: licensing.UnencumberedLicenses, Usage: "unencumbered licenses", @@ -64,17 +62,17 @@ var ( ) type LicenseFlagGroup struct { - LicenseFull *Flag - IgnoredLicenses *Flag - LicenseConfidenceLevel *Flag + LicenseFull *Flag[bool] + IgnoredLicenses *Flag[[]string] + LicenseConfidenceLevel *Flag[float64] // License Categories - LicenseForbidden *Flag // mapped to CRITICAL - LicenseRestricted *Flag // mapped to HIGH - LicenseReciprocal *Flag // mapped to MEDIUM - LicenseNotice *Flag // mapped to LOW - LicensePermissive *Flag // mapped to LOW - LicenseUnencumbered *Flag // mapped to LOW + LicenseForbidden *Flag[[]string] // mapped to CRITICAL + LicenseRestricted *Flag[[]string] // mapped to HIGH + LicenseReciprocal *Flag[[]string] // mapped to MEDIUM + LicenseNotice *Flag[[]string] // mapped to LOW + LicensePermissive *Flag[[]string] // mapped to LOW + LicenseUnencumbered *Flag[[]string] // mapped to LOW } type LicenseOptions struct { @@ -87,15 +85,15 @@ type LicenseOptions struct { func NewLicenseFlagGroup() *LicenseFlagGroup { return &LicenseFlagGroup{ - LicenseFull: &LicenseFull, - IgnoredLicenses: &IgnoredLicenses, - LicenseConfidenceLevel: &LicenseConfidenceLevel, - LicenseForbidden: &LicenseForbidden, - LicenseRestricted: &LicenseRestricted, - LicenseReciprocal: &LicenseReciprocal, - LicenseNotice: &LicenseNotice, - LicensePermissive: &LicensePermissive, - LicenseUnencumbered: &LicenseUnencumbered, + LicenseFull: LicenseFull.Clone(), + IgnoredLicenses: IgnoredLicenses.Clone(), + LicenseConfidenceLevel: LicenseConfidenceLevel.Clone(), + LicenseForbidden: LicenseForbidden.Clone(), + LicenseRestricted: LicenseRestricted.Clone(), + LicenseReciprocal: LicenseReciprocal.Clone(), + LicenseNotice: LicenseNotice.Clone(), + LicensePermissive: LicensePermissive.Clone(), + LicenseUnencumbered: LicenseUnencumbered.Clone(), } } @@ -103,24 +101,37 @@ func (f *LicenseFlagGroup) Name() string { return "License" } -func (f *LicenseFlagGroup) Flags() []*Flag { - return []*Flag{f.LicenseFull, f.IgnoredLicenses, f.LicenseForbidden, f.LicenseRestricted, f.LicenseReciprocal, - f.LicenseNotice, f.LicensePermissive, f.LicenseUnencumbered, f.LicenseConfidenceLevel} +func (f *LicenseFlagGroup) Flags() []Flagger { + return []Flagger{ + f.LicenseFull, + f.IgnoredLicenses, + f.LicenseForbidden, + f.LicenseRestricted, + f.LicenseReciprocal, + f.LicenseNotice, + f.LicensePermissive, + f.LicenseUnencumbered, + f.LicenseConfidenceLevel, + } } -func (f *LicenseFlagGroup) ToOptions() LicenseOptions { +func (f *LicenseFlagGroup) ToOptions() (LicenseOptions, error) { + if err := parseFlags(f); err != nil { + return LicenseOptions{}, err + } + licenseCategories := make(map[types.LicenseCategory][]string) - licenseCategories[types.CategoryForbidden] = getStringSlice(f.LicenseForbidden) - licenseCategories[types.CategoryRestricted] = getStringSlice(f.LicenseRestricted) - licenseCategories[types.CategoryReciprocal] = getStringSlice(f.LicenseReciprocal) - licenseCategories[types.CategoryNotice] = getStringSlice(f.LicenseNotice) - licenseCategories[types.CategoryPermissive] = getStringSlice(f.LicensePermissive) - licenseCategories[types.CategoryUnencumbered] = getStringSlice(f.LicenseUnencumbered) + licenseCategories[types.CategoryForbidden] = f.LicenseForbidden.Value() + licenseCategories[types.CategoryRestricted] = f.LicenseRestricted.Value() + licenseCategories[types.CategoryReciprocal] = f.LicenseReciprocal.Value() + licenseCategories[types.CategoryNotice] = f.LicenseNotice.Value() + licenseCategories[types.CategoryPermissive] = f.LicensePermissive.Value() + licenseCategories[types.CategoryUnencumbered] = f.LicenseUnencumbered.Value() return LicenseOptions{ - LicenseFull: getBool(f.LicenseFull), - IgnoredLicenses: getStringSlice(f.IgnoredLicenses), - LicenseConfidenceLevel: getFloat(f.LicenseConfidenceLevel), + LicenseFull: f.LicenseFull.Value(), + IgnoredLicenses: f.IgnoredLicenses.Value(), + LicenseConfidenceLevel: f.LicenseConfidenceLevel.Value(), LicenseCategories: licenseCategories, - } + }, nil } diff --git a/pkg/flag/misconf_flags.go b/pkg/flag/misconf_flags.go index 10db4bb81421..492960b60ff9 100644 --- a/pkg/flag/misconf_flags.go +++ b/pkg/flag/misconf_flags.go @@ -15,67 +15,59 @@ import ( // config-policy: "custom-policy/policy" // policy-namespaces: "user" var ( - ResetPolicyBundleFlag = Flag{ + ResetPolicyBundleFlag = Flag[bool]{ Name: "reset-policy-bundle", ConfigName: "misconfiguration.reset-policy-bundle", - Default: false, Usage: "remove policy bundle", } - IncludeNonFailuresFlag = Flag{ + IncludeNonFailuresFlag = Flag[bool]{ Name: "include-non-failures", ConfigName: "misconfiguration.include-non-failures", - Default: false, Usage: "include successes and exceptions, available with '--scanners misconfig'", } - HelmValuesFileFlag = Flag{ + HelmValuesFileFlag = Flag[[]string]{ Name: "helm-values", ConfigName: "misconfiguration.helm.values", - Default: []string{}, Usage: "specify paths to override the Helm values.yaml files", } - HelmSetFlag = Flag{ + HelmSetFlag = Flag[[]string]{ Name: "helm-set", ConfigName: "misconfiguration.helm.set", - Default: []string{}, Usage: "specify Helm values on the command line (can specify multiple or separate values with commas: key1=val1,key2=val2)", } - HelmSetFileFlag = Flag{ + HelmSetFileFlag = Flag[[]string]{ Name: "helm-set-file", ConfigName: "misconfiguration.helm.set-file", - Default: []string{}, Usage: "specify Helm values from respective files specified via the command line (can specify multiple or separate values with commas: key1=path1,key2=path2)", } - HelmSetStringFlag = Flag{ + HelmSetStringFlag = Flag[[]string]{ Name: "helm-set-string", ConfigName: "misconfiguration.helm.set-string", - Default: []string{}, Usage: "specify Helm string values on the command line (can specify multiple or separate values with commas: key1=val1,key2=val2)", } - TfVarsFlag = Flag{ + TfVarsFlag = Flag[[]string]{ Name: "tf-vars", ConfigName: "misconfiguration.terraform.vars", - Default: []string{}, Usage: "specify paths to override the Terraform tfvars files", } - CfParamsFlag = Flag{ + CfParamsFlag = Flag[[]string]{ Name: "cf-params", ConfigName: "misconfiguration.cloudformation.params", Default: []string{}, Usage: "specify paths to override the CloudFormation parameters files", } - TerraformExcludeDownloaded = Flag{ + TerraformExcludeDownloaded = Flag[bool]{ Name: "tf-exclude-downloaded-modules", ConfigName: "misconfiguration.terraform.exclude-downloaded-modules", - Default: false, Usage: "exclude misconfigurations for downloaded terraform modules", } - PolicyBundleRepositoryFlag = Flag{ + PolicyBundleRepositoryFlag = Flag[string]{ Name: "policy-bundle-repository", ConfigName: "misconfiguration.policy-bundle-repository", Default: fmt.Sprintf("%s:%d", policy.BundleRepository, policy.BundleVersion), Usage: "OCI registry URL to retrieve policy bundle from", } - MisconfigScannersFlag = Flag{ + MisconfigScannersFlag = Flag[[]string]{ Name: "misconfig-scanners", ConfigName: "misconfiguration.scanners", Default: xstrings.ToStringSlice(analyzer.TypeConfigFiles), @@ -85,19 +77,19 @@ var ( // MisconfFlagGroup composes common printer flag structs used for commands providing misconfiguration scanning. type MisconfFlagGroup struct { - IncludeNonFailures *Flag - ResetPolicyBundle *Flag - PolicyBundleRepository *Flag + IncludeNonFailures *Flag[bool] + ResetPolicyBundle *Flag[bool] + PolicyBundleRepository *Flag[string] // Values Files - HelmValues *Flag - HelmValueFiles *Flag - HelmFileValues *Flag - HelmStringValues *Flag - TerraformTFVars *Flag - CloudformationParamVars *Flag - TerraformExcludeDownloaded *Flag - MisconfigScanners *Flag + HelmValues *Flag[[]string] + HelmValueFiles *Flag[[]string] + HelmFileValues *Flag[[]string] + HelmStringValues *Flag[[]string] + TerraformTFVars *Flag[[]string] + CloudformationParamVars *Flag[[]string] + TerraformExcludeDownloaded *Flag[bool] + MisconfigScanners *Flag[[]string] } type MisconfOptions struct { @@ -118,18 +110,18 @@ type MisconfOptions struct { func NewMisconfFlagGroup() *MisconfFlagGroup { return &MisconfFlagGroup{ - IncludeNonFailures: &IncludeNonFailuresFlag, - ResetPolicyBundle: &ResetPolicyBundleFlag, - PolicyBundleRepository: &PolicyBundleRepositoryFlag, + IncludeNonFailures: IncludeNonFailuresFlag.Clone(), + ResetPolicyBundle: ResetPolicyBundleFlag.Clone(), + PolicyBundleRepository: PolicyBundleRepositoryFlag.Clone(), - HelmValues: &HelmSetFlag, - HelmFileValues: &HelmSetFileFlag, - HelmStringValues: &HelmSetStringFlag, - HelmValueFiles: &HelmValuesFileFlag, - TerraformTFVars: &TfVarsFlag, - CloudformationParamVars: &CfParamsFlag, - TerraformExcludeDownloaded: &TerraformExcludeDownloaded, - MisconfigScanners: &MisconfigScannersFlag, + HelmValues: HelmSetFlag.Clone(), + HelmFileValues: HelmSetFileFlag.Clone(), + HelmStringValues: HelmSetStringFlag.Clone(), + HelmValueFiles: HelmValuesFileFlag.Clone(), + TerraformTFVars: TfVarsFlag.Clone(), + CloudformationParamVars: CfParamsFlag.Clone(), + TerraformExcludeDownloaded: TerraformExcludeDownloaded.Clone(), + MisconfigScanners: MisconfigScannersFlag.Clone(), } } @@ -137,8 +129,8 @@ func (f *MisconfFlagGroup) Name() string { return "Misconfiguration" } -func (f *MisconfFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *MisconfFlagGroup) Flags() []Flagger { + return []Flagger{ f.IncludeNonFailures, f.ResetPolicyBundle, f.PolicyBundleRepository, @@ -154,17 +146,21 @@ func (f *MisconfFlagGroup) Flags() []*Flag { } func (f *MisconfFlagGroup) ToOptions() (MisconfOptions, error) { + if err := parseFlags(f); err != nil { + return MisconfOptions{}, err + } + return MisconfOptions{ - IncludeNonFailures: getBool(f.IncludeNonFailures), - ResetPolicyBundle: getBool(f.ResetPolicyBundle), - PolicyBundleRepository: getString(f.PolicyBundleRepository), - HelmValues: getStringSlice(f.HelmValues), - HelmValueFiles: getStringSlice(f.HelmValueFiles), - HelmFileValues: getStringSlice(f.HelmFileValues), - HelmStringValues: getStringSlice(f.HelmStringValues), - TerraformTFVars: getStringSlice(f.TerraformTFVars), - CloudFormationParamVars: getStringSlice(f.CloudformationParamVars), - TfExcludeDownloaded: getBool(f.TerraformExcludeDownloaded), - MisconfigScanners: getUnderlyingStringSlice[analyzer.Type](f.MisconfigScanners), + IncludeNonFailures: f.IncludeNonFailures.Value(), + ResetPolicyBundle: f.ResetPolicyBundle.Value(), + PolicyBundleRepository: f.PolicyBundleRepository.Value(), + HelmValues: f.HelmValues.Value(), + HelmValueFiles: f.HelmValueFiles.Value(), + HelmFileValues: f.HelmFileValues.Value(), + HelmStringValues: f.HelmStringValues.Value(), + TerraformTFVars: f.TerraformTFVars.Value(), + CloudFormationParamVars: f.CloudformationParamVars.Value(), + TfExcludeDownloaded: f.TerraformExcludeDownloaded.Value(), + MisconfigScanners: xstrings.ToTSlice[analyzer.Type](f.MisconfigScanners.Value()), }, nil } diff --git a/pkg/flag/module_flags.go b/pkg/flag/module_flags.go index 1a84e2acbb68..a3fdca308208 100644 --- a/pkg/flag/module_flags.go +++ b/pkg/flag/module_flags.go @@ -11,14 +11,14 @@ import ( // - spring4shell var ( - ModuleDirFlag = Flag{ + ModuleDirFlag = Flag[string]{ Name: "module-dir", ConfigName: "module.dir", Default: module.DefaultDir, Usage: "specify directory to the wasm modules that will be loaded", Persistent: true, } - EnableModulesFlag = Flag{ + EnableModulesFlag = Flag[[]string]{ Name: "enable-modules", ConfigName: "module.enable-modules", Default: []string{}, @@ -29,8 +29,8 @@ var ( // ModuleFlagGroup defines flags for modules type ModuleFlagGroup struct { - Dir *Flag - EnabledModules *Flag + Dir *Flag[string] + EnabledModules *Flag[[]string] } type ModuleOptions struct { @@ -40,8 +40,8 @@ type ModuleOptions struct { func NewModuleFlagGroup() *ModuleFlagGroup { return &ModuleFlagGroup{ - Dir: &ModuleDirFlag, - EnabledModules: &EnableModulesFlag, + Dir: ModuleDirFlag.Clone(), + EnabledModules: EnableModulesFlag.Clone(), } } @@ -49,16 +49,20 @@ func (f *ModuleFlagGroup) Name() string { return "Module" } -func (f *ModuleFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *ModuleFlagGroup) Flags() []Flagger { + return []Flagger{ f.Dir, f.EnabledModules, } } -func (f *ModuleFlagGroup) ToOptions() ModuleOptions { - return ModuleOptions{ - ModuleDir: getString(f.Dir), - EnabledModules: getStringSlice(f.EnabledModules), +func (f *ModuleFlagGroup) ToOptions() (ModuleOptions, error) { + if err := parseFlags(f); err != nil { + return ModuleOptions{}, err } + + return ModuleOptions{ + ModuleDir: f.Dir.Value(), + EnabledModules: f.EnabledModules.Value(), + }, nil } diff --git a/pkg/flag/options.go b/pkg/flag/options.go index fb3d69eaa396..cfdce46e2240 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -9,10 +9,12 @@ import ( "sync" "time" + "github.com/samber/lo" "github.com/spf13/cast" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" + "golang.org/x/exp/slices" "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" @@ -22,10 +24,13 @@ import ( "github.com/aquasecurity/trivy/pkg/result" "github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/version" - xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) -type Flag struct { +type FlagType interface { + int | string | []string | bool | time.Duration | float64 +} + +type Flag[T FlagType] struct { // Name is for CLI flag and environment variable. // If this field is empty, it will be available only in config file. Name string @@ -36,8 +41,8 @@ type Flag struct { // Shorthand is a shorthand letter. Shorthand string - // Default is the default value. It must be filled to determine the flag type. - Default any + // Default is the default value. It should be defined when the value is different from the zero value. + Default T // Values is a list of allowed values. // It currently supports string flags and string slice flags only. @@ -45,7 +50,7 @@ type Flag struct { // ValueNormalize is a function to normalize the value. // It can be used for aliases, etc. - ValueNormalize func(string) string + ValueNormalize func(T) T // Usage explains how to use the flag. Usage string @@ -58,6 +63,10 @@ type Flag struct { // Aliases represents aliases Aliases []Alias + + // value is the value passed through CLI flag, env, or config file. + // It is populated after flag.Parse() is called. + value T } type Alias struct { @@ -66,12 +75,230 @@ type Alias struct { Deprecated bool } +func (f *Flag[T]) Clone() *Flag[T] { + var t T + ff := *f + ff.value = t + fff := &ff + return fff +} + +func (f *Flag[T]) Parse() error { + if f == nil { + return nil + } + + v := f.parse() + if v == nil { + f.value = lo.Empty[T]() + return nil + } + + value, ok := f.cast(v).(T) + if !ok { + return xerrors.Errorf("failed to parse flag %s", f.Name) + } + + if f.ValueNormalize != nil { + value = f.ValueNormalize(value) + } + + if f.isSet() && !f.allowedValue(value) { + return xerrors.Errorf(`invalid argument "%s" for "--%s" flag: must be one of %q`, value, f.Name, f.Values) + } + + f.value = value + return nil +} + +func (f *Flag[T]) parse() any { + // First, looks for aliases in config file (trivy.yaml). + // Note that viper.RegisterAlias cannot be used for this purpose. + var v any + for _, alias := range f.Aliases { + if alias.ConfigName == "" { + continue + } + v = viper.Get(alias.ConfigName) + if v != nil { + log.Logger.Warnf("'%s' in config file is deprecated. Use '%s' instead.", alias.ConfigName, f.ConfigName) + return v + } + } + return viper.Get(f.ConfigName) +} + +// cast converts the value to the type of the flag. +func (f *Flag[T]) cast(val any) any { + switch any(f.Default).(type) { + case bool: + return cast.ToBool(val) + case string: + return cast.ToString(val) + case int: + return cast.ToInt(val) + case float64, float32: + return cast.ToFloat64(val) + case time.Duration: + return cast.ToDuration(val) + case []string: + if s, ok := val.(string); ok && strings.Contains(s, ",") { + // Split environmental variables by comma as it is not done by viper. + // cf. https://github.com/spf13/viper/issues/380 + // It is split by spaces only. + // https://github.com/spf13/cast/blob/48ddde5701366ade1d3aba346e09bb58430d37c6/caste.go#L1296-L1297 + val = strings.Split(s, ",") + } + return cast.ToStringSlice(val) + } + return val +} + +func (f *Flag[T]) isSet() bool { + configNames := lo.FilterMap(f.Aliases, func(alias Alias, _ int) (string, bool) { + return alias.ConfigName, alias.ConfigName != "" + }) + configNames = append(configNames, f.ConfigName) + + return lo.SomeBy(configNames, viper.IsSet) +} + +func (f *Flag[T]) allowedValue(v any) bool { + if len(f.Values) == 0 { + return true + } + switch value := v.(type) { + case string: + return slices.Contains(f.Values, value) + case []string: + for _, v := range value { + if !slices.Contains(f.Values, v) { + return false + } + } + } + return true +} + +func (f *Flag[T]) GetName() string { + return f.Name +} + +func (f *Flag[T]) GetAliases() []Alias { + return f.Aliases +} + +func (f *Flag[T]) Value() (t T) { + if f == nil { + return t + } + return f.value +} + +func (f *Flag[T]) Add(cmd *cobra.Command) { + if f == nil || f.Name == "" { + return + } + var flags *pflag.FlagSet + if f.Persistent { + flags = cmd.PersistentFlags() + } else { + flags = cmd.Flags() + } + + switch v := any(f.Default).(type) { + case int: + flags.IntP(f.Name, f.Shorthand, v, f.Usage) + case string: + usage := f.Usage + if len(f.Values) > 0 { + usage += fmt.Sprintf(" (%s)", strings.Join(f.Values, ",")) + } + flags.StringP(f.Name, f.Shorthand, v, usage) + case []string: + usage := f.Usage + if len(f.Values) > 0 { + usage += fmt.Sprintf(" (%s)", strings.Join(f.Values, ",")) + } + flags.StringSliceP(f.Name, f.Shorthand, v, usage) + case bool: + flags.BoolP(f.Name, f.Shorthand, v, f.Usage) + case time.Duration: + flags.DurationP(f.Name, f.Shorthand, v, f.Usage) + case float64: + flags.Float64P(f.Name, f.Shorthand, v, f.Usage) + } + + if f.Deprecated { + flags.MarkHidden(f.Name) // nolint: gosec + } +} + +func (f *Flag[T]) Bind(cmd *cobra.Command) error { + if f == nil { + return nil + } else if f.Name == "" { + // This flag is available only in trivy.yaml + viper.SetDefault(f.ConfigName, f.Default) + return nil + } + + // Bind CLI flags + flag := cmd.Flags().Lookup(f.Name) + if f == nil { + // Lookup local persistent flags + flag = cmd.PersistentFlags().Lookup(f.Name) + } + if err := viper.BindPFlag(f.ConfigName, flag); err != nil { + return xerrors.Errorf("bind flag error: %w", err) + } + + // Bind environmental variable + if err := f.BindEnv(); err != nil { + return err + } + + return nil +} + +func (f *Flag[T]) BindEnv() error { + // We don't use viper.AutomaticEnv, so we need to add a prefix manually here. + envName := strings.ToUpper("trivy_" + strings.ReplaceAll(f.Name, "-", "_")) + if err := viper.BindEnv(f.ConfigName, envName); err != nil { + return xerrors.Errorf("bind env error: %w", err) + } + + // Bind env aliases + for _, alias := range f.Aliases { + envAlias := strings.ToUpper("trivy_" + strings.ReplaceAll(alias.Name, "-", "_")) + if err := viper.BindEnv(f.ConfigName, envAlias); err != nil { + return xerrors.Errorf("bind env error: %w", err) + } + if alias.Deprecated { + if _, ok := os.LookupEnv(envAlias); ok { + log.Logger.Warnf("'%s' is deprecated. Use '%s' instead.", envAlias, envName) + } + } + } + return nil +} + type FlagGroup interface { Name() string - Flags() []*Flag + Flags() []Flagger +} + +type Flagger interface { + GetName() string + GetAliases() []Alias + + Parse() error + Add(cmd *cobra.Command) + Bind(cmd *cobra.Command) error } type Flags struct { + GlobalFlagGroup *GlobalFlagGroup AWSFlagGroup *AWSFlagGroup CacheFlagGroup *CacheFlagGroup CloudFlagGroup *CloudFlagGroup @@ -217,163 +444,7 @@ func (o *Options) outputPluginWriter(ctx context.Context) (io.Writer, func() err return pw, cleanup, nil } -func addFlag(cmd *cobra.Command, flag *Flag) { - if flag == nil || flag.Name == "" { - return - } - var flags *pflag.FlagSet - if flag.Persistent { - flags = cmd.PersistentFlags() - } else { - flags = cmd.Flags() - } - - switch v := flag.Default.(type) { - case int: - flags.IntP(flag.Name, flag.Shorthand, v, flag.Usage) - case string: - usage := flag.Usage - if len(flag.Values) > 0 { - usage += fmt.Sprintf(" (%s)", strings.Join(flag.Values, ",")) - } - flags.VarP(newCustomStringValue(v, flag.Values, flag.ValueNormalize), flag.Name, flag.Shorthand, usage) - case []string: - usage := flag.Usage - if len(flag.Values) > 0 { - usage += fmt.Sprintf(" (%s)", strings.Join(flag.Values, ",")) - } - flags.VarP(newCustomStringSliceValue(v, flag.Values, flag.ValueNormalize), flag.Name, flag.Shorthand, usage) - case bool: - flags.BoolP(flag.Name, flag.Shorthand, v, flag.Usage) - case time.Duration: - flags.DurationP(flag.Name, flag.Shorthand, v, flag.Usage) - case float64: - flags.Float64P(flag.Name, flag.Shorthand, v, flag.Usage) - } - - if flag.Deprecated { - flags.MarkHidden(flag.Name) // nolint: gosec - } -} - -func bind(cmd *cobra.Command, flag *Flag) error { - if flag == nil { - return nil - } else if flag.Name == "" { - // This flag is available only in trivy.yaml - viper.SetDefault(flag.ConfigName, flag.Default) - return nil - } - - // Bind CLI flags - f := cmd.Flags().Lookup(flag.Name) - if f == nil { - // Lookup local persistent flags - f = cmd.PersistentFlags().Lookup(flag.Name) - } - if err := viper.BindPFlag(flag.ConfigName, f); err != nil { - return xerrors.Errorf("bind flag error: %w", err) - } - - // Bind environmental variable - if err := bindEnv(flag); err != nil { - return err - } - - return nil -} - -func bindEnv(flag *Flag) error { - // We don't use viper.AutomaticEnv, so we need to add a prefix manually here. - envName := strings.ToUpper("trivy_" + strings.ReplaceAll(flag.Name, "-", "_")) - if err := viper.BindEnv(flag.ConfigName, envName); err != nil { - return xerrors.Errorf("bind env error: %w", err) - } - - // Bind env aliases - for _, alias := range flag.Aliases { - envAlias := strings.ToUpper("trivy_" + strings.ReplaceAll(alias.Name, "-", "_")) - if err := viper.BindEnv(flag.ConfigName, envAlias); err != nil { - return xerrors.Errorf("bind env error: %w", err) - } - if alias.Deprecated { - if _, ok := os.LookupEnv(envAlias); ok { - log.Logger.Warnf("'%s' is deprecated. Use '%s' instead.", envAlias, envName) - } - } - } - return nil -} - -func getString(flag *Flag) string { - return cast.ToString(getValue(flag)) -} - -func getUnderlyingString[T xstrings.String](flag *Flag) T { - s := getString(flag) - return T(s) -} - -func getStringSlice(flag *Flag) []string { - // viper always returns a string for ENV - // https://github.com/spf13/viper/blob/419fd86e49ef061d0d33f4d1d56d5e2a480df5bb/viper.go#L545-L553 - // and uses strings.Field to separate values (whitespace only) - // we need to separate env values with ',' - v := cast.ToStringSlice(getValue(flag)) - switch { - case len(v) == 0: // no strings - return nil - case len(v) == 1 && strings.Contains(v[0], ","): // unseparated string - v = strings.Split(v[0], ",") - } - return v -} - -func getUnderlyingStringSlice[T xstrings.String](flag *Flag) []T { - ss := getStringSlice(flag) - if len(ss) == 0 { - return nil - } - return xstrings.ToTSlice[T](ss) -} - -func getInt(flag *Flag) int { - return cast.ToInt(getValue(flag)) -} - -func getFloat(flag *Flag) float64 { - return cast.ToFloat64(getValue(flag)) -} - -func getBool(flag *Flag) bool { - return cast.ToBool(getValue(flag)) -} - -func getDuration(flag *Flag) time.Duration { - return cast.ToDuration(getValue(flag)) -} - -func getValue(flag *Flag) any { - if flag == nil { - return nil - } - - // First, looks for aliases in config file (trivy.yaml). - // Note that viper.RegisterAlias cannot be used for this purpose. - var v any - for _, alias := range flag.Aliases { - if alias.ConfigName == "" { - continue - } - v = viper.Get(alias.ConfigName) - if v != nil { - log.Logger.Warnf("'%s' in config file is deprecated. Use '%s' instead.", alias.ConfigName, flag.ConfigName) - return v - } - } - return viper.Get(flag.ConfigName) -} - +// groups returns all the flag groups other than global flags func (f *Flags) groups() []FlagGroup { var groups []FlagGroup // This order affects the usage message, so they are sorted by frequency of use. @@ -438,7 +509,11 @@ func (f *Flags) AddFlags(cmd *cobra.Command) { aliases := make(flagAliases) for _, group := range f.groups() { for _, flag := range group.Flags() { - addFlag(cmd, flag) + if lo.IsNil(flag) || flag.GetName() == "" { + continue + } + // Register the CLI flag + flag.Add(cmd) // Register flag aliases aliases.Add(flag) @@ -451,14 +526,13 @@ func (f *Flags) AddFlags(cmd *cobra.Command) { func (f *Flags) Usages(cmd *cobra.Command) string { var usages string for _, group := range f.groups() { - flags := pflag.NewFlagSet(cmd.Name(), pflag.ContinueOnError) lflags := cmd.LocalFlags() for _, flag := range group.Flags() { - if flag == nil || flag.Name == "" { + if lo.IsNil(flag) || flag.GetName() == "" { continue } - flags.AddFlag(lflags.Lookup(flag.Name)) + flags.AddFlag(lflags.Lookup(flag.GetName())) } if !flags.HasAvailableFlags() { continue @@ -476,7 +550,7 @@ func (f *Flags) Bind(cmd *cobra.Command) error { continue } for _, flag := range group.Flags() { - if err := bind(cmd, flag); err != nil { + if err := flag.Bind(cmd); err != nil { return xerrors.Errorf("flag groups: %w", err) } } @@ -485,19 +559,31 @@ func (f *Flags) Bind(cmd *cobra.Command) error { } // nolint: gocyclo -func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, error) { +func (f *Flags) ToOptions(args []string) (Options, error) { var err error opts := Options{ - AppVersion: version.AppVersion(), - GlobalOptions: globalFlags.ToOptions(), + AppVersion: version.AppVersion(), + } + + if f.GlobalFlagGroup != nil { + opts.GlobalOptions, err = f.GlobalFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("global flag error: %w", err) + } } if f.AWSFlagGroup != nil { - opts.AWSOptions = f.AWSFlagGroup.ToOptions() + opts.AWSOptions, err = f.AWSFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("aws flag error: %w", err) + } } if f.CloudFlagGroup != nil { - opts.CloudOptions = f.CloudFlagGroup.ToOptions() + opts.CloudOptions, err = f.CloudFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("cloud flag error: %w", err) + } } if f.CacheFlagGroup != nil { @@ -510,7 +596,7 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, if f.DBFlagGroup != nil { opts.DBOptions, err = f.DBFlagGroup.ToOptions() if err != nil { - return Options{}, xerrors.Errorf("flag error: %w", err) + return Options{}, xerrors.Errorf("db flag error: %w", err) } } @@ -529,7 +615,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, } if f.LicenseFlagGroup != nil { - opts.LicenseOptions = f.LicenseFlagGroup.ToOptions() + opts.LicenseOptions, err = f.LicenseFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("license flag error: %w", err) + } } if f.MisconfFlagGroup != nil { @@ -540,7 +629,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, } if f.ModuleFlagGroup != nil { - opts.ModuleOptions = f.ModuleFlagGroup.ToOptions() + opts.ModuleOptions, err = f.ModuleFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("module flag error: %w", err) + } } if f.RegoFlagGroup != nil { @@ -551,7 +643,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, } if f.RemoteFlagGroup != nil { - opts.RemoteOptions = f.RemoteFlagGroup.ToOptions() + opts.RemoteOptions, err = f.RemoteFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("remote flag error: %w", err) + } } if f.RegistryFlagGroup != nil { @@ -562,7 +657,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, } if f.RepoFlagGroup != nil { - opts.RepoOptions = f.RepoFlagGroup.ToOptions() + opts.RepoOptions, err = f.RepoFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("rego flag error: %w", err) + } } if f.ReportFlagGroup != nil { @@ -587,11 +685,17 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, } if f.SecretFlagGroup != nil { - opts.SecretOptions = f.SecretFlagGroup.ToOptions() + opts.SecretOptions, err = f.SecretFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("secret flag error: %w", err) + } } if f.VulnerabilityFlagGroup != nil { - opts.VulnerabilityOptions = f.VulnerabilityFlagGroup.ToOptions() + opts.VulnerabilityOptions, err = f.VulnerabilityFlagGroup.ToOptions() + if err != nil { + return Options{}, xerrors.Errorf("vulnerability flag error: %w", err) + } } opts.Align() @@ -599,6 +703,15 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, return opts, nil } +func parseFlags(fg FlagGroup) error { + for _, flag := range fg.Flags() { + if err := flag.Parse(); err != nil { + return xerrors.Errorf("unable to parse flag: %w", err) + } + } + return nil +} + type flagAlias struct { formalName string deprecated bool @@ -608,13 +721,10 @@ type flagAlias struct { // flagAliases have aliases for CLI flags type flagAliases map[string]*flagAlias -func (a flagAliases) Add(flag *Flag) { - if flag == nil { - return - } - for _, alias := range flag.Aliases { +func (a flagAliases) Add(flag Flagger) { + for _, alias := range flag.GetAliases() { a[alias.Name] = &flagAlias{ - formalName: flag.Name, + formalName: flag.GetName(), deprecated: alias.Deprecated, } } diff --git a/pkg/flag/options_test.go b/pkg/flag/options_test.go index f3a76d177730..092e09d7b411 100644 --- a/pkg/flag/options_test.go +++ b/pkg/flag/options_test.go @@ -1,82 +1,127 @@ -package flag +package flag_test import ( + "github.com/aquasecurity/trivy/pkg/flag" + "github.com/aquasecurity/trivy/pkg/types" + "github.com/samber/lo" + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" "testing" "github.com/spf13/viper" - "github.com/stretchr/testify/assert" - - "github.com/aquasecurity/trivy/pkg/types" ) -func Test_getStringSlice(t *testing.T) { - type env struct { +func TestFlag_Parse(t *testing.T) { + type kv struct { key string - value string + value any } tests := []struct { - name string - flag *Flag - flagValue interface{} - env env - want []string + name string + flag *kv + env *kv + want []string + wantErr string }{ { - name: "happy path. Empty value", - flag: &ScannersFlag, - flagValue: "", - want: nil, + name: "flag, string slice", + flag: &kv{ + key: "scan.scanners", + value: []string{ + "vuln", + "misconfig", + }, + }, + want: []string{ + string(types.VulnerabilityScanner), + string(types.MisconfigScanner), + }, }, { - name: "happy path. String value", - flag: &ScannersFlag, - flagValue: "license,vuln", + name: "env, string", + env: &kv{ + key: "TRIVY_SCANNERS", + value: "vuln,misconfig", + }, want: []string{ - string(types.LicenseScanner), string(types.VulnerabilityScanner), + string(types.MisconfigScanner), }, }, { - name: "happy path. Slice value", - flag: &ScannersFlag, - flagValue: []string{ - "license", - "secret", + name: "flag, alias", + flag: &kv{ + key: "scan.security-checks", + value: "vulnerability,config", }, want: []string{ - string(types.LicenseScanner), - string(types.SecretScanner), + string(types.VulnerabilityScanner), + string(types.MisconfigScanner), }, }, { - name: "happy path. Env value", - flag: &ScannersFlag, - env: env{ + name: "env, alias", + env: &kv{ key: "TRIVY_SECURITY_CHECKS", - value: "rbac,misconfig", + value: "vulnerability,config", }, want: []string{ - string(types.RBACScanner), + string(types.VulnerabilityScanner), string(types.MisconfigScanner), }, }, + { + name: "flag, invalid value", + flag: &kv{ + key: "scan.scanners", + value: "vuln,invalid", + }, + wantErr: `invalid argument "[vuln invalid]" for "--scanners" flag`, + }, + { + name: "env, invalid value", + env: &kv{ + key: "TRIVY_SCANNERS", + value: "vuln,invalid", + }, + wantErr: `invalid argument "[vuln invalid]" for "--scanners" flag`, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.env.key == "" { - viper.Set(tt.flag.ConfigName, tt.flagValue) - } else { - err := viper.BindEnv(tt.flag.ConfigName, tt.env.key) - assert.NoError(t, err) + t.Cleanup(viper.Reset) - t.Setenv(tt.env.key, tt.env.value) + if tt.flag != nil { + viper.Set(tt.flag.key, tt.flag.value) + } else { + t.Setenv(tt.env.key, tt.env.value.(string)) } - sl := getStringSlice(tt.flag) - assert.Equal(t, tt.want, sl) + app := &cobra.Command{} + f := flag.ScannersFlag.Clone() + f.Add(app) + require.NoError(t, f.Bind(app)) - viper.Reset() + err := f.Parse() + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, f.Value()) }) } } + +func setValue[T comparable](key string, value T) { + if !lo.IsEmpty(value) { + viper.Set(key, value) + } +} + +func setSliceValue[T any](key string, value []T) { + if len(value) > 0 { + viper.Set(key, value) + } +} diff --git a/pkg/flag/registry_flags.go b/pkg/flag/registry_flags.go index 7aed50d5e2f2..552eaf4276d6 100644 --- a/pkg/flag/registry_flags.go +++ b/pkg/flag/registry_flags.go @@ -9,30 +9,27 @@ import ( ) var ( - UsernameFlag = Flag{ + UsernameFlag = Flag[[]string]{ Name: "username", ConfigName: "registry.username", - Default: []string{}, Usage: "username. Comma-separated usernames allowed.", } - PasswordFlag = Flag{ + PasswordFlag = Flag[[]string]{ Name: "password", ConfigName: "registry.password", - Default: []string{}, Usage: "password. Comma-separated passwords allowed. TRIVY_PASSWORD should be used for security reasons.", } - RegistryTokenFlag = Flag{ + RegistryTokenFlag = Flag[string]{ Name: "registry-token", ConfigName: "registry.token", - Default: "", Usage: "registry token", } ) type RegistryFlagGroup struct { - Username *Flag - Password *Flag - RegistryToken *Flag + Username *Flag[[]string] + Password *Flag[[]string] + RegistryToken *Flag[string] } type RegistryOptions struct { @@ -42,9 +39,9 @@ type RegistryOptions struct { func NewRegistryFlagGroup() *RegistryFlagGroup { return &RegistryFlagGroup{ - Username: &UsernameFlag, - Password: &PasswordFlag, - RegistryToken: &RegistryTokenFlag, + Username: UsernameFlag.Clone(), + Password: PasswordFlag.Clone(), + RegistryToken: RegistryTokenFlag.Clone(), } } @@ -52,8 +49,8 @@ func (f *RegistryFlagGroup) Name() string { return "Registry" } -func (f *RegistryFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *RegistryFlagGroup) Flags() []Flagger { + return []Flagger{ f.Username, f.Password, f.RegistryToken, @@ -61,9 +58,13 @@ func (f *RegistryFlagGroup) Flags() []*Flag { } func (f *RegistryFlagGroup) ToOptions() (RegistryOptions, error) { + if err := parseFlags(f); err != nil { + return RegistryOptions{}, err + } + var credentials []types.Credential - users := getStringSlice(f.Username) - passwords := getStringSlice(f.Password) + users := f.Username.Value() + passwords := f.Password.Value() if len(users) != len(passwords) { return RegistryOptions{}, xerrors.New("the length of usernames and passwords must match") } @@ -76,6 +77,6 @@ func (f *RegistryFlagGroup) ToOptions() (RegistryOptions, error) { return RegistryOptions{ Credentials: credentials, - RegistryToken: getString(f.RegistryToken), + RegistryToken: f.RegistryToken.Value(), }, nil } diff --git a/pkg/flag/rego_flags.go b/pkg/flag/rego_flags.go index af615f7ed7bc..e0b21f73030b 100644 --- a/pkg/flag/rego_flags.go +++ b/pkg/flag/rego_flags.go @@ -7,40 +7,35 @@ package flag // config-policy: "custom-policy/policy" // policy-namespaces: "user" var ( - SkipPolicyUpdateFlag = Flag{ + SkipPolicyUpdateFlag = Flag[bool]{ Name: "skip-policy-update", ConfigName: "rego.skip-policy-update", - Default: false, Usage: "skip fetching rego policy updates", } - TraceFlag = Flag{ + TraceFlag = Flag[bool]{ Name: "trace", ConfigName: "rego.trace", - Default: false, Usage: "enable more verbose trace output for custom queries", } - ConfigPolicyFlag = Flag{ + ConfigPolicyFlag = Flag[[]string]{ Name: "config-policy", ConfigName: "rego.policy", - Default: []string{}, Usage: "specify the paths to the Rego policy files or to the directories containing them, applying config files", Aliases: []Alias{ {Name: "policy"}, }, } - ConfigDataFlag = Flag{ + ConfigDataFlag = Flag[[]string]{ Name: "config-data", ConfigName: "rego.data", - Default: []string{}, Usage: "specify paths from which data for the Rego policies will be recursively loaded", Aliases: []Alias{ {Name: "data"}, }, } - PolicyNamespaceFlag = Flag{ + PolicyNamespaceFlag = Flag[[]string]{ Name: "policy-namespaces", ConfigName: "rego.namespaces", - Default: []string{}, Usage: "Rego namespaces", Aliases: []Alias{ {Name: "namespaces"}, @@ -50,11 +45,11 @@ var ( // RegoFlagGroup composes common printer flag structs used for commands providing misconfinguration scanning. type RegoFlagGroup struct { - SkipPolicyUpdate *Flag - Trace *Flag - PolicyPaths *Flag - DataPaths *Flag - PolicyNamespaces *Flag + SkipPolicyUpdate *Flag[bool] + Trace *Flag[bool] + PolicyPaths *Flag[[]string] + DataPaths *Flag[[]string] + PolicyNamespaces *Flag[[]string] } type RegoOptions struct { @@ -67,11 +62,11 @@ type RegoOptions struct { func NewRegoFlagGroup() *RegoFlagGroup { return &RegoFlagGroup{ - SkipPolicyUpdate: &SkipPolicyUpdateFlag, - Trace: &TraceFlag, - PolicyPaths: &ConfigPolicyFlag, - DataPaths: &ConfigDataFlag, - PolicyNamespaces: &PolicyNamespaceFlag, + SkipPolicyUpdate: SkipPolicyUpdateFlag.Clone(), + Trace: TraceFlag.Clone(), + PolicyPaths: ConfigPolicyFlag.Clone(), + DataPaths: ConfigDataFlag.Clone(), + PolicyNamespaces: PolicyNamespaceFlag.Clone(), } } @@ -79,8 +74,8 @@ func (f *RegoFlagGroup) Name() string { return "Rego" } -func (f *RegoFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *RegoFlagGroup) Flags() []Flagger { + return []Flagger{ f.SkipPolicyUpdate, f.Trace, f.PolicyPaths, @@ -90,11 +85,15 @@ func (f *RegoFlagGroup) Flags() []*Flag { } func (f *RegoFlagGroup) ToOptions() (RegoOptions, error) { + if err := parseFlags(f); err != nil { + return RegoOptions{}, err + } + return RegoOptions{ - SkipPolicyUpdate: getBool(f.SkipPolicyUpdate), - Trace: getBool(f.Trace), - PolicyPaths: getStringSlice(f.PolicyPaths), - DataPaths: getStringSlice(f.DataPaths), - PolicyNamespaces: getStringSlice(f.PolicyNamespaces), + SkipPolicyUpdate: f.SkipPolicyUpdate.Value(), + Trace: f.Trace.Value(), + PolicyPaths: f.PolicyPaths.Value(), + DataPaths: f.DataPaths.Value(), + PolicyNamespaces: f.PolicyNamespaces.Value(), }, nil } diff --git a/pkg/flag/remote_flags.go b/pkg/flag/remote_flags.go index c06e9605087c..9277f2db908f 100644 --- a/pkg/flag/remote_flags.go +++ b/pkg/flag/remote_flags.go @@ -12,31 +12,28 @@ const ( ) var ( - ServerTokenFlag = Flag{ + ServerTokenFlag = Flag[string]{ Name: "token", ConfigName: "server.token", - Default: "", Usage: "for authentication in client/server mode", } - ServerTokenHeaderFlag = Flag{ + ServerTokenHeaderFlag = Flag[string]{ Name: "token-header", ConfigName: "server.token-header", Default: DefaultTokenHeader, Usage: "specify a header name for token in client/server mode", } - ServerAddrFlag = Flag{ + ServerAddrFlag = Flag[string]{ Name: "server", ConfigName: "server.addr", - Default: "", Usage: "server address in client mode", } - ServerCustomHeadersFlag = Flag{ + ServerCustomHeadersFlag = Flag[[]string]{ Name: "custom-headers", ConfigName: "server.custom-headers", - Default: []string{}, Usage: "custom headers in client mode", } - ServerListenFlag = Flag{ + ServerListenFlag = Flag[string]{ Name: "listen", ConfigName: "server.listen", Default: "localhost:4954", @@ -48,15 +45,15 @@ var ( // used for commands requiring reporting logic. type RemoteFlagGroup struct { // for client/server - Token *Flag - TokenHeader *Flag + Token *Flag[string] + TokenHeader *Flag[string] // for client - ServerAddr *Flag - CustomHeaders *Flag + ServerAddr *Flag[string] + CustomHeaders *Flag[[]string] // for server - Listen *Flag + Listen *Flag[string] } type RemoteOptions struct { @@ -70,10 +67,10 @@ type RemoteOptions struct { func NewClientFlags() *RemoteFlagGroup { return &RemoteFlagGroup{ - Token: &ServerTokenFlag, - TokenHeader: &ServerTokenHeaderFlag, - ServerAddr: &ServerAddrFlag, - CustomHeaders: &ServerCustomHeadersFlag, + Token: ServerTokenFlag.Clone(), + TokenHeader: ServerTokenHeaderFlag.Clone(), + ServerAddr: ServerAddrFlag.Clone(), + CustomHeaders: ServerCustomHeadersFlag.Clone(), } } @@ -89,16 +86,26 @@ func (f *RemoteFlagGroup) Name() string { return "Client/Server" } -func (f *RemoteFlagGroup) Flags() []*Flag { - return []*Flag{f.Token, f.TokenHeader, f.ServerAddr, f.CustomHeaders, f.Listen} +func (f *RemoteFlagGroup) Flags() []Flagger { + return []Flagger{ + f.Token, + f.TokenHeader, + f.ServerAddr, + f.CustomHeaders, + f.Listen, + } } -func (f *RemoteFlagGroup) ToOptions() RemoteOptions { - serverAddr := getString(f.ServerAddr) - customHeaders := splitCustomHeaders(getStringSlice(f.CustomHeaders)) - listen := getString(f.Listen) - token := getString(f.Token) - tokenHeader := getString(f.TokenHeader) +func (f *RemoteFlagGroup) ToOptions() (RemoteOptions, error) { + if err := parseFlags(f); err != nil { + return RemoteOptions{}, err + } + + serverAddr := f.ServerAddr.Value() + customHeaders := splitCustomHeaders(f.CustomHeaders.Value()) + listen := f.Listen.Value() + token := f.Token.Value() + tokenHeader := f.TokenHeader.Value() if serverAddr == "" && listen == "" { switch { @@ -125,7 +132,7 @@ func (f *RemoteFlagGroup) ToOptions() RemoteOptions { ServerAddr: serverAddr, CustomHeaders: customHeaders, Listen: listen, - } + }, nil } func splitCustomHeaders(headers []string) http.Header { diff --git a/pkg/flag/remote_flags_test.go b/pkg/flag/remote_flags_test.go index 5c2eea6b762d..4500b0bb5ca9 100644 --- a/pkg/flag/remote_flags_test.go +++ b/pkg/flag/remote_flags_test.go @@ -1,6 +1,7 @@ package flag_test import ( + "github.com/stretchr/testify/require" "net/http" "testing" @@ -108,12 +109,13 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) { // Assert options f := &flag.RemoteFlagGroup{ - ServerAddr: &flag.ServerAddrFlag, - CustomHeaders: &flag.ServerCustomHeadersFlag, - Token: &flag.ServerTokenFlag, - TokenHeader: &flag.ServerTokenHeaderFlag, + ServerAddr: flag.ServerAddrFlag.Clone(), + CustomHeaders: flag.ServerCustomHeadersFlag.Clone(), + Token: flag.ServerTokenFlag.Clone(), + TokenHeader: flag.ServerTokenHeaderFlag.Clone(), } - got := f.ToOptions() + got, err := f.ToOptions() + require.NoError(t, err) assert.Equalf(t, tt.want, got, "ToOptions()") // Assert log messages diff --git a/pkg/flag/repo.go b/pkg/flag/repo_flags.go similarity index 51% rename from pkg/flag/repo.go rename to pkg/flag/repo_flags.go index 6d59281ba4fb..31ac0b634a3e 100644 --- a/pkg/flag/repo.go +++ b/pkg/flag/repo_flags.go @@ -1,30 +1,27 @@ package flag var ( - FetchBranchFlag = Flag{ + FetchBranchFlag = Flag[string]{ Name: "branch", ConfigName: "repository.branch", - Default: "", Usage: "pass the branch name to be scanned", } - FetchCommitFlag = Flag{ + FetchCommitFlag = Flag[string]{ Name: "commit", ConfigName: "repository.commit", - Default: "", Usage: "pass the commit hash to be scanned", } - FetchTagFlag = Flag{ + FetchTagFlag = Flag[string]{ Name: "tag", ConfigName: "repository.tag", - Default: "", Usage: "pass the tag name to be scanned", } ) type RepoFlagGroup struct { - Branch *Flag - Commit *Flag - Tag *Flag + Branch *Flag[string] + Commit *Flag[string] + Tag *Flag[string] } type RepoOptions struct { @@ -35,9 +32,9 @@ type RepoOptions struct { func NewRepoFlagGroup() *RepoFlagGroup { return &RepoFlagGroup{ - Branch: &FetchBranchFlag, - Commit: &FetchCommitFlag, - Tag: &FetchTagFlag, + Branch: FetchBranchFlag.Clone(), + Commit: FetchCommitFlag.Clone(), + Tag: FetchTagFlag.Clone(), } } @@ -45,14 +42,22 @@ func (f *RepoFlagGroup) Name() string { return "Repository" } -func (f *RepoFlagGroup) Flags() []*Flag { - return []*Flag{f.Branch, f.Commit, f.Tag} +func (f *RepoFlagGroup) Flags() []Flagger { + return []Flagger{ + f.Branch, + f.Commit, + f.Tag, + } } -func (f *RepoFlagGroup) ToOptions() RepoOptions { - return RepoOptions{ - RepoBranch: getString(f.Branch), - RepoCommit: getString(f.Commit), - RepoTag: getString(f.Tag), +func (f *RepoFlagGroup) ToOptions() (RepoOptions, error) { + if err := parseFlags(f); err != nil { + return RepoOptions{}, err } + + return RepoOptions{ + RepoBranch: f.Branch.Value(), + RepoCommit: f.Commit.Value(), + RepoTag: f.Tag.Value(), + }, nil } diff --git a/pkg/flag/report_flags.go b/pkg/flag/report_flags.go index 54554be3d126..5ab4788e71cc 100644 --- a/pkg/flag/report_flags.go +++ b/pkg/flag/report_flags.go @@ -22,7 +22,7 @@ import ( // dependency-tree: true // severity: HIGH,CRITICAL var ( - FormatFlag = Flag{ + FormatFlag = Flag[string]{ Name: "format", ConfigName: "format", Shorthand: "f", @@ -30,70 +30,65 @@ var ( Values: xstrings.ToStringSlice(types.SupportedFormats), Usage: "format", } - ReportFormatFlag = Flag{ + ReportFormatFlag = Flag[string]{ Name: "report", ConfigName: "report", Default: "all", - Values: []string{"all", "summary"}, - Usage: "specify a report format for the output", + Values: []string{ + "all", + "summary", + }, + Usage: "specify a report format for the output", } - TemplateFlag = Flag{ + TemplateFlag = Flag[string]{ Name: "template", ConfigName: "template", Shorthand: "t", - Default: "", Usage: "output template", } - DependencyTreeFlag = Flag{ + DependencyTreeFlag = Flag[bool]{ Name: "dependency-tree", ConfigName: "dependency-tree", - Default: false, Usage: "[EXPERIMENTAL] show dependency origin tree of vulnerable packages", } - ListAllPkgsFlag = Flag{ + ListAllPkgsFlag = Flag[bool]{ Name: "list-all-pkgs", ConfigName: "list-all-pkgs", - Default: false, Usage: "enabling the option will output all packages regardless of vulnerability", } - IgnoreFileFlag = Flag{ + IgnoreFileFlag = Flag[string]{ Name: "ignorefile", ConfigName: "ignorefile", Default: result.DefaultIgnoreFile, Usage: "specify .trivyignore file", } - IgnorePolicyFlag = Flag{ + IgnorePolicyFlag = Flag[string]{ Name: "ignore-policy", ConfigName: "ignore-policy", - Default: "", Usage: "specify the Rego file path to evaluate each vulnerability", } - ExitCodeFlag = Flag{ + ExitCodeFlag = Flag[int]{ Name: "exit-code", ConfigName: "exit-code", - Default: 0, Usage: "specify exit code when any security issues are found", } - ExitOnEOLFlag = Flag{ + ExitOnEOLFlag = Flag[int]{ Name: "exit-on-eol", ConfigName: "exit-on-eol", - Default: 0, Usage: "exit with the specified code when the OS reaches end of service/life", } - OutputFlag = Flag{ + OutputFlag = Flag[string]{ Name: "output", ConfigName: "output", Shorthand: "o", - Default: "", Usage: "output file name", } - OutputPluginArgFlag = Flag{ + OutputPluginArgFlag = Flag[string]{ Name: "output-plugin-arg", ConfigName: "output-plugin-arg", - Default: "", Usage: "[EXPERIMENTAL] output plugin arguments", } - SeverityFlag = Flag{ + SeverityFlag = Flag[[]string]{ Name: "severity", ConfigName: "severity", Shorthand: "s", @@ -101,10 +96,9 @@ var ( Values: dbTypes.SeverityNames, Usage: "severities of security issues to be displayed", } - ComplianceFlag = Flag{ + ComplianceFlag = Flag[string]{ Name: "compliance", ConfigName: "scan.compliance", - Default: "", Usage: "compliance report to generate", } ) @@ -112,19 +106,19 @@ var ( // ReportFlagGroup composes common printer flag structs // used for commands requiring reporting logic. type ReportFlagGroup struct { - Format *Flag - ReportFormat *Flag - Template *Flag - DependencyTree *Flag - ListAllPkgs *Flag - IgnoreFile *Flag - IgnorePolicy *Flag - ExitCode *Flag - ExitOnEOL *Flag - Output *Flag - OutputPluginArg *Flag - Severity *Flag - Compliance *Flag + Format *Flag[string] + ReportFormat *Flag[string] + Template *Flag[string] + DependencyTree *Flag[bool] + ListAllPkgs *Flag[bool] + IgnoreFile *Flag[string] + IgnorePolicy *Flag[string] + ExitCode *Flag[int] + ExitOnEOL *Flag[int] + Output *Flag[string] + OutputPluginArg *Flag[string] + Severity *Flag[[]string] + Compliance *Flag[string] } type ReportOptions struct { @@ -145,19 +139,19 @@ type ReportOptions struct { func NewReportFlagGroup() *ReportFlagGroup { return &ReportFlagGroup{ - Format: &FormatFlag, - ReportFormat: &ReportFormatFlag, - Template: &TemplateFlag, - DependencyTree: &DependencyTreeFlag, - ListAllPkgs: &ListAllPkgsFlag, - IgnoreFile: &IgnoreFileFlag, - IgnorePolicy: &IgnorePolicyFlag, - ExitCode: &ExitCodeFlag, - ExitOnEOL: &ExitOnEOLFlag, - Output: &OutputFlag, - OutputPluginArg: &OutputPluginArgFlag, - Severity: &SeverityFlag, - Compliance: &ComplianceFlag, + Format: FormatFlag.Clone(), + ReportFormat: ReportFormatFlag.Clone(), + Template: TemplateFlag.Clone(), + DependencyTree: DependencyTreeFlag.Clone(), + ListAllPkgs: ListAllPkgsFlag.Clone(), + IgnoreFile: IgnoreFileFlag.Clone(), + IgnorePolicy: IgnorePolicyFlag.Clone(), + ExitCode: ExitCodeFlag.Clone(), + ExitOnEOL: ExitOnEOLFlag.Clone(), + Output: OutputFlag.Clone(), + OutputPluginArg: OutputPluginArgFlag.Clone(), + Severity: SeverityFlag.Clone(), + Compliance: ComplianceFlag.Clone(), } } @@ -165,8 +159,8 @@ func (f *ReportFlagGroup) Name() string { return "Report" } -func (f *ReportFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *ReportFlagGroup) Flags() []Flagger { + return []Flagger{ f.Format, f.ReportFormat, f.Template, @@ -184,10 +178,14 @@ func (f *ReportFlagGroup) Flags() []*Flag { } func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { - format := getUnderlyingString[types.Format](f.Format) - template := getString(f.Template) - dependencyTree := getBool(f.DependencyTree) - listAllPkgs := getBool(f.ListAllPkgs) + if err := parseFlags(f); err != nil { + return ReportOptions{}, err + } + + format := types.Format(f.Format.Value()) + template := f.Template.Value() + dependencyTree := f.DependencyTree.Value() + listAllPkgs := f.ListAllPkgs.Value() if template != "" { if format == "" { @@ -222,13 +220,13 @@ func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { listAllPkgs = true } - cs, err := loadComplianceTypes(getString(f.Compliance)) + cs, err := loadComplianceTypes(f.Compliance.Value()) if err != nil { return ReportOptions{}, xerrors.Errorf("unable to load compliance spec: %w", err) } var outputPluginArgs []string - if arg := getString(f.OutputPluginArg); arg != "" { + if arg := f.OutputPluginArg.Value(); arg != "" { outputPluginArgs, err = shellwords.Parse(arg) if err != nil { return ReportOptions{}, xerrors.Errorf("unable to parse output plugin argument: %w", err) @@ -237,17 +235,17 @@ func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { return ReportOptions{ Format: format, - ReportFormat: getString(f.ReportFormat), + ReportFormat: f.ReportFormat.Value(), Template: template, DependencyTree: dependencyTree, ListAllPkgs: listAllPkgs, - IgnoreFile: getString(f.IgnoreFile), - ExitCode: getInt(f.ExitCode), - ExitOnEOL: getInt(f.ExitOnEOL), - IgnorePolicy: getString(f.IgnorePolicy), - Output: getString(f.Output), + IgnoreFile: f.IgnoreFile.Value(), + ExitCode: f.ExitCode.Value(), + ExitOnEOL: f.ExitOnEOL.Value(), + IgnorePolicy: f.IgnorePolicy.Value(), + Output: f.Output.Value(), OutputPluginArgs: outputPluginArgs, - Severities: toSeverity(getStringSlice(f.Severity)), + Severities: toSeverity(f.Severity.Value()), Compliance: cs, }, nil } diff --git a/pkg/flag/report_flags_test.go b/pkg/flag/report_flags_test.go index c2e92fc5fcbd..4207aa2747d1 100644 --- a/pkg/flag/report_flags_test.go +++ b/pkg/flag/report_flags_test.go @@ -185,6 +185,8 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Cleanup(viper.Reset) + level := zap.WarnLevel if tt.fields.debug { level = zap.DebugLevel @@ -192,34 +194,34 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { core, obs := observer.New(level) log.Logger = zap.New(core).Sugar() - viper.Set(flag.FormatFlag.ConfigName, string(tt.fields.format)) - viper.Set(flag.TemplateFlag.ConfigName, tt.fields.template) - viper.Set(flag.DependencyTreeFlag.ConfigName, tt.fields.dependencyTree) - viper.Set(flag.ListAllPkgsFlag.ConfigName, tt.fields.listAllPkgs) - viper.Set(flag.IgnoreFileFlag.ConfigName, tt.fields.ignoreFile) - viper.Set(flag.IgnoreUnfixedFlag.ConfigName, tt.fields.ignoreUnfixed) - viper.Set(flag.IgnorePolicyFlag.ConfigName, tt.fields.ignorePolicy) - viper.Set(flag.ExitCodeFlag.ConfigName, tt.fields.exitCode) - viper.Set(flag.ExitOnEOLFlag.ConfigName, tt.fields.exitOnEOSL) - viper.Set(flag.OutputFlag.ConfigName, tt.fields.output) - viper.Set(flag.OutputPluginArgFlag.ConfigName, tt.fields.outputPluginArgs) - viper.Set(flag.SeverityFlag.ConfigName, tt.fields.severities) - viper.Set(flag.ComplianceFlag.ConfigName, tt.fields.compliance) + setValue(flag.FormatFlag.ConfigName, string(tt.fields.format)) + setValue(flag.TemplateFlag.ConfigName, tt.fields.template) + setValue(flag.DependencyTreeFlag.ConfigName, tt.fields.dependencyTree) + setValue(flag.ListAllPkgsFlag.ConfigName, tt.fields.listAllPkgs) + setValue(flag.IgnoreFileFlag.ConfigName, tt.fields.ignoreFile) + setValue(flag.IgnoreUnfixedFlag.ConfigName, tt.fields.ignoreUnfixed) + setValue(flag.IgnorePolicyFlag.ConfigName, tt.fields.ignorePolicy) + setValue(flag.ExitCodeFlag.ConfigName, tt.fields.exitCode) + setValue(flag.ExitOnEOLFlag.ConfigName, tt.fields.exitOnEOSL) + setValue(flag.OutputFlag.ConfigName, tt.fields.output) + setValue(flag.OutputPluginArgFlag.ConfigName, tt.fields.outputPluginArgs) + setValue(flag.SeverityFlag.ConfigName, tt.fields.severities) + setValue(flag.ComplianceFlag.ConfigName, tt.fields.compliance) // Assert options f := &flag.ReportFlagGroup{ - Format: &flag.FormatFlag, - Template: &flag.TemplateFlag, - DependencyTree: &flag.DependencyTreeFlag, - ListAllPkgs: &flag.ListAllPkgsFlag, - IgnoreFile: &flag.IgnoreFileFlag, - IgnorePolicy: &flag.IgnorePolicyFlag, - ExitCode: &flag.ExitCodeFlag, - ExitOnEOL: &flag.ExitOnEOLFlag, - Output: &flag.OutputFlag, - OutputPluginArg: &flag.OutputPluginArgFlag, - Severity: &flag.SeverityFlag, - Compliance: &flag.ComplianceFlag, + Format: flag.FormatFlag.Clone(), + Template: flag.TemplateFlag.Clone(), + DependencyTree: flag.DependencyTreeFlag.Clone(), + ListAllPkgs: flag.ListAllPkgsFlag.Clone(), + IgnoreFile: flag.IgnoreFileFlag.Clone(), + IgnorePolicy: flag.IgnorePolicyFlag.Clone(), + ExitCode: flag.ExitCodeFlag.Clone(), + ExitOnEOL: flag.ExitOnEOLFlag.Clone(), + Output: flag.OutputFlag.Clone(), + OutputPluginArg: flag.OutputPluginArgFlag.Clone(), + Severity: flag.SeverityFlag.Clone(), + Compliance: flag.ComplianceFlag.Clone(), } got, err := f.ToOptions() diff --git a/pkg/flag/sbom_flags.go b/pkg/flag/sbom_flags.go index 5d8c515c1d74..f5ab1aff3189 100644 --- a/pkg/flag/sbom_flags.go +++ b/pkg/flag/sbom_flags.go @@ -7,25 +7,23 @@ import ( ) var ( - ArtifactTypeFlag = Flag{ + ArtifactTypeFlag = Flag[string]{ Name: "artifact-type", ConfigName: "sbom.artifact-type", - Default: "", Usage: "deprecated", Deprecated: true, } - SBOMFormatFlag = Flag{ + SBOMFormatFlag = Flag[string]{ Name: "sbom-format", ConfigName: "sbom.format", - Default: "", Usage: "deprecated", Deprecated: true, } ) type SBOMFlagGroup struct { - ArtifactType *Flag // deprecated - SBOMFormat *Flag // deprecated + ArtifactType *Flag[string] // deprecated + SBOMFormat *Flag[string] // deprecated } type SBOMOptions struct { @@ -33,8 +31,8 @@ type SBOMOptions struct { func NewSBOMFlagGroup() *SBOMFlagGroup { return &SBOMFlagGroup{ - ArtifactType: &ArtifactTypeFlag, - SBOMFormat: &SBOMFormatFlag, + ArtifactType: ArtifactTypeFlag.Clone(), + SBOMFormat: SBOMFormatFlag.Clone(), } } @@ -42,16 +40,20 @@ func (f *SBOMFlagGroup) Name() string { return "SBOM" } -func (f *SBOMFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *SBOMFlagGroup) Flags() []Flagger { + return []Flagger{ f.ArtifactType, f.SBOMFormat, } } func (f *SBOMFlagGroup) ToOptions() (SBOMOptions, error) { - artifactType := getString(f.ArtifactType) - sbomFormat := getString(f.SBOMFormat) + if err := parseFlags(f); err != nil { + return SBOMOptions{}, err + } + + artifactType := f.ArtifactType.Value() + sbomFormat := f.SBOMFormat.Value() if artifactType != "" || sbomFormat != "" { log.Logger.Error("'trivy sbom' is now for scanning SBOM. " + diff --git a/pkg/flag/scan_flags.go b/pkg/flag/scan_flags.go index 0464436a4f11..aba3961b0243 100644 --- a/pkg/flag/scan_flags.go +++ b/pkg/flag/scan_flags.go @@ -3,31 +3,31 @@ package flag import ( "runtime" + "github.com/samber/lo" + "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/types" xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) var ( - SkipDirsFlag = Flag{ + SkipDirsFlag = Flag[[]string]{ Name: "skip-dirs", ConfigName: "scan.skip-dirs", - Default: []string{}, Usage: "specify the directories or glob patterns to skip", } - SkipFilesFlag = Flag{ + SkipFilesFlag = Flag[[]string]{ Name: "skip-files", ConfigName: "scan.skip-files", Default: []string{}, Usage: "specify the files or glob patterns to skip", } - OfflineScanFlag = Flag{ + OfflineScanFlag = Flag[bool]{ Name: "offline-scan", ConfigName: "scan.offline", - Default: false, Usage: "do not issue API requests to identify dependencies", } - ScannersFlag = Flag{ + ScannersFlag = Flag[[]string]{ Name: "scanners", ConfigName: "scan.scanners", Default: xstrings.ToStringSlice(types.Scanners{ @@ -40,17 +40,19 @@ var ( types.SecretScanner, types.LicenseScanner, }), - ValueNormalize: func(s string) string { - switch s { - case "vulnerability": - return string(types.VulnerabilityScanner) - case "misconf", "misconfiguration": - return string(types.MisconfigScanner) - case "config": - log.Logger.Warn("'--scanner config' is deprecated. Use '--scanner misconfig' instead. See https://github.com/aquasecurity/trivy/discussions/5586 for the detail.") - return string(types.MisconfigScanner) - } - return s + ValueNormalize: func(ss []string) []string { + return lo.Map(ss, func(s string, _ int) string { + switch s { + case "vulnerability": + return string(types.VulnerabilityScanner) + case "misconf", "misconfiguration": + return string(types.MisconfigScanner) + case "config": + log.Logger.Warn("'--scanners config' is deprecated. Use '--scanners misconfig' instead. See https://github.com/aquasecurity/trivy/discussions/5586 for the detail.") + return string(types.MisconfigScanner) + } + return s + }) }, Aliases: []Alias{ { @@ -61,57 +63,57 @@ var ( }, Usage: "comma-separated list of what security issues to detect", } - FilePatternsFlag = Flag{ + FilePatternsFlag = Flag[[]string]{ Name: "file-patterns", ConfigName: "scan.file-patterns", - Default: []string{}, Usage: "specify config file patterns", } - SlowFlag = Flag{ + SlowFlag = Flag[bool]{ Name: "slow", ConfigName: "scan.slow", Default: false, Usage: "scan over time with lower CPU and memory utilization", Deprecated: true, } - ParallelFlag = Flag{ + ParallelFlag = Flag[int]{ Name: "parallel", ConfigName: "scan.parallel", Default: 5, Usage: "number of goroutines enabled for parallel scanning, set 0 to auto-detect parallelism", } - SBOMSourcesFlag = Flag{ + SBOMSourcesFlag = Flag[[]string]{ Name: "sbom-sources", ConfigName: "scan.sbom-sources", - Default: []string{}, - Values: []string{"oci", "rekor"}, - Usage: "[EXPERIMENTAL] try to retrieve SBOM from the specified sources", + Values: []string{ + "oci", + "rekor", + }, + Usage: "[EXPERIMENTAL] try to retrieve SBOM from the specified sources", } - RekorURLFlag = Flag{ + RekorURLFlag = Flag[string]{ Name: "rekor-url", ConfigName: "scan.rekor-url", Default: "https://rekor.sigstore.dev", Usage: "[EXPERIMENTAL] address of rekor STL server", } - IncludeDevDepsFlag = Flag{ + IncludeDevDepsFlag = Flag[bool]{ Name: "include-dev-deps", ConfigName: "include-dev-deps", - Default: false, Usage: "include development dependencies in the report (supported: npm, yarn)", } ) type ScanFlagGroup struct { - SkipDirs *Flag - SkipFiles *Flag - OfflineScan *Flag - Scanners *Flag - FilePatterns *Flag - Slow *Flag // deprecated - Parallel *Flag - SBOMSources *Flag - RekorURL *Flag - IncludeDevDeps *Flag + SkipDirs *Flag[[]string] + SkipFiles *Flag[[]string] + OfflineScan *Flag[bool] + Scanners *Flag[[]string] + FilePatterns *Flag[[]string] + Slow *Flag[bool] // deprecated + Parallel *Flag[int] + SBOMSources *Flag[[]string] + RekorURL *Flag[string] + IncludeDevDeps *Flag[bool] } type ScanOptions struct { @@ -129,16 +131,16 @@ type ScanOptions struct { func NewScanFlagGroup() *ScanFlagGroup { return &ScanFlagGroup{ - SkipDirs: &SkipDirsFlag, - SkipFiles: &SkipFilesFlag, - OfflineScan: &OfflineScanFlag, - Scanners: &ScannersFlag, - FilePatterns: &FilePatternsFlag, - Parallel: &ParallelFlag, - SBOMSources: &SBOMSourcesFlag, - RekorURL: &RekorURLFlag, - IncludeDevDeps: &IncludeDevDepsFlag, - Slow: &SlowFlag, + SkipDirs: SkipDirsFlag.Clone(), + SkipFiles: SkipFilesFlag.Clone(), + OfflineScan: OfflineScanFlag.Clone(), + Scanners: ScannersFlag.Clone(), + FilePatterns: FilePatternsFlag.Clone(), + Parallel: ParallelFlag.Clone(), + SBOMSources: SBOMSourcesFlag.Clone(), + RekorURL: RekorURLFlag.Clone(), + IncludeDevDeps: IncludeDevDepsFlag.Clone(), + Slow: SlowFlag.Clone(), } } @@ -146,8 +148,8 @@ func (f *ScanFlagGroup) Name() string { return "Scan" } -func (f *ScanFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *ScanFlagGroup) Flags() []Flagger { + return []Flagger{ f.SkipDirs, f.SkipFiles, f.OfflineScan, @@ -162,12 +164,16 @@ func (f *ScanFlagGroup) Flags() []*Flag { } func (f *ScanFlagGroup) ToOptions(args []string) (ScanOptions, error) { + if err := parseFlags(f); err != nil { + return ScanOptions{}, err + } + var target string if len(args) == 1 { target = args[0] } - parallel := getInt(f.Parallel) + parallel := f.Parallel.Value() if f.Parallel != nil && parallel == 0 { log.Logger.Infof("Set '--parallel' to the number of CPUs (%d)", runtime.NumCPU()) parallel = runtime.NumCPU() @@ -175,14 +181,14 @@ func (f *ScanFlagGroup) ToOptions(args []string) (ScanOptions, error) { return ScanOptions{ Target: target, - SkipDirs: getStringSlice(f.SkipDirs), - SkipFiles: getStringSlice(f.SkipFiles), - OfflineScan: getBool(f.OfflineScan), - Scanners: getUnderlyingStringSlice[types.Scanner](f.Scanners), - FilePatterns: getStringSlice(f.FilePatterns), + SkipDirs: f.SkipDirs.Value(), + SkipFiles: f.SkipFiles.Value(), + OfflineScan: f.OfflineScan.Value(), + Scanners: xstrings.ToTSlice[types.Scanner](f.Scanners.Value()), + FilePatterns: f.FilePatterns.Value(), Parallel: parallel, - SBOMSources: getStringSlice(f.SBOMSources), - RekorURL: getString(f.RekorURL), - IncludeDevDeps: getBool(f.IncludeDevDeps), + SBOMSources: f.SBOMSources.Value(), + RekorURL: f.RekorURL.Value(), + IncludeDevDeps: f.IncludeDevDeps.Value(), }, nil } diff --git a/pkg/flag/scan_flags_test.go b/pkg/flag/scan_flags_test.go index 7c9d2ba42457..2d5cb718b0d1 100644 --- a/pkg/flag/scan_flags_test.go +++ b/pkg/flag/scan_flags_test.go @@ -1,9 +1,9 @@ package flag_test import ( + "github.com/spf13/viper" "testing" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -109,23 +109,23 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - viper.Set(flag.SkipDirsFlag.ConfigName, tt.fields.skipDirs) - viper.Set(flag.SkipFilesFlag.ConfigName, tt.fields.skipFiles) - viper.Set(flag.OfflineScanFlag.ConfigName, tt.fields.offlineScan) - viper.Set(flag.ScannersFlag.ConfigName, tt.fields.scanners) + t.Cleanup(viper.Reset) + setSliceValue(flag.SkipDirsFlag.ConfigName, tt.fields.skipDirs) + setSliceValue(flag.SkipFilesFlag.ConfigName, tt.fields.skipFiles) + setValue(flag.OfflineScanFlag.ConfigName, tt.fields.offlineScan) + setValue(flag.ScannersFlag.ConfigName, tt.fields.scanners) // Assert options f := &flag.ScanFlagGroup{ - SkipDirs: &flag.SkipDirsFlag, - SkipFiles: &flag.SkipFilesFlag, - OfflineScan: &flag.OfflineScanFlag, - Scanners: &flag.ScannersFlag, + SkipDirs: flag.SkipDirsFlag.Clone(), + SkipFiles: flag.SkipFilesFlag.Clone(), + OfflineScan: flag.OfflineScanFlag.Clone(), + Scanners: flag.ScannersFlag.Clone(), } got, err := f.ToOptions(tt.args) tt.assertion(t, err) assert.Equalf(t, tt.want, got, "ToOptions()") }) - } } diff --git a/pkg/flag/secret_flags.go b/pkg/flag/secret_flags.go index 0190480f813d..295753c48cc5 100644 --- a/pkg/flag/secret_flags.go +++ b/pkg/flag/secret_flags.go @@ -1,7 +1,7 @@ package flag var ( - SecretConfigFlag = Flag{ + SecretConfigFlag = Flag[string]{ Name: "secret-config", ConfigName: "secret.config", Default: "trivy-secret.yaml", @@ -10,7 +10,7 @@ var ( ) type SecretFlagGroup struct { - SecretConfig *Flag + SecretConfig *Flag[string] } type SecretOptions struct { @@ -19,7 +19,7 @@ type SecretOptions struct { func NewSecretFlagGroup() *SecretFlagGroup { return &SecretFlagGroup{ - SecretConfig: &SecretConfigFlag, + SecretConfig: SecretConfigFlag.Clone(), } } @@ -27,12 +27,16 @@ func (f *SecretFlagGroup) Name() string { return "Secret" } -func (f *SecretFlagGroup) Flags() []*Flag { - return []*Flag{f.SecretConfig} +func (f *SecretFlagGroup) Flags() []Flagger { + return []Flagger{f.SecretConfig} } -func (f *SecretFlagGroup) ToOptions() SecretOptions { - return SecretOptions{ - SecretConfigPath: getString(f.SecretConfig), +func (f *SecretFlagGroup) ToOptions() (SecretOptions, error) { + if err := parseFlags(f); err != nil { + return SecretOptions{}, err } + + return SecretOptions{ + SecretConfigPath: f.SecretConfig.Value(), + }, nil } diff --git a/pkg/flag/value.go b/pkg/flag/value.go deleted file mode 100644 index 45108fe7f556..000000000000 --- a/pkg/flag/value.go +++ /dev/null @@ -1,104 +0,0 @@ -package flag - -import ( - "strings" - - "github.com/samber/lo" - "golang.org/x/exp/slices" - "golang.org/x/xerrors" -) - -type ValueNormalizeFunc func(string) string - -// -- string Value -type customStringValue struct { - value *string - allowed []string - normalize ValueNormalizeFunc -} - -func newCustomStringValue(val string, allowed []string, fn ValueNormalizeFunc) *customStringValue { - return &customStringValue{ - value: &val, - allowed: allowed, - normalize: fn, - } -} - -func (s *customStringValue) Set(val string) error { - if s.normalize != nil { - val = s.normalize(val) - } - if len(s.allowed) > 0 && !slices.Contains(s.allowed, val) { - return xerrors.Errorf("must be one of %q", s.allowed) - } - s.value = &val - return nil -} -func (s *customStringValue) Type() string { - return "string" -} - -func (s *customStringValue) String() string { return *s.value } - -// -- stringSlice Value -type customStringSliceValue struct { - value *[]string - allowed []string - normalize ValueNormalizeFunc - changed bool -} - -func newCustomStringSliceValue(val, allowed []string, fn ValueNormalizeFunc) *customStringSliceValue { - return &customStringSliceValue{ - value: &val, - allowed: allowed, - normalize: fn, - } -} - -func (s *customStringSliceValue) Set(val string) error { - values := strings.Split(val, ",") - if s.normalize != nil { - values = lo.Map(values, func(item string, _ int) string { return s.normalize(item) }) - } - for _, v := range values { - if len(s.allowed) > 0 && !slices.Contains(s.allowed, v) { - return xerrors.Errorf("must be one of %q", s.allowed) - } - } - if !s.changed { - *s.value = values - } else { - *s.value = append(*s.value, values...) - } - s.changed = true - return nil -} - -func (s *customStringSliceValue) Type() string { - return "stringSlice" -} - -func (s *customStringSliceValue) String() string { - if len(*s.value) == 0 { - // "[]" is not recognized as a zero value - // cf. https://github.com/spf13/pflag/blob/d5e0c0615acee7028e1e2740a11102313be88de1/flag.go#L553-L565 - return "" - } - return "[" + strings.Join(*s.value, ",") + "]" -} - -func (s *customStringSliceValue) Append(val string) error { - s.changed = true - return s.Set(val) -} - -func (s *customStringSliceValue) Replace(val []string) error { - *s.value = val - return nil -} - -func (s *customStringSliceValue) GetSlice() []string { - return *s.value -} diff --git a/pkg/flag/vulnerability_flags.go b/pkg/flag/vulnerability_flags.go index bf476c38fff9..3989fbfa1c51 100644 --- a/pkg/flag/vulnerability_flags.go +++ b/pkg/flag/vulnerability_flags.go @@ -9,7 +9,7 @@ import ( ) var ( - VulnTypeFlag = Flag{ + VulnTypeFlag = Flag[[]string]{ Name: "vuln-type", ConfigName: "vulnerability.type", Default: []string{ @@ -22,20 +22,18 @@ var ( }, Usage: "comma-separated list of vulnerability types", } - IgnoreUnfixedFlag = Flag{ + IgnoreUnfixedFlag = Flag[bool]{ Name: "ignore-unfixed", ConfigName: "vulnerability.ignore-unfixed", - Default: false, Usage: "display only fixed vulnerabilities", } - IgnoreStatusFlag = Flag{ + IgnoreStatusFlag = Flag[[]string]{ Name: "ignore-status", ConfigName: "vulnerability.ignore-status", - Default: []string{}, Values: dbTypes.Statuses, Usage: "comma-separated list of vulnerability status to ignore", } - VEXFlag = Flag{ + VEXFlag = Flag[string]{ Name: "vex", ConfigName: "vulnerability.vex", Default: "", @@ -44,10 +42,10 @@ var ( ) type VulnerabilityFlagGroup struct { - VulnType *Flag - IgnoreUnfixed *Flag - IgnoreStatus *Flag - VEXPath *Flag + VulnType *Flag[[]string] + IgnoreUnfixed *Flag[bool] + IgnoreStatus *Flag[[]string] + VEXPath *Flag[string] } type VulnerabilityOptions struct { @@ -58,10 +56,10 @@ type VulnerabilityOptions struct { func NewVulnerabilityFlagGroup() *VulnerabilityFlagGroup { return &VulnerabilityFlagGroup{ - VulnType: &VulnTypeFlag, - IgnoreUnfixed: &IgnoreUnfixedFlag, - IgnoreStatus: &IgnoreStatusFlag, - VEXPath: &VEXFlag, + VulnType: VulnTypeFlag.Clone(), + IgnoreUnfixed: IgnoreUnfixedFlag.Clone(), + IgnoreStatus: IgnoreStatusFlag.Clone(), + VEXPath: VEXFlag.Clone(), } } @@ -69,8 +67,8 @@ func (f *VulnerabilityFlagGroup) Name() string { return "Vulnerability" } -func (f *VulnerabilityFlagGroup) Flags() []*Flag { - return []*Flag{ +func (f *VulnerabilityFlagGroup) Flags() []Flagger { + return []Flagger{ f.VulnType, f.IgnoreUnfixed, f.IgnoreStatus, @@ -78,12 +76,16 @@ func (f *VulnerabilityFlagGroup) Flags() []*Flag { } } -func (f *VulnerabilityFlagGroup) ToOptions() VulnerabilityOptions { +func (f *VulnerabilityFlagGroup) ToOptions() (VulnerabilityOptions, error) { + if err := parseFlags(f); err != nil { + return VulnerabilityOptions{}, err + } + // Just convert string to dbTypes.Status as the validated values are passed here. - ignoreStatuses := lo.Map(getStringSlice(f.IgnoreStatus), func(s string, _ int) dbTypes.Status { + ignoreStatuses := lo.Map(f.IgnoreStatus.Value(), func(s string, _ int) dbTypes.Status { return dbTypes.NewStatus(s) }) - ignoreUnfixed := getBool(f.IgnoreUnfixed) + ignoreUnfixed := f.IgnoreUnfixed.Value() switch { case ignoreUnfixed && len(ignoreStatuses) > 0: @@ -103,8 +105,8 @@ func (f *VulnerabilityFlagGroup) ToOptions() VulnerabilityOptions { log.Logger.Debugw("Ignore statuses", "statuses", ignoreStatuses) return VulnerabilityOptions{ - VulnType: getStringSlice(f.VulnType), + VulnType: f.VulnType.Value(), IgnoreStatuses: ignoreStatuses, - VEXPath: getString(f.VEXPath), - } + VEXPath: f.VEXPath.Value(), + }, nil } diff --git a/pkg/flag/vulnerability_flags_test.go b/pkg/flag/vulnerability_flags_test.go index a6055ab2a1ea..02ee3c8d9605 100644 --- a/pkg/flag/vulnerability_flags_test.go +++ b/pkg/flag/vulnerability_flags_test.go @@ -1,6 +1,7 @@ package flag_test import ( + "github.com/stretchr/testify/require" "testing" "github.com/spf13/viper" @@ -57,10 +58,11 @@ func TestVulnerabilityFlagGroup_ToOptions(t *testing.T) { // Assert options f := &flag.VulnerabilityFlagGroup{ - VulnType: &flag.VulnTypeFlag, + VulnType: flag.VulnTypeFlag.Clone(), } - got := f.ToOptions() + got, err := f.ToOptions() + require.NoError(t, err) assert.Equalf(t, tt.want, got, "ToOptions()") // Assert log messages diff --git a/pkg/sbom/cyclonedx/core/cyclonedx.go b/pkg/sbom/cyclonedx/core/cyclonedx.go index 023e1c18d7b1..fc326145d5c6 100644 --- a/pkg/sbom/cyclonedx/core/cyclonedx.go +++ b/pkg/sbom/cyclonedx/core/cyclonedx.go @@ -230,7 +230,7 @@ func (c *CycloneDX) Vulnerabilities(uniq map[string]*cdx.Vulnerability) *[]cdx.V return *value }) sort.Slice(vulns, func(i, j int) bool { - return vulns[i].BOMRef < vulns[j].BOMRef + return vulns[i].ID < vulns[j].ID }) return &vulns } diff --git a/pkg/x/strings/strings.go b/pkg/x/strings/strings.go index ce534b18d68a..78c87231605a 100644 --- a/pkg/x/strings/strings.go +++ b/pkg/x/strings/strings.go @@ -7,12 +7,18 @@ type String interface { } func ToStringSlice[T String](ss []T) []string { + if ss == nil { + return nil + } return lo.Map(ss, func(s T, _ int) string { return string(s) }) } func ToTSlice[T String](ss []string) []T { + if ss == nil { + return nil + } return lo.Map(ss, func(s string, _ int) T { return T(s) })