diff --git a/syft/pkg/cataloger/python/parse_wheel_egg_metadata.go b/syft/pkg/cataloger/python/parse_wheel_egg_metadata.go index e0ab4fb8d90..88e98527547 100644 --- a/syft/pkg/cataloger/python/parse_wheel_egg_metadata.go +++ b/syft/pkg/cataloger/python/parse_wheel_egg_metadata.go @@ -93,15 +93,7 @@ func extractRFC5322Fields(locationReader file.LocationReadCloser) (map[string]an key = strings.ReplaceAll(strings.TrimSpace(line[0:i]), "-", "") val := getFieldType(key, strings.TrimSpace(line[i+1:])) - if strSlice, ok := val.([]string); ok { - if fields[key] == nil { - fields[key] = strSlice - } else { - fields[key] = append(fields[key].([]string), strSlice...) - } - } else { - fields[key] = val - } + fields[key] = handleSingleOrMultiField(fields[key], val) } else { log.Warnf("cannot parse field from path: %q from line: %q", locationReader.Path(), line) } @@ -110,6 +102,25 @@ func extractRFC5322Fields(locationReader file.LocationReadCloser) (map[string]an return fields, nil } +func handleSingleOrMultiField(existingValue, val any) any { + strSlice, ok := val.([]string) + if !ok { + return val + } + if existingValue == nil { + return strSlice + } + + switch existingValueTy := existingValue.(type) { + case []string: + return append(existingValueTy, strSlice...) + case string: + return append([]string{existingValueTy}, strSlice...) + } + + return append([]string{fmt.Sprintf("%s", existingValue)}, strSlice...) +} + func getFieldType(key, in string) any { if plural, ok := pluralFields[key]; ok && plural { return []string{in} diff --git a/syft/pkg/cataloger/python/parse_wheel_egg_metadata_test.go b/syft/pkg/cataloger/python/parse_wheel_egg_metadata_test.go index d4fb79a2c70..2ebc771e336 100644 --- a/syft/pkg/cataloger/python/parse_wheel_egg_metadata_test.go +++ b/syft/pkg/cataloger/python/parse_wheel_egg_metadata_test.go @@ -1,11 +1,15 @@ package python import ( + "io" "os" + "strings" "testing" "github.com/go-test/deep" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/anchore/syft/internal/cmptest" "github.com/anchore/syft/syft/file" @@ -180,3 +184,59 @@ func TestParseWheelEggMetadataInvalid(t *testing.T) { }) } } + +func Test_extractRFC5322Fields(t *testing.T) { + + tests := []struct { + name string + input string + want map[string]any + wantErr require.ErrorAssertionFunc + }{ + { + name: "with valid plural fields", + input: ` +Name: mxnet +Version: 1.8.0 +Requires-Dist: numpy (>=1.16.6) +Requires-Dist: requests (>=2.22.0) +ProvidesExtra: cryptoutils ; extra == 'secure' +ProvidesExtra: socks ; extra == 'secure' +`, + want: map[string]any{ + "Name": "mxnet", + "Version": "1.8.0", + "RequiresDist": []string{"numpy (>=1.16.6)", "requests (>=2.22.0)"}, + "ProvidesExtra": []string{"cryptoutils ; extra == 'secure'", "socks ; extra == 'secure'"}, + }, + }, + { + name: "with invalid plural fields (overwrite)", + input: ` +Name: mxnet +Version: 1.8.0 +Version: 1.9.0 +`, + want: map[string]any{ + "Name": "mxnet", + "Version": "1.9.0", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.wantErr == nil { + tt.wantErr = require.NoError + } + + reader := file.NewLocationReadCloser( + file.NewLocation("/made/up"), + io.NopCloser(strings.NewReader(tt.input)), + ) + + got, err := extractRFC5322Fields(reader) + tt.wantErr(t, err) + assert.Equal(t, tt.want, got) + }) + } +}