Skip to content

Commit

Permalink
Don't panic with 'omitempty' and uncomparable type
Browse files Browse the repository at this point in the history
Fixes #360
  • Loading branch information
arp242 committed Jul 28, 2022
1 parent 2e74712 commit 8d9ffad
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
10 changes: 7 additions & 3 deletions encode.go
Expand Up @@ -261,7 +261,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
enc.eElement(reflect.ValueOf(v))
return
}
encPanic(errors.New(fmt.Sprintf("Unable to convert \"%s\" to neither int64 nor float64", n)))
encPanic(fmt.Errorf("unable to convert %q to int64 or float64", n))
}

switch rv.Kind() {
Expand Down Expand Up @@ -504,7 +504,8 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && isEmpty(fieldVal) {

if opts.omitempty && enc.isEmpty(fieldVal) {
continue
}
if opts.omitzero && isZero(fieldVal) {
Expand Down Expand Up @@ -648,11 +649,14 @@ func isZero(rv reflect.Value) bool {
return false
}

func isEmpty(rv reflect.Value) bool {
func (enc *Encoder) isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return rv.Len() == 0
case reflect.Struct:
if !rv.Type().Comparable() {
encPanic(fmt.Errorf("type %q cannot be used with omitempty as it's uncomparable", rv.Type()))
}
return reflect.Zero(rv.Type()).Interface() == rv.Interface()
case reflect.Bool:
return !rv.Bool()
Expand Down
32 changes: 32 additions & 0 deletions encode_test.go
Expand Up @@ -212,6 +212,38 @@ time = 1985-06-18T15:16:17Z
v, expected, nil)
}

func TestEncodeWithOmitEmptyError(t *testing.T) {
type nest struct {
Field []string `toml:"Field,omitempty"`
}

tests := []struct {
in interface{}
wantErr string
}{
{ // Make sure it doesn't panic on uncomparable types; #360
struct {
Values nest `toml:"values,omitempty"`
Empty nest `toml:"empty,omitempty"`
}{Values: nest{[]string{"XXX"}}},
"cannot be used with omitempty as it's uncomparable",
},
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
buf := new(bytes.Buffer)
err := NewEncoder(buf).Encode(tt.in)
if !errorContains(err, tt.wantErr) {
t.Fatalf("wrong error: %v", err)
}
if buf.String() != "" {
t.Errorf("output not empty:\n%s", buf)
}
})
}
}

func TestEncodeWithOmitZero(t *testing.T) {
type simple struct {
Number int `toml:"number,omitzero"`
Expand Down

0 comments on commit 8d9ffad

Please sign in to comment.