Permalink
Browse files

Fix issue with marshal caching vrw (#3785)

We were caching the encoder function enclosing over the passed in
ValueReadWriter. This causes problems when the encoder is used again but
with another ValueReadWriter.

Instead, make the internal encoder function take a ValueReadWriter.
  • Loading branch information...
arv committed Nov 15, 2017
1 parent 76400c6 commit 38a979833530b527e7e3a2a3657264e3f3768a2e
Showing with 102 additions and 198 deletions.
  1. +50 −50 go/marshal/encode.go
  2. +15 −15 go/marshal/encode_type.go
  3. +34 −124 go/marshal/encode_type_test.go
  4. +2 −8 go/ngql/query_test.go
  5. +1 −1 go/util/datetime/date_time_test.go
View
@@ -125,8 +125,8 @@ func MustMarshalOpt(vrw types.ValueReadWriter, v interface{}, opt Opt) types.Val
nt := nomsTags{
set: opt.Set,
}
encoder := typeEncoder(vrw, rv.Type(), map[string]reflect.Type{}, nt)
return encoder(rv)
encoder := typeEncoder(rv.Type(), map[string]reflect.Type{}, nt)
return encoder(rv, vrw)
}
// Marshaler is an interface types can implement to provide their own encoding.
@@ -197,34 +197,34 @@ var emptyInterface = reflect.TypeOf((*interface{})(nil)).Elem()
var marshalerInterface = reflect.TypeOf((*Marshaler)(nil)).Elem()
var structNameMarshalerInterface = reflect.TypeOf((*StructNameMarshaler)(nil)).Elem()
type encoderFunc func(v reflect.Value) types.Value
type encoderFunc func(v reflect.Value, vrw types.ValueReadWriter) types.Value
func boolEncoder(v reflect.Value) types.Value {
func boolEncoder(v reflect.Value, vrw types.ValueReadWriter) types.Value {
return types.Bool(v.Bool())
}
func float64Encoder(v reflect.Value) types.Value {
func float64Encoder(v reflect.Value, vrw types.ValueReadWriter) types.Value {
return types.Number(v.Float())
}
func intEncoder(v reflect.Value) types.Value {
func intEncoder(v reflect.Value, vrw types.ValueReadWriter) types.Value {
return types.Number(float64(v.Int()))
}
func uintEncoder(v reflect.Value) types.Value {
func uintEncoder(v reflect.Value, vrw types.ValueReadWriter) types.Value {
return types.Number(float64(v.Uint()))
}
func stringEncoder(v reflect.Value) types.Value {
func stringEncoder(v reflect.Value, vrw types.ValueReadWriter) types.Value {
return types.String(v.String())
}
func nomsValueEncoder(v reflect.Value) types.Value {
func nomsValueEncoder(v reflect.Value, vrw types.ValueReadWriter) types.Value {
return v.Interface().(types.Value)
}
func marshalerEncoder(vrw types.ValueReadWriter, t reflect.Type) encoderFunc {
return func(v reflect.Value) types.Value {
func marshalerEncoder(t reflect.Type) encoderFunc {
return func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
val, err := v.Interface().(Marshaler).MarshalNoms(vrw)
if err != nil {
panic(&marshalNomsError{err})
@@ -236,9 +236,9 @@ func marshalerEncoder(vrw types.ValueReadWriter, t reflect.Type) encoderFunc {
}
}
func typeEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[string]reflect.Type, tags nomsTags) encoderFunc {
func typeEncoder(t reflect.Type, seenStructs map[string]reflect.Type, tags nomsTags) encoderFunc {
if t.Implements(marshalerInterface) {
return marshalerEncoder(vrw, t)
return marshalerEncoder(t)
}
switch t.Kind() {
@@ -253,22 +253,22 @@ func typeEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[stri
case reflect.String:
return stringEncoder
case reflect.Struct:
return structEncoder(vrw, t, seenStructs)
return structEncoder(t, seenStructs)
case reflect.Slice, reflect.Array:
if shouldEncodeAsSet(t, tags) {
return setFromListEncoder(vrw, t, seenStructs)
return setFromListEncoder(t, seenStructs)
}
return listEncoder(vrw, t, seenStructs)
return listEncoder(t, seenStructs)
case reflect.Map:
if shouldEncodeAsSet(t, tags) {
return setEncoder(vrw, t, seenStructs)
return setEncoder(t, seenStructs)
}
return mapEncoder(vrw, t, seenStructs)
return mapEncoder(t, seenStructs)
case reflect.Interface:
return func(v reflect.Value) types.Value {
return func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
// Get the dynamic type.
v2 := reflect.ValueOf(v.Interface())
return typeEncoder(vrw, v2.Type(), seenStructs, tags)(v2)
return typeEncoder(v2.Type(), seenStructs, tags)(v2, vrw)
}
case reflect.Ptr:
// Allow implementations of types.Value (like *types.Type)
@@ -289,7 +289,7 @@ func getStructName(t reflect.Type) string {
return strings.Title(t.Name())
}
func structEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
func structEncoder(t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
if t.Implements(nomsValueInterface) {
return nomsValueEncoder
}
@@ -302,39 +302,39 @@ func structEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[st
structName := getStructName(t)
seenStructs[t.Name()] = t
fields, knownShape, originalFieldIndex := typeFields(vrw, t, seenStructs, false, false)
fields, knownShape, originalFieldIndex := typeFields(t, seenStructs, false, false)
if knownShape {
fieldNames := make([]string, len(fields))
for i, f := range fields {
fieldNames[i] = f.name
}
structTemplate := types.MakeStructTemplate(structName, fieldNames)
e = func(v reflect.Value) types.Value {
e = func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
values := make(types.ValueSlice, len(fields))
for i, f := range fields {
values[i] = f.encoder(v.FieldByIndex(f.index))
values[i] = f.encoder(v.FieldByIndex(f.index), vrw)
}
return structTemplate.NewStruct(values)
}
} else if originalFieldIndex == nil {
// Slower path: cannot precompute the Noms type since there are Noms collections,
// but at least there are a set number of fields.
e = func(v reflect.Value) types.Value {
e = func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
data := make(types.StructData, len(fields))
for _, f := range fields {
fv := v.FieldByIndex(f.index)
if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) {
continue
}
data[f.name] = f.encoder(fv)
data[f.name] = f.encoder(fv, vrw)
}
return types.NewStruct(structName, data)
}
} else {
// Slowest path - we are extending some other struct. We need to start with the
// type of that struct and extend.
e = func(v reflect.Value) types.Value {
e = func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
fv := v.FieldByIndex(originalFieldIndex)
ret := fv.Interface().(types.Struct)
if ret.IsZeroValue() {
@@ -345,7 +345,7 @@ func structEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[st
if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) {
continue
}
ret = ret.Set(f.name, f.encoder(fv))
ret = ret.Set(f.name, f.encoder(fv, vrw))
}
return ret
}
@@ -461,7 +461,7 @@ func validateField(f reflect.StructField, t reflect.Type) {
}
}
func typeFields(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[string]reflect.Type, computeType, embedded bool) (fields fieldSlice, knownShape bool, originalFieldIndex []int) {
func typeFields(t reflect.Type, seenStructs map[string]reflect.Type, computeType, embedded bool) (fields fieldSlice, knownShape bool, originalFieldIndex []int) {
knownShape = true
for i := 0; i < t.NumField(); i++ {
index := make([]int, 1)
@@ -478,7 +478,7 @@ func typeFields(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[strin
}
if f.Anonymous && f.PkgPath == "" && !tags.hasName {
embeddedFields, embeddedKnownShape, embeddedOriginalFieldIndex := typeFields(vrw, f.Type, seenStructs, computeType, true)
embeddedFields, embeddedKnownShape, embeddedOriginalFieldIndex := typeFields(f.Type, seenStructs, computeType, true)
if embeddedOriginalFieldIndex != nil {
originalFieldIndex = append(index, embeddedOriginalFieldIndex...)
}
@@ -495,7 +495,7 @@ func typeFields(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[strin
var nt *types.Type
validateField(f, t)
if computeType {
nt = encodeType(vrw, f.Type, seenStructs, tags)
nt = encodeType(f.Type, seenStructs, tags)
if nt == nil {
knownShape = false
}
@@ -507,7 +507,7 @@ func typeFields(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[strin
fields = append(fields, field{
name: tags.name,
encoder: typeEncoder(vrw, f.Type, seenStructs, tags),
encoder: typeEncoder(f.Type, seenStructs, tags),
index: index,
nomsType: nt,
omitEmpty: tags.omitEmpty,
@@ -522,7 +522,7 @@ func typeFields(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[strin
return
}
func listEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
func listEncoder(t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
e := encoderCache.get(t)
if e != nil {
return e
@@ -533,23 +533,23 @@ func listEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[stri
var init sync.RWMutex
init.Lock()
defer init.Unlock()
e = func(v reflect.Value) types.Value {
e = func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
init.RLock()
defer init.RUnlock()
values := make([]types.Value, v.Len())
for i := 0; i < v.Len(); i++ {
values[i] = elemEncoder(v.Index(i))
values[i] = elemEncoder(v.Index(i), vrw)
}
return types.NewList(vrw, values...)
}
encoderCache.set(t, e)
elemEncoder = typeEncoder(vrw, t.Elem(), seenStructs, nomsTags{})
elemEncoder = typeEncoder(t.Elem(), seenStructs, nomsTags{})
return e
}
// Encode set from array or slice
func setFromListEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
func setFromListEncoder(t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
e := setEncoderCache.get(t)
if e != nil {
return e
@@ -560,22 +560,22 @@ func setFromListEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs m
var init sync.RWMutex
init.Lock()
defer init.Unlock()
e = func(v reflect.Value) types.Value {
e = func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
init.RLock()
defer init.RUnlock()
values := make([]types.Value, v.Len())
for i := 0; i < v.Len(); i++ {
values[i] = elemEncoder(v.Index(i))
values[i] = elemEncoder(v.Index(i), vrw)
}
return types.NewSet(vrw, values...)
}
setEncoderCache.set(t, e)
elemEncoder = typeEncoder(vrw, t.Elem(), seenStructs, nomsTags{})
elemEncoder = typeEncoder(t.Elem(), seenStructs, nomsTags{})
return e
}
func setEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
func setEncoder(t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
e := setEncoderCache.get(t)
if e != nil {
return e
@@ -586,22 +586,22 @@ func setEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[strin
var init sync.RWMutex
init.Lock()
defer init.Unlock()
e = func(v reflect.Value) types.Value {
e = func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
init.RLock()
defer init.RUnlock()
values := make([]types.Value, v.Len(), v.Len())
for i, k := range v.MapKeys() {
values[i] = encoder(k)
values[i] = encoder(k, vrw)
}
return types.NewSet(vrw, values...)
}
setEncoderCache.set(t, e)
encoder = typeEncoder(vrw, t.Key(), seenStructs, nomsTags{})
encoder = typeEncoder(t.Key(), seenStructs, nomsTags{})
return e
}
func mapEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
func mapEncoder(t reflect.Type, seenStructs map[string]reflect.Type) encoderFunc {
e := encoderCache.get(t)
if e != nil {
return e
@@ -613,21 +613,21 @@ func mapEncoder(vrw types.ValueReadWriter, t reflect.Type, seenStructs map[strin
var init sync.RWMutex
init.Lock()
defer init.Unlock()
e = func(v reflect.Value) types.Value {
e = func(v reflect.Value, vrw types.ValueReadWriter) types.Value {
init.RLock()
defer init.RUnlock()
keys := v.MapKeys()
kvs := make([]types.Value, 2*len(keys))
for i, k := range keys {
kvs[2*i] = keyEncoder(k)
kvs[2*i+1] = valueEncoder(v.MapIndex(k))
kvs[2*i] = keyEncoder(k, vrw)
kvs[2*i+1] = valueEncoder(v.MapIndex(k), vrw)
}
return types.NewMap(vrw, kvs...)
}
encoderCache.set(t, e)
keyEncoder = typeEncoder(vrw, t.Key(), seenStructs, nomsTags{})
valueEncoder = typeEncoder(vrw, t.Elem(), seenStructs, nomsTags{})
keyEncoder = typeEncoder(t.Key(), seenStructs, nomsTags{})
valueEncoder = typeEncoder(t.Elem(), seenStructs, nomsTags{})
return e
}
Oops, something went wrong.

0 comments on commit 38a9798

Please sign in to comment.