Skip to content

Commit

Permalink
Add support for time.Time flags
Browse files Browse the repository at this point in the history
Signed-off-by: Knut Ahlers <knut@ahlers.me>
  • Loading branch information
Luzifer committed Sep 17, 2018
1 parent 913d4f1 commit f4e07b5
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 13 deletions.
110 changes: 97 additions & 13 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,25 @@ import (
validator "gopkg.in/validator.v2"
)

type afterFunc func() error

var (
autoEnv bool
fs *pflag.FlagSet
variableDefaults map[string]string

timeParserFormats = []string{
// Default constants
time.RFC3339Nano, time.RFC3339,
time.RFC1123Z, time.RFC1123,
time.RFC822Z, time.RFC822,
time.RFC850, time.RubyDate, time.UnixDate, time.ANSIC,
"2006-01-02 15:04:05.999999999 -0700 MST",
// More uncommon time formats
"2006-01-02 15:04:05", "2006-01-02 15:04:05Z07:00", // Simplified ISO time format
"01/02/2006 15:04:05", "01/02/2006 15:04:05Z07:00", // US time format
"02.01.2006 15:04:05", "02.01.2006 15:04:05Z07:00", // DE time format
}
)

func init() {
Expand Down Expand Up @@ -61,6 +76,11 @@ func Args() []string {
return fs.Args()
}

// AddTimeParserFormats adds custom formats to parse time.Time fields
func AddTimeParserFormats(f ...string) {
timeParserFormats = append(timeParserFormats, f...)
}

// AutoEnv enables or disables automated env variable guessing. If no `env` struct
// tag was set and AutoEnv is enabled the env variable name is derived from the
// name of the field: `MyFieldName` will get `MY_FIELD_NAME`
Expand Down Expand Up @@ -97,22 +117,37 @@ func parse(in interface{}, args []string) error {
}

fs = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
if err := execTags(in, fs); err != nil {
afterFuncs, err := execTags(in, fs)
if err != nil {
return err
}

return fs.Parse(args)
if err := fs.Parse(args); err != nil {
return err
}

if afterFuncs != nil {
for _, f := range afterFuncs {
if err := f(); err != nil {
return err
}
}
}

return nil
}

func execTags(in interface{}, fs *pflag.FlagSet) error {
func execTags(in interface{}, fs *pflag.FlagSet) ([]afterFunc, error) {
if reflect.TypeOf(in).Kind() != reflect.Ptr {
return errors.New("Calling parser with non-pointer")
return nil, errors.New("Calling parser with non-pointer")
}

if reflect.ValueOf(in).Elem().Kind() != reflect.Struct {
return errors.New("Calling parser with pointer to non-struct")
return nil, errors.New("Calling parser with pointer to non-struct")
}

afterFuncs := []afterFunc{}

st := reflect.ValueOf(in).Elem()
for i := 0; i < st.NumField(); i++ {
valField := st.Field(i)
Expand All @@ -134,7 +169,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
if value == "" {
v = time.Duration(0)
} else {
return err
return nil, err
}
}

Expand All @@ -148,6 +183,53 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
valField.Set(reflect.ValueOf(v))
}
continue

case reflect.TypeOf(time.Time{}):
var sVar string

if typeField.Tag.Get("flag") != "" {
if len(parts) == 1 {
fs.StringVar(&sVar, parts[0], value, typeField.Tag.Get("description"))
} else {
fs.StringVarP(&sVar, parts[0], parts[1], value, typeField.Tag.Get("description"))
}
} else {
sVar = value
}

afterFuncs = append(afterFuncs, func(valField reflect.Value, sVar *string) func() error {
return func() error {
if *sVar == "" {
// No time, no problem
return nil
}

// Check whether we could have a timestamp
if ts, err := strconv.ParseInt(*sVar, 10, 64); err == nil {
t := time.Unix(ts, 0)
valField.Set(reflect.ValueOf(t))
return nil
}

// We haven't so lets walk through possible time formats
matched := false
for _, tf := range timeParserFormats {
if t, err := time.Parse(tf, *sVar); err == nil {
matched = true
valField.Set(reflect.ValueOf(t))
return nil
}
}

if !matched {
return fmt.Errorf("Value %q did not match expected time formats", *sVar)
}

return nil
}
}(valField, &sVar))

continue
}

switch typeField.Type.Kind() {
Expand Down Expand Up @@ -180,7 +262,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
if value == "" {
vt = 0
} else {
return err
return nil, err
}
}
if typeField.Tag.Get("flag") != "" {
Expand All @@ -195,7 +277,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
if value == "" {
vt = 0
} else {
return err
return nil, err
}
}
if typeField.Tag.Get("flag") != "" {
Expand All @@ -210,7 +292,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
if value == "" {
vt = 0.0
} else {
return err
return nil, err
}
}
if typeField.Tag.Get("flag") != "" {
Expand All @@ -220,9 +302,11 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
}

case reflect.Struct:
if err := execTags(valField.Addr().Interface(), fs); err != nil {
return err
afs, err := execTags(valField.Addr().Interface(), fs)
if err != nil {
return nil, err
}
afterFuncs = append(afterFuncs, afs...)

case reflect.Slice:
switch typeField.Type.Elem().Kind() {
Expand All @@ -231,7 +315,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
for _, v := range strings.Split(value, ",") {
it, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
if err != nil {
return err
return nil, err
}
def = append(def, int(it))
}
Expand All @@ -258,7 +342,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error {
}
}

return nil
return afterFuncs, nil
}

func registerFlagFloat(t reflect.Kind, fs *pflag.FlagSet, field interface{}, parts []string, vt float64, desc string) {
Expand Down
51 changes: 51 additions & 0 deletions time_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package rconfig

import (
"fmt"
"testing"
"time"
)

func TestParseTime(t *testing.T) {
type ts struct {
Test time.Time `flag:"time"`
TestS time.Time `flag:"other-time,o"`
TestDef time.Time `default:"2006-01-02T15:04:05.999999999Z"`
TestDE time.Time `default:"18.09.2018 20:25:31"`
}

var (
err error
args []string
cfg ts
)

for _, tf := range timeParserFormats {
expect := time.Now().Format(tf)

cfg = ts{}
args = []string{
fmt.Sprintf("--time=%s", expect),
"-o", expect,
}

if err = parse(&cfg, args); err != nil {
t.Fatalf("Time format %q did not parse: %s", tf, err)
}

for name, ti := range map[string]time.Time{
"Long flag": cfg.Test,
"Short flag": cfg.TestS,
"Default flag": cfg.TestDef,
"DE flag": cfg.TestDE,
} {
if ti.IsZero() {
t.Errorf("%s did parse to zero with format %q", name, tf)
}
}

if e := cfg.Test.Format(tf); e != expect {
t.Errorf("Parsed time %q did not match expectation %q", e, expect)
}
}
}

0 comments on commit f4e07b5

Please sign in to comment.