diff --git a/decode.go b/decode.go index 746a1f6c..e106fa97 100644 --- a/decode.go +++ b/decode.go @@ -315,10 +315,8 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { } return badtype("slice", data) } - sliceLen := datav.Len() - if sliceLen != rv.Len() { - return e("expected array length %d; got TOML array of length %d", - rv.Len(), sliceLen) + if l := datav.Len(); l != rv.Len() { + return e("expected array length %d; got TOML array of length %d", rv.Len(), l) } return md.unifySliceArray(datav, rv) } @@ -340,11 +338,10 @@ func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error { } func (md *MetaData) unifySliceArray(data, rv reflect.Value) error { - sliceLen := data.Len() - for i := 0; i < sliceLen; i++ { - v := data.Index(i).Interface() - sliceval := indirect(rv.Index(i)) - if err := md.unify(v, sliceval); err != nil { + l := data.Len() + for i := 0; i < l; i++ { + err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i))) + if err != nil { return err } } diff --git a/encode.go b/encode.go index 41188412..ee8ce9a1 100644 --- a/encode.go +++ b/encode.go @@ -17,13 +17,11 @@ import ( type tomlEncodeError struct{ error } var ( - errArrayMixedElementTypes = errors.New("toml: cannot encode array with mixed element types") - errArrayNilElement = errors.New("toml: cannot encode array with nil element") - errNonString = errors.New("toml: cannot encode a map with non-string key type") - errAnonNonStruct = errors.New("toml: cannot encode an anonymous field that is not a struct") - errArrayNoTable = errors.New("toml: TOML array element cannot contain a table") - errNoKey = errors.New("toml: top-level values must be Go maps or structs") - errAnything = errors.New("") // used in testing + errArrayNilElement = errors.New("toml: cannot encode array with nil element") + errNonString = errors.New("toml: cannot encode a map with non-string key type") + errAnonNonStruct = errors.New("toml: cannot encode an anonymous field that is not a struct") + errNoKey = errors.New("toml: top-level values must be Go maps or structs") + errAnything = errors.New("") // used in testing ) var quotedReplacer = strings.NewReplacer( @@ -141,7 +139,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) { // generic structs (or whatever the underlying type of a TextMarshaler is). switch t := rv.Interface().(type) { case time.Time, encoding.TextMarshaler: - enc.keyEqElement(key, rv) + enc.writeKeyValue(key, rv, false) return // TODO: #76 would make this superfluous after implemented. case Primitive: @@ -156,12 +154,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) { reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: - enc.keyEqElement(key, rv) + enc.writeKeyValue(key, rv, false) case reflect.Array, reflect.Slice: if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) { enc.eArrayOfTables(key, rv) } else { - enc.keyEqElement(key, rv) + enc.writeKeyValue(key, rv, false) } case reflect.Interface: if rv.IsNil() { @@ -185,8 +183,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) { } } -// eElement encodes any value that can be an array element (primitives and -// arrays). +// eElement encodes any value that can be an array element. func (enc *Encoder) eElement(rv reflect.Value) { switch v := rv.Interface().(type) { case time.Time: @@ -194,7 +191,7 @@ func (enc *Encoder) eElement(rv reflect.Value) { enc.wf(v.Format(time.RFC3339Nano)) return case encoding.TextMarshaler: - // Special case. Use text marshaler if it's available for this value. + // Use text marshaler if it's available for this value. if s, err := v.MarshalText(); err != nil { encPanic(err) } else { @@ -202,14 +199,15 @@ func (enc *Encoder) eElement(rv reflect.Value) { } return } + switch rv.Kind() { + case reflect.String: + enc.writeQuoted(rv.String()) case reflect.Bool: enc.wf(strconv.FormatBool(rv.Bool())) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, - reflect.Int64: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: enc.wf(strconv.FormatInt(rv.Int(), 10)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, - reflect.Uint32, reflect.Uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: enc.wf(strconv.FormatUint(rv.Uint(), 10)) case reflect.Float32: f := rv.Float() @@ -231,17 +229,19 @@ func (enc *Encoder) eElement(rv reflect.Value) { } case reflect.Array, reflect.Slice: enc.eArrayOrSliceElement(rv) + case reflect.Struct: + enc.eStruct(nil, rv, true) + case reflect.Map: + enc.eMap(nil, rv, true) case reflect.Interface: enc.eElement(rv.Elem()) - case reflect.String: - enc.writeQuoted(rv.String()) default: - encPanic(fmt.Errorf("unexpected primitive type: %s", rv.Kind())) + encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface())) } } -// By the TOML spec, all floats must have a decimal with at least one -// number on either side. +// By the TOML spec, all floats must have a decimal with at least one number on +// either side. func floatAddDecimal(fstr string) string { if !strings.Contains(fstr, ".") { return fstr + ".0" @@ -278,7 +278,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { enc.newline() enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll()) enc.newline() - enc.eMapOrStruct(key, trv) + enc.eMapOrStruct(key, trv, false) } } @@ -292,22 +292,22 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) { enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll()) enc.newline() } - enc.eMapOrStruct(key, rv) + enc.eMapOrStruct(key, rv, false) } -func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) { +func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) { switch rv := eindirect(rv); rv.Kind() { case reflect.Map: - enc.eMap(key, rv) + enc.eMap(key, rv, inline) case reflect.Struct: - enc.eStruct(key, rv) + enc.eStruct(key, rv, inline) default: // Should never happen? panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String()) } } -func (enc *Encoder) eMap(key Key, rv reflect.Value) { +func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) { rt := rv.Type() if rt.Key().Kind() != reflect.String { encPanic(errNonString) @@ -325,57 +325,76 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value) { } } - var writeMapKeys = func(mapKeys []string) { + var writeMapKeys = func(mapKeys []string, trailC bool) { sort.Strings(mapKeys) - for _, mapKey := range mapKeys { - mrv := rv.MapIndex(reflect.ValueOf(mapKey)) - if isNil(mrv) { - // Don't write anything for nil fields. + for i, mapKey := range mapKeys { + val := rv.MapIndex(reflect.ValueOf(mapKey)) + if isNil(val) { continue } - enc.encode(key.add(mapKey), mrv) + + if inline { + enc.writeKeyValue(Key{mapKey}, val, true) + if trailC || i != len(mapKeys)-1 { + enc.wf(", ") + } + } else { + enc.encode(key.add(mapKey), val) + } } } - writeMapKeys(mapKeysDirect) - writeMapKeys(mapKeysSub) + + if inline { + enc.wf("{") + } + writeMapKeys(mapKeysDirect, len(mapKeysSub) > 0) + writeMapKeys(mapKeysSub, false) + if inline { + enc.wf("}") + } } -func (enc *Encoder) eStruct(key Key, rv reflect.Value) { +func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) { // Write keys for fields directly under this key first, because if we write - // a field that creates a new table, then all keys under it will be in that + // a field that creates a new table then all keys under it will be in that // table (not the one we're writing here). - rt := rv.Type() - var fieldsDirect, fieldsSub [][]int - var addFields func(rt reflect.Type, rv reflect.Value, start []int) + // + // Fields is a [][]int: for fieldsDirect this always has one entry (the + // struct index). For fieldsSub it contains two entries: the parent field + // index from tv, and the field indexes for the fields of the sub. + var ( + rt = rv.Type() + fieldsDirect, fieldsSub [][]int + addFields func(rt reflect.Type, rv reflect.Value, start []int) + ) addFields = func(rt reflect.Type, rv reflect.Value, start []int) { for i := 0; i < rt.NumField(); i++ { f := rt.Field(i) - // skip unexported fields - if f.PkgPath != "" && !f.Anonymous { + if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields. continue } + frv := rv.Field(i) + + // Treat anonymous struct fields with tag names as though they are + // not anonymous, like encoding/json does. + // + // Non-struct anonymous fields use the normal encoding logic. if f.Anonymous { t := f.Type switch t.Kind() { case reflect.Struct: - // Treat anonymous struct fields with - // tag names as though they are not - // anonymous, like encoding/json does. if getOptions(f.Tag).name == "" { addFields(t, frv, append(start, f.Index...)) continue } case reflect.Ptr: - if t.Elem().Kind() == reflect.Struct && - getOptions(f.Tag).name == "" { + if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" { if !frv.IsNil() { addFields(t.Elem(), frv.Elem(), append(start, f.Index...)) } continue } - // Fall through to the normal field encoding logic below - // for non-struct anonymous fields. } } @@ -388,35 +407,49 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) { } addFields(rt, rv, nil) - var writeFields = func(fields [][]int) { + writeFields := func(fields [][]int) { for _, fieldIndex := range fields { - sft := rt.FieldByIndex(fieldIndex) - sf := rv.FieldByIndex(fieldIndex) - if isNil(sf) { - // Don't write anything for nil fields. + fieldType := rt.FieldByIndex(fieldIndex) + fieldVal := rv.FieldByIndex(fieldIndex) + + if isNil(fieldVal) { /// Don't write anything for nil fields. continue } - opts := getOptions(sft.Tag) + opts := getOptions(fieldType.Tag) if opts.skip { continue } - keyName := sft.Name + keyName := fieldType.Name if opts.name != "" { keyName = opts.name } - if opts.omitempty && isEmpty(sf) { + if opts.omitempty && isEmpty(fieldVal) { continue } - if opts.omitzero && isZero(sf) { + if opts.omitzero && isZero(fieldVal) { continue } - enc.encode(key.add(keyName), sf) + if inline { + enc.writeKeyValue(Key{keyName}, fieldVal, true) + if fieldIndex[0] != len(fields)-1 { + enc.wf(", ") + } + } else { + enc.encode(key.add(keyName), fieldVal) + } } } + + if inline { + enc.wf("{") + } writeFields(fieldsDirect) writeFields(fieldsSub) + if inline { + enc.wf("}") + } } // tomlTypeName returns the TOML type name of the Go value's type. It is @@ -487,31 +520,18 @@ func tomlArrayType(rv reflect.Value) tomlType { if isNil(rv) || !rv.IsValid() || rv.Len() == 0 { return nil } - firstType := tomlTypeOfGo(rv.Index(0)) - if firstType == nil { - encPanic(errArrayNilElement) - } + /// Don't allow nil. rvlen := rv.Len() for i := 1; i < rvlen; i++ { - elem := rv.Index(i) - switch elemType := tomlTypeOfGo(elem); { - case elemType == nil: + if tomlTypeOfGo(rv.Index(i)) == nil { encPanic(errArrayNilElement) - case !typeEqual(firstType, elemType): - encPanic(errArrayMixedElementTypes) } } - // If we have a nested array, then we must make sure that the nested array - // contains ONLY primitives. - // - // This checks arbitrarily nested arrays. - if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) { - nest := tomlArrayType(eindirect(rv.Index(0))) - if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) { - encPanic(errArrayNoTable) - } + firstType := tomlTypeOfGo(rv.Index(0)) + if firstType == nil { + encPanic(errArrayNilElement) } return firstType } @@ -570,13 +590,20 @@ func (enc *Encoder) newline() { } } -func (enc *Encoder) keyEqElement(key Key, val reflect.Value) { +// Write a key/value pair: +// +// key = +// +// If inline is true it won't add a newline at the end. +func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) { if len(key) == 0 { encPanic(errNoKey) } enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1)) enc.eElement(val) - enc.newline() + if !inline { + enc.newline() + } } func (enc *Encoder) wf(format string, v ...interface{}) { diff --git a/encode_test.go b/encode_test.go index 542f31f9..0da64586 100644 --- a/encode_test.go +++ b/encode_test.go @@ -396,27 +396,32 @@ Fun = "why would you do this?" } } -func encodeExpected( - t *testing.T, label string, val interface{}, wantStr string, wantErr error, -) { +func encodeExpected(t *testing.T, label string, val interface{}, want string, wantErr error) { t.Helper() - var buf bytes.Buffer - enc := NewEncoder(&buf) - err := enc.Encode(val) - if err != wantErr { - if wantErr != nil { - if wantErr == errAnything && err != nil { - return + + t.Run(label, func(t *testing.T) { + var buf bytes.Buffer + err := NewEncoder(&buf).Encode(val) + if err != wantErr { + if wantErr != nil { + if wantErr == errAnything && err != nil { + return + } + t.Errorf("want Encode error %v, got %v", wantErr, err) + } else { + t.Errorf("Encode failed: %s", err) } - t.Errorf("%s: want Encode error %v, got %v", label, wantErr, err) - } else { - t.Errorf("%s: Encode failed: %s", label, err) } - } - if err != nil { - return - } - if got := buf.String(); wantStr != got { - t.Errorf("%s\nhave: %s\nwant: %s\n", label, got, wantStr) - } + if err != nil { + return + } + + have := strings.TrimSpace(buf.String()) + want = strings.TrimSpace(want) + if want != have { + t.Errorf("\nhave: %s\nwant: %s\n", have, want) + // v, _ := json.MarshalIndent(val, "", " ") + // t.Log(string(v)) + } + }) } diff --git a/go.mod b/go.mod index 7a6fd52b..fce1851c 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/BurntSushi/toml go 1.13 -require github.com/BurntSushi/toml-test v0.1.1-0.20210624055653-1f6389604dc6 +require github.com/BurntSushi/toml-test v0.1.1-0.20210704062846-269931e74e3f diff --git a/go.sum b/go.sum index 4af5dafc..6905afff 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/BurntSushi/toml v0.3.2-0.20210614224209-34d990aa228d/go.mod h1:2QZjSXA5e+XyFeCAxxtL8Z4StYUsTquL8ODGPR3C3MA= github.com/BurntSushi/toml v0.3.2-0.20210621044154-20a94d639b8e/go.mod h1:t4zg8TkHfP16Vb3x4WKIw7zVYMit5QFtPEO8lOWxzTg= +github.com/BurntSushi/toml v0.3.2-0.20210624061728-01bfc69d1057/go.mod h1:NMj2lD5LfMqcE0w8tnqOsH6944oaqpI1974lrIwerfE= github.com/BurntSushi/toml-test v0.1.1-0.20210620192437-de01089bbf76/go.mod h1:P/PrhmZ37t5llHfDuiouWXtFgqOoQ12SAh9j6EjrBR4= -github.com/BurntSushi/toml-test v0.1.1-0.20210624055653-1f6389604dc6 h1:hRkQ1B9Jtdssyzo0Cr3EgS2WMwPscGNrCDNCeYshAQA= github.com/BurntSushi/toml-test v0.1.1-0.20210624055653-1f6389604dc6/go.mod h1:UAIt+Eo8itMZAAgImXkPGDMYsT1SsJkVdB5TuONl86A= +github.com/BurntSushi/toml-test v0.1.1-0.20210704062846-269931e74e3f h1:2bJvwBZX/Ajv19zGY3hvuHDInegqjxsz9ht9Smlr7Rk= +github.com/BurntSushi/toml-test v0.1.1-0.20210704062846-269931e74e3f/go.mod h1:fnFWrIwqgHsEjVsW3RYCJmDo86oq9eiJ9u6bnqhtm2g= zgo.at/zli v0.0.0-20210619044753-e7020a328e59/go.mod h1:HLAc12TjNGT+VRXr76JnsNE3pbooQtwKWhX+RlDjQ2Y= diff --git a/move_test.go b/move_test.go index c8ba5a75..7136eff1 100644 --- a/move_test.go +++ b/move_test.go @@ -198,12 +198,12 @@ ArrayOfMixedSlices = [[1, 2], ["a", "b"]] wantOutput: "Empty = []\n", }, "(error) slice with element type mismatch (string and integer)": { - input: struct{ Mixed []interface{} }{[]interface{}{1, "a"}}, - wantError: errArrayMixedElementTypes, + input: struct{ Mixed []interface{} }{[]interface{}{1, "a"}}, + wantOutput: "Mixed = [1, \"a\"]\n", }, "(error) slice with element type mismatch (integer and float)": { - input: struct{ Mixed []interface{} }{[]interface{}{1, 2.5}}, - wantError: errArrayMixedElementTypes, + input: struct{ Mixed []interface{} }{[]interface{}{1, 2.5}}, + wantOutput: "Mixed = [1, 2.5]\n", }, "slice with elems of differing Go types, same TOML types": { input: struct { @@ -223,7 +223,7 @@ ArrayOfMixedSlices = [[1, 2], ["a", "b"]] input: struct{ Mixed []interface{} }{ []interface{}{1, []interface{}{2}}, }, - wantError: errArrayMixedElementTypes, + wantOutput: "Mixed = [1, [2]]\n", }, "(error) slice with 1 nil element": { input: struct{ NilElement1 []interface{} }{[]interface{}{nil}}, @@ -379,17 +379,41 @@ ArrayOfMixedSlices = [[1, 2], ["a", "b"]] input: []struct{ Int int }{{1}, {2}, {3}}, wantError: errNoKey, }, - "(error) slice of slice": { + "(error) map no string key": { + input: map[int]string{1: ""}, + wantError: errNonString, + }, + + "tbl-in-arr-struct": { + input: struct { + Arr [][]struct{ A, B, C int } + }{[][]struct{ A, B, C int }{{{1, 2, 3}, {4, 5, 6}}}}, + wantOutput: "Arr = [[{A = 1, B = 2, C = 3}, {A = 4, B = 5, C = 6}]]", + }, + + "tbl-in-arr-map": { + input: map[string]interface{}{ + "arr": []interface{}{[]interface{}{ + map[string]interface{}{ + "a": []interface{}{"hello", "world"}, + "b": []interface{}{1.12, 4.1}, + "c": 1, + "d": map[string]interface{}{"e": "E"}, + "f": struct{ A, B int }{1, 2}, + "g": []struct{ A, B int }{{3, 4}, {5, 6}}, + }, + }}, + }, + wantOutput: `arr = [[{a = ["hello", "world"], b = [1.12, 4.1], c = 1, d = {e = "E"}, f = {A = 1, B = 2}, g = [{A = 3, B = 4}, {A = 5, B = 6}]}]]`, + }, + + "slice of slice": { input: struct { Slices [][]struct{ Int int } }{ [][]struct{ Int int }{{{1}}, {{2}}, {{3}}}, }, - wantError: errArrayNoTable, - }, - "(error) map no string key": { - input: map[int]string{1: ""}, - wantError: errNonString, + wantOutput: "Slices = [[{Int = 1}], [{Int = 2}], [{Int = 3}]]", }, } for label, test := range tests { diff --git a/parse.go b/parse.go index 34f1bf19..1cebd6fd 100644 --- a/parse.go +++ b/parse.go @@ -324,9 +324,13 @@ func (p *parser) valueDatetime(it item) (interface{}, tomlType) { } func (p *parser) valueArray(it item) (interface{}, tomlType) { - array := make([]interface{}, 0) - types := make([]tomlType, 0) + p.setType(p.currentKey, tomlArray) + // p.setType(p.currentKey, typ) + var ( + array []interface{} + types []tomlType + ) for it = p.next(); it.typ != itemArrayEnd; it = p.next() { if it.typ == itemCommentStart { p.expect(itemText) @@ -337,8 +341,7 @@ func (p *parser) valueArray(it item) (interface{}, tomlType) { array = append(array, val) types = append(types, typ) } - return array, p.typeOfArray(types) - + return array, tomlArray } func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tomlType) { @@ -511,11 +514,12 @@ func (p *parser) set(key string, val interface{}, typ tomlType) { // It will make sure that the key hasn't already been defined, account for // implicit key groups. func (p *parser) setValue(key string, value interface{}) { - var tmpHash interface{} - var ok bool - - hash := p.mapping - keyContext := make(Key, 0) + var ( + tmpHash interface{} + ok bool + hash = p.mapping + keyContext Key + ) for _, k := range p.context { keyContext = append(keyContext, k) if tmpHash, ok = hash[k]; !ok { @@ -544,6 +548,11 @@ func (p *parser) setValue(key string, value interface{}) { // // Note that since it has already been defined (as a hash), we don't // want to overwrite it. So our business is done. + if p.isArray(keyContext) { + p.removeImplicit(keyContext) + hash[key] = value + return + } if p.isImplicit(keyContext) { p.removeImplicit(keyContext) return @@ -553,6 +562,7 @@ func (p *parser) setValue(key string, value interface{}) { // key, which is *always* wrong. p.panicf("Key '%s' has already been defined.", keyContext) } + hash[key] = value } @@ -577,6 +587,7 @@ func (p *parser) setType(key string, typ tomlType) { func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = true } func (p *parser) removeImplicit(key Key) { p.implicits[key.String()] = false } func (p *parser) isImplicit(key Key) bool { return p.implicits[key.String()] } +func (p *parser) isArray(key Key) bool { return p.types[key.String()] == tomlArray } func (p *parser) addImplicitContext(key Key) { p.addImplicit(key) p.addContext(key, false) diff --git a/toml_test.go b/toml_test.go index 9e85f5e4..0f0c1d66 100644 --- a/toml_test.go +++ b/toml_test.go @@ -52,8 +52,6 @@ func TestToml(t *testing.T) { "valid/datetime-local-date", "valid/datetime-local-time", "valid/datetime-local", - "valid/array-mix-string-table", - "valid/inline-table-nest", }, } diff --git a/type_check.go b/type_check.go index c73f8afc..d56aa80f 100644 --- a/type_check.go +++ b/type_check.go @@ -68,24 +68,3 @@ func (p *parser) typeOfPrimitive(lexItem item) tomlType { p.bug("Cannot infer primitive type of lex item '%s'.", lexItem) panic("unreachable") } - -// typeOfArray returns a tomlType for an array given a list of types of its -// values. -// -// In the current spec, if an array is homogeneous, then its type is always -// "Array". If the array is not homogeneous, an error is generated. -func (p *parser) typeOfArray(types []tomlType) tomlType { - // Empty arrays are cool. - if len(types) == 0 { - return tomlArray - } - - theType := types[0] - for _, t := range types[1:] { - if !typeEqual(theType, t) { - p.panicf("Array contains values of type '%s' and '%s', but "+ - "arrays must be homogeneous.", theType, t) - } - } - return tomlArray -}