Skip to content

Commit

Permalink
perf: use a map to track which advisories should be checked for which…
Browse files Browse the repository at this point in the history
… packages (#216)

* perf: use a map to track which advisories should be checked for which packages

* fix: ensure that vulnerabilities are deduplicated

* refactor: record the number of loaded vulnerabilities

* fix: don't load advisories that don't affect any packages
  • Loading branch information
G-Rath committed Sep 14, 2023
1 parent 1604a55 commit 2e92110
Show file tree
Hide file tree
Showing 13 changed files with 180 additions and 76 deletions.
17 changes: 17 additions & 0 deletions fixtures/configs-extra-dbs/db/osv-1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"id": "OSV-1",
"affected": [
{
"package": {
"ecosystem": "npm",
"name": "request"
}
},
{
"package": {
"ecosystem": "npm",
"name": "@cypress/request"
}
}
]
}
3 changes: 3 additions & 0 deletions fixtures/configs-extra-dbs/db/osv-2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"id": "OSV-2"
}
3 changes: 3 additions & 0 deletions fixtures/configs-extra-dbs/db/osv-3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"id": "OSV-3"
}
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func describeDB(db database.DB) string {
color.YellowString("%d", tt.BatchSize),
)
case *database.ZipDB:
count := len(tt.Vulnerabilities(true))
count := tt.VulnerabilitiesCount

return fmt.Sprintf(
"%s %s, including withdrawn - last updated %s",
Expand All @@ -179,7 +179,7 @@ func describeDB(db database.DB) string {
tt.UpdatedAt,
)
case *database.DirDB:
count := len(tt.Vulnerabilities(true))
count := tt.VulnerabilitiesCount

return fmt.Sprintf(
"%s %s, including withdrawn",
Expand Down
4 changes: 2 additions & 2 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,12 +1036,12 @@ func TestRun_Configs(t *testing.T) {
wantStdout: `
Loaded the following OSV databases:
api#https://example.com/v1 (using batches of 1000)
dir#file:/fixtures/configs-extra-dbs (0 vulnerabilities, including withdrawn)
dir#file:/fixtures/configs-extra-dbs (3 vulnerabilities, including withdrawn)
zip#https://example.com/osvs/all
fixtures/configs-extra-dbs/yarn.lock: found 0 packages
Using config at fixtures/configs-extra-dbs/.osv-detector.yaml (0 ignores)
Using db api#https://example.com/v1 (using batches of 1000)
Using db dir#file:/fixtures/configs-extra-dbs (0 vulnerabilities, including withdrawn)
Using db dir#file:/fixtures/configs-extra-dbs (3 vulnerabilities, including withdrawn)
no known vulnerabilities found
`,
Expand Down
4 changes: 2 additions & 2 deletions pkg/database/dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ var ErrDirPathWrongProtocol = errors.New("directory path must start with \"file:
// load walks the filesystem starting with the working directory within the local path,
// loading all OSVs found along the way.
func (db *DirDB) load() error {
db.vulnerabilities = []OSV{}
db.vulnerabilities = make(map[string][]OSV)

if !strings.HasPrefix(db.LocalPath, "file:") {
return ErrDirPathWrongProtocol
Expand Down Expand Up @@ -78,7 +78,7 @@ func (db *DirDB) load() error {
return nil
}

db.vulnerabilities = append(db.vulnerabilities, pa)
db.addVulnerability(pa)

return nil
})
Expand Down
14 changes: 12 additions & 2 deletions pkg/database/dir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@ import (
func TestNewDirDB(t *testing.T) {
t.Parallel()

osvs := []database.OSV{{ID: "OSV-1"}, {ID: "OSV-2"}, {ID: "GHSA-1234"}}
osvs := []database.OSV{
withDefaultAffected("OSV-1"),
withDefaultAffected("OSV-2"),
{
ID: "GHSA-1234",
Affected: []database.Affected{
{Package: database.Package{Ecosystem: "npm", Name: "request"}},
{Package: database.Package{Ecosystem: "npm", Name: "@cypress/request"}},
},
},
}

db, err := database.NewDirDB(database.Config{URL: "file:/fixtures/db"}, false)

Expand Down Expand Up @@ -69,7 +79,7 @@ func TestNewDirDB_DoesNotExist(t *testing.T) {
func TestNewDirDB_WorkingDirectory(t *testing.T) {
t.Parallel()

osvs := []database.OSV{{ID: "OSV-1"}}
osvs := []database.OSV{withDefaultAffected("OSV-1")}

db, err := database.NewDirDB(database.Config{URL: "file:/fixtures/db", WorkingDirectory: "nested-1"}, false)

Expand Down
16 changes: 15 additions & 1 deletion pkg/database/fixtures/db/file.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
{
"id": "GHSA-1234"
"id": "GHSA-1234",
"affected": [
{
"package": {
"ecosystem": "npm",
"name": "request"
}
},
{
"package": {
"ecosystem": "npm",
"name": "@cypress/request"
}
}
]
}
11 changes: 10 additions & 1 deletion pkg/database/fixtures/db/nested-1/osv-1.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
{
"id": "OSV-1"
"id": "OSV-1",
"affected": [
{
"package": {
"name": "mine",
"ecosystem": "PyPi"
},
"versions": []
}
]
}
11 changes: 10 additions & 1 deletion pkg/database/fixtures/db/nested-2/osv-2.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
{
"id": "OSV-2"
"id": "OSV-2",
"affected": [
{
"package": {
"name": "mine",
"ecosystem": "PyPi"
},
"versions": []
}
]
}
44 changes: 34 additions & 10 deletions pkg/database/mem-check.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,39 @@ import (
// an OSV database that lives in-memory, and can be used by other structs
// that handle loading the vulnerabilities from where ever
type memDB struct {
vulnerabilities []OSV
vulnerabilities map[string][]OSV
VulnerabilitiesCount int
}

func (db *memDB) Vulnerabilities(includeWithdrawn bool) []OSV {
if includeWithdrawn {
return db.vulnerabilities
func (db *memDB) addVulnerability(osv OSV) {
db.VulnerabilitiesCount++

for _, affected := range osv.Affected {
hash := string(affected.Package.Ecosystem) + "-" + affected.Package.NormalizedName()
vulns := db.vulnerabilities[hash]

if vulns == nil {
vulns = []OSV{}
}

db.vulnerabilities[hash] = append(vulns, osv)
}
}

func (db *memDB) Vulnerabilities(includeWithdrawn bool) []OSV {
var vulnerabilities []OSV
ids := make(map[string]struct{})

for _, vulnerability := range db.vulnerabilities {
if vulnerability.Withdrawn == nil {
vulnerabilities = append(vulnerabilities, vulnerability)
for _, vulns := range db.vulnerabilities {
for _, vulnerability := range vulns {
if _, ok := ids[vulnerability.ID]; ok {
continue
}

if (vulnerability.Withdrawn == nil) || includeWithdrawn {
vulnerabilities = append(vulnerabilities, vulnerability)
ids[vulnerability.ID] = struct{}{}
}
}
}

Expand All @@ -29,9 +49,13 @@ func (db *memDB) Vulnerabilities(includeWithdrawn bool) []OSV {
func (db *memDB) VulnerabilitiesAffectingPackage(pkg internal.PackageDetails) Vulnerabilities {
var vulnerabilities Vulnerabilities

for _, vulnerability := range db.Vulnerabilities(false) {
if vulnerability.IsAffected(pkg) && !vulnerabilities.Includes(vulnerability) {
vulnerabilities = append(vulnerabilities, vulnerability)
hash := string(pkg.Ecosystem) + "-" + pkg.Name

if vulns, ok := db.vulnerabilities[hash]; ok {
for _, vulnerability := range vulns {
if vulnerability.Withdrawn == nil && vulnerability.IsAffected(pkg) && !vulnerabilities.Includes(vulnerability) {
vulnerabilities = append(vulnerabilities, vulnerability)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/database/zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
return
}

db.vulnerabilities = append(db.vulnerabilities, osv)
db.addVulnerability(osv)
}

// load fetches a zip archive of the OSV database and loads known vulnerabilities
Expand All @@ -162,7 +162,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
// so that a new version of the archive is only downloaded if it has been
// modified, per HTTP caching standards.
func (db *ZipDB) load() error {
db.vulnerabilities = []OSV{}
db.vulnerabilities = make(map[string][]OSV)

body, err := db.fetchZip()

Expand Down

0 comments on commit 2e92110

Please sign in to comment.