From fb6e26db2cd6d21992f1925c43f7da95a83880c3 Mon Sep 17 00:00:00 2001 From: Adam S Levy Date: Thu, 23 Jul 2020 16:33:59 -0800 Subject: [PATCH] Add support for url.URL --- bind.go | 8 ++++++++ bind_test.go | 21 +++++++++++++++++++++ url.go | 20 ++++++++++++++++++++ 3 files changed, 49 insertions(+) create mode 100644 url.go diff --git a/bind.go b/bind.go index 28cfe7a..a16f76d 100644 --- a/bind.go +++ b/bind.go @@ -93,6 +93,7 @@ import ( "encoding/json" "flag" "fmt" + "net/url" "reflect" "strings" "time" @@ -396,6 +397,9 @@ func (b bind) bind(fs FlagSet, v interface{}) (err error) { _, isBinder := fieldI.(Binder) _, isFlagValue := fieldI.(flag.Value) + _, isJSONRawMessage := fieldI.(*json.RawMessage) + _, isURL := fieldI.(*url.URL) + isFlagValue = isFlagValue || isJSONRawMessage || isURL isStruct := fieldT.Kind() == reflect.Struct @@ -530,6 +534,8 @@ func bindSTDFlag(fs STDFlagSet, tag flagTag, p interface{}) bool { fs.Var(p, tag.Name, tag.Usage) case *json.RawMessage: fs.Var((*JSONRawMessage)(p), tag.Name, tag.Usage) + case *url.URL: + fs.Var((*URL)(p), tag.Name, tag.Usage) case *bool: val := *p fs.BoolVar(p, tag.Name, val, tag.Usage) @@ -580,6 +586,8 @@ func bindPFlag(fs PFlagSet, tag flagTag, p interface{}, typeName string) bool { f = fs.VarPF(pp, tag.Name, tag.ShortName, tag.Usage) case *json.RawMessage: f = fs.VarPF((*JSONRawMessage)(p), tag.Name, tag.ShortName, tag.Usage) + case *url.URL: + f = fs.VarPF((*URL)(p), tag.Name, tag.ShortName, tag.Usage) case *bool: val := *p fs.BoolVarP(p, tag.Name, tag.ShortName, val, tag.Usage) diff --git a/bind_test.go b/bind_test.go index 3d949b7..0ef9f93 100644 --- a/bind_test.go +++ b/bind_test.go @@ -27,6 +27,7 @@ import ( "fmt" "io" "net/http" + "net/url" "testing" "time" @@ -202,6 +203,9 @@ type ValidTestFlags struct { ExportedInterface interface{} + CustomURLPtr *url.URL + CustomURL url.URL + custom bool } @@ -287,6 +291,10 @@ var tests = []BindTest{ "-struct-b-bool", "-nested-struct-a-bool", "-embedded-struct-b-bool", + "-custom-url", + "http://example.com", + "-custom-url-ptr", + "http://example.com", "-custom", }, ExpF: &ValidTestFlags{ @@ -327,6 +335,11 @@ var tests = []BindTest{ NestedFlat: StructB{true}, StructA: StructA{true, false}, StructB: StructB{true}, + CustomURL: func() url.URL { + u := mustParseURL("http://example.com") + return *u + }(), + CustomURLPtr: mustParseURL("http://example.com"), custom: true, }, }, { @@ -436,3 +449,11 @@ var tests = []BindTest{ }{http.Client{Timeout: 5 * time.Second}}, }, } + +func mustParseURL(rawurl string) *url.URL { + u, err := url.Parse(rawurl) + if err != nil { + panic(err) + } + return u +} diff --git a/url.go b/url.go new file mode 100644 index 0000000..034a62b --- /dev/null +++ b/url.go @@ -0,0 +1,20 @@ +package flagbind + +import "net/url" + +type URL url.URL + +func (u *URL) Set(text string) error { + _u, err := url.Parse(text) + if err != nil { + return err + } + *u = (URL)(*_u) + return nil +} + +func (u URL) String() string { + return (*url.URL)(&u).String() +} + +func (u URL) Type() string { return "URL" }