Skip to content

Commit

Permalink
fix(flag): add error when there are no supported security checks (#2713)
Browse files Browse the repository at this point in the history
  • Loading branch information
DmitriyLewen committed Aug 16, 2022
1 parent aef02aa commit 917f388
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 48 deletions.
5 changes: 4 additions & 1 deletion pkg/flag/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,10 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
}

if f.ScanFlagGroup != nil {
opts.ScanOptions = f.ScanFlagGroup.ToOptions(args)
opts.ScanOptions, err = f.ScanFlagGroup.ToOptions(args)
if err != nil {
return Options{}, xerrors.Errorf("scan flag error: %w", err)
}
}

if f.SecretFlagGroup != nil {
Expand Down
23 changes: 13 additions & 10 deletions pkg/flag/scan_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"strings"

"golang.org/x/exp/slices"
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/types"
)

Expand Down Expand Up @@ -69,36 +69,39 @@ func (f *ScanFlagGroup) Flags() []*Flag {
return []*Flag{f.SkipDirs, f.SkipFiles, f.OfflineScan, f.SecurityChecks}
}

func (f *ScanFlagGroup) ToOptions(args []string) ScanOptions {
func (f *ScanFlagGroup) ToOptions(args []string) (ScanOptions, error) {
var target string
if len(args) == 1 {
target = args[0]
}
securityChecks, err := parseSecurityCheck(getStringSlice(f.SecurityChecks))
if err != nil {
return ScanOptions{}, xerrors.Errorf("unable to parse security checks: %w", err)
}

return ScanOptions{
Target: target,
SkipDirs: getStringSlice(f.SkipDirs),
SkipFiles: getStringSlice(f.SkipFiles),
OfflineScan: getBool(f.OfflineScan),
SecurityChecks: parseSecurityCheck(getStringSlice(f.SecurityChecks)),
}
SecurityChecks: securityChecks,
}, nil
}

func parseSecurityCheck(securityCheck []string) []string {
func parseSecurityCheck(securityCheck []string) ([]string, error) {
switch {
case len(securityCheck) == 0: // no checks
return nil
case len(securityCheck) == 0: // no checks. Can be empty when generating SBOM
return nil, nil
case len(securityCheck) == 1 && strings.Contains(securityCheck[0], ","): // get checks from flag
securityCheck = strings.Split(securityCheck[0], ",")
}

var securityChecks []string
for _, v := range securityCheck {
if !slices.Contains(types.SecurityChecks, v) {
log.Logger.Warnf("unknown security check: %s", v)
continue
return nil, xerrors.Errorf("unknown security check: %s", v)
}
securityChecks = append(securityChecks, v)
}
return securityChecks
return securityChecks, nil
}
65 changes: 28 additions & 37 deletions pkg/flag/scan_flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ package flag_test
import (
"testing"

"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"

"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestScanFlagGroup_ToOptions(t *testing.T) {
Expand All @@ -22,11 +19,11 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
securityChecks string
}
tests := []struct {
name string
args []string
fields fields
want flag.ScanOptions
wantLogs []string
name string
args []string
fields fields
want flag.ScanOptions
assertion require.ErrorAssertionFunc
}{
{
name: "happy path",
Expand All @@ -35,6 +32,7 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
want: flag.ScanOptions{
Target: "alpine:latest",
},
assertion: require.NoError,
},
{
name: "happy path for configs",
Expand All @@ -46,30 +44,31 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
Target: "alpine:latest",
SecurityChecks: []string{types.SecurityCheckConfig},
},
assertion: require.NoError,
},
{
name: "with wrong security check",
fields: fields{
securityChecks: "vuln,WRONG-CHECK",
},
want: flag.ScanOptions{
SecurityChecks: []string{types.SecurityCheckVulnerability},
},
wantLogs: []string{
`unknown security check: WRONG-CHECK`,
want: flag.ScanOptions{},
assertion: func(t require.TestingT, err error, msgs ...interface{}) {
require.ErrorContains(t, err, "unknown security check")
},
},
{
name: "without target (args)",
args: []string{},
fields: fields{},
want: flag.ScanOptions{},
name: "without target (args)",
args: []string{},
fields: fields{},
want: flag.ScanOptions{},
assertion: require.NoError,
},
{
name: "with two or more targets (args)",
args: []string{"alpine:latest", "nginx:latest"},
fields: fields{},
want: flag.ScanOptions{},
name: "with two or more targets (args)",
args: []string{"alpine:latest", "nginx:latest"},
fields: fields{},
want: flag.ScanOptions{},
assertion: require.NoError,
},
{
name: "skip two files",
Expand All @@ -79,6 +78,7 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
want: flag.ScanOptions{
SkipFiles: []string{"file1", "file2"},
},
assertion: require.NoError,
},
{
name: "skip two folders",
Expand All @@ -88,6 +88,7 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
want: flag.ScanOptions{
SkipDirs: []string{"dir1", "dir2"},
},
assertion: require.NoError,
},
{
name: "offline scan",
Expand All @@ -97,16 +98,12 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
want: flag.ScanOptions{
OfflineScan: true,
},
assertion: require.NoError,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
level := zap.WarnLevel

core, obs := observer.New(level)
log.Logger = zap.New(core).Sugar()

viper.Set(flag.SkipDirsFlag.ConfigName, tt.fields.skipDirs)
viper.Set(flag.SkipFilesFlag.ConfigName, tt.fields.skipFiles)
viper.Set(flag.OfflineScanFlag.ConfigName, tt.fields.offlineScan)
Expand All @@ -121,15 +118,9 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
SecurityChecks: &flag.SecurityChecksFlag,
}

got := f.ToOptions(tt.args)
got, err := f.ToOptions(tt.args)
tt.assertion(t, err)
assert.Equalf(t, tt.want, got, "ToOptions()")

// Assert log messages
var gotMessages []string
for _, entry := range obs.AllUntimed() {
gotMessages = append(gotMessages, entry.Message)
}
assert.Equal(t, tt.wantLogs, gotMessages, tt.name)
})

}
Expand Down

0 comments on commit 917f388

Please sign in to comment.