Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aliases in CLI flag, env and config #3481

Merged
merged 2 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ require (
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/spdx/tools-golang v0.3.1-0.20230104082527-d6f58551be3f
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/cast v1.5.0
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/subosito/gotenv v1.4.1 // indirect
Expand Down
15 changes: 14 additions & 1 deletion pkg/flag/db_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ var (
ConfigName: "db.skip-update",
Value: false,
Usage: "skip updating vulnerability database",
Aliases: []Alias{
{
Name: "skip-update",
Deprecated: true, // --security-update was renamed to --skip-db-update
},
},
}
NoProgressFlag = Flag{
Name: "no-progress",
Expand Down Expand Up @@ -84,7 +90,14 @@ func (f *DBFlagGroup) Name() string {
}

func (f *DBFlagGroup) Flags() []*Flag {
return []*Flag{f.Reset, f.DownloadDBOnly, f.SkipDBUpdate, f.NoProgress, f.DBRepository, f.Light}
return []*Flag{
f.Reset,
f.DownloadDBOnly,
f.SkipDBUpdate,
f.NoProgress,
f.DBRepository,
f.Light,
}
}

func (f *DBFlagGroup) ToOptions() (DBOptions, error) {
Expand Down
17 changes: 14 additions & 3 deletions pkg/flag/kubernetes_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ var (
ConfigName: "kubernetes.context",
Value: "",
Usage: "specify a context to scan",
Aliases: []Alias{
{Name: "ctx"},
},
}
K8sNamespaceFlag = Flag{
Name: "namespace",
Expand All @@ -23,8 +26,11 @@ var (
ComponentsFlag = Flag{
Name: "components",
ConfigName: "kubernetes.components",
Value: []string{"workload", "infra"},
Usage: "specify which components to scan",
Value: []string{
"workload",
"infra",
},
Usage: "specify which components to scan",
}
)

Expand Down Expand Up @@ -56,7 +62,12 @@ func (f *K8sFlagGroup) Name() string {
}

func (f *K8sFlagGroup) Flags() []*Flag {
return []*Flag{f.ClusterContext, f.Namespace, f.KubeConfig, f.Components}
return []*Flag{
f.ClusterContext,
f.Namespace,
f.KubeConfig,
f.Components,
}
}

func (f *K8sFlagGroup) ToOptions() K8sOptions {
Expand Down
143 changes: 104 additions & 39 deletions pkg/flag/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package flag
import (
"fmt"
"io"
"os"
"strings"
"sync"
"time"

"github.com/spf13/cast"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
Expand Down Expand Up @@ -38,6 +41,15 @@ type Flag struct {

// Deprecated represents if the flag is deprecated
Deprecated bool

// Aliases represents aliases
Aliases []Alias
}

type Alias struct {
Name string
ConfigName string
Deprecated bool
}

type FlagGroup interface {
Expand Down Expand Up @@ -142,39 +154,58 @@ func bind(cmd *cobra.Command, flag *Flag) error {
viper.SetDefault(flag.ConfigName, flag.Value)
return nil
}

// Bind CLI flags
if flag.Persistent {
if err := viper.BindPFlag(flag.ConfigName, cmd.PersistentFlags().Lookup(flag.Name)); err != nil {
return err
return xerrors.Errorf("bind flag error: %w", err)
}
} else {
if err := viper.BindPFlag(flag.ConfigName, cmd.Flags().Lookup(flag.Name)); err != nil {
return err
return xerrors.Errorf("bind flag error: %w", err)
}
}
// We don't use viper.AutomaticEnv, so we need to add a prefix manually here.
if err := viper.BindEnv(flag.ConfigName, strings.ToUpper("trivy_"+strings.ReplaceAll(flag.Name, "-", "_"))); err != nil {

// Bind environmental variable
if err := bindEnv(flag); err != nil {
return err
}

return nil
}

func getString(flag *Flag) string {
if flag == nil {
return ""
func bindEnv(flag *Flag) error {
// We don't use viper.AutomaticEnv, so we need to add a prefix manually here.
envName := strings.ToUpper("trivy_" + strings.ReplaceAll(flag.Name, "-", "_"))
if err := viper.BindEnv(flag.ConfigName, envName); err != nil {
return xerrors.Errorf("bind env error: %w", err)
}
return viper.GetString(flag.ConfigName)

// Bind env aliases
for _, alias := range flag.Aliases {
envAlias := strings.ToUpper("trivy_" + strings.ReplaceAll(alias.Name, "-", "_"))
if err := viper.BindEnv(flag.ConfigName, envAlias); err != nil {
return xerrors.Errorf("bind env error: %w", err)
}
if alias.Deprecated {
if _, ok := os.LookupEnv(envAlias); ok {
log.Logger.Warnf("'%s' is deprecated. Use '%s' instead.", envAlias, envName)
}
}
}
return nil
}

func getString(flag *Flag) string {
return cast.ToString(getValue(flag))
}

func getStringSlice(flag *Flag) []string {
if flag == nil {
return nil
}
// viper always returns a string for ENV
// https://github.com/spf13/viper/blob/419fd86e49ef061d0d33f4d1d56d5e2a480df5bb/viper.go#L545-L553
// and uses strings.Field to separate values (whitespace only)
// we need to separate env values with ','
v := viper.GetStringSlice(flag.ConfigName)
v := cast.ToStringSlice(getValue(flag))
switch {
case len(v) == 0: // no strings
return nil
Expand All @@ -185,24 +216,36 @@ func getStringSlice(flag *Flag) []string {
}

func getInt(flag *Flag) int {
if flag == nil {
return 0
}
return viper.GetInt(flag.ConfigName)
return cast.ToInt(getValue(flag))
}

func getBool(flag *Flag) bool {
if flag == nil {
return false
}
return viper.GetBool(flag.ConfigName)
return cast.ToBool(getValue(flag))
}

func getDuration(flag *Flag) time.Duration {
return cast.ToDuration(getValue(flag))
}

func getValue(flag *Flag) any {
if flag == nil {
return 0
return nil
}

// First, looks for aliases in config file (trivy.yaml).
// Note that viper.RegisterAlias cannot be used for this purpose.
var v any
for _, alias := range flag.Aliases {
if alias.ConfigName == "" {
continue
}
v = viper.Get(alias.ConfigName)
if v != nil {
log.Logger.Warnf("'%s' in config file is deprecated. Use '%s' instead.", alias.ConfigName, flag.ConfigName)
return v
}
}
return viper.GetDuration(flag.ConfigName)
return viper.Get(flag.ConfigName)
}

func (f *Flags) groups() []FlagGroup {
Expand Down Expand Up @@ -260,13 +303,17 @@ func (f *Flags) groups() []FlagGroup {
}

func (f *Flags) AddFlags(cmd *cobra.Command) {
aliases := make(flagAliases)
for _, group := range f.groups() {
for _, flag := range group.Flags() {
addFlag(cmd, flag)

// Register flag aliases
aliases.Add(flag)
}
}

cmd.Flags().SetNormalizeFunc(flagNameNormalize)
cmd.Flags().SetNormalizeFunc(aliases.NormalizeFunc())
}

func (f *Flags) Usages(cmd *cobra.Command) string {
Expand Down Expand Up @@ -403,20 +450,38 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF
return opts, nil
}

func flagNameNormalize(f *pflag.FlagSet, name string) pflag.NormalizedName {
switch name {
case "skip-update":
name = SkipDBUpdateFlag.Name
case "policy":
name = ConfigPolicyFlag.Name
case "data":
name = ConfigDataFlag.Name
case "namespaces":
name = PolicyNamespaceFlag.Name
case "ctx":
name = ClusterContextFlag.Name
case "security-checks":
name = ScannersFlag.Name
}
return pflag.NormalizedName(name)
type flagAlias struct {
formalName string
deprecated bool
once sync.Once
}

// flagAliases have aliases for CLI flags
type flagAliases map[string]*flagAlias

func (a flagAliases) Add(flag *Flag) {
if flag == nil {
return
}
for _, alias := range flag.Aliases {
a[alias.Name] = &flagAlias{
formalName: flag.Name,
deprecated: alias.Deprecated,
}
}
}

func (a flagAliases) NormalizeFunc() func(*pflag.FlagSet, string) pflag.NormalizedName {
return func(_ *pflag.FlagSet, name string) pflag.NormalizedName {
if alias, ok := a[name]; ok {
if alias.deprecated {
// NormalizeFunc is called several times
alias.once.Do(func() {
log.Logger.Warnf("'--%s' is deprecated. Use '--%s' instead.", name, alias.formalName)
})
}
name = alias.formalName
}
return pflag.NormalizedName(name)
}
}
9 changes: 9 additions & 0 deletions pkg/flag/rego_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,27 @@ var (
ConfigName: "rego.policy",
Value: []string{},
Usage: "specify paths to the Rego policy files directory, applying config files",
Aliases: []Alias{
{Name: "policy"},
},
}
ConfigDataFlag = Flag{
Name: "config-data",
ConfigName: "rego.data",
Value: []string{},
Usage: "specify paths from which data for the Rego policies will be recursively loaded",
Aliases: []Alias{
{Name: "data"},
},
}
PolicyNamespaceFlag = Flag{
Name: "policy-namespaces",
ConfigName: "rego.namespaces",
Value: []string{},
Usage: "Rego namespaces",
Aliases: []Alias{
{Name: "namespaces"},
},
}
)

Expand Down
7 changes: 7 additions & 0 deletions pkg/flag/scan_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ var (
types.VulnerabilityScanner,
types.SecretScanner,
},
Aliases: []Alias{
{
Name: "security-checks",
ConfigName: "scan.security-checks",
Deprecated: true, // --security-checks was renamed to --scanners
},
},
Usage: "comma-separated list of what security issues to detect (vuln,config,secret,license)",
}
FilePatternsFlag = Flag{
Expand Down