From 917f388852b39a0d31da4a17a73c7302b3dc0d6f Mon Sep 17 00:00:00 2001 From: DmitriyLewen <91113035+DmitriyLewen@users.noreply.github.com> Date: Tue, 16 Aug 2022 12:57:46 +0600 Subject: [PATCH] fix(flag): add error when there are no supported security checks (#2713) --- pkg/flag/options.go | 5 ++- pkg/flag/scan_flags.go | 23 +++++++------ pkg/flag/scan_flags_test.go | 65 ++++++++++++++++--------------------- 3 files changed, 45 insertions(+), 48 deletions(-) diff --git a/pkg/flag/options.go b/pkg/flag/options.go index a05ca131cd15..35caeb2de46b 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -361,7 +361,10 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF } if f.ScanFlagGroup != nil { - opts.ScanOptions = f.ScanFlagGroup.ToOptions(args) + opts.ScanOptions, err = f.ScanFlagGroup.ToOptions(args) + if err != nil { + return Options{}, xerrors.Errorf("scan flag error: %w", err) + } } if f.SecretFlagGroup != nil { diff --git a/pkg/flag/scan_flags.go b/pkg/flag/scan_flags.go index 3d59f425c52d..d9fa5c152f45 100644 --- a/pkg/flag/scan_flags.go +++ b/pkg/flag/scan_flags.go @@ -5,8 +5,8 @@ import ( "strings" "golang.org/x/exp/slices" + "golang.org/x/xerrors" - "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/types" ) @@ -69,25 +69,29 @@ func (f *ScanFlagGroup) Flags() []*Flag { return []*Flag{f.SkipDirs, f.SkipFiles, f.OfflineScan, f.SecurityChecks} } -func (f *ScanFlagGroup) ToOptions(args []string) ScanOptions { +func (f *ScanFlagGroup) ToOptions(args []string) (ScanOptions, error) { var target string if len(args) == 1 { target = args[0] } + securityChecks, err := parseSecurityCheck(getStringSlice(f.SecurityChecks)) + if err != nil { + return ScanOptions{}, xerrors.Errorf("unable to parse security checks: %w", err) + } return ScanOptions{ Target: target, SkipDirs: getStringSlice(f.SkipDirs), SkipFiles: getStringSlice(f.SkipFiles), OfflineScan: getBool(f.OfflineScan), - SecurityChecks: parseSecurityCheck(getStringSlice(f.SecurityChecks)), - } + SecurityChecks: securityChecks, + }, nil } -func parseSecurityCheck(securityCheck []string) []string { +func parseSecurityCheck(securityCheck []string) ([]string, error) { switch { - case len(securityCheck) == 0: // no checks - return nil + case len(securityCheck) == 0: // no checks. Can be empty when generating SBOM + return nil, nil case len(securityCheck) == 1 && strings.Contains(securityCheck[0], ","): // get checks from flag securityCheck = strings.Split(securityCheck[0], ",") } @@ -95,10 +99,9 @@ func parseSecurityCheck(securityCheck []string) []string { var securityChecks []string for _, v := range securityCheck { if !slices.Contains(types.SecurityChecks, v) { - log.Logger.Warnf("unknown security check: %s", v) - continue + return nil, xerrors.Errorf("unknown security check: %s", v) } securityChecks = append(securityChecks, v) } - return securityChecks + return securityChecks, nil } diff --git a/pkg/flag/scan_flags_test.go b/pkg/flag/scan_flags_test.go index 1e12a0a721b9..caae63e10f28 100644 --- a/pkg/flag/scan_flags_test.go +++ b/pkg/flag/scan_flags_test.go @@ -3,14 +3,11 @@ package flag_test import ( "testing" - "github.com/spf13/viper" - "github.com/stretchr/testify/assert" - "go.uber.org/zap" - "go.uber.org/zap/zaptest/observer" - "github.com/aquasecurity/trivy/pkg/flag" - "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/types" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestScanFlagGroup_ToOptions(t *testing.T) { @@ -22,11 +19,11 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { securityChecks string } tests := []struct { - name string - args []string - fields fields - want flag.ScanOptions - wantLogs []string + name string + args []string + fields fields + want flag.ScanOptions + assertion require.ErrorAssertionFunc }{ { name: "happy path", @@ -35,6 +32,7 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { want: flag.ScanOptions{ Target: "alpine:latest", }, + assertion: require.NoError, }, { name: "happy path for configs", @@ -46,30 +44,31 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { Target: "alpine:latest", SecurityChecks: []string{types.SecurityCheckConfig}, }, + assertion: require.NoError, }, { name: "with wrong security check", fields: fields{ securityChecks: "vuln,WRONG-CHECK", }, - want: flag.ScanOptions{ - SecurityChecks: []string{types.SecurityCheckVulnerability}, - }, - wantLogs: []string{ - `unknown security check: WRONG-CHECK`, + want: flag.ScanOptions{}, + assertion: func(t require.TestingT, err error, msgs ...interface{}) { + require.ErrorContains(t, err, "unknown security check") }, }, { - name: "without target (args)", - args: []string{}, - fields: fields{}, - want: flag.ScanOptions{}, + name: "without target (args)", + args: []string{}, + fields: fields{}, + want: flag.ScanOptions{}, + assertion: require.NoError, }, { - name: "with two or more targets (args)", - args: []string{"alpine:latest", "nginx:latest"}, - fields: fields{}, - want: flag.ScanOptions{}, + name: "with two or more targets (args)", + args: []string{"alpine:latest", "nginx:latest"}, + fields: fields{}, + want: flag.ScanOptions{}, + assertion: require.NoError, }, { name: "skip two files", @@ -79,6 +78,7 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { want: flag.ScanOptions{ SkipFiles: []string{"file1", "file2"}, }, + assertion: require.NoError, }, { name: "skip two folders", @@ -88,6 +88,7 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { want: flag.ScanOptions{ SkipDirs: []string{"dir1", "dir2"}, }, + assertion: require.NoError, }, { name: "offline scan", @@ -97,16 +98,12 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { want: flag.ScanOptions{ OfflineScan: true, }, + assertion: require.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - level := zap.WarnLevel - - core, obs := observer.New(level) - log.Logger = zap.New(core).Sugar() - viper.Set(flag.SkipDirsFlag.ConfigName, tt.fields.skipDirs) viper.Set(flag.SkipFilesFlag.ConfigName, tt.fields.skipFiles) viper.Set(flag.OfflineScanFlag.ConfigName, tt.fields.offlineScan) @@ -121,15 +118,9 @@ func TestScanFlagGroup_ToOptions(t *testing.T) { SecurityChecks: &flag.SecurityChecksFlag, } - got := f.ToOptions(tt.args) + got, err := f.ToOptions(tt.args) + tt.assertion(t, err) assert.Equalf(t, tt.want, got, "ToOptions()") - - // Assert log messages - var gotMessages []string - for _, entry := range obs.AllUntimed() { - gotMessages = append(gotMessages, entry.Message) - } - assert.Equal(t, tt.wantLogs, gotMessages, tt.name) }) }