Skip to content

Commit

Permalink
Allow mixed values in arrays
Browse files Browse the repository at this point in the history
Decoding of mixed arrays work:

	[1, 1.0]  # Int and float
	['a', 1]
	['a', {tbl=42}, [[{a=1}]]]

This also adds encoding of inline tables, as otherwise we can't encode
arrays with tables in them.

Fixes 305
  • Loading branch information
arp242 committed Jul 4, 2021
1 parent 01bfc69 commit 7d6f80f
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 154 deletions.
15 changes: 6 additions & 9 deletions decode.go
Expand Up @@ -315,10 +315,8 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
}
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
if l := datav.Len(); l != rv.Len() {
return e("expected array length %d; got TOML array of length %d", rv.Len(), l)
}
return md.unifySliceArray(datav, rv)
}
Expand All @@ -340,11 +338,10 @@ func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
}

func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
l := data.Len()
for i := 0; i < l; i++ {
err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i)))
if err != nil {
return err
}
}
Expand Down
187 changes: 107 additions & 80 deletions encode.go
Expand Up @@ -17,13 +17,11 @@ import (
type tomlEncodeError struct{ error }

var (
errArrayMixedElementTypes = errors.New("toml: cannot encode array with mixed element types")
errArrayNilElement = errors.New("toml: cannot encode array with nil element")
errNonString = errors.New("toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New("toml: cannot encode an anonymous field that is not a struct")
errArrayNoTable = errors.New("toml: TOML array element cannot contain a table")
errNoKey = errors.New("toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
errArrayNilElement = errors.New("toml: cannot encode array with nil element")
errNonString = errors.New("toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New("toml: cannot encode an anonymous field that is not a struct")
errNoKey = errors.New("toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
)

var quotedReplacer = strings.NewReplacer(
Expand Down Expand Up @@ -141,7 +139,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch t := rv.Interface().(type) {
case time.Time, encoding.TextMarshaler:
enc.keyEqElement(key, rv)
enc.writeKeyValue(key, rv, false)
return
// TODO: #76 would make this superfluous after implemented.
case Primitive:
Expand All @@ -156,12 +154,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
enc.writeKeyValue(key, rv, false)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
enc.writeKeyValue(key, rv, false)
}
case reflect.Interface:
if rv.IsNil() {
Expand All @@ -185,31 +183,31 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
}
}

// eElement encodes any value that can be an array element (primitives and
// arrays).
// eElement encodes any value that can be an array element.
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Using TextMarshaler adds extra quotes, which we don't want.
enc.wf(v.Format(time.RFC3339Nano))
return
case encoding.TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
// Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}

switch rv.Kind() {
case reflect.String:
enc.writeQuoted(rv.String())
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
f := rv.Float()
Expand All @@ -231,17 +229,19 @@ func (enc *Encoder) eElement(rv reflect.Value) {
}
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Struct:
enc.eStruct(nil, rv, true)
case reflect.Map:
enc.eMap(nil, rv, true)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
encPanic(fmt.Errorf("unexpected primitive type: %s", rv.Kind()))
encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface()))
}
}

// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
// By the TOML spec, all floats must have a decimal with at least one number on
// either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
Expand Down Expand Up @@ -278,7 +278,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
enc.eMapOrStruct(key, trv, false)
}
}

Expand All @@ -292,22 +292,22 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
enc.eMapOrStruct(key, rv, false)
}

func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
enc.eMap(key, rv, inline)
case reflect.Struct:
enc.eStruct(key, rv)
enc.eStruct(key, rv, inline)
default:
// Should never happen?
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}

func (enc *Encoder) eMap(key Key, rv reflect.Value) {
func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
Expand All @@ -325,57 +325,76 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value) {
}
}

var writeMapKeys = func(mapKeys []string) {
var writeMapKeys = func(mapKeys []string, trailC bool) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
for i, mapKey := range mapKeys {
val := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(val) {
continue
}
enc.encode(key.add(mapKey), mrv)

if inline {
enc.writeKeyValue(Key{mapKey}, val, true)
if trailC || i != len(mapKeys)-1 {
enc.wf(", ")
}
} else {
enc.encode(key.add(mapKey), val)
}
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)

if inline {
enc.wf("{")
}
writeMapKeys(mapKeysDirect, len(mapKeysSub) > 0)
writeMapKeys(mapKeysSub, false)
if inline {
enc.wf("}")
}
}

func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// a field that creates a new table then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
//
// Fields is a [][]int: for fieldsDirect this always has one entry (the
// struct index). For fieldsSub it contains two entries: the parent field
// index from tv, and the field indexes for the fields of the sub.
var (
rt = rv.Type()
fieldsDirect, fieldsSub [][]int
addFields func(rt reflect.Type, rv reflect.Value, start []int)
)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexported fields
if f.PkgPath != "" && !f.Anonymous {
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
continue
}

frv := rv.Field(i)

// Treat anonymous struct fields with tag names as though they are
// not anonymous, like encoding/json does.
//
// Non-struct anonymous fields use the normal encoding logic.
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
// Treat anonymous struct fields with
// tag names as though they are not
// anonymous, like encoding/json does.
if getOptions(f.Tag).name == "" {
addFields(t, frv, append(start, f.Index...))
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct &&
getOptions(f.Tag).name == "" {
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), append(start, f.Index...))
}
continue
}
// Fall through to the normal field encoding logic below
// for non-struct anonymous fields.
}
}

Expand All @@ -388,35 +407,49 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
}
addFields(rt, rv, nil)

var writeFields = func(fields [][]int) {
writeFields := func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
fieldType := rt.FieldByIndex(fieldIndex)
fieldVal := rv.FieldByIndex(fieldIndex)

if isNil(fieldVal) { /// Don't write anything for nil fields.
continue
}

opts := getOptions(sft.Tag)
opts := getOptions(fieldType.Tag)
if opts.skip {
continue
}
keyName := sft.Name
keyName := fieldType.Name
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && isEmpty(sf) {
if opts.omitempty && isEmpty(fieldVal) {
continue
}
if opts.omitzero && isZero(sf) {
if opts.omitzero && isZero(fieldVal) {
continue
}

enc.encode(key.add(keyName), sf)
if inline {
enc.writeKeyValue(Key{keyName}, fieldVal, true)
if fieldIndex[0] != len(fields)-1 {
enc.wf(", ")
}
} else {
enc.encode(key.add(keyName), fieldVal)
}
}
}

if inline {
enc.wf("{")
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
if inline {
enc.wf("}")
}
}

// tomlTypeName returns the TOML type name of the Go value's type. It is
Expand Down Expand Up @@ -487,31 +520,18 @@ func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}

/// Don't allow nil.
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
if tomlTypeOfGo(rv.Index(i)) == nil {
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}

// If we have a nested array, then we must make sure that the nested array
// contains ONLY primitives.
//
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
return firstType
}
Expand Down Expand Up @@ -570,13 +590,20 @@ func (enc *Encoder) newline() {
}
}

func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
// Write a key/value pair:
//
// key = <any value>
//
// If inline is true it won't add a newline at the end.
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
if len(key) == 0 {
encPanic(errNoKey)
}
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
if !inline {
enc.newline()
}
}

func (enc *Encoder) wf(format string, v ...interface{}) {
Expand Down

0 comments on commit 7d6f80f

Please sign in to comment.