Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: use a map to track which advisories should be checked for which packages #216

Merged
merged 4 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions fixtures/configs-extra-dbs/db/osv-1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"id": "OSV-1",
"affected": [
{
"package": {
"ecosystem": "npm",
"name": "request"
}
},
{
"package": {
"ecosystem": "npm",
"name": "@cypress/request"
}
}
]
}
3 changes: 3 additions & 0 deletions fixtures/configs-extra-dbs/db/osv-2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"id": "OSV-2"
}
3 changes: 3 additions & 0 deletions fixtures/configs-extra-dbs/db/osv-3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"id": "OSV-3"
}
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func describeDB(db database.DB) string {
color.YellowString("%d", tt.BatchSize),
)
case *database.ZipDB:
count := len(tt.Vulnerabilities(true))
count := tt.VulnerabilitiesCount

return fmt.Sprintf(
"%s %s, including withdrawn - last updated %s",
Expand All @@ -179,7 +179,7 @@ func describeDB(db database.DB) string {
tt.UpdatedAt,
)
case *database.DirDB:
count := len(tt.Vulnerabilities(true))
count := tt.VulnerabilitiesCount

return fmt.Sprintf(
"%s %s, including withdrawn",
Expand Down
4 changes: 2 additions & 2 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,12 +1036,12 @@ func TestRun_Configs(t *testing.T) {
wantStdout: `
Loaded the following OSV databases:
api#https://example.com/v1 (using batches of 1000)
dir#file:/fixtures/configs-extra-dbs (0 vulnerabilities, including withdrawn)
dir#file:/fixtures/configs-extra-dbs (3 vulnerabilities, including withdrawn)
zip#https://example.com/osvs/all
fixtures/configs-extra-dbs/yarn.lock: found 0 packages
Using config at fixtures/configs-extra-dbs/.osv-detector.yaml (0 ignores)
Using db api#https://example.com/v1 (using batches of 1000)
Using db dir#file:/fixtures/configs-extra-dbs (0 vulnerabilities, including withdrawn)
Using db dir#file:/fixtures/configs-extra-dbs (3 vulnerabilities, including withdrawn)

no known vulnerabilities found
`,
Expand Down
4 changes: 2 additions & 2 deletions pkg/database/dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ var ErrDirPathWrongProtocol = errors.New("directory path must start with \"file:
// load walks the filesystem starting with the working directory within the local path,
// loading all OSVs found along the way.
func (db *DirDB) load() error {
db.vulnerabilities = []OSV{}
db.vulnerabilities = make(map[string][]OSV)

if !strings.HasPrefix(db.LocalPath, "file:") {
return ErrDirPathWrongProtocol
Expand Down Expand Up @@ -78,7 +78,7 @@ func (db *DirDB) load() error {
return nil
}

db.vulnerabilities = append(db.vulnerabilities, pa)
db.addVulnerability(pa)

return nil
})
Expand Down
14 changes: 12 additions & 2 deletions pkg/database/dir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@ import (
func TestNewDirDB(t *testing.T) {
t.Parallel()

osvs := []database.OSV{{ID: "OSV-1"}, {ID: "OSV-2"}, {ID: "GHSA-1234"}}
osvs := []database.OSV{
withDefaultAffected("OSV-1"),
withDefaultAffected("OSV-2"),
{
ID: "GHSA-1234",
Affected: []database.Affected{
{Package: database.Package{Ecosystem: "npm", Name: "request"}},
{Package: database.Package{Ecosystem: "npm", Name: "@cypress/request"}},
},
},
}

db, err := database.NewDirDB(database.Config{URL: "file:/fixtures/db"}, false)

Expand Down Expand Up @@ -69,7 +79,7 @@ func TestNewDirDB_DoesNotExist(t *testing.T) {
func TestNewDirDB_WorkingDirectory(t *testing.T) {
t.Parallel()

osvs := []database.OSV{{ID: "OSV-1"}}
osvs := []database.OSV{withDefaultAffected("OSV-1")}

db, err := database.NewDirDB(database.Config{URL: "file:/fixtures/db", WorkingDirectory: "nested-1"}, false)

Expand Down
16 changes: 15 additions & 1 deletion pkg/database/fixtures/db/file.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
{
"id": "GHSA-1234"
"id": "GHSA-1234",
"affected": [
{
"package": {
"ecosystem": "npm",
"name": "request"
}
},
{
"package": {
"ecosystem": "npm",
"name": "@cypress/request"
}
}
]
}
11 changes: 10 additions & 1 deletion pkg/database/fixtures/db/nested-1/osv-1.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
{
"id": "OSV-1"
"id": "OSV-1",
"affected": [
{
"package": {
"name": "mine",
"ecosystem": "PyPi"
},
"versions": []
}
]
}
11 changes: 10 additions & 1 deletion pkg/database/fixtures/db/nested-2/osv-2.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
{
"id": "OSV-2"
"id": "OSV-2",
"affected": [
{
"package": {
"name": "mine",
"ecosystem": "PyPi"
},
"versions": []
}
]
}
44 changes: 34 additions & 10 deletions pkg/database/mem-check.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,39 @@ import (
// an OSV database that lives in-memory, and can be used by other structs
// that handle loading the vulnerabilities from where ever
type memDB struct {
vulnerabilities []OSV
vulnerabilities map[string][]OSV
VulnerabilitiesCount int
}

func (db *memDB) Vulnerabilities(includeWithdrawn bool) []OSV {
if includeWithdrawn {
return db.vulnerabilities
func (db *memDB) addVulnerability(osv OSV) {
db.VulnerabilitiesCount++

for _, affected := range osv.Affected {
hash := string(affected.Package.Ecosystem) + "-" + affected.Package.NormalizedName()
vulns := db.vulnerabilities[hash]

if vulns == nil {
vulns = []OSV{}
}

db.vulnerabilities[hash] = append(vulns, osv)
}
}

func (db *memDB) Vulnerabilities(includeWithdrawn bool) []OSV {
var vulnerabilities []OSV
ids := make(map[string]struct{})

for _, vulnerability := range db.vulnerabilities {
if vulnerability.Withdrawn == nil {
vulnerabilities = append(vulnerabilities, vulnerability)
for _, vulns := range db.vulnerabilities {
for _, vulnerability := range vulns {
if _, ok := ids[vulnerability.ID]; ok {
continue
}

if (vulnerability.Withdrawn == nil) || includeWithdrawn {
vulnerabilities = append(vulnerabilities, vulnerability)
ids[vulnerability.ID] = struct{}{}
}
}
}

Expand All @@ -29,9 +49,13 @@ func (db *memDB) Vulnerabilities(includeWithdrawn bool) []OSV {
func (db *memDB) VulnerabilitiesAffectingPackage(pkg internal.PackageDetails) Vulnerabilities {
var vulnerabilities Vulnerabilities

for _, vulnerability := range db.Vulnerabilities(false) {
if vulnerability.IsAffected(pkg) && !vulnerabilities.Includes(vulnerability) {
vulnerabilities = append(vulnerabilities, vulnerability)
hash := string(pkg.Ecosystem) + "-" + pkg.Name

if vulns, ok := db.vulnerabilities[hash]; ok {
for _, vulnerability := range vulns {
if vulnerability.Withdrawn == nil && vulnerability.IsAffected(pkg) && !vulnerabilities.Includes(vulnerability) {
vulnerabilities = append(vulnerabilities, vulnerability)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/database/zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
return
}

db.vulnerabilities = append(db.vulnerabilities, osv)
db.addVulnerability(osv)
}

// load fetches a zip archive of the OSV database and loads known vulnerabilities
Expand All @@ -162,7 +162,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
// so that a new version of the archive is only downloaded if it has been
// modified, per HTTP caching standards.
func (db *ZipDB) load() error {
db.vulnerabilities = []OSV{}
db.vulnerabilities = make(map[string][]OSV)

body, err := db.fetchZip()

Expand Down