Skip to content

Commit

Permalink
Implement a config resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
carlpett authored and alecthomas committed Sep 13, 2017
1 parent 23bcc3c commit 76855bf
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 52 deletions.
19 changes: 15 additions & 4 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Application struct {
terminate func(status int) // See Terminate()
noInterspersed bool // can flags be interspersed with args (or must they come first)
defaultEnvars bool
resolvers []*ConfigResolver
completion bool
helpFlag *Clause
helpCommand *CmdClause
Expand All @@ -46,6 +47,7 @@ func New(name, help string) *Application {
defaultUsage: &UsageContext{
Template: DefaultUsageTemplate,
},
resolvers: make([]*ConfigResolver, 0),
}
a.flagGroup = newFlagGroup()
a.argGroup = newArgGroup()
Expand Down Expand Up @@ -128,6 +130,14 @@ func (a *Application) DefaultEnvars() *Application {
return a
}

// ConfigResolver configures all flags, that are not given on the commandline
// or as environment variables, to be looked up using the config resolver
// given.
func (a *Application) ConfigResolver(c ConfigResolver) *Application {
a.resolvers = append(a.resolvers, &c)
return a
}

// Terminate specifies the termination handler. Defaults to os.Exit(status).
// If nil is passed, a no-op function will be used.
func (a *Application) Terminate(terminate func(int)) *Application {
Expand Down Expand Up @@ -170,6 +180,7 @@ func (a *Application) parseContext(ignoreDefault bool, args []string) (*ParseCon
return nil, err
}
context := tokenize(args, ignoreDefault)
context.Resolvers = a.resolvers
err := parse(context, a)
return context, err
}
Expand Down Expand Up @@ -382,7 +393,7 @@ func (a *Application) setDefaults(context *ParseContext) error {
// Check required flags and set defaults.
for _, flag := range context.flags.long {
if flagElements[flag.name] == nil {
if err := flag.setDefault(); err != nil {
if err := flag.setDefault(context); err != nil {
return err
}
} else {
Expand All @@ -392,7 +403,7 @@ func (a *Application) setDefaults(context *ParseContext) error {

for _, arg := range context.arguments.args {
if argElements[arg.name] == nil {
if err := arg.setDefault(); err != nil {
if err := arg.setDefault(context); err != nil {
return err
}
} else {
Expand All @@ -411,15 +422,15 @@ func (a *Application) validateRequired(context *ParseContext) error {
for _, flag := range context.flags.long {
if flagElements[flag.name] == nil {
// Check required flags were provided.
if flag.needsValue() {
if flag.needsValue(context) {
return TError("required flag --{{.Arg0}} not provided", V{"Arg0": flag.name})
}
}
}

for _, arg := range context.arguments.args {
if argElements[arg.name] == nil {
if arg.needsValue() {
if arg.needsValue(context) {
return TError("required argument '{{.Arg0}}' not provided", V{"Arg0": arg.name})
}
}
Expand Down
74 changes: 60 additions & 14 deletions clause.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ type Clause struct {
actionMixin
completionsMixin

name string
shorthand rune
help string
placeholder string
hidden bool
defaultValues []string
value Value
required bool
envar string
noEnvar bool
name string
shorthand rune
help string
placeholder string
hidden bool
defaultValues []string
value Value
required bool
envar string
noEnvar bool
resolverKey string
noConfigResolver bool
}

func NewClause(name, help string) *Clause {
Expand Down Expand Up @@ -139,6 +141,23 @@ func (c *Clause) NoEnvar() *Clause {
return c
}

// ConfigResolverKey overrides the default value(s) for a flag from a config
// file, if it is set. Several default values can be provided by using new
// lines to separate them.
func (c *Clause) ConfigResolverKey(key string) *Clause {
c.resolverKey = key
c.noConfigResolver = false
return c
}

// NoConfigResolver forces resolver variable defaults to be disabled for this flag.
// Most useful in conjunction with app.ConfigResolver().
func (c *Clause) NoConfigResolver() *Clause {
c.resolverKey = ""
c.noConfigResolver = true
return c
}

// PlaceHolder sets the place-holder string used for flag values in the help. The
// default behaviour is to use the value provided by Default() if provided,
// then fall back on the capitalized flag name.
Expand All @@ -165,9 +184,9 @@ func (c *Clause) Short(name rune) *Clause {
return c
}

func (c *Clause) needsValue() bool {
func (c *Clause) needsValue(context *ParseContext) bool {
haveDefault := len(c.defaultValues) > 0
return c.required && !(haveDefault || c.HasEnvarValue())
return c.required && !(haveDefault || c.HasEnvarValue() || c.HasConfigResolvers(context))
}

func (c *Clause) reset() {
Expand All @@ -176,7 +195,7 @@ func (c *Clause) reset() {
}
}

func (c *Clause) setDefault() error {
func (c *Clause) setDefault(context *ParseContext) error {
if c.HasEnvarValue() {
c.reset()
if v, ok := c.value.(cumulativeValue); !ok || !v.IsCumulative() {
Expand All @@ -189,6 +208,9 @@ func (c *Clause) setDefault() error {
}
}
return nil
} else if c.HasConfigResolvers(context) {
c.reset()
c.value.Set(c.GetConfigResolverValue(context))
} else if len(c.defaultValues) > 0 {
c.reset()
for _, defaultValue := range c.defaultValues {
Expand Down Expand Up @@ -227,9 +249,33 @@ func (c *Clause) GetSplitEnvarValue() []string {
return values
}

func (c *Clause) HasConfigResolvers(context *ParseContext) bool {
if len(context.Resolvers) == 0 || c.noConfigResolver {
return false
}
return true
}

func (c *Clause) GetConfigResolverValue(context *ParseContext) string {
var key string
if c.resolverKey != "" {
key = c.resolverKey
} else {
key = c.name
}

for _, r := range context.Resolvers {
val := (*r).Resolve(key, context)
if val != "" {
return val
}
}

return ""
}

func (c *Clause) SetValue(value Value) {
c.value = value
c.setDefault()
}

// StringMap provides key=value parsing into a map.
Expand Down
85 changes: 71 additions & 14 deletions clause_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,28 +68,85 @@ func TestFloat32(t *testing.T) {
assert.InEpsilon(t, 123.45, *v, 0.001)
}

func TestDefaultScalarValueIsSetBeforeParse(t *testing.T) {
func TestUnicodeShortFlag(t *testing.T) {
app := newTestApp()
f := app.Flag("long", "").Short('ä').Bool()
_, err := app.Parse([]string{"-ä"})
assert.NoError(t, err)
assert.True(t, *f)
}

type TestResolver struct {
vals map[string]string
}

func (r *TestResolver) Resolve(key string, context *ParseContext) string {
return r.vals[key]
}

func TestResolverSimple(t *testing.T) {
app := newTestApp()
v := app.Flag("a", "").Default("123").Int()
assert.Equal(t, *v, 123)
_, err := app.Parse([]string{"--a", "456"})
app.ConfigResolver(&TestResolver{vals: map[string]string{"hello": "world"}})
f := app.Flag("hello", "help").String()
_, err := app.Parse([]string{})
assert.NoError(t, err)
assert.Equal(t, *v, 456)
assert.Equal(t, "world", *f)
}

func TestDefaultCumulativeValueIsSetBeforeParse(t *testing.T) {
func TestResolverSatisfiesRequired(t *testing.T) {
app := newTestApp()
v := app.Flag("a", "").Default("123", "456").Ints()
assert.Equal(t, *v, []int{123, 456})
_, err := app.Parse([]string{"--a", "789", "--a", "123"})
app.ConfigResolver(&TestResolver{vals: map[string]string{"hello": "world"}})
f := app.Flag("hello", "help").Required().String()
_, err := app.Parse([]string{})
assert.NoError(t, err)
assert.Equal(t, *v, []int{789, 123})
assert.Equal(t, "world", *f)
}

func TestUnicodeShortFlag(t *testing.T) {
func TestResolverKeyOverride(t *testing.T) {
app := newTestApp()
f := app.Flag("long", "").Short('ä').Bool()
_, err := app.Parse([]string{"-ä"})
app.ConfigResolver(&TestResolver{vals: map[string]string{"foo": "world"}})
f := app.Flag("hello", "help").ConfigResolverKey("foo").String()
_, err := app.Parse([]string{})
assert.NoError(t, err)
assert.True(t, *f)
assert.Equal(t, "world", *f)
}

func TestResolverDisable(t *testing.T) {
app := newTestApp()
app.ConfigResolver(&TestResolver{vals: map[string]string{"hello": "world"}})
f := app.Flag("hello", "help").NoConfigResolver().String()
_, err := app.Parse([]string{})
assert.NoError(t, err)
assert.Equal(t, "", *f)
}

func TestResolverLowerPriorityThanFlag(t *testing.T) {
app := newTestApp()
app.ConfigResolver(&TestResolver{vals: map[string]string{"hello": "world"}})
f := app.Flag("hello", "help").String()
_, err := app.Parse([]string{"--hello", "there"})
assert.NoError(t, err)
assert.Equal(t, "there", *f)
}

func TestResolverLowerPriorityThanEnvar(t *testing.T) {
os.Setenv("TEST_RESOLVER", "foo")
app := newTestApp()
app.ConfigResolver(&TestResolver{vals: map[string]string{"hello": "world"}})
f := app.Flag("hello", "help").Envar("TEST_RESOLVER").String()
_, err := app.Parse([]string{})
assert.NoError(t, err)
assert.Equal(t, "foo", *f)
}

func TestResolverFallbackWithMultipleResolvers(t *testing.T) {
app := newTestApp()
app.ConfigResolver(&TestResolver{vals: map[string]string{"hello": "world"}})
app.ConfigResolver(&TestResolver{vals: map[string]string{"hello": "there", "foo": "bar"}})
f1 := app.Flag("hello", "help").String()
f2 := app.Flag("foo", "help").String()
_, err := app.Parse([]string{})
assert.NoError(t, err)
assert.Equal(t, "world", *f1)
assert.Equal(t, "bar", *f2)
}
42 changes: 22 additions & 20 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ func (f *FlagGroupModel) FlagSummary() string {
}

type ClauseModel struct {
Name string
Help string
Short rune
Default []string
Envar string
PlaceHolder string
Required bool
Hidden bool
Value Value
Cumulative bool
Name string
Help string
Short rune
Default []string
Envar string
ConfigResolverKey string
PlaceHolder string
Required bool
Hidden bool
Value Value
Cumulative bool
}

func (c *ClauseModel) String() string {
Expand Down Expand Up @@ -251,16 +252,17 @@ func (f *flagGroup) Model() *FlagGroupModel {
func (f *Clause) Model() *ClauseModel {
_, cumulative := f.value.(cumulativeValue)
return &ClauseModel{
Name: f.name,
Help: f.help,
Short: f.shorthand,
Default: f.defaultValues,
Envar: f.envar,
PlaceHolder: f.placeholder,
Required: f.required,
Hidden: f.hidden,
Value: f.value,
Cumulative: cumulative,
Name: f.name,
Help: f.help,
Short: f.shorthand,
Default: f.defaultValues,
Envar: f.envar,
ConfigResolverKey: f.resolverKey,
PlaceHolder: f.placeholder,
Required: f.required,
Hidden: f.hidden,
Value: f.value,
Cumulative: cumulative,
}
}

Expand Down
6 changes: 6 additions & 0 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (p ParseElements) ArgMap() map[string]*ParseElement {
// any).
type ParseContext struct {
SelectedCommand *CmdClause
Resolvers []*ConfigResolver
ignoreDefault bool
argsOnly bool
peek []*Token
Expand All @@ -132,6 +133,11 @@ type ParseContext struct {
Elements ParseElements
}

// A ConfigResolver retrieves configuration from an external source
type ConfigResolver interface {
Resolve(string, *ParseContext) string
}

// LastCmd returns true if the element is the last (sub)command
// being evaluated.
func (p *ParseContext) LastCmd(element *ParseElement) bool {
Expand Down
4 changes: 4 additions & 0 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func (c *cmdMixin) fromStruct(clause *CmdClause, v interface{}) error { // nolin
required := tag.Get("required")
hidden := tag.Get("hidden")
env := tag.Get("env")
resolverKey := tag.Get("resolverKey")
enum := tag.Get("enum")
name := strings.ToLower(strings.Join(camelCase(ft.Name), "-"))
if tag.Get("long") != "" {
Expand Down Expand Up @@ -98,6 +99,9 @@ func (c *cmdMixin) fromStruct(clause *CmdClause, v interface{}) error { // nolin
if env != "" {
clause = clause.Envar(env)
}
if resolverKey != "" {
clause = clause.ConfigResolverKey(resolverKey)
}
ptr := field.Addr().Interface()
if ft.Type == reflect.TypeOf(time.Duration(0)) {
clause.DurationVar(ptr.(*time.Duration))
Expand Down

0 comments on commit 76855bf

Please sign in to comment.