Skip to content

Commit

Permalink
Add tests for SetValue method
Browse files Browse the repository at this point in the history
  • Loading branch information
Milad Abbasi committed Jan 19, 2021
1 parent cdd0f1a commit f9bb8b3
Show file tree
Hide file tree
Showing 4 changed files with 717 additions and 359 deletions.
18 changes: 11 additions & 7 deletions errors.go
Expand Up @@ -13,6 +13,9 @@ var (
// Only ".json", ".yml", ".yaml" and ".env" file types are supported
ErrUnsupportedFileExt = errors.New("unsupported file extension")

// Struct field is unexported
ErrUnSettableField = errors.New("unSettable field")

// Provider could not find value with specified key
ErrKeyNotFound = errors.New("key not found")

Expand All @@ -27,13 +30,14 @@ var (
)

const (
unsupportedTypeErrFormat = `%w: cannot handle type "%v" at "%v"`
unsupportedFileExtErrFormat = `%w: %v`
decodeFailedErrFormat = `failed to decode: %w`
requiredFieldErrFormat = `%w: no value found for "%v"`
unsupportedElementTypeErrFormat = `%w: cannot handle slice/array of "%v" at "%v"`
parseErrFormat = `%w at "%v": %v`
overflowErrFormat = `%w: "%v" overflows type "%v" at "%v"`
unsupportedTypeErrFormat = `%w: %v`
badFieldErrFormat = `bad field "%v": %w`
unsupportedFileExtErrFormat = `%w: %v`
unSettableFieldErrFormat = `%w: %v`
decodeFailedErrFormat = `failed to decode: %w`
requiredFieldErrFormat = `%w: no value found for "%v"`
parseErrFormat = `%w at "%v": %v`
overflowErrFormat = `%w: "%v" overflows type "%v" at "%v"`
)

// An InvalidInputError describes an invalid argument passed to Into function
Expand Down
175 changes: 72 additions & 103 deletions input.go
Expand Up @@ -49,7 +49,7 @@ func NewInput(i interface{}) (*Input, error) {

f := Field{
Value: v.Elem(),
Tags: &ConfigTags{},
Tags: new(ConfigTags),
}

if err := in.traverseField(&f); err != nil {
Expand Down Expand Up @@ -80,14 +80,11 @@ func (in *Input) traverseField(f *Field) error {
return nil
}

switch f.Value.Kind() {
case reflect.Struct:
if isTime(f.Value) || isURL(f.Value) {
in.collectField(f)

return nil
}
if err := in.isSupportedType(f.Value.Type()); err != nil {
return fmt.Errorf(badFieldErrFormat, in.getPath(f.Path), err)
}

if isStruct(f.Value.Type()) {
for i := 0; i < f.Value.NumField(); i++ {
nestedField := Field{
Value: f.Value.Field(i),
Expand All @@ -100,59 +97,70 @@ func (in *Input) traverseField(f *Field) error {
}
}

case reflect.Ptr:
pv := reflect.New(f.Value.Type().Elem())
f.Value.Set(pv)
return nil
}

if f.Value.Kind() == reflect.Ptr && isStruct(f.Value.Type().Elem()) {
if f.Value.IsNil() {
initPtr(f.Value)
}

pointedField := Field{
Value: pv.Elem(),
Value: f.Value.Elem(),
Tags: f.Tags,
Path: f.Path,
}

return in.traverseField(&pointedField)
}

case reflect.Slice, reflect.Array:
switch f.Value.Type().Elem().Kind() {
case reflect.Slice,
reflect.Array,
reflect.Uintptr,
reflect.Chan,
reflect.Func,
reflect.Interface,
reflect.UnsafePointer:
return fmt.Errorf(
unsupportedElementTypeErrFormat,
ErrUnsupportedType, f.Value.Type().Elem().Kind(), in.getPath(f.Path),
)
in.collectField(f)
return nil
}

default:
in.collectField(f)
}
func (in *Input) collectField(f *Field) {
in.Fields = append(in.Fields, f)
}

case reflect.Uintptr,
func (in *Input) isSupportedType(t reflect.Type) error {
switch t.Kind() {
case reflect.Invalid,
reflect.Uintptr,
reflect.Chan,
reflect.Func,
reflect.Interface,
reflect.UnsafePointer:
return fmt.Errorf(
unsupportedTypeErrFormat,
ErrUnsupportedType, f.Value.Kind(), in.getPath(f.Path),
)
return fmt.Errorf(unsupportedTypeErrFormat, ErrUnsupportedType, t.Kind())

case reflect.Slice, reflect.Array:
switch t.Elem().Kind() {
case reflect.Slice, reflect.Array:
return fmt.Errorf(unsupportedTypeErrFormat, ErrUnsupportedType, "multi-dimensional slice/array")

default:
in.collectField(f)
default:
return in.isSupportedType(t.Elem())
}

case reflect.Ptr:
return in.isSupportedType(t.Elem())
}

return nil
}

func (in *Input) collectField(f *Field) {
in.Fields = append(in.Fields, f)
}

// SetValue validates and sets the value of a struct field
// returns error in case of unSettable field or unsupported type
func (in *Input) SetValue(f *Field, value string) error {
if !f.Value.CanSet() {
return fmt.Errorf(
unSettableFieldErrFormat,
ErrUnSettableField, in.getPath(f.Path),
)
}
if err := in.isSupportedType(f.Value.Type()); err != nil {
return fmt.Errorf(badFieldErrFormat, in.getPath(f.Path), err)
}

if f.Tags.Expand {
value = os.ExpandEnv(value)
}
Expand All @@ -165,7 +173,7 @@ func (in *Input) SetValue(f *Field, value string) error {
return in.setBool(f, value)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if isDuration(f.Value) {
if isDuration(f.Value.Type()) {
return in.setDuration(f, value)
}

Expand All @@ -180,8 +188,11 @@ func (in *Input) SetValue(f *Field, value string) error {
case reflect.Complex64, reflect.Complex128:
return in.setComplex(f, value)

case reflect.Slice, reflect.Array:
return in.setList(f, value)
case reflect.Slice:
return in.setSlice(f, value)

case reflect.Array:
return in.setArray(f, value)

case reflect.Map:
return in.setMap(f, value)
Expand All @@ -190,19 +201,13 @@ func (in *Input) SetValue(f *Field, value string) error {
return in.setPointer(f, value)

case reflect.Struct:
if isTime(f.Value) {
if isTime(f.Value.Type()) {
return in.setTime(f, value)
}

if isURL(f.Value) {
if isURL(f.Value.Type()) {
return in.setUrl(f, value)
}

default:
return fmt.Errorf(
unsupportedTypeErrFormat,
ErrUnsupportedType, f.Value.Kind(), in.getPath(f.Path),
)
}

return nil
Expand Down Expand Up @@ -284,7 +289,7 @@ func (in *Input) setFloat(f *Field, value string) error {
}

func (in *Input) setComplex(f *Field, value string) error {
c, err := strconv.ParseComplex(value, 64)
c, err := strconv.ParseComplex(value, 128)
if err != nil {
return fmt.Errorf(
parseErrFormat,
Expand All @@ -302,44 +307,8 @@ func (in *Input) setComplex(f *Field, value string) error {
return nil
}

func (in *Input) setList(f *Field, value string) error {
switch f.Value.Type().Elem().Kind() {
case reflect.Slice,
reflect.Array,
reflect.Uintptr,
reflect.Chan,
reflect.Func,
reflect.Interface,
reflect.UnsafePointer:
return fmt.Errorf(
unsupportedElementTypeErrFormat,
ErrUnsupportedType, f.Value.Type().Elem().Kind(), in.getPath(f.Path),
)
}

var items []string
for _, v := range strings.Split(value, f.Tags.Separator) {
item := strings.TrimSpace(v)
if len(item) > 0 {
items = append(items, item)
}
}
if len(items) == 0 {
return nil
}

switch f.Value.Kind() {
case reflect.Slice:
return in.setSlice(f, items)

case reflect.Array:
return in.setArray(f, items)
}

return nil
}

func (in *Input) setSlice(f *Field, items []string) error {
func (in *Input) setSlice(f *Field, value string) error {
items := extractItems(value, f.Tags.Separator)
size := len(items)
if size == 0 {
return nil
Expand All @@ -362,18 +331,18 @@ func (in *Input) setSlice(f *Field, items []string) error {
return nil
}

func (in *Input) setArray(f *Field, items []string) error {
func (in *Input) setArray(f *Field, value string) error {
items := extractItems(value, f.Tags.Separator)
size := f.Value.Len()
if size == 0 || len(items) == 0 {
return nil
}

at := reflect.ArrayOf(size, f.Value.Type().Elem())
av := reflect.New(at).Elem()
a := reflect.New(reflect.ArrayOf(size, f.Value.Type().Elem())).Elem()

for i := 0; i < size; i++ {
nestedField := Field{
Value: av.Index(i),
Value: a.Index(i),
Tags: f.Tags,
Path: f.Path,
}
Expand All @@ -383,7 +352,7 @@ func (in *Input) setArray(f *Field, items []string) error {
}
}

f.Value.Set(av)
f.Value.Set(a)
return nil
}

Expand All @@ -393,10 +362,12 @@ func (in *Input) setMap(f *Field, value string) error {
}

func (in *Input) setPointer(f *Field, value string) error {
p := reflect.New(f.Value.Type().Elem())
f.Value.Set(p)
if f.Value.IsNil() {
initPtr(f.Value)
}

pointedField := Field{
Value: p.Elem(),
Value: f.Value.Elem(),
Tags: f.Tags,
Path: f.Path,
}
Expand All @@ -412,12 +383,6 @@ func (in *Input) setDuration(f *Field, value string) error {
ErrParsing, in.getPath(f.Path), err,
)
}
if f.Value.OverflowInt(int64(d)) {
return fmt.Errorf(
overflowErrFormat,
ErrValueOverflow, d, f.Value.Kind(), in.getPath(f.Path),
)
}

f.Value.SetInt(int64(d))
return nil
Expand Down Expand Up @@ -449,6 +414,10 @@ func (in *Input) setUrl(f *Field, value string) error {
return nil
}

func initPtr(v reflect.Value) {
v.Set(reflect.New(v.Type().Elem()))
}

// getPath returns a dot separated string prefixed with struct name
func (in *Input) getPath(paths []string) string {
return in.Name + "." + strings.Join(paths, ".")
Expand Down

0 comments on commit f9bb8b3

Please sign in to comment.