From b8aca6b0c4616747b07b79e72a4172e1b575acb7 Mon Sep 17 00:00:00 2001 From: Milad Abbasi Date: Fri, 1 Jan 2021 15:07:38 +0330 Subject: [PATCH] Add file and env providers --- .gitignore | 3 - env.go | 184 ++++++++++++++++++++++++ errors.go | 12 +- file.go | 137 ++++++++++++++++++ go.mod | 6 + go.sum | 7 + gonfig.go | 404 +++++++++++------------------------------------------ input.go | 386 ++++++++++++++++++++++++++++++++++++++++++++++++++ snake.go | 20 --- tags.go | 50 +++++-- utils.go | 61 ++++++++ 11 files changed, 913 insertions(+), 357 deletions(-) create mode 100644 env.go create mode 100644 file.go create mode 100644 go.sum create mode 100644 input.go delete mode 100644 snake.go create mode 100644 utils.go diff --git a/.gitignore b/.gitignore index f76c4c1..5bbfe60 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,3 @@ # vendor .DS_Store - -.env* -!.env.example diff --git a/env.go b/env.go new file mode 100644 index 0000000..6560d9d --- /dev/null +++ b/env.go @@ -0,0 +1,184 @@ +package gonfig + +import ( + "errors" + "os" + "strings" + + "github.com/joho/godotenv" +) + +// EnvProvider loads values from environment variables to provided struct +type EnvProvider struct { + // Prefix is used when finding values from environment variables, defaults to "" + EnvPrefix string + + // SnakeCase specifies whether to convert field names to snake_case or not, defaults to true + SnakeCase bool + + // UpperCase specifies whether to convert field names to UPPERCASE or not, defaults to true + UpperCase bool + + // FieldSeparator is used to separate field names, defaults to "_" + FieldSeparator string + + // Source is used to retrieve environment variables + // It can be either a path to a file or empty string, if empty OS will be used + Source string + + // Whether to report error if env file is not found, defaults to false + Required bool +} + +// NewEnvProvider creates a new EnvProvider +func NewEnvProvider() *EnvProvider { + return &EnvProvider{ + EnvPrefix: "", + SnakeCase: true, + UpperCase: true, + FieldSeparator: "_", + Source: "", + Required: false, + } +} + +// NewEnvFileProvider creates a new EnvProvider from .env file +func NewEnvFileProvider(path string) *EnvProvider { + return &EnvProvider{ + EnvPrefix: "", + SnakeCase: true, + UpperCase: true, + FieldSeparator: "_", + Source: path, + Required: false, + } +} + +// Name of provider +func (ep *EnvProvider) Name() string { + return "ENV provider" +} + +// Fill takes struct fields and fills their values +func (ep *EnvProvider) Fill(in *Input) error { + content, err := ep.envMap() + if err != nil { + return err + } + + for _, f := range in.Fields { + value, err := ep.provide(content, f.Tags.Config, f.Path) + if err != nil { + if errors.Is(err, ErrKeyNotFound) { + continue + } + + return err + } + + err = in.setValue(f, value) + if err != nil { + return err + } + + f.IsSet = true + } + + return nil +} + +// envMap returns environment variables map from either OS or file specified by source +// Defaults to operating system env variables +func (ep *EnvProvider) envMap() (map[string]string, error) { + envs := envFromOS() + var fileEnvs map[string]string + var err error + + if ep.Source != "" { + fileEnvs, err = envFromFile(ep.Source) + } + if err != nil { + notExistsErr := errors.Is(err, os.ErrNotExist) + if (notExistsErr && ep.Required) || !notExistsErr { + return nil, err + } + } + + if len(envs) == 0 { + if len(fileEnvs) == 0 { + return nil, nil + } + + envs = make(map[string]string) + } + + for k, v := range fileEnvs { + _, exists := envs[k] + if !exists { + envs[k] = v + } + } + + return envs, nil +} + +// returns environment variables map retrieved from operating system +func envFromOS() map[string]string { + envs := os.Environ() + if len(envs) == 0 { + return nil + } + + envMap := make(map[string]string) + + for _, env := range envs { + keyValue := strings.SplitN(env, "=", 2) + if len(keyValue) < 2 { + continue + } + + envMap[keyValue[0]] = keyValue[1] + } + + return envMap +} + +// returns environment variables map retrieved from specified file +func envFromFile(path string) (map[string]string, error) { + m, err := godotenv.Read(path) + if err != nil { + return nil, err + } + + return m, nil +} + +// provide find a value from env variables based on specified key and path +func (ep *EnvProvider) provide(content map[string]string, key string, path []string) (string, error) { + k := ep.buildKey(key, path) + value, exists := content[k] + if !exists { + return "", ErrKeyNotFound + } + + return value, nil +} + +// buildKey prefix key with EnvPrefix, if not provided, path slice will be used +func (ep *EnvProvider) buildKey(key string, path []string) string { + if key != "" { + return ep.EnvPrefix + key + } + + k := strings.Join(path, ep.FieldSeparator) + if ep.SnakeCase { + k = toSnakeCase(k) + } + if ep.UpperCase { + k = strings.ToUpper(k) + } + + k = ep.EnvPrefix + k + + return k +} diff --git a/errors.go b/errors.go index 187c1b0..b102ea5 100644 --- a/errors.go +++ b/errors.go @@ -10,8 +10,14 @@ var ( // Can not handle specified type ErrUnsupportedType = errors.New("unsupported type") + // Only ".json", ".yml", ".yaml" and ".env" file types are supported + ErrUnsupportedFileExt = errors.New("unsupported file extension") + + // Provider could not find value with specified key + ErrKeyNotFound = errors.New("key not found") + // Field is required but no value provided - ErrMissingValue = errors.New("missing value") + ErrRequiredField = errors.New("field is required") // Could not parse the string value ErrParsing = errors.New("failed parsing") @@ -21,8 +27,10 @@ var ( ) const ( - missingValueErrFormat = `%w: "%v" is required` unsupportedTypeErrFormat = `%w: cannot handle type "%v" at "%v"` + unsupportedFileExtErrFormat = `%w: %v` + decodeFailedErrFormat = `failed to decode: %w` + requiredFieldErrFormat = `%w: no value found for "%v"` unsupportedElementTypeErrFormat = `%w: cannot handle slice/array of "%v" at "%v"` parseErrFormat = `%w at "%v": %v` overflowErrFormat = `%w: "%v" overflows type "%v" at "%v"` diff --git a/file.go b/file.go new file mode 100644 index 0000000..08811dc --- /dev/null +++ b/file.go @@ -0,0 +1,137 @@ +package gonfig + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/BurntSushi/toml" + "gopkg.in/yaml.v2" +) + +// Supported file extensions +const ( + JSON = ".json" + YML = ".yml" + YAML = ".yaml" + ENV = ".env" + TOML = ".toml" +) + +// FileProvider loads values from file to provided struct +type FileProvider struct { + // Path to file + FilePath string + + // File will be decoded based on extension + // .json, .yml(.yaml), .env and .toml file extensions are supported + FileExt string + + // Whether to report error if file is not found, defaults to false + Required bool +} + +// NewFileProvider creates a new FileProvider from specified path +func NewFileProvider(path string) *FileProvider { + return &FileProvider{ + FilePath: path, + FileExt: filepath.Ext(path), + Required: false, + } +} + +// Name of provider +func (fp *FileProvider) Name() string { + return "File provider" +} + +// UnmarshalStruct takes a struct pointer and loads values from provided file into it +func (fp *FileProvider) UnmarshalStruct(i interface{}) error { + return fp.decode(i) +} + +// Fill takes struct fields and and checks if their value is set +func (fp *FileProvider) Fill(in *Input) error { + var content map[string]interface{} + if err := fp.decode(&content); err != nil { + return err + } + + for _, f := range in.Fields { + if f.IsSet { + continue + } + + var key string + switch fp.FileExt { + case JSON: + key = f.Tags.Json + case YML, YAML: + key = f.Tags.Yaml + case TOML: + key = f.Tags.Toml + } + + _, err := fp.provide(content, key, f.Path) + if err == nil { + f.IsSet = true + } + } + + return nil +} + +// decode opens specified file and loads its content to input argument +func (fp *FileProvider) decode(i interface{}) (err error) { + f, err := os.Open(fp.FilePath) + if err != nil { + if os.IsNotExist(err) && !fp.Required { + return nil + } + + return fmt.Errorf("file provider: %w", err) + } + defer func() { + if cerr := f.Close(); cerr != nil && err == nil { + err = cerr + } + }() + + switch fp.FileExt { + case JSON: + err = json.NewDecoder(f).Decode(i) + + case YML, YAML: + err = yaml.NewDecoder(f).Decode(i) + + case TOML: + _, err = toml.DecodeReader(f, i) + + default: + err = fmt.Errorf(unsupportedFileExtErrFormat, ErrUnsupportedFileExt, fp.FileExt) + } + + if err != nil { + return fmt.Errorf(decodeFailedErrFormat, err) + } + + return nil +} + +// provide find a value from file content based on specified key and path +func (fp *FileProvider) provide(content map[string]interface{}, key string, path []string) (string, error) { + return traverseMap(content, fp.buildPath(key, path)) +} + +// buildPath makes a path from key and path slice +func (fp *FileProvider) buildPath(key string, path []string) []string { + newPath := make([]string, len(path)) + copy(newPath, path) + + if key != "" { + newPath[len(newPath)-1] = key + } + + return newPath +} diff --git a/go.mod b/go.mod index 47bbdeb..0536672 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,9 @@ module github.com/milad-abbasi/gonfig go 1.15 + +require ( + github.com/BurntSushi/toml v0.3.1 + github.com/joho/godotenv v1.3.0 + gopkg.in/yaml.v2 v2.4.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..635f88a --- /dev/null +++ b/go.sum @@ -0,0 +1,7 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/gonfig.go b/gonfig.go index 6fe70e7..c386eae 100644 --- a/gonfig.go +++ b/gonfig.go @@ -2,366 +2,124 @@ package gonfig import ( "fmt" - "net/url" - "os" + "path/filepath" "reflect" - "strconv" - "strings" - "time" ) -// TODO: separate this into another struct -type Gonfig struct { - Prefix string - structName string - ce ConfigErrors -} +// Config loads values from specified providers into given struct +type Config struct { + // Providers are applied at the order specified + // If multiple values are provided for a field, last one will get applied + Providers []Provider -func New(prefix string) *Gonfig { - return &Gonfig{ - Prefix: prefix, - } + // Collection of errors during loading values into provided struct + ce ConfigErrors } -// Input must be a non-nil struct pointer -func checkInput(i interface{}) error { - t := reflect.TypeOf(i) - v := reflect.ValueOf(i) +// Provider is used to provide values +// It can implement either Unmarshaler or Filler interface or both +// Name method is used for error messages +type Provider interface { + Name() string +} - if t == nil || - t.Kind() != reflect.Ptr || - v.IsNil() || - t.Elem().Kind() != reflect.Struct { - return &InvalidInputError{ - Type: t, - Value: v, - } - } +// Unmarshaler can be implemented by providers to receive struct pointer and unmarshal values into it +type Unmarshaler interface { + UnmarshalStruct(i interface{}) (err error) +} - return nil +// Filler can be implemented by providers to receive struct fields and set their value +type Filler interface { + Fill(in *Input) (err error) } -func (g *Gonfig) Into(i interface{}) error { - if err := checkInput(i); err != nil { - return err - } +// Load creates a new Config object +func Load() *Config { + return &Config{} +} - v := reflect.ValueOf(i) - g.structName = v.Type().String() - v = v.Elem() +// FromEnv adds an EnvProvider to Providers list +func (c *Config) FromEnv() *Config { + return c.FromEnvWithConfig(NewEnvProvider()) +} - g.populate(v, "", &ConfigTags{}) +// FromEnvWithConfig adds an EnvProvider to Providers list with specified config +func (c *Config) FromEnvWithConfig(ep *EnvProvider) *Config { + c.Providers = append(c.Providers, ep) + return c +} - if len(g.ce) != 0 { - return g.ce +// FromFile adds a FileProvider to Providers list +// In case of .env file, it adds a EnvProvider to the list +func (c *Config) FromFile(path string) *Config { + if filepath.Ext(path) == ENV { + return c.FromEnvWithConfig(NewEnvFileProvider(path)) } - return nil + return c.FromFileWithConfig(NewFileProvider(path)) } -func (g *Gonfig) populate(v reflect.Value, value string, tags *ConfigTags, path ...string) { - if tags.Ignore || !v.CanSet() { - return +// FromRequiredFile adds a FileProvider to Providers list with specified config +// In case of .env file, it adds a EnvProvider to the list +func (c *Config) FromFileWithConfig(fp *FileProvider) *Config { + if fp.FileExt == ENV || filepath.Ext(fp.FilePath) == ENV { + return c.FromEnvWithConfig(NewEnvFileProvider(fp.FilePath)) } - // TODO: it should not called here, if struct => bug! - if v.Kind() != reflect.Struct && value == "" { - var key string - if tags.Config != "" { - key = g.Prefix + tags.Config - } else { - key = g.Prefix + toScreamingSnakeCase(path) - } - - var exists bool - value, exists = os.LookupEnv(key) - if !exists { - if tags.Required { - g.collectError(fmt.Errorf(missingValueErrFormat, ErrMissingValue, g.getPath(path))) - return - } else { - value = tags.Default - } - } + c.Providers = append(c.Providers, fp) + return c +} - if tags.Expand { - value = os.ExpandEnv(value) - } +// Into will apply all specified providers in order declared +// and validate final struct for required and default fields +// If multiple values are provided for a field, last one will get applied +func (c *Config) Into(i interface{}) error { + in, err := NewInput(i) + if err != nil { + return err } - switch v.Kind() { - case reflect.String: - v.SetString(value) - - case reflect.Bool: - b, err := strconv.ParseBool(value) - if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return - } - - v.SetBool(b) - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - var d time.Duration - var i int64 - var err error - - if isDuration(v) { - d, err = time.ParseDuration(value) - if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return + for _, p := range c.Providers { + if u, ok := p.(Unmarshaler); ok { + if err := u.UnmarshalStruct(i); err != nil { + c.collectError(err) } - - i = int64(d) - } else { - i, err = strconv.ParseInt(value, 0, 64) - if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return - } - } - - if v.OverflowInt(i) { - g.collectError( - fmt.Errorf( - overflowErrFormat, - ErrValueOverflow, i, v.Kind(), g.getPath(path), - ), - ) - return - } - - v.SetInt(i) - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - i, err := strconv.ParseUint(value, 0, 64) - if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return - } - - if v.OverflowUint(i) { - g.collectError( - fmt.Errorf( - overflowErrFormat, - ErrValueOverflow, i, v.Kind(), g.getPath(path), - ), - ) - return - } - - v.SetUint(i) - - case reflect.Float32, reflect.Float64: - f, err := strconv.ParseFloat(value, v.Type().Bits()) - if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return - } - - if v.OverflowFloat(f) { - g.collectError( - fmt.Errorf( - overflowErrFormat, - ErrValueOverflow, f, v.Kind(), g.getPath(path), - ), - ) - return } - v.SetFloat(f) - - case reflect.Complex64, reflect.Complex128: - c, err := strconv.ParseComplex(value, v.Type().Bits()) - if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return - } - - if v.OverflowComplex(c) { - g.collectError( - fmt.Errorf( - overflowErrFormat, - ErrValueOverflow, c, v.Kind(), g.getPath(path), - ), - ) - return - } - - v.SetComplex(c) - - case reflect.Slice, reflect.Array: - switch v.Type().Elem().Kind() { - case reflect.Slice, - reflect.Array, - reflect.Uintptr, - reflect.Chan, - reflect.Func, - reflect.Interface, - reflect.UnsafePointer: - g.collectError( - fmt.Errorf( - unsupportedElementTypeErrFormat, - ErrUnsupportedType, v.Type().Elem().Kind(), g.getPath(path), - ), - ) - return - } - - var items []string - for _, v := range strings.Split(value, tags.Separator) { - item := strings.TrimSpace(v) - if len(item) > 0 { - items = append(items, item) - } - } - if len(items) == 0 { - return - } - - switch v.Kind() { - // FIXME: in case of parse error slice should not get initialized - case reflect.Slice: - size := len(items) - sv := reflect.MakeSlice(reflect.SliceOf(v.Type().Elem()), size, size) - - for i := range items { - g.populate(sv.Index(i), items[i], tags, path...) - } - - v.Set(sv) - - case reflect.Array: - size := v.Len() - if size == 0 { - return + if f, ok := p.(Filler); ok { + if err := f.Fill(in); err != nil { + c.collectError(err) } - - at := reflect.ArrayOf(size, v.Type().Elem()) - av := reflect.New(at).Elem() - - for i := 0; i < size; i++ { - g.populate(av.Index(i), items[i], tags, path...) - } - - v.Set(av) } + } - case reflect.Map: - // TODO - - case reflect.Ptr: - pv := reflect.New(v.Type().Elem()) - g.populate(pv.Elem(), value, tags, path...) - v.Set(pv) - - case reflect.Struct: - if isTime(v) { - format := tags.Format - if format == "" { - format = time.RFC3339 - } - - t, err := time.Parse(format, value) - if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return + for _, f := range in.Fields { + if !f.IsSet { + if f.Tags.Required { + c.collectError(fmt.Errorf(requiredFieldErrFormat, ErrRequiredField, in.getPath(f.Path))) + } else if f.Tags.Default != "" { + err := in.setValue(f, f.Tags.Default) + if err != nil { + c.collectError(err) + } } - - v.Set(reflect.ValueOf(t)) - return } - if isURL(v) { - u, err := url.Parse(value) + if f.Tags.Expand && f.Value.Kind() == reflect.String { + err := in.setValue(f, f.Value.String()) if err != nil { - g.collectError( - fmt.Errorf( - parseErrFormat, - ErrParsing, g.getPath(path), err, - ), - ) - return + c.collectError(err) } - - v.Set(reflect.ValueOf(*u)) - return - } - - for i := 0; i < v.NumField(); i++ { - currentPath := append(path, v.Type().Field(i).Name) - - g.populate( - v.Field(i), - value, - getTags(v.Type().Field(i).Tag), - currentPath..., - ) } - - default: - g.collectError( - fmt.Errorf( - unsupportedTypeErrFormat, - ErrUnsupportedType, v.Kind(), g.getPath(path), - ), - ) } -} - -func (g *Gonfig) collectError(e error) { - g.ce = append(g.ce, e) -} -func (g *Gonfig) getPath(paths []string) string { - return g.structName + "." + strings.Join(paths, ".") -} - -func isDuration(v reflect.Value) bool { - return v.Type().PkgPath() == "time" && v.Type().Name() == "Duration" -} + if len(c.ce) != 0 { + return c.ce + } -func isTime(v reflect.Value) bool { - return v.Type().PkgPath() == "time" && v.Type().Name() == "Time" + return nil } -func isURL(v reflect.Value) bool { - return v.Type().PkgPath() == "net/url" && v.Type().Name() == "URL" +func (c *Config) collectError(e error) { + c.ce = append(c.ce, e) } diff --git a/input.go b/input.go new file mode 100644 index 0000000..87bb888 --- /dev/null +++ b/input.go @@ -0,0 +1,386 @@ +package gonfig + +import ( + "fmt" + "net/url" + "os" + "reflect" + "strconv" + "strings" + "time" +) + +// Input stores information about given struct +type Input struct { + // Struct name is used for error messages + Name string + + // Fields information + Fields []*Field +} + +// Struct field information +type Field struct { + // Field value + Value reflect.Value + + // Field tags + Tags *ConfigTags + + // Slice of field names from root of struct all the way down to the field + Path []string + + // IsSet specifies whether field value is set by one of the providers + IsSet bool +} + +// NewInput validates and returns a new Input with all settable fields +// Input argument must be a non-nil struct pointer +func NewInput(i interface{}) (*Input, error) { + v := reflect.ValueOf(i) + + if err := checkInput(v); err != nil { + return nil, err + } + + in := Input{ + Name: v.Type().String(), + } + + f := Field{ + Value: v.Elem(), + Tags: &ConfigTags{}, + } + + if err := in.traverseFiled(&f); err != nil { + return nil, err + } + + return &in, nil +} + +// checkInput checks for a non-nil struct pointer +func checkInput(v reflect.Value) error { + if v.Type() == nil || + v.Type().Kind() != reflect.Ptr || + v.IsNil() || + v.Type().Elem().Kind() != reflect.Struct { + return &InvalidInputError{ + Type: v.Type(), + Value: v, + } + } + + return nil +} + +// traverseFiled recursively traverse all fields and collect their information +func (in *Input) traverseFiled(f *Field) error { + if !f.Value.CanSet() || f.Tags.Ignore { + return nil + } + + switch f.Value.Kind() { + case reflect.Struct: + if isTime(f.Value) || isURL(f.Value) { + in.collectField(f) + + return nil + } + + for i := 0; i < f.Value.NumField(); i++ { + nestedField := Field{ + Value: f.Value.Field(i), + Tags: getTags(f.Value.Type().Field(i).Tag), + Path: append(f.Path, f.Value.Type().Field(i).Name), + } + + if err := in.traverseFiled(&nestedField); err != nil { + return err + } + } + + case reflect.Ptr: + pv := reflect.New(f.Value.Type().Elem()) + f.Value.Set(pv) + + pointedField := Field{ + Value: pv.Elem(), + Tags: f.Tags, + Path: f.Path, + } + + return in.traverseFiled(&pointedField) + + case reflect.Slice, reflect.Array: + switch f.Value.Type().Elem().Kind() { + case reflect.Slice, + reflect.Array, + reflect.Uintptr, + reflect.Chan, + reflect.Func, + reflect.Interface, + reflect.UnsafePointer: + return fmt.Errorf( + unsupportedElementTypeErrFormat, + ErrUnsupportedType, f.Value.Type().Elem().Kind(), in.getPath(f.Path), + ) + + default: + in.collectField(f) + } + + case reflect.Uintptr, + reflect.Chan, + reflect.Func, + reflect.Interface, + reflect.UnsafePointer: + return fmt.Errorf( + unsupportedTypeErrFormat, + ErrUnsupportedType, f.Value.Kind(), in.getPath(f.Path), + ) + + default: + in.collectField(f) + } + + return nil +} + +func (in *Input) collectField(f *Field) { + in.Fields = append(in.Fields, f) +} + +// setValue validates and sets the value of a struct field +func (in *Input) setValue(f *Field, value string) error { + if f.Tags.Expand { + value = os.ExpandEnv(value) + } + + switch f.Value.Kind() { + case reflect.String: + f.Value.SetString(value) + + case reflect.Bool: + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + + f.Value.SetBool(b) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var d time.Duration + var i int64 + var err error + + if isDuration(f.Value) { + d, err = time.ParseDuration(value) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + + i = int64(d) + } else { + i, err = strconv.ParseInt(value, 0, 64) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + } + + if f.Value.OverflowInt(i) { + return fmt.Errorf( + overflowErrFormat, + ErrValueOverflow, i, f.Value.Kind(), in.getPath(f.Path), + ) + } + + f.Value.SetInt(i) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, err := strconv.ParseUint(value, 0, 64) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + + if f.Value.OverflowUint(i) { + return fmt.Errorf( + overflowErrFormat, + ErrValueOverflow, i, f.Value.Kind(), in.getPath(f.Path), + ) + } + + f.Value.SetUint(i) + + case reflect.Float32, reflect.Float64: + fv, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + + if f.Value.OverflowFloat(fv) { + return fmt.Errorf( + overflowErrFormat, + ErrValueOverflow, fv, f.Value.Kind(), in.getPath(f.Path), + ) + } + + f.Value.SetFloat(fv) + + case reflect.Complex64, reflect.Complex128: + cx, err := strconv.ParseComplex(value, 64) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + + if f.Value.OverflowComplex(cx) { + return fmt.Errorf( + overflowErrFormat, + ErrValueOverflow, cx, f.Value.Kind(), in.getPath(f.Path), + ) + } + + f.Value.SetComplex(cx) + + case reflect.Slice, reflect.Array: + switch f.Value.Type().Elem().Kind() { + case reflect.Slice, + reflect.Array, + reflect.Uintptr, + reflect.Chan, + reflect.Func, + reflect.Interface, + reflect.UnsafePointer: + return fmt.Errorf( + unsupportedElementTypeErrFormat, + ErrUnsupportedType, f.Value.Type().Elem().Kind(), in.getPath(f.Path), + ) + } + + var items []string + for _, v := range strings.Split(value, f.Tags.Separator) { + item := strings.TrimSpace(v) + if len(item) > 0 { + items = append(items, item) + } + } + if len(items) == 0 { + return nil + } + + switch f.Value.Kind() { + case reflect.Slice: + size := len(items) + sv := reflect.MakeSlice(reflect.SliceOf(f.Value.Type().Elem()), size, size) + + for i := range items { + nestedField := Field{ + Value: sv.Index(i), + Tags: f.Tags, + Path: f.Path, + } + + if err := in.setValue(&nestedField, items[i]); err != nil { + return err + } + } + + f.Value.Set(sv) + + case reflect.Array: + size := f.Value.Len() + if size == 0 { + return nil + } + + at := reflect.ArrayOf(size, f.Value.Type().Elem()) + av := reflect.New(at).Elem() + + for i := 0; i < size; i++ { + nestedField := Field{ + Value: av.Index(i), + Tags: f.Tags, + Path: f.Path, + } + + if err := in.setValue(&nestedField, items[i]); err != nil { + return err + } + } + + f.Value.Set(av) + } + + case reflect.Map: + // TODO + + case reflect.Ptr: + pv := reflect.New(f.Value.Type().Elem()) + f.Value.Set(pv) + pointedField := Field{ + Value: pv.Elem(), + Tags: f.Tags, + Path: f.Path, + } + + return in.setValue(&pointedField, value) + + case reflect.Struct: + if isTime(f.Value) { + t, err := time.Parse(f.Tags.Format, value) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + + f.Value.Set(reflect.ValueOf(t)) + return nil + } + + if isURL(f.Value) { + u, err := url.Parse(value) + if err != nil { + return fmt.Errorf( + parseErrFormat, + ErrParsing, in.getPath(f.Path), err, + ) + } + + f.Value.Set(reflect.ValueOf(*u)) + return nil + } + + default: + return fmt.Errorf( + unsupportedTypeErrFormat, + ErrUnsupportedType, f.Value.Kind(), in.getPath(f.Path), + ) + } + + return nil +} + +// getPath returns a dot separated string prefixed with struct name +func (in *Input) getPath(paths []string) string { + return in.Name + "." + strings.Join(paths, ".") +} diff --git a/snake.go b/snake.go deleted file mode 100644 index 05bd0c0..0000000 --- a/snake.go +++ /dev/null @@ -1,20 +0,0 @@ -package gonfig - -import ( - "regexp" - "strings" -) - -var ( - firstCapRegex = regexp.MustCompile("([A-Z])([A-Z][a-z])") - allCapRegex = regexp.MustCompile("([a-z0-9])([A-Z])") -) - -func toScreamingSnakeCase(in []string) string { - s := strings.Join(in, "_") - out := firstCapRegex.ReplaceAllString(s, "${1}_${2}") - out = allCapRegex.ReplaceAllString(out, "${1}_${2}") - out = strings.ReplaceAll(out, "-", "_") - - return strings.ToUpper(out) -} diff --git a/tags.go b/tags.go index f5dd661..c702a58 100644 --- a/tags.go +++ b/tags.go @@ -1,40 +1,58 @@ package gonfig -import "reflect" +import ( + "reflect" + "strings" + "time" +) const ( defaultSeparator = " " ignoreCharacter = "-" ) -// all possible useful tags +// Possible tags, all are optional type ConfigTags struct { - // config key name, use "-" to ignore, defaults to field name + // Key to be used by providers to retrieve the needed value, defaults to field name. + // Use "-" to ignore the field. Config string - // default value for field + // json tag for json files + Json string + + // yaml tag for yaml files + Yaml string + + // toml tag for toml files + Toml string + + // Default value for field. Default string - // specify if value should be present, defaults to false + // Specify if value should be present, defaults to false. Required bool - // specify if field should be ignored, defaults to false + // Specify if field should be ignored, defaults to false. Ignore bool - // specify if value should be expanded from env, defaults to false + // Specify if value should be expanded from env, defaults to false. Expand bool - // separator to be used for slice/array items, defaults to " " + // Separator to be used for slice/array items, defaults to " ". Separator string - // format to be used for parsing time strings, defaults to time.RFC3339 + // Format to be used for parsing time strings, defaults to time.RFC3339. Format string } +// Returns default config tags. func getTags(st reflect.StructTag) *ConfigTags { tags := ConfigTags{ Config: st.Get("config"), Default: st.Get("default"), + Json: extractKeyName(st.Get("json")), + Yaml: extractKeyName(st.Get("yaml")), + Toml: extractKeyName(st.Get("toml")), Required: st.Get("required") == "true", Ignore: st.Get("ignore") == "true", Expand: st.Get("expand") == "true", @@ -48,6 +66,20 @@ func getTags(st reflect.StructTag) *ConfigTags { if tags.Separator == "" { tags.Separator = defaultSeparator } + if tags.Format == "" { + tags.Format = time.RFC3339 + } return &tags } + +// It extracts name of the key from file tag, ignoring options +// e.g. calling with "field,omitempty" would return "field" +func extractKeyName(key string) string { + slice := strings.Split(key, ",") + if len(slice) == 0 { + return "" + } + + return slice[0] +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..762cfdc --- /dev/null +++ b/utils.go @@ -0,0 +1,61 @@ +package gonfig + +import ( + "fmt" + "reflect" + "regexp" + "strings" +) + +var ( + firstCapRegex = regexp.MustCompile("([A-Z])([A-Z][a-z])") + allCapRegex = regexp.MustCompile("([a-z0-9])([A-Z])") +) + +// toSnakeCase converts input string into snake_case form +func toSnakeCase(s string) string { + out := firstCapRegex.ReplaceAllString(s, "${1}_${2}") + out = allCapRegex.ReplaceAllString(out, "${1}_${2}") + out = strings.ReplaceAll(out, "-", "_") + + return out +} + +func isDuration(v reflect.Value) bool { + return v.Type().PkgPath() == "time" && v.Type().Name() == "Duration" +} + +func isTime(v reflect.Value) bool { + return v.Type().PkgPath() == "time" && v.Type().Name() == "Time" +} + +func isURL(v reflect.Value) bool { + return v.Type().PkgPath() == "net/url" && v.Type().Name() == "URL" +} + +// traverseMap finds a value in a map based on provided path +func traverseMap(m map[string]interface{}, path []string) (string, error) { + if len(path) == 0 { + return "", ErrKeyNotFound + } + first, path := path[0], path[1:] + + value, exists := m[first] + if !exists { + value, exists = m[strings.ToLower(first)] + if !exists { + return "", ErrKeyNotFound + } + } + + if len(path) == 0 { + return fmt.Sprint(value), nil + } + + nestedMap, ok := value.(map[string]interface{}) + if !ok { + return "", ErrKeyNotFound + } + + return traverseMap(nestedMap, path) +}