Skip to content

Commit

Permalink
refactor: unify package addition and vulnerability scanning (#6579)
Browse files Browse the repository at this point in the history
Signed-off-by: knqyf263 <knqyf263@gmail.com>
  • Loading branch information
knqyf263 committed May 20, 2024
1 parent d465d9d commit cf1a7bf
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 111 deletions.
71 changes: 36 additions & 35 deletions pkg/scanner/langpkg/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ var (
)

type Scanner interface {
Packages(target types.ScanTarget, options types.ScanOptions) types.Results
Scan(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Results, error)
}

Expand All @@ -34,24 +33,7 @@ func NewScanner() Scanner {
return &scanner{}
}

func (s *scanner) Packages(target types.ScanTarget, _ types.ScanOptions) types.Results {
var results types.Results
for _, app := range target.Applications {
if len(app.Packages) == 0 {
continue
}

results = append(results, types.Result{
Target: targetName(app.Type, app.FilePath),
Class: types.ClassLangPkg,
Type: app.Type,
Packages: app.Packages,
})
}
return results
}

func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.ScanOptions) (types.Results, error) {
func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.ScanOptions) (types.Results, error) {
apps := target.Applications
log.Info("Number of language-specific files", log.Int("num", len(apps)))
if len(apps) == 0 {
Expand All @@ -66,34 +48,53 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.Sca
}

ctx = log.WithContextPrefix(ctx, string(app.Type))
result := types.Result{
Target: targetName(app.Type, app.FilePath),
Class: types.ClassLangPkg,
Type: app.Type,
}

// Prevent the same log messages from being displayed many times for the same type.
if _, ok := printedTypes[app.Type]; !ok {
log.InfoContext(ctx, "Detecting vulnerabilities...")
printedTypes[app.Type] = struct{}{}
if opts.ListAllPackages {
sort.Sort(app.Packages)
result.Packages = app.Packages
}

log.DebugContext(ctx, "Scanning packages from the file", log.String("file_path", app.FilePath))
vulns, err := library.Detect(ctx, app.Type, app.Packages)
if err != nil {
return nil, xerrors.Errorf("failed vulnerability detection of packages: %w", err)
} else if len(vulns) == 0 {
continue
if opts.Scanners.Enabled(types.VulnerabilityScanner) {
var err error
result.Vulnerabilities, err = s.scanVulnerabilities(ctx, app, printedTypes)
if err != nil {
return nil, err
}
}

results = append(results, types.Result{
Target: targetName(app.Type, app.FilePath),
Vulnerabilities: vulns,
Class: types.ClassLangPkg,
Type: app.Type,
})
if len(result.Packages) == 0 && len(result.Vulnerabilities) == 0 {
continue
}
results = append(results, result)
}
sort.Slice(results, func(i, j int) bool {
return results[i].Target < results[j].Target
})
return results, nil
}

func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes map[ftypes.LangType]struct{}) (
[]types.DetectedVulnerability, error) {

// Prevent the same log messages from being displayed many times for the same type.
if _, ok := printedTypes[app.Type]; !ok {
log.InfoContext(ctx, "Detecting vulnerabilities...")
printedTypes[app.Type] = struct{}{}
}

log.DebugContext(ctx, "Scanning packages for vulnerabilities", log.String("file_path", app.FilePath))
vulns, err := library.Detect(ctx, app.Type, app.Packages)
if err != nil {
return nil, xerrors.Errorf("failed vulnerability detection of libraries: %w", err)
}
return vulns, err
}

func targetName(appType ftypes.LangType, filePath string) string {
if t, ok := PkgTargets[appType]; ok && filePath == "" {
// When the file path is empty, we will overwrite it with the pre-defined value.
Expand Down
66 changes: 18 additions & 48 deletions pkg/scanner/local/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"golang.org/x/xerrors"

dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
ospkgDetector "github.com/aquasecurity/trivy/pkg/detector/ospkg"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/applier"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
Expand Down Expand Up @@ -105,39 +106,19 @@ func (s Scanner) Scan(ctx context.Context, targetName, artifactKey string, blobK
}

func (s Scanner) ScanTarget(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Results, ftypes.OS, error) {
var eosl bool
var results, pkgResults types.Results
var err error
var results types.Results

// By default, we need to remove dev dependencies from the result
// IncludeDevDeps option allows you not to remove them
excludeDevDeps(target.Applications, options.IncludeDevDeps)

// Fill OS packages and language-specific packages
if options.ListAllPackages {
if res := s.osPkgScanner.Packages(target, options); len(res.Packages) != 0 {
pkgResults = append(pkgResults, res)
}
pkgResults = append(pkgResults, s.langPkgScanner.Packages(target, options)...)
}

// Scan packages for vulnerabilities
if options.Scanners.Enabled(types.VulnerabilityScanner) {
var vulnResults types.Results
vulnResults, eosl, err = s.scanVulnerabilities(ctx, target, options)
if err != nil {
return nil, ftypes.OS{}, xerrors.Errorf("failed to detect vulnerabilities: %w", err)
}
target.OS.Eosl = eosl

// Merge package results into vulnerability results
mergedResults := s.fillPkgsInVulns(pkgResults, vulnResults)

results = append(results, mergedResults...)
} else {
// If vulnerability scanning is not enabled, it just adds package results.
results = append(results, pkgResults...)
// Add packages if needed and scan packages for vulnerabilities
vulnResults, eosl, err := s.scanVulnerabilities(ctx, target, options)
if err != nil {
return nil, ftypes.OS{}, xerrors.Errorf("failed to detect vulnerabilities: %w", err)
}
target.OS.Eosl = eosl
results = append(results, vulnResults...)

// Store misconfigurations
results = append(results, s.misconfsToResults(target.Misconfigurations, options)...)
Expand Down Expand Up @@ -172,17 +153,24 @@ func (s Scanner) ScanTarget(ctx context.Context, target types.ScanTarget, option

func (s Scanner) scanVulnerabilities(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (
types.Results, bool, error) {
if !options.ListAllPackages && !options.Scanners.Enabled(types.VulnerabilityScanner) {
return nil, false, nil
}

var eosl bool
var results types.Results

if slices.Contains(options.VulnType, types.VulnTypeOS) {
vuln, detectedEOSL, err := s.osPkgScanner.Scan(ctx, target, options)
if err != nil {
switch {
case errors.Is(err, ospkgDetector.ErrUnsupportedOS):
// do nothing
case err != nil:
return nil, false, xerrors.Errorf("unable to scan OS packages: %w", err)
} else if vuln.Target != "" {
case vuln.Target != "":
results = append(results, vuln)
eosl = detectedEOSL
}
eosl = detectedEOSL
}

if slices.Contains(options.VulnType, types.VulnTypeLibrary) {
Expand All @@ -196,24 +184,6 @@ func (s Scanner) scanVulnerabilities(ctx context.Context, target types.ScanTarge
return results, eosl, nil
}

func (s Scanner) fillPkgsInVulns(pkgResults, vulnResults types.Results) types.Results {
var results types.Results
if len(pkgResults) == 0 { // '--list-all-pkgs' == false or packages not found
return vulnResults
}
for _, result := range pkgResults {
if r, found := lo.Find(vulnResults, func(r types.Result) bool {
return r.Class == result.Class && r.Target == result.Target && r.Type == result.Type
}); found {
r.Packages = result.Packages
results = append(results, r)
} else { // when package result has no vulnerabilities we still need to add it to result(for 'list-all-pkgs')
results = append(results, result)
}
}
return results
}

func (s Scanner) misconfsToResults(misconfs []ftypes.Misconfiguration, options types.ScanOptions) types.Results {
if !ShouldScanMisconfigOrRbac(options.Scanners) &&
!options.ImageConfigScanners.Enabled(types.MisconfigScanner) {
Expand Down
50 changes: 22 additions & 28 deletions pkg/scanner/ospkg/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ospkg

import (
"context"
"errors"
"fmt"
"sort"
"time"
Expand All @@ -15,7 +14,6 @@ import (
)

type Scanner interface {
Packages(target types.ScanTarget, options types.ScanOptions) types.Result
Scan(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Result, bool, error)
}

Expand All @@ -25,21 +23,7 @@ func NewScanner() Scanner {
return &scanner{}
}

func (s *scanner) Packages(target types.ScanTarget, _ types.ScanOptions) types.Result {
if len(target.Packages) == 0 || !target.OS.Detected() {
return types.Result{}
}

sort.Sort(target.Packages)
return types.Result{
Target: fmt.Sprintf("%s (%s %s)", target.Name, target.OS.Family, target.OS.Name),
Class: types.ClassOSPkg,
Type: target.OS.Family,
Packages: target.Packages,
}
}

func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.ScanOptions) (types.Result, bool, error) {
func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.ScanOptions) (types.Result, bool, error) {
if !target.OS.Detected() {
log.Debug("Detected OS: unknown")
return types.Result{}, false, nil
Expand All @@ -52,19 +36,29 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.Sca
target.OS.Name += "-ESM"
}

result := types.Result{
Target: fmt.Sprintf("%s (%s %s)", target.Name, target.OS.Family, target.OS.Name),
Class: types.ClassOSPkg,
Type: target.OS.Family,
}

if opts.ListAllPackages {
sort.Sort(target.Packages)
result.Packages = target.Packages
}

if !opts.Scanners.Enabled(types.VulnerabilityScanner) {
// Return packages only
return result, false, nil
}

vulns, eosl, err := ospkgDetector.Detect(ctx, "", target.OS.Family, target.OS.Name, target.Repository, time.Time{},
target.Packages)
if errors.Is(err, ospkgDetector.ErrUnsupportedOS) {
return types.Result{}, false, nil
} else if err != nil {
return types.Result{}, false, xerrors.Errorf("failed vulnerability detection of OS packages: %w", err)
if err != nil {
// Return a result for those who want to override the error handling.
return result, false, xerrors.Errorf("failed vulnerability detection of OS packages: %w", err)
}
result.Vulnerabilities = vulns

artifactDetail := fmt.Sprintf("%s (%s %s)", target.Name, target.OS.Family, target.OS.Name)
return types.Result{
Target: artifactDetail,
Vulnerabilities: vulns,
Class: types.ClassOSPkg,
Type: target.OS.Family,
}, eosl, nil
return result, eosl, nil
}

0 comments on commit cf1a7bf

Please sign in to comment.