From c394138d38dc45e5b751175bf130b4bc4eee5b1d Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 02:39:35 +0800 Subject: [PATCH 01/21] use option for optional fields --- compiler/fory_compiler/generators/go.py | 26 +- go/fory/codegen/decoder.go | 216 +++++ go/fory/codegen/encoder.go | 192 ++++ go/fory/codegen/utils.go | 55 +- go/fory/field_info.go | 29 + go/fory/optional/optional.go | 249 +++++ go/fory/optional_serializer.go | 271 ++++++ go/fory/struct.go | 865 +++++++++++++++++- go/fory/struct_test.go | 64 +- go/fory/type_def.go | 21 +- go/fory/type_resolver.go | 33 +- integration_tests/idl_tests/cpp/main.cc | 63 ++ integration_tests/idl_tests/generate_idl.py | 2 + .../idl_tests/go/idl_roundtrip_test.go | 240 +++++ .../idl_tests/idl/optional_types.fdl | 61 ++ .../fory/idl_tests/IdlRoundTripTest.java | 76 ++ .../python/src/idl_tests/roundtrip.py | 123 +++ .../idl_tests/rust/tests/idl_roundtrip.rs | 64 ++ 18 files changed, 2608 insertions(+), 42 deletions(-) create mode 100644 go/fory/optional/optional.go create mode 100644 go/fory/optional_serializer.go create mode 100644 integration_tests/idl_tests/idl/optional_types.fdl diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index 6d64471daa..e377f08c33 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -268,6 +268,8 @@ def collect_message_imports(self, message: Message, imports: Set[str]): """Collect imports for a message and its nested types recursively.""" for field in message.fields: self.collect_imports(field.field_type, imports) + if self.field_uses_option(field): + imports.add('optional "github.com/apache/fory/go/fory/optional"') for nested_msg in message.nested_messages: self.collect_message_imports(nested_msg, imports) for nested_union in message.nested_unions: @@ -677,6 +679,18 @@ def get_array_type_tag(self, field: Field) -> Optional[str]: return "type=uint8_array" return None + def field_uses_option(self, field: Field) -> bool: + """Return True if field should use optional.Optional in generated Go code.""" + if not field.optional or field.ref: + return False + if isinstance(field.field_type, PrimitiveType): + base_type = self.PRIMITIVE_MAP[field.field_type.kind] + return base_type not in ("[]byte", "time.Time") + if isinstance(field.field_type, NamedType): + named_type = self.schema.get_type(field.field_type.name) + return isinstance(named_type, Enum) + return False + def generate_type( self, field_type: FieldType, @@ -685,17 +699,26 @@ def generate_type( element_optional: bool = False, element_ref: bool = False, parent_stack: Optional[List[Message]] = None, + use_option: bool = True, ) -> str: """Generate Go type string.""" if isinstance(field_type, PrimitiveType): base_type = self.PRIMITIVE_MAP[field_type.kind] if nullable and base_type not in ("[]byte",): + if use_option and not ref and base_type != "time.Time": + return f"optional.Optional[{base_type}]" return f"*{base_type}" return base_type elif isinstance(field_type, NamedType): type_name = self.resolve_nested_type_name(field_type.name, parent_stack) - if nullable or ref: + if nullable: + if use_option and not ref: + named_type = self.schema.get_type(field_type.name) + if isinstance(named_type, Enum): + return f"optional.Optional[{type_name}]" + return f"*{type_name}" + if ref: return f"*{type_name}" return type_name @@ -707,6 +730,7 @@ def generate_type( False, False, parent_stack, + use_option=False, ) return f"[]{element_type}" diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 6e62550338..15aedb8db2 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -90,6 +90,9 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t// Field: %s (%s)\n", field.GoName, field.Type.String()) fieldAccess := fmt.Sprintf("v.%s", field.GoName) + if field.IsOptional { + return generateOptionReadTyped(buf, field, fieldAccess) + } // Handle special named types first // According to new spec, time types are "other internal types" and use ReadValue @@ -238,6 +241,130 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { return nil } +func generateOptionReadTyped(buf *bytes.Buffer, field *FieldInfo, fieldAccess string) error { + elemType := field.OptionalElem + if elemType == nil { + fmt.Fprintf(buf, "\tctx.ReadValue(reflect.ValueOf(&%s).Elem(), fory.RefModeTracking, true)\n", fieldAccess) + return nil + } + fmt.Fprintf(buf, "\t{\n") + if isReferencableType(elemType) { + fmt.Fprintf(buf, "\t\tif ctx.TrackRef() {\n") + fmt.Fprintf(buf, "\t\t\trefID, refErr := ctx.RefResolver().TryPreserveRefId(buf)\n") + fmt.Fprintf(buf, "\t\t\tif refErr != nil {\n") + fmt.Fprintf(buf, "\t\t\t\treturn refErr\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif refID < int32(fory.NotNullValueFlag) {\n") + fmt.Fprintf(buf, "\t\t\t\tif refID == int32(fory.NullFlag) {\n") + fmt.Fprintf(buf, "\t\t\t\t\t%s.Has = false\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\t\treturn nil\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tobj := ctx.RefResolver().GetReadObject(refID)\n") + fmt.Fprintf(buf, "\t\t\t\tif obj.IsValid() {\n") + fmt.Fprintf(buf, "\t\t\t\t\ttarget := reflect.ValueOf(&%s.Value).Elem()\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\t\tif obj.Type().AssignableTo(target.Type()) {\n") + fmt.Fprintf(buf, "\t\t\t\t\t\ttarget.Set(obj)\n") + fmt.Fprintf(buf, "\t\t\t\t\t\t%s.Has = true\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\t\t\treturn nil\n") + fmt.Fprintf(buf, "\t\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\t%s.Has = false\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\treturn nil\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t%s.Has = true\n", fieldAccess) + if err := generateOptionValueRead(buf, elemType, fmt.Sprintf("%s.Value", fieldAccess)); err != nil { + return err + } + fmt.Fprintf(buf, "\t\t\tif refID >= 0 {\n") + fmt.Fprintf(buf, "\t\t\t\tctx.RefResolver().SetReadObject(refID, reflect.ValueOf(%s.Value))\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\treturn nil\n") + fmt.Fprintf(buf, "\t\t}\n") + } + fmt.Fprintf(buf, "\t\tflag := buf.ReadInt8(err)\n") + fmt.Fprintf(buf, "\t\tif flag == fory.NullFlag {\n") + fmt.Fprintf(buf, "\t\t\t%s.Has = false\n", fieldAccess) + fmt.Fprintf(buf, "\t\t} else {\n") + fmt.Fprintf(buf, "\t\t\t%s.Has = true\n", fieldAccess) + if err := generateOptionValueRead(buf, elemType, fmt.Sprintf("%s.Value", fieldAccess)); err != nil { + return err + } + fmt.Fprintf(buf, "\t\t}\n") + fmt.Fprintf(buf, "\t}\n") + return nil +} + +func generateOptionValueRead(buf *bytes.Buffer, elemType types.Type, valueExpr string) error { + // Handle special named types first + if named, ok := elemType.(*types.Named); ok { + typeStr := named.String() + switch typeStr { + case "time.Time", "github.com/apache/fory/go/fory.Date": + fmt.Fprintf(buf, "\t\t\tctx.ReadValue(reflect.ValueOf(&%s).Elem(), fory.RefModeNone, true)\n", valueExpr) + return nil + } + if _, ok := named.Underlying().(*types.Struct); ok { + fmt.Fprintf(buf, "\t\t\tctx.ReadValue(reflect.ValueOf(&%s).Elem(), fory.RefModeNone, true)\n", valueExpr) + return nil + } + } + + if basic, ok := elemType.Underlying().(*types.Basic); ok { + switch basic.Kind() { + case types.Bool: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadBool(err)\n", valueExpr) + case types.Int8: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadInt8(err)\n", valueExpr) + case types.Int16: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadInt16(err)\n", valueExpr) + case types.Int32: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadVarint32(err)\n", valueExpr) + case types.Int, types.Int64: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadVarint64(err)\n", valueExpr) + case types.Uint8: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadByte(err)\n", valueExpr) + case types.Uint16: + fmt.Fprintf(buf, "\t\t\t%s = uint16(buf.ReadInt16(err))\n", valueExpr) + case types.Uint32: + fmt.Fprintf(buf, "\t\t\t%s = uint32(buf.ReadInt32(err))\n", valueExpr) + case types.Uint, types.Uint64: + fmt.Fprintf(buf, "\t\t\t%s = uint64(buf.ReadInt64(err))\n", valueExpr) + case types.Float32: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadFloat32(err)\n", valueExpr) + case types.Float64: + fmt.Fprintf(buf, "\t\t\t%s = buf.ReadFloat64(err)\n", valueExpr) + case types.String: + fmt.Fprintf(buf, "\t\t\t%s = ctx.ReadString()\n", valueExpr) + default: + fmt.Fprintf(buf, "\t\t\t// TODO: unsupported basic type %s\n", basic.String()) + } + return nil + } + + if slice, ok := elemType.(*types.Slice); ok { + return generateSliceReadInlineNoNull(buf, slice, valueExpr) + } + if mapType, ok := elemType.(*types.Map); ok { + return generateMapReadInlineNoNull(buf, mapType, valueExpr) + } + + unwrappedType := types.Unalias(elemType) + if iface, ok := unwrappedType.(*types.Interface); ok { + if iface.Empty() { + fmt.Fprintf(buf, "\t\t\tctx.ReadValue(reflect.ValueOf(&%s).Elem(), fory.RefModeNone, true)\n", valueExpr) + return nil + } + } + + if _, ok := elemType.Underlying().(*types.Struct); ok { + fmt.Fprintf(buf, "\t\t\tctx.ReadValue(reflect.ValueOf(&%s).Elem(), fory.RefModeNone, true)\n", valueExpr) + return nil + } + + fmt.Fprintf(buf, "\t\t\tctx.ReadValue(reflect.ValueOf(&%s).Elem(), fory.RefModeNone, true)\n", valueExpr) + return nil +} + // Note: generateSliceRead is no longer used since we use WriteReferencable/ReadValue for slice fields // generateSliceRead generates code to deserialize a slice according to the list format func generateSliceRead(buf *bytes.Buffer, sliceType *types.Slice, fieldAccess string) error { @@ -414,6 +541,68 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc return nil } +func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fieldAccess string) error { + elemType := sliceType.Elem() + indent := "\t\t\t" + + unwrappedElem := types.Unalias(elemType) + if iface, ok := unwrappedElem.(*types.Interface); ok && iface.Empty() { + fmt.Fprintf(buf, "%s// Dynamic slice []any handling - no null flag\n", indent) + fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVaruint32(err))\n", indent) + fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) + fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s} else {\n", indent) + fmt.Fprintf(buf, "%s\t_ = buf.ReadInt8(err) // collection flags\n", indent) + fmt.Fprintf(buf, "%s\t%s = make([]any, sliceLen)\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s\tfor i := range %s {\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s\t\tctx.ReadValue(reflect.ValueOf(&%s[i]).Elem(), fory.RefModeTracking, true)\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) + return nil + } + + if isPrimitiveSliceElemType(elemType) { + return generatePrimitiveSliceReadInlineNoNull(buf, sliceType, fieldAccess, indent) + } + + elemIsReferencable := isReferencableType(elemType) + fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVaruint32(err))\n", indent) + fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) + fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String()) + fmt.Fprintf(buf, "%s} else {\n", indent) + if err := writeSliceReadElements(buf, sliceType, elemType, fieldAccess, elemIsReferencable, indent+"\t"); err != nil { + return err + } + fmt.Fprintf(buf, "%s}\n", indent) + return nil +} + +func generatePrimitiveSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fieldAccess string, indent string) error { + elemType := sliceType.Elem() + basic := elemType.Underlying().(*types.Basic) + switch basic.Kind() { + case types.Bool: + fmt.Fprintf(buf, "%s%s = fory.ReadBoolSlice(buf, err)\n", indent, fieldAccess) + case types.Int8: + fmt.Fprintf(buf, "%s%s = fory.ReadInt8Slice(buf, err)\n", indent, fieldAccess) + case types.Uint8: + fmt.Fprintf(buf, "%s%s = fory.ReadByteSlice(buf, err)\n", indent, fieldAccess) + case types.Int16: + fmt.Fprintf(buf, "%s%s = fory.ReadInt16Slice(buf, err)\n", indent, fieldAccess) + case types.Int32: + fmt.Fprintf(buf, "%s%s = fory.ReadInt32Slice(buf, err)\n", indent, fieldAccess) + case types.Int64: + fmt.Fprintf(buf, "%s%s = fory.ReadInt64Slice(buf, err)\n", indent, fieldAccess) + case types.Float32: + fmt.Fprintf(buf, "%s%s = fory.ReadFloat32Slice(buf, err)\n", indent, fieldAccess) + case types.Float64: + fmt.Fprintf(buf, "%s%s = fory.ReadFloat64Slice(buf, err)\n", indent, fieldAccess) + default: + return fmt.Errorf("unsupported primitive type for ARRAY protocol: %s", basic.String()) + } + return nil +} + // writeSliceReadElements generates the element reading code for a slice with specified indentation func writeSliceReadElements(buf *bytes.Buffer, sliceType *types.Slice, elemType types.Type, fieldAccess string, elemIsReferencable bool, indent string) error { // ReadData collection header @@ -760,6 +949,33 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st return nil } +func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAccess string) error { + keyType := mapType.Key() + valueType := mapType.Elem() + + keyIsInterface := false + valueIsInterface := false + unwrappedKey := types.Unalias(keyType) + unwrappedValue := types.Unalias(valueType) + if iface, ok := unwrappedKey.(*types.Interface); ok && iface.Empty() { + keyIsInterface = true + } + if iface, ok := unwrappedValue.(*types.Interface); ok && iface.Empty() { + valueIsInterface = true + } + + indent := "\t\t\t" + fmt.Fprintf(buf, "%smapLen := int(buf.ReadVaruint32(err))\n", indent) + fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) + fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String()) + fmt.Fprintf(buf, "%s} else {\n", indent) + if err := writeMapReadChunks(buf, mapType, fieldAccess, keyType, valueType, keyIsInterface, valueIsInterface, indent+"\t"); err != nil { + return err + } + fmt.Fprintf(buf, "%s}\n", indent) + return nil +} + // writeMapReadChunks generates the map chunk reading code with specified indentation func writeMapReadChunks(buf *bytes.Buffer, mapType *types.Map, fieldAccess string, keyType, valueType types.Type, keyIsInterface, valueIsInterface bool, indent string) error { fmt.Fprintf(buf, "%s%s = make(%s, mapLen)\n", indent, fieldAccess, mapType.String()) diff --git a/go/fory/codegen/encoder.go b/go/fory/codegen/encoder.go index d3552681fa..b2178d3d3d 100644 --- a/go/fory/codegen/encoder.go +++ b/go/fory/codegen/encoder.go @@ -76,6 +76,9 @@ func generateFieldWriteTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t// Field: %s (%s)\n", field.GoName, field.Type.String()) fieldAccess := fmt.Sprintf("v.%s", field.GoName) + if field.IsOptional { + return generateOptionWriteTyped(buf, field, fieldAccess) + } // Handle special named types first // According to new spec, time types are "other internal types" and need WriteValue @@ -224,6 +227,108 @@ func generateFieldWriteTyped(buf *bytes.Buffer, field *FieldInfo) error { return nil } +func generateOptionWriteTyped(buf *bytes.Buffer, field *FieldInfo, fieldAccess string) error { + elemType := field.OptionalElem + if elemType == nil { + fmt.Fprintf(buf, "\tctx.WriteValue(reflect.ValueOf(%s), fory.RefModeTracking, true)\n", fieldAccess) + return nil + } + fmt.Fprintf(buf, "\tif !%s.Has {\n", fieldAccess) + fmt.Fprintf(buf, "\t\tbuf.WriteInt8(fory.NullFlag)\n") + fmt.Fprintf(buf, "\t} else {\n") + if isReferencableType(elemType) { + fmt.Fprintf(buf, "\t\tif ctx.TrackRef() {\n") + fmt.Fprintf(buf, "\t\t\trefWritten, err := ctx.RefResolver().WriteRefOrNull(buf, reflect.ValueOf(%s.Value))\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\tif err != nil {\n") + fmt.Fprintf(buf, "\t\t\t\treturn err\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif refWritten {\n") + fmt.Fprintf(buf, "\t\t\t\treturn nil\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t} else {\n") + fmt.Fprintf(buf, "\t\t\tbuf.WriteInt8(fory.NotNullValueFlag)\n") + fmt.Fprintf(buf, "\t\t}\n") + } else { + fmt.Fprintf(buf, "\t\tbuf.WriteInt8(fory.NotNullValueFlag)\n") + } + if err := generateOptionValueWrite(buf, elemType, fmt.Sprintf("%s.Value", fieldAccess)); err != nil { + return err + } + fmt.Fprintf(buf, "\t}\n") + return nil +} + +func generateOptionValueWrite(buf *bytes.Buffer, elemType types.Type, valueExpr string) error { + // Handle special named types first + if named, ok := elemType.(*types.Named); ok { + typeStr := named.String() + switch typeStr { + case "time.Time", "github.com/apache/fory/go/fory.Date": + fmt.Fprintf(buf, "\t\tctx.WriteValue(reflect.ValueOf(%s), fory.RefModeNone, true)\n", valueExpr) + return nil + } + if _, ok := named.Underlying().(*types.Struct); ok { + fmt.Fprintf(buf, "\t\tctx.WriteValue(reflect.ValueOf(%s), fory.RefModeNone, true)\n", valueExpr) + return nil + } + } + + if basic, ok := elemType.Underlying().(*types.Basic); ok { + switch basic.Kind() { + case types.Bool: + fmt.Fprintf(buf, "\t\tbuf.WriteBool(%s)\n", valueExpr) + case types.Int8: + fmt.Fprintf(buf, "\t\tbuf.WriteByte_(byte(%s))\n", valueExpr) + case types.Int16: + fmt.Fprintf(buf, "\t\tbuf.WriteInt16(%s)\n", valueExpr) + case types.Int32: + fmt.Fprintf(buf, "\t\tbuf.WriteVarint32(%s)\n", valueExpr) + case types.Int, types.Int64: + fmt.Fprintf(buf, "\t\tbuf.WriteVarint64(%s)\n", valueExpr) + case types.Uint8: + fmt.Fprintf(buf, "\t\tbuf.WriteByte_(%s)\n", valueExpr) + case types.Uint16: + fmt.Fprintf(buf, "\t\tbuf.WriteInt16(int16(%s))\n", valueExpr) + case types.Uint32: + fmt.Fprintf(buf, "\t\tbuf.WriteInt32(int32(%s))\n", valueExpr) + case types.Uint, types.Uint64: + fmt.Fprintf(buf, "\t\tbuf.WriteInt64(int64(%s))\n", valueExpr) + case types.Float32: + fmt.Fprintf(buf, "\t\tbuf.WriteFloat32(%s)\n", valueExpr) + case types.Float64: + fmt.Fprintf(buf, "\t\tbuf.WriteFloat64(%s)\n", valueExpr) + case types.String: + fmt.Fprintf(buf, "\t\tctx.WriteString(%s)\n", valueExpr) + default: + fmt.Fprintf(buf, "\t\t// TODO: unsupported basic type %s\n", basic.String()) + } + return nil + } + + if slice, ok := elemType.(*types.Slice); ok { + return generateSliceWriteInlineNoNull(buf, slice, valueExpr) + } + if mapType, ok := elemType.(*types.Map); ok { + return generateMapWriteInlineNoNull(buf, mapType, valueExpr) + } + + unwrappedType := types.Unalias(elemType) + if iface, ok := unwrappedType.(*types.Interface); ok { + if iface.Empty() { + fmt.Fprintf(buf, "\t\tctx.WriteValue(reflect.ValueOf(%s), fory.RefModeNone, true)\n", valueExpr) + return nil + } + } + + if _, ok := elemType.Underlying().(*types.Struct); ok { + fmt.Fprintf(buf, "\t\tctx.WriteValue(reflect.ValueOf(%s), fory.RefModeNone, true)\n", valueExpr) + return nil + } + + fmt.Fprintf(buf, "\t\tctx.WriteValue(reflect.ValueOf(%s), fory.RefModeNone, true)\n", valueExpr) + return nil +} + // generateElementTypeIDWrite generates code to write the element type ID for slice serialization func generateElementTypeIDWrite(buf *bytes.Buffer, elemType types.Type) error { // Handle basic types @@ -377,6 +482,63 @@ func generateSliceWriteInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAc return nil } +func generateSliceWriteInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fieldAccess string) error { + elemType := sliceType.Elem() + indent := "\t\t" + + // Dynamic slice []any handling (no null flag) + unwrappedElem := types.Unalias(elemType) + if iface, ok := unwrappedElem.(*types.Interface); ok && iface.Empty() { + fmt.Fprintf(buf, "%s// Dynamic slice []any handling - no null flag\n", indent) + fmt.Fprintf(buf, "%ssliceLen := 0\n", indent) + fmt.Fprintf(buf, "%sif %s != nil {\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s\tsliceLen = len(%s)\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sbuf.WriteVaruint32(uint32(sliceLen))\n", indent) + fmt.Fprintf(buf, "%sif sliceLen > 0 {\n", indent) + fmt.Fprintf(buf, "%s\tbuf.WriteInt8(1) // CollectionTrackingRef only\n", indent) + fmt.Fprintf(buf, "%s\tfor _, elem := range %s {\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s\t\tctx.WriteValue(reflect.ValueOf(elem), fory.RefModeTracking, true)\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) + return nil + } + + // Primitive slice - use ARRAY protocol helpers without null flag + if isPrimitiveSliceElemType(elemType) { + basic := elemType.Underlying().(*types.Basic) + return writePrimitiveSliceCall(buf, basic, fieldAccess, indent) + } + + // Non-primitive slices use LIST protocol, no null flag + elemIsReferencable := isReferencableType(elemType) + fmt.Fprintf(buf, "%ssliceLen := 0\n", indent) + fmt.Fprintf(buf, "%sif %s != nil {\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s\tsliceLen = len(%s)\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sbuf.WriteVaruint32(uint32(sliceLen))\n", indent) + fmt.Fprintf(buf, "%sif sliceLen > 0 {\n", indent) + fmt.Fprintf(buf, "%s\tcollectFlag := 12 // CollectionIsSameType | CollectionIsDeclElementType\n", indent) + if elemIsReferencable { + fmt.Fprintf(buf, "%s\tif ctx.TrackRef() {\n", indent) + fmt.Fprintf(buf, "%s\t\tcollectFlag |= 1 // CollectionTrackingRef\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + } + fmt.Fprintf(buf, "%s\tbuf.WriteInt8(int8(collectFlag))\n", indent) + fmt.Fprintf(buf, "%s\tfor _, elem := range %s {\n", indent, fieldAccess) + if elemIsReferencable { + fmt.Fprintf(buf, "%s\t\tif ctx.TrackRef() {\n", indent) + fmt.Fprintf(buf, "%s\t\t\tbuf.WriteInt8(-1) // NotNullValueFlag for element\n", indent) + fmt.Fprintf(buf, "%s\t\t}\n", indent) + } + if err := generateSliceElementWriteInlineIndented(buf, elemType, "elem", indent+"\t\t"); err != nil { + return err + } + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) + return nil +} + // isPrimitiveSliceElemType checks if the element type is a primitive type that uses ARRAY protocol func isPrimitiveSliceElemType(elemType types.Type) bool { if basic, ok := elemType.Underlying().(*types.Basic); ok { @@ -512,6 +674,36 @@ func generateMapWriteInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess s return nil } +func generateMapWriteInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAccess string) error { + keyType := mapType.Key() + valueType := mapType.Elem() + + keyIsInterface := false + valueIsInterface := false + unwrappedKey := types.Unalias(keyType) + unwrappedValue := types.Unalias(valueType) + if iface, ok := unwrappedKey.(*types.Interface); ok && iface.Empty() { + keyIsInterface = true + } + if iface, ok := unwrappedValue.(*types.Interface); ok && iface.Empty() { + valueIsInterface = true + } + + indent := "\t\t" + fmt.Fprintf(buf, "%smapLen := 0\n", indent) + fmt.Fprintf(buf, "%sif %s != nil {\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s\tmapLen = len(%s)\n", indent, fieldAccess) + fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sbuf.WriteVaruint32(uint32(mapLen))\n", indent) + + // Write map chunks without null flag (xlang style) + if err := writeMapChunksCode(buf, keyType, valueType, fieldAccess, keyIsInterface, valueIsInterface, indent); err != nil { + return err + } + + return nil +} + // writeMapChunksCode generates the map chunk writing code with specified indentation func writeMapChunksCode(buf *bytes.Buffer, keyType, valueType types.Type, fieldAccess string, keyIsInterface, valueIsInterface bool, indent string) error { // WriteData chunks for non-empty map diff --git a/go/fory/codegen/utils.go b/go/fory/codegen/utils.go index 8fefa0230e..d18ac944f7 100644 --- a/go/fory/codegen/utils.go +++ b/go/fory/codegen/utils.go @@ -18,6 +18,7 @@ package codegen import ( + "fmt" "go/types" "sort" "unicode" @@ -33,9 +34,11 @@ type FieldInfo struct { Index int // Original field index in struct IsPrimitive bool // Whether it's a Fory primitive type IsPointer bool // Whether it's a pointer type + IsOptional bool // Whether it's a fory optional.Optional[T] Nullable bool // Whether the field can be null (pointer types) TypeID string // Fory TypeID for sorting PrimitiveSize int // Size for primitive type sorting + OptionalElem types.Type // Element type for optional.Optional[T] } // StructInfo contains metadata about a struct to generate code for @@ -56,10 +59,33 @@ func toSnakeCase(s string) string { return string(result) } +func getOptionalElementType(t types.Type) (types.Type, bool) { + t = types.Unalias(t) + named, ok := t.(*types.Named) + if !ok { + return nil, false + } + obj := named.Obj() + if obj == nil || obj.Name() != "Optional" { + return nil, false + } + if obj.Pkg() == nil || obj.Pkg().Path() != "github.com/apache/fory/go/fory/optional" { + return nil, false + } + typeArgs := named.TypeArgs() + if typeArgs == nil || typeArgs.Len() != 1 { + return nil, false + } + return typeArgs.At(0), true +} + // isSupportedFieldType checks if a field type is supported func isSupportedFieldType(t types.Type) bool { // Unwrap alias types (e.g., 'any' is an alias for 'interface{}') t = types.Unalias(t) + if elem, ok := getOptionalElementType(t); ok { + t = elem + } // Handle pointer types if ptr, ok := t.(*types.Pointer); ok { @@ -116,6 +142,9 @@ func isSupportedFieldType(t types.Type) bool { func isPrimitiveType(t types.Type) bool { // Unwrap alias types t = types.Unalias(t) + if elem, ok := getOptionalElementType(t); ok { + t = elem + } // Handle pointer types if ptr, ok := t.(*types.Pointer); ok { @@ -142,6 +171,9 @@ func isPrimitiveType(t types.Type) bool { func getTypeID(t types.Type) string { // Unwrap alias types t = types.Unalias(t) + if elem, ok := getOptionalElementType(t); ok { + t = elem + } // Handle pointer types if ptr, ok := t.(*types.Pointer); ok { @@ -250,6 +282,9 @@ func getTypeID(t types.Type) string { func getPrimitiveSize(t types.Type) int { // Unwrap alias types t = types.Unalias(t) + if elem, ok := getOptionalElementType(t); ok { + t = elem + } // Handle pointer types if ptr, ok := t.(*types.Pointer); ok { @@ -483,6 +518,22 @@ func analyzeField(field *types.Var, index int) (*FieldInfo, error) { return nil, nil // Skip unsupported types } + optionalElem, isOptional := getOptionalElementType(fieldType) + if isOptional && optionalElem != nil { + base := optionalElem + for { + if ptr, ok := base.(*types.Pointer); ok { + base = ptr.Elem() + continue + } + break + } + switch base.Underlying().(type) { + case *types.Struct, *types.Slice, *types.Map: + return nil, fmt.Errorf("field %s: optional.Optional is not supported for struct/slice/map", goName) + } + } + // Analyze type information isPrimitive := isPrimitiveType(fieldType) isPointer := false @@ -505,8 +556,10 @@ func analyzeField(field *types.Var, index int) (*FieldInfo, error) { Index: index, IsPrimitive: isPrimitive, IsPointer: isPointer, - Nullable: isPointer, // Pointer types are nullable, slices/maps are non-nullable in xlang mode + IsOptional: isOptional, + Nullable: isPointer || isOptional, // Pointer and optional types are nullable in xlang mode TypeID: typeID, PrimitiveSize: primitiveSize, + OptionalElem: optionalElem, }, nil } diff --git a/go/fory/field_info.go b/go/fory/field_info.go index 151e300428..7bff20735f 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -43,6 +43,10 @@ type FieldMeta struct { FieldIndex int // -1 if field doesn't exist in current struct (for compatible mode) FieldDef FieldDef // original FieldDef from remote TypeDef (for compatible mode skip) + // Optional fields (fory/optional.Optional[T]) + IsOptional bool + OptionalInfo optionalInfo + // Pre-computed sizes (for fixed primitives) FixedSize int // 0 if not fixed-size, else 1/2/4/8 @@ -147,6 +151,10 @@ func GroupFields(fields []FieldInfo) FieldGroup { // Categorize fields for i := range fields { field := &fields[i] + if field.Meta.IsOptional { + g.RemainingFields = append(g.RemainingFields, *field) + continue + } if isFixedSizePrimitive(field.DispatchId, field.Meta.Nullable) { // Non-nullable fixed-size primitives only field.Meta.FixedSize = getFixedSizeByDispatchId(field.DispatchId) @@ -445,6 +453,9 @@ func isUnionType(t reflect.Type) bool { if t == nil { return false } + if info, ok := getOptionalInfo(t); ok { + t = info.valueType + } if t.Kind() == reflect.Ptr { t = t.Elem() } @@ -462,6 +473,9 @@ func isStructField(t reflect.Type) bool { if t == nil { return false } + if info, ok := getOptionalInfo(t); ok { + t = info.valueType + } if isUnionType(t) { return false } @@ -748,6 +762,12 @@ func typesCompatible(actual, expected reflect.Type) bool { if actual == nil || expected == nil { return false } + if info, ok := getOptionalInfo(actual); ok { + actual = info.valueType + } + if info, ok := getOptionalInfo(expected); ok { + expected = info.valueType + } if actual == expected { return true } @@ -783,6 +803,12 @@ func elementTypesCompatible(actual, expected reflect.Type) bool { if actual == nil || expected == nil { return false } + if info, ok := getOptionalInfo(actual); ok { + actual = info.valueType + } + if info, ok := getOptionalInfo(expected); ok { + expected = info.valueType + } if actual == expected || actual.AssignableTo(expected) || expected.AssignableTo(actual) { return true } @@ -796,6 +822,9 @@ func elementTypesCompatible(actual, expected reflect.Type) bool { // This is used when the type is not registered in typesInfo // Note: Uses VARINT32/VARINT64/VAR_UINT32/VAR_UINT64 to match Java xlang mode and Rust func typeIdFromKind(type_ reflect.Type) TypeId { + if info, ok := getOptionalInfo(type_); ok { + return typeIdFromKind(info.valueType) + } switch type_.Kind() { case reflect.Bool: return BOOL diff --git a/go/fory/optional/optional.go b/go/fory/optional/optional.go new file mode 100644 index 0000000000..d66bcba072 --- /dev/null +++ b/go/fory/optional/optional.go @@ -0,0 +1,249 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package optional + +// Optional represents an optional value without pointer indirection. +type Optional[T any] struct { + Value T + Has bool +} + +// Some returns an Optional containing a value. +func Some[T any](v T) Optional[T] { + return Optional[T]{Value: v, Has: true} +} + +// None returns an empty Optional. +func None[T any]() Optional[T] { + return Optional[T]{} +} + +// FromPtr converts a pointer to an Optional. +func FromPtr[T any](v *T) Optional[T] { + if v == nil { + return None[T]() + } + return Some(*v) +} + +// Ptr returns a pointer to the contained value or nil. +func (o Optional[T]) Ptr() *T { + if !o.Has { + return nil + } + v := o.Value + return &v +} + +// IsSome reports whether the optional contains a value. +func (o Optional[T]) IsSome() bool { return o.Has } + +// IsNone reports whether the optional is empty. +func (o Optional[T]) IsNone() bool { return !o.Has } + +// Expect returns the contained value or panics with the provided message. +func (o Optional[T]) Expect(message string) T { + if o.Has { + return o.Value + } + panic(message) +} + +// Unwrap returns the contained value or panics. +func (o Optional[T]) Unwrap() T { + if o.Has { + return o.Value + } + panic("optional: unwrap on None") +} + +// UnwrapOr returns the contained value or a default. +func (o Optional[T]) UnwrapOr(defaultValue T) T { + if o.Has { + return o.Value + } + return defaultValue +} + +// UnwrapOrDefault returns the contained value or the zero value. +func (o Optional[T]) UnwrapOrDefault() T { + if o.Has { + return o.Value + } + var zero T + return zero +} + +// UnwrapOrElse returns the contained value or computes a default. +func (o Optional[T]) UnwrapOrElse(defaultFn func() T) T { + if o.Has { + return o.Value + } + return defaultFn() +} + +// Map maps an Optional[T] to Optional[U] by applying a function. +func Map[T, U any](o Optional[T], f func(T) U) Optional[U] { + if o.Has { + return Some(f(o.Value)) + } + return None[U]() +} + +// MapOr applies a function to the contained value or returns a default. +func MapOr[T, U any](o Optional[T], defaultValue U, f func(T) U) U { + if o.Has { + return f(o.Value) + } + return defaultValue +} + +// MapOrElse applies a function to the contained value or computes a default. +func MapOrElse[T, U any](o Optional[T], defaultFn func() U, f func(T) U) U { + if o.Has { + return f(o.Value) + } + return defaultFn() +} + +// And returns None if either option is None, otherwise returns the second option. +func And[T, U any](o Optional[T], other Optional[U]) Optional[U] { + if o.Has { + return other + } + return None[U]() +} + +// AndThen returns None if this option is None, otherwise calls f and returns its result. +func AndThen[T, U any](o Optional[T], f func(T) Optional[U]) Optional[U] { + if o.Has { + return f(o.Value) + } + return None[U]() +} + +// Or returns the option if it is Some, otherwise returns other. +func (o Optional[T]) Or(other Optional[T]) Optional[T] { + if o.Has { + return o + } + return other +} + +// OrElse returns the option if it is Some, otherwise returns the result of f. +func (o Optional[T]) OrElse(f func() Optional[T]) Optional[T] { + if o.Has { + return o + } + return f() +} + +// Filter returns None if the predicate returns false. +func (o Optional[T]) Filter(predicate func(T) bool) Optional[T] { + if o.Has && predicate(o.Value) { + return o + } + return None[T]() +} + +// Result represents a simplified Result type for OkOr helpers. +type Result[T any] struct { + Value T + Err error +} + +// OkOr transforms the option into a Result, using err if None. +func (o Optional[T]) OkOr(err error) Result[T] { + if o.Has { + return Result[T]{Value: o.Value} + } + return Result[T]{Err: err} +} + +// OkOrElse transforms the option into a Result, using a function to produce the error. +func (o Optional[T]) OkOrElse(errFn func() error) Result[T] { + if o.Has { + return Result[T]{Value: o.Value} + } + return Result[T]{Err: errFn()} +} + +// Take takes the value out, leaving None in its place. +func (o *Optional[T]) Take() Optional[T] { + if o == nil || !o.Has { + return None[T]() + } + v := o.Value + o.Has = false + var zero T + o.Value = zero + return Some(v) +} + +// Set sets the option to Some(value). +func (o *Optional[T]) Set(v T) { + if o == nil { + return + } + o.Value = v + o.Has = true +} + +// Flatten transforms Optional[Optional[T]] into Optional[T]. +func Flatten[T any](o Optional[Optional[T]]) Optional[T] { + if !o.Has { + return None[T]() + } + return o.Value +} + +// Int8 wraps an int8 value in Optional. +func Int8(v int8) Optional[int8] { return Some(v) } + +// Int16 wraps an int16 value in Optional. +func Int16(v int16) Optional[int16] { return Some(v) } + +// Int32 wraps an int32 value in Optional. +func Int32(v int32) Optional[int32] { return Some(v) } + +// Int64 wraps an int64 value in Optional. +func Int64(v int64) Optional[int64] { return Some(v) } + +// Int wraps an int value in Optional. +func Int(v int) Optional[int] { return Some(v) } + +// Uint8 wraps a uint8 value in Optional. +func Uint8(v uint8) Optional[uint8] { return Some(v) } + +// Uint16 wraps a uint16 value in Optional. +func Uint16(v uint16) Optional[uint16] { return Some(v) } + +// Uint32 wraps a uint32 value in Optional. +func Uint32(v uint32) Optional[uint32] { return Some(v) } + +// Uint64 wraps a uint64 value in Optional. +func Uint64(v uint64) Optional[uint64] { return Some(v) } + +// Uint wraps a uint value in Optional. +func Uint(v uint) Optional[uint] { return Some(v) } + +// String wraps a string value in Optional. +func String(v string) Optional[string] { return Some(v) } + +// Bool wraps a bool value in Optional. +func Bool(v bool) Optional[bool] { return Some(v) } diff --git a/go/fory/optional_serializer.go b/go/fory/optional_serializer.go new file mode 100644 index 0000000000..b96a31a412 --- /dev/null +++ b/go/fory/optional_serializer.go @@ -0,0 +1,271 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "fmt" + "reflect" + "strings" +) + +const optionalPkgPath = "github.com/apache/fory/go/fory/optional" + +// optionalInfo describes the Optional[T] layout for fast access. +type optionalInfo struct { + valueType reflect.Type + valueOffset uintptr + hasOffset uintptr +} + +func getOptionalInfo(type_ reflect.Type) (optionalInfo, bool) { + if type_ == nil { + return optionalInfo{}, false + } + if type_.Kind() == reflect.Ptr { + return optionalInfo{}, false + } + if type_.Kind() != reflect.Struct { + return optionalInfo{}, false + } + if type_.PkgPath() != optionalPkgPath { + return optionalInfo{}, false + } + name := type_.Name() + if name != "Optional" && !strings.HasPrefix(name, "Optional[") { + return optionalInfo{}, false + } + valueField, ok := type_.FieldByName("Value") + if !ok { + return optionalInfo{}, false + } + hasField, ok := type_.FieldByName("Has") + if !ok || hasField.Type.Kind() != reflect.Bool { + return optionalInfo{}, false + } + return optionalInfo{ + valueType: valueField.Type, + valueOffset: valueField.Offset, + hasOffset: hasField.Offset, + }, true +} + +func validateOptionalValueType(valueType reflect.Type) error { + if valueType == nil { + return fmt.Errorf("optional value type is nil") + } + base := valueType + for base.Kind() == reflect.Ptr { + base = base.Elem() + } + switch base.Kind() { + case reflect.Struct, reflect.Slice, reflect.Map: + return fmt.Errorf("optional.Optional[%s] is not supported for struct/slice/map", valueType.String()) + default: + return nil + } +} + +func isOptionalType(type_ reflect.Type) bool { + _, ok := getOptionalInfo(type_) + return ok +} + +func unwrapOptionalType(type_ reflect.Type) (reflect.Type, bool) { + info, ok := getOptionalInfo(type_) + if !ok { + return type_, false + } + return info.valueType, true +} + +func optionalHasValue(value reflect.Value, info optionalInfo) bool { + if value.Kind() == reflect.Ptr { + if value.IsNil() { + return false + } + value = value.Elem() + } + return value.FieldByName("Has").Bool() +} + +// optionalSerializer handles Optional[T] values by writing null flags and delegating to the element serializer. +type optionalSerializer struct { + optionalType reflect.Type + valueType reflect.Type + valueIndex int + hasIndex int + valueSerializer Serializer +} + +func newOptionalSerializer(optionalType reflect.Type, info optionalInfo, valueSerializer Serializer) *optionalSerializer { + valueField, _ := optionalType.FieldByName("Value") + hasField, _ := optionalType.FieldByName("Has") + return &optionalSerializer{ + optionalType: optionalType, + valueType: info.valueType, + valueIndex: valueField.Index[0], + hasIndex: hasField.Index[0], + valueSerializer: valueSerializer, + } +} + +func (s *optionalSerializer) unwrap(value reflect.Value) reflect.Value { + if value.Kind() == reflect.Ptr { + return value.Elem() + } + return value +} + +func (s *optionalSerializer) has(value reflect.Value) bool { + value = s.unwrap(value) + return value.Field(s.hasIndex).Bool() +} + +func (s *optionalSerializer) valueField(value reflect.Value) reflect.Value { + value = s.unwrap(value) + return value.Field(s.valueIndex) +} + +func (s *optionalSerializer) setHas(value reflect.Value, has bool) { + value = s.unwrap(value) + value.Field(s.hasIndex).SetBool(has) +} + +func (s *optionalSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + if !s.has(value) { + s.writeNull(ctx, refMode, writeType) + return + } + valueField := s.valueField(value) + s.writeValue(ctx, refMode, writeType, valueField) +} + +func (s *optionalSerializer) writeNull(ctx *WriteContext, refMode RefMode, writeType bool) { + switch refMode { + case RefModeTracking, RefModeNullOnly: + ctx.Buffer().WriteInt8(NullFlag) + return + case RefModeNone: + // For RefModeNone, write zero value data without any flag. + zero := reflect.New(s.valueType).Elem() + if writeType { + info, err := ctx.TypeResolver().getTypeInfo(zero, true) + if err != nil { + ctx.SetError(FromError(err)) + return + } + ctx.TypeResolver().WriteTypeInfo(ctx.Buffer(), info, ctx.Err()) + } + s.valueSerializer.WriteData(ctx, zero) + } +} + +func (s *optionalSerializer) writeValue(ctx *WriteContext, refMode RefMode, writeType bool, valueField reflect.Value) { + switch refMode { + case RefModeTracking: + refWritten, err := ctx.RefResolver().WriteRefOrNull(ctx.Buffer(), valueField) + if err != nil { + ctx.SetError(FromError(err)) + return + } + if refWritten { + return + } + case RefModeNullOnly: + ctx.Buffer().WriteInt8(NotNullValueFlag) + case RefModeNone: + // No ref/null flag written. + } + if writeType { + info, err := ctx.TypeResolver().getTypeInfo(valueField, true) + if err != nil { + ctx.SetError(FromError(err)) + return + } + ctx.TypeResolver().WriteTypeInfo(ctx.Buffer(), info, ctx.Err()) + } + s.valueSerializer.WriteData(ctx, valueField) +} + +func (s *optionalSerializer) WriteData(ctx *WriteContext, value reflect.Value) { + // WriteData assumes the value is present and writes data only (no null/ref flags). + valueField := s.valueField(value) + s.valueSerializer.WriteData(ctx, valueField) +} + +func (s *optionalSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + buf := ctx.Buffer() + switch refMode { + case RefModeTracking: + refID, refErr := ctx.RefResolver().TryPreserveRefId(buf) + if refErr != nil { + ctx.SetError(FromError(refErr)) + return + } + if refID < int32(NotNullValueFlag) { + if refID == int32(NullFlag) { + s.setHas(value, false) + return + } + refObj := ctx.RefResolver().GetReadObject(refID) + if refObj.IsValid() { + valueField := s.valueField(value) + if refObj.Type().AssignableTo(valueField.Type()) { + valueField.Set(refObj) + s.setHas(value, true) + return + } + } + } + case RefModeNullOnly: + flag := buf.ReadInt8(ctx.Err()) + if flag == NullFlag { + s.setHas(value, false) + return + } + case RefModeNone: + // No null flag. + } + if readType { + typeID := buf.ReadVaruint32Small7(ctx.Err()) + if ctx.HasError() { + return + } + internalTypeID := TypeId(typeID & 0xFF) + if IsNamespacedType(TypeId(typeID)) || internalTypeID == COMPATIBLE_STRUCT || internalTypeID == STRUCT { + typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(buf, typeID, ctx.Err()) + if structSer, ok := typeInfo.Serializer.(*structSerializer); ok && len(structSer.fieldDefs) > 0 { + valueField := s.valueField(value) + s.setHas(value, true) + structSer.ReadData(ctx, valueField) + return + } + } + } + s.ReadData(ctx, value) +} + +func (s *optionalSerializer) ReadData(ctx *ReadContext, value reflect.Value) { + valueField := s.valueField(value) + s.setHas(value, true) + s.valueSerializer.ReadData(ctx, valueField) +} + +func (s *optionalSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} diff --git a/go/fory/struct.go b/go/fory/struct.go index deea9d2ec9..3110a76717 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -152,6 +152,14 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { } fieldType := field.Type + optionalInfo, isOptional := getOptionalInfo(fieldType) + baseType := fieldType + if isOptional { + if err := validateOptionalValueType(optionalInfo.valueType); err != nil { + return fmt.Errorf("field %s: %w", field.Name, err) + } + baseType = optionalInfo.valueType + } var fieldSerializer Serializer // For any fields, don't get a serializer - use WriteValue/ReadValue instead @@ -179,9 +187,9 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { // Override TypeId based on compress/encoding tags for integer types // This matches the logic in type_def.go:buildFieldDefs - baseKind := fieldType.Kind() + baseKind := baseType.Kind() if baseKind == reflect.Ptr { - baseKind = fieldType.Elem().Kind() + baseKind = baseType.Elem().Kind() } switch baseKind { case reflect.Uint32: @@ -245,15 +253,15 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { if typeResolver.fory.config.IsXlang { // xlang mode: only pointer types are nullable by default per xlang spec // Slices and maps are NOT nullable - they serialize as empty when nil - nullableFlag = fieldType.Kind() == reflect.Ptr + nullableFlag = isOptional || fieldType.Kind() == reflect.Ptr } else { // Native mode: Go's natural semantics - all nil-able types are nullable - nullableFlag = fieldType.Kind() == reflect.Ptr || + nullableFlag = isOptional || fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Map || fieldType.Kind() == reflect.Interface } - if foryTag.NullableSet { + if foryTag.NullableSet && !isOptional { // Override nullable flag if explicitly set in fory tag nullableFlag = foryTag.Nullable } @@ -279,7 +287,7 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { refMode = RefModeNullOnly } // Pre-compute WriteType: true for struct fields in compatible mode - writeType := typeResolver.Compatible() && isStructField(fieldType) + writeType := typeResolver.Compatible() && isStructField(baseType) // Pre-compute DispatchId, with special handling for enum fields and pointer-to-numeric var dispatchId DispatchId @@ -323,6 +331,8 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { FieldIndex: i, WriteType: writeType, HasGenerics: isCollectionType(fieldTypeId), // Container fields have declared element types + IsOptional: isOptional, + OptionalInfo: optionalInfo, TagID: foryTag.ID, HasForyTag: foryTag.HasTag, TagRefSet: foryTag.RefSet, @@ -670,6 +680,15 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err fieldType = remoteType } + optionalInfo, isOptional := getOptionalInfo(fieldType) + baseType := fieldType + if isOptional { + if err := validateOptionalValueType(optionalInfo.valueType); err != nil { + return fmt.Errorf("field %s: %w", def.name, err) + } + baseType = optionalInfo.valueType + } + // Get TypeId from FieldType's TypeId method fieldTypeId := def.fieldType.TypeId() // Pre-compute RefMode based on FieldDef flags (trackingRef and nullable) @@ -680,15 +699,19 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err refMode = RefModeNullOnly } // Pre-compute WriteType: true for struct fields in compatible mode - writeType := typeResolver.Compatible() && isStructField(fieldType) + writeType := typeResolver.Compatible() && isStructField(baseType) // Pre-compute DispatchId, with special handling for pointer-to-numeric and enum fields // IMPORTANT: For compatible mode reading, we must use the REMOTE nullable flag // to determine DispatchId, because Java wrote data with its nullable semantics. var dispatchId DispatchId localKind := fieldType.Kind() + baseKind := localKind + if isOptional { + baseKind = baseType.Kind() + } localIsPtr := localKind == reflect.Ptr - localIsPrimitive := isPrimitiveDispatchKind(localKind) || (localIsPtr && isPrimitiveDispatchKind(fieldType.Elem().Kind())) + localIsPrimitive := isPrimitiveDispatchKind(baseKind) || (localIsPtr && isPrimitiveDispatchKind(fieldType.Elem().Kind())) if localIsPrimitive { if localIsPtr { @@ -706,11 +729,11 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err dispatchId = getDispatchIdFromTypeId(fieldTypeId, true) } else { // Local is T, remote is NOT nullable - use primitive DispatchId - dispatchId = GetDispatchId(fieldType) + dispatchId = GetDispatchId(baseType) } } } else { - dispatchId = GetDispatchId(fieldType) + dispatchId = GetDispatchId(baseType) } if fieldSerializer != nil { if _, ok := fieldSerializer.(*enumSerializer); ok { @@ -735,16 +758,18 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err IsPtr: fieldType != nil && fieldType.Kind() == reflect.Ptr, Serializer: fieldSerializer, Meta: &FieldMeta{ - Name: fieldName, - Type: fieldType, - TypeId: fieldTypeId, - Nullable: def.nullable, // Use remote nullable flag - FieldIndex: fieldIndex, - FieldDef: def, // Save original FieldDef for skipping - WriteType: writeType, - HasGenerics: isCollectionType(fieldTypeId), // Container fields have declared element types - TagID: def.tagID, - HasForyTag: def.tagID >= 0, + Name: fieldName, + Type: fieldType, + TypeId: fieldTypeId, + Nullable: def.nullable, // Use remote nullable flag + FieldIndex: fieldIndex, + FieldDef: def, // Save original FieldDef for skipping + WriteType: writeType, + HasGenerics: isCollectionType(fieldTypeId), // Container fields have declared element types + IsOptional: isOptional, + OptionalInfo: optionalInfo, + TagID: def.tagID, + HasForyTag: def.tagID >= 0, }, } fields = append(fields, fieldInfo) @@ -777,8 +802,8 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err // Local nullable is determined by whether the Go field is a pointer type if i < len(s.fieldDefs) && field.Meta.FieldIndex >= 0 { remoteNullable := s.fieldDefs[i].nullable - // Check if local Go field is a pointer type (can be nil = nullable) - localNullable := field.IsPtr + // Check if local Go field is nullable (pointer or Option) + localNullable := field.IsPtr || field.Meta.IsOptional if remoteNullable != localNullable { s.typeDefDiffers = true break @@ -824,9 +849,13 @@ func (s *structSerializer) computeHash() int32 { if isUserDefinedType(int16(typeId)) { typeId = UNKNOWN } + fieldTypeForHash := field.Meta.Type + if field.Meta.IsOptional { + fieldTypeForHash = field.Meta.OptionalInfo.valueType + } // For fixed-size arrays with primitive elements, use primitive array type IDs - if field.Meta.Type.Kind() == reflect.Array { - elemKind := field.Meta.Type.Elem().Kind() + if fieldTypeForHash.Kind() == reflect.Array { + elemKind := fieldTypeForHash.Elem().Kind() switch elemKind { case reflect.Int8: typeId = INT8_ARRAY @@ -851,13 +880,13 @@ func (s *structSerializer) computeHash() int32 { default: typeId = LIST } - } else if field.Meta.Type.Kind() == reflect.Slice { + } else if fieldTypeForHash.Kind() == reflect.Slice { if !isPrimitiveArrayType(int16(typeId)) && typeId != BINARY { typeId = LIST } - } else if field.Meta.Type.Kind() == reflect.Map { + } else if fieldTypeForHash.Kind() == reflect.Map { // fory.Set[T] is defined as map[T]struct{} - check for struct{} elem type - if isSetReflectType(field.Meta.Type) { + if isSetReflectType(fieldTypeForHash) { typeId = SET } else { typeId = MAP @@ -869,13 +898,17 @@ func (s *structSerializer) computeHash() int32 { // - Default: false for ALL fields (xlang default - aligned with all languages) // - Primitives are always non-nullable // - Can be overridden by explicit fory tag - nullable := false // Default to nullable=false for xlang mode - if field.Meta.TagNullableSet { + nullable := field.Meta.IsOptional // Optional fields are nullable by default + if field.Meta.TagNullableSet && !field.Meta.IsOptional { // Use explicit tag value if set nullable = field.Meta.TagNullable } // Primitives are never nullable, regardless of tag - if isNonNullablePrimitiveKind(field.Meta.Type.Kind()) && !isEnumField { + fieldTypeForNullable := field.Meta.Type + if field.Meta.IsOptional { + fieldTypeForNullable = field.Meta.OptionalInfo.valueType + } + if isNonNullablePrimitiveKind(fieldTypeForNullable.Kind()) && !isEnumField { nullable = false } @@ -1255,6 +1288,20 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { // writeRemainingField writes a non-primitive field (string, slice, map, struct, enum) func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Pointer, field *FieldInfo, value reflect.Value) { buf := ctx.Buffer() + if field.Meta.IsOptional { + if ptr != nil { + if writeOptionFast(ctx, field, unsafe.Add(ptr, field.Offset)) { + return + } + } + fieldValue := value.Field(field.Meta.FieldIndex) + if field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, fieldValue) + } else { + ctx.WriteValue(fieldValue, RefModeTracking, true) + } + return + } // Fast path dispatch using pre-computed DispatchId // ptr must be valid (addressable value) if ptr != nil { @@ -1718,6 +1765,438 @@ func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Poi } } +func writeOptionFast(ctx *WriteContext, field *FieldInfo, optPtr unsafe.Pointer) bool { + buf := ctx.Buffer() + has := *(*bool)(unsafe.Add(optPtr, field.Meta.OptionalInfo.hasOffset)) + valuePtr := unsafe.Add(optPtr, field.Meta.OptionalInfo.valueOffset) + switch field.DispatchId { + case StringDispatchId: + if field.RefMode != RefModeNone { + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + } else if !has { + ctx.WriteString("") + return true + } + if has { + ctx.WriteString(*(*string)(valuePtr)) + } else { + ctx.WriteString("") + } + return true + case NullableTaggedInt64DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteTaggedInt64(*(*int64)(valuePtr)) + } else { + buf.WriteTaggedInt64(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteTaggedInt64(*(*int64)(valuePtr)) + return true + case NullableTaggedUint64DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteTaggedUint64(*(*uint64)(valuePtr)) + } else { + buf.WriteTaggedUint64(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteTaggedUint64(*(*uint64)(valuePtr)) + return true + case NullableBoolDispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteBool(*(*bool)(valuePtr)) + } else { + buf.WriteBool(false) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteBool(*(*bool)(valuePtr)) + return true + case NullableInt8DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteInt8(*(*int8)(valuePtr)) + } else { + buf.WriteInt8(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteInt8(*(*int8)(valuePtr)) + return true + case NullableUint8DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteUint8(*(*uint8)(valuePtr)) + } else { + buf.WriteUint8(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(*(*uint8)(valuePtr)) + return true + case NullableInt16DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteInt16(*(*int16)(valuePtr)) + } else { + buf.WriteInt16(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteInt16(*(*int16)(valuePtr)) + return true + case NullableUint16DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteUint16(*(*uint16)(valuePtr)) + } else { + buf.WriteUint16(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint16(*(*uint16)(valuePtr)) + return true + case NullableInt32DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteInt32(*(*int32)(valuePtr)) + } else { + buf.WriteInt32(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteInt32(*(*int32)(valuePtr)) + return true + case NullableUint32DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteUint32(*(*uint32)(valuePtr)) + } else { + buf.WriteUint32(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint32(*(*uint32)(valuePtr)) + return true + case NullableInt64DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteInt64(*(*int64)(valuePtr)) + } else { + buf.WriteInt64(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteInt64(*(*int64)(valuePtr)) + return true + case NullableUint64DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteUint64(*(*uint64)(valuePtr)) + } else { + buf.WriteUint64(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint64(*(*uint64)(valuePtr)) + return true + case NullableFloat32DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteFloat32(*(*float32)(valuePtr)) + } else { + buf.WriteFloat32(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteFloat32(*(*float32)(valuePtr)) + return true + case NullableFloat64DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteFloat64(*(*float64)(valuePtr)) + } else { + buf.WriteFloat64(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteFloat64(*(*float64)(valuePtr)) + return true + case NullableVarint32DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteVarint32(*(*int32)(valuePtr)) + } else { + buf.WriteVarint32(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteVarint32(*(*int32)(valuePtr)) + return true + case NullableVarUint32DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteVaruint32(*(*uint32)(valuePtr)) + } else { + buf.WriteVaruint32(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteVaruint32(*(*uint32)(valuePtr)) + return true + case NullableVarint64DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteVarint64(*(*int64)(valuePtr)) + } else { + buf.WriteVarint64(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteVarint64(*(*int64)(valuePtr)) + return true + case NullableVarUint64DispatchId: + if field.RefMode == RefModeNone { + if has { + buf.WriteVaruint64(*(*uint64)(valuePtr)) + } else { + buf.WriteVaruint64(0) + } + return true + } + if !has { + buf.WriteInt8(NullFlag) + return true + } + buf.WriteInt8(NotNullValueFlag) + buf.WriteVaruint64(*(*uint64)(valuePtr)) + return true + case PrimitiveBoolDispatchId: + if has { + buf.WriteBool(*(*bool)(valuePtr)) + } else { + buf.WriteBool(false) + } + return true + case PrimitiveInt8DispatchId: + if has { + buf.WriteInt8(*(*int8)(valuePtr)) + } else { + buf.WriteInt8(0) + } + return true + case PrimitiveUint8DispatchId: + if has { + buf.WriteUint8(*(*uint8)(valuePtr)) + } else { + buf.WriteUint8(0) + } + return true + case PrimitiveInt16DispatchId: + if has { + buf.WriteInt16(*(*int16)(valuePtr)) + } else { + buf.WriteInt16(0) + } + return true + case PrimitiveUint16DispatchId: + if has { + buf.WriteUint16(*(*uint16)(valuePtr)) + } else { + buf.WriteUint16(0) + } + return true + case PrimitiveInt32DispatchId: + if has { + buf.WriteInt32(*(*int32)(valuePtr)) + } else { + buf.WriteInt32(0) + } + return true + case PrimitiveVarint32DispatchId: + if has { + buf.WriteVarint32(*(*int32)(valuePtr)) + } else { + buf.WriteVarint32(0) + } + return true + case PrimitiveInt64DispatchId: + if has { + buf.WriteInt64(*(*int64)(valuePtr)) + } else { + buf.WriteInt64(0) + } + return true + case PrimitiveVarint64DispatchId: + if has { + buf.WriteVarint64(*(*int64)(valuePtr)) + } else { + buf.WriteVarint64(0) + } + return true + case PrimitiveIntDispatchId: + if has { + buf.WriteVarint64(int64(*(*int)(valuePtr))) + } else { + buf.WriteVarint64(0) + } + return true + case PrimitiveUint32DispatchId: + if has { + buf.WriteUint32(*(*uint32)(valuePtr)) + } else { + buf.WriteUint32(0) + } + return true + case PrimitiveVarUint32DispatchId: + if has { + buf.WriteVaruint32(*(*uint32)(valuePtr)) + } else { + buf.WriteVaruint32(0) + } + return true + case PrimitiveUint64DispatchId: + if has { + buf.WriteUint64(*(*uint64)(valuePtr)) + } else { + buf.WriteUint64(0) + } + return true + case PrimitiveVarUint64DispatchId: + if has { + buf.WriteVaruint64(*(*uint64)(valuePtr)) + } else { + buf.WriteVaruint64(0) + } + return true + case PrimitiveUintDispatchId: + if has { + buf.WriteVaruint64(uint64(*(*uint)(valuePtr))) + } else { + buf.WriteVaruint64(0) + } + return true + case PrimitiveTaggedInt64DispatchId: + if has { + buf.WriteTaggedInt64(*(*int64)(valuePtr)) + } else { + buf.WriteTaggedInt64(0) + } + return true + case PrimitiveTaggedUint64DispatchId: + if has { + buf.WriteTaggedUint64(*(*uint64)(valuePtr)) + } else { + buf.WriteTaggedUint64(0) + } + return true + case PrimitiveFloat32DispatchId: + if has { + buf.WriteFloat32(*(*float32)(valuePtr)) + } else { + buf.WriteFloat32(0) + } + return true + case PrimitiveFloat64DispatchId: + if has { + buf.WriteFloat64(*(*float64)(valuePtr)) + } else { + buf.WriteFloat64(0) + } + return true + default: + return false + } +} + func (s *structSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() @@ -2035,6 +2514,20 @@ func (s *structSerializer) ReadData(ctx *ReadContext, value reflect.Value) { func (s *structSerializer) readRemainingField(ctx *ReadContext, ptr unsafe.Pointer, field *FieldInfo, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() + if field.Meta.IsOptional { + if ptr != nil { + if readOptionFast(ctx, field, unsafe.Add(ptr, field.Offset)) { + return + } + } + fieldValue := value.Field(field.Meta.FieldIndex) + if field.Serializer != nil { + field.Serializer.Read(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, fieldValue) + } else { + ctx.ReadValue(fieldValue, RefModeTracking, true) + } + return + } // Fast path dispatch using pre-computed DispatchId // ptr must be valid (addressable value) if ptr != nil { @@ -2362,6 +2855,309 @@ func (s *structSerializer) readRemainingField(ctx *ReadContext, ptr unsafe.Point } } +func readOptionFast(ctx *ReadContext, field *FieldInfo, optPtr unsafe.Pointer) bool { + buf := ctx.Buffer() + err := ctx.Err() + hasPtr := (*bool)(unsafe.Add(optPtr, field.Meta.OptionalInfo.hasOffset)) + valuePtr := unsafe.Add(optPtr, field.Meta.OptionalInfo.valueOffset) + switch field.DispatchId { + case StringDispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*string)(valuePtr) = "" + return true + } + } + *hasPtr = true + *(*string)(valuePtr) = ctx.ReadString() + return true + case NullableTaggedInt64DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*int64)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*int64)(valuePtr) = buf.ReadTaggedInt64(err) + return true + case NullableTaggedUint64DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*uint64)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*uint64)(valuePtr) = buf.ReadTaggedUint64(err) + return true + case NullableBoolDispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*bool)(valuePtr) = false + return true + } + } + *hasPtr = true + *(*bool)(valuePtr) = buf.ReadBool(err) + return true + case NullableInt8DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*int8)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*int8)(valuePtr) = buf.ReadInt8(err) + return true + case NullableUint8DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*uint8)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*uint8)(valuePtr) = buf.ReadUint8(err) + return true + case NullableInt16DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*int16)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*int16)(valuePtr) = buf.ReadInt16(err) + return true + case NullableUint16DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*uint16)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*uint16)(valuePtr) = buf.ReadUint16(err) + return true + case NullableInt32DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*int32)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*int32)(valuePtr) = buf.ReadInt32(err) + return true + case NullableUint32DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*uint32)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*uint32)(valuePtr) = buf.ReadUint32(err) + return true + case NullableInt64DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*int64)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*int64)(valuePtr) = buf.ReadInt64(err) + return true + case NullableUint64DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*uint64)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*uint64)(valuePtr) = buf.ReadUint64(err) + return true + case NullableFloat32DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*float32)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*float32)(valuePtr) = buf.ReadFloat32(err) + return true + case NullableFloat64DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*float64)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*float64)(valuePtr) = buf.ReadFloat64(err) + return true + case NullableVarint32DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*int32)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*int32)(valuePtr) = buf.ReadVarint32(err) + return true + case NullableVarUint32DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*uint32)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*uint32)(valuePtr) = buf.ReadVaruint32(err) + return true + case NullableVarint64DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*int64)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*int64)(valuePtr) = buf.ReadVarint64(err) + return true + case NullableVarUint64DispatchId: + if field.RefMode != RefModeNone { + flag := buf.ReadInt8(err) + if flag == NullFlag { + *hasPtr = false + *(*uint64)(valuePtr) = 0 + return true + } + } + *hasPtr = true + *(*uint64)(valuePtr) = buf.ReadVaruint64(err) + return true + case PrimitiveBoolDispatchId: + *hasPtr = true + *(*bool)(valuePtr) = buf.ReadBool(err) + return true + case PrimitiveInt8DispatchId: + *hasPtr = true + *(*int8)(valuePtr) = buf.ReadInt8(err) + return true + case PrimitiveUint8DispatchId: + *hasPtr = true + *(*uint8)(valuePtr) = buf.ReadUint8(err) + return true + case PrimitiveInt16DispatchId: + *hasPtr = true + *(*int16)(valuePtr) = buf.ReadInt16(err) + return true + case PrimitiveUint16DispatchId: + *hasPtr = true + *(*uint16)(valuePtr) = buf.ReadUint16(err) + return true + case PrimitiveInt32DispatchId: + *hasPtr = true + *(*int32)(valuePtr) = buf.ReadInt32(err) + return true + case PrimitiveVarint32DispatchId: + *hasPtr = true + *(*int32)(valuePtr) = buf.ReadVarint32(err) + return true + case PrimitiveInt64DispatchId: + *hasPtr = true + *(*int64)(valuePtr) = buf.ReadInt64(err) + return true + case PrimitiveVarint64DispatchId: + *hasPtr = true + *(*int64)(valuePtr) = buf.ReadVarint64(err) + return true + case PrimitiveIntDispatchId: + *hasPtr = true + *(*int)(valuePtr) = int(buf.ReadVarint64(err)) + return true + case PrimitiveUint32DispatchId: + *hasPtr = true + *(*uint32)(valuePtr) = buf.ReadUint32(err) + return true + case PrimitiveVarUint32DispatchId: + *hasPtr = true + *(*uint32)(valuePtr) = buf.ReadVaruint32(err) + return true + case PrimitiveUint64DispatchId: + *hasPtr = true + *(*uint64)(valuePtr) = buf.ReadUint64(err) + return true + case PrimitiveVarUint64DispatchId: + *hasPtr = true + *(*uint64)(valuePtr) = buf.ReadVaruint64(err) + return true + case PrimitiveUintDispatchId: + *hasPtr = true + *(*uint)(valuePtr) = uint(buf.ReadVaruint64(err)) + return true + case PrimitiveTaggedInt64DispatchId: + *hasPtr = true + *(*int64)(valuePtr) = buf.ReadTaggedInt64(err) + return true + case PrimitiveTaggedUint64DispatchId: + *hasPtr = true + *(*uint64)(valuePtr) = buf.ReadTaggedUint64(err) + return true + case PrimitiveFloat32DispatchId: + *hasPtr = true + *(*float32)(valuePtr) = buf.ReadFloat32(err) + return true + case PrimitiveFloat64DispatchId: + *hasPtr = true + *(*float64)(valuePtr) = buf.ReadFloat64(err) + return true + default: + return false + } +} + // readFieldsInOrder reads fields in the order they appear in s.fields (TypeDef order) // This is used in compatible mode where Java writes fields in TypeDef order // Precondition: value.CanAddr() must be true (checked by caller) @@ -2374,6 +3170,15 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val s.skipField(ctx, field) return } + if field.Meta.IsOptional { + fieldValue := value.Field(field.Meta.FieldIndex) + if field.Serializer != nil { + field.Serializer.Read(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, fieldValue) + } else { + ctx.ReadValue(fieldValue, RefModeTracking, true) + } + return + } // Fast path for fixed-size primitive types (no ref flag from remote schema) if isFixedSizePrimitive(field.DispatchId, field.Meta.Nullable) { diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index 1fcdff63d5..c6c0a48388 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/apache/fory/go/fory/optional" "github.com/stretchr/testify/require" ) @@ -73,6 +74,65 @@ func TestUnsignedTypeSerialization(t *testing.T) { } } +func TestOptionFieldSerialization(t *testing.T) { + type Nested struct { + Name string + } + type OptionStruct struct { + OptInt optional.Optional[int32] + OptZero optional.Optional[int32] + OptString optional.Optional[string] + OptBool optional.Optional[bool] + } + + f := New(WithXlang(true), WithCompatible(false)) + require.NoError(t, f.RegisterStruct(OptionStruct{}, 1100)) + + obj := OptionStruct{ + OptInt: optional.Some[int32](123), + OptZero: optional.Some[int32](0), + OptString: optional.Some("hello"), + OptBool: optional.Some(true), + } + + data, err := f.Serialize(obj) + require.NoError(t, err) + + var result any + err = f.Deserialize(data, &result) + require.NoError(t, err) + + out := result.(*OptionStruct) + require.True(t, out.OptInt.Has) + require.Equal(t, int32(123), out.OptInt.Value) + require.True(t, out.OptZero.Has) + require.Equal(t, int32(0), out.OptZero.Value) + require.True(t, out.OptString.Has) + require.Equal(t, "hello", out.OptString.Value) + require.True(t, out.OptBool.Has) + require.Equal(t, true, out.OptBool.Value) +} + +func TestOptionFieldUnsupportedTypes(t *testing.T) { + type Nested struct { + Name string + } + type OptionStruct struct { + OptStruct optional.Optional[Nested] + } + type OptionSlice struct { + OptSlice optional.Optional[[]int] + } + type OptionMap struct { + OptMap optional.Optional[map[string]int] + } + + f := New(WithXlang(true), WithCompatible(false)) + require.Error(t, f.RegisterStruct(OptionStruct{}, 1101)) + require.Error(t, f.RegisterStruct(OptionSlice{}, 1102)) + require.Error(t, f.RegisterStruct(OptionMap{}, 1103)) +} + // Test struct for compatible mode tests (must be named struct at package level) type SetFieldsStruct struct { SetField Set[string] @@ -200,14 +260,14 @@ func TestSetFieldTypeId(t *testing.T) { field.Meta.Name, field.Meta.Type, field.Meta.TypeId, field.Serializer) if field.Meta.Name == "set_field" { - require.Equal(t, SET, field.Meta.TypeId, "SetField should have TypeId=SET(21)") + require.Equal(t, TypeId(SET), field.Meta.TypeId, "SetField should have TypeId=SET(21)") require.NotNil(t, field.Serializer, "SetField serializer should not be nil") _, isSetSerializer := field.Serializer.(setSerializer) require.True(t, isSetSerializer, "SetField serializer should be setSerializer") } if field.Meta.Name == "map_field" { - require.Equal(t, MAP, field.Meta.TypeId, "MapField should have TypeId=MAP(22)") + require.Equal(t, TypeId(MAP), field.Meta.TypeId, "MapField should have TypeId=MAP(22)") require.NotNil(t, field.Serializer, "MapField serializer should not be nil") } } diff --git a/go/fory/type_def.go b/go/fory/type_def.go index ce02e1e0f0..363f46091d 100644 --- a/go/fory/type_def.go +++ b/go/fory/type_def.go @@ -427,6 +427,13 @@ func buildFieldDefs(fory *Fory, value reflect.Value) ([]FieldDef, error) { nameEncoding := fory.typeResolver.typeNameEncoder.ComputeEncodingWith(fieldName, fieldNameEncodings) + fieldType := field.Type + optionalInfo, isOptional := getOptionalInfo(fieldType) + baseType := fieldType + if isOptional { + baseType = optionalInfo.valueType + } + ft, err := buildFieldType(fory, fieldValue) if err != nil { return nil, fmt.Errorf("failed to build field type for field %s: %w", fieldName, err) @@ -434,10 +441,10 @@ func buildFieldDefs(fory *Fory, value reflect.Value) ([]FieldDef, error) { // Apply encoding override from struct tags if set // This works for both direct types and pointer-wrapped types - baseKind := field.Type.Kind() + baseKind := baseType.Kind() // Handle pointer types - get the element kind if baseKind == reflect.Ptr { - baseKind = field.Type.Elem().Kind() + baseKind = baseType.Elem().Kind() } // Check if we need to override the TypeID based on compress/encoding tags @@ -512,16 +519,16 @@ func buildFieldDefs(fory *Fory, value reflect.Value) ([]FieldDef, error) { if fory.config.IsXlang { // xlang mode: only pointer types are nullable by default per xlang spec // Slices and maps are NOT nullable - they serialize as empty when nil - nullableFlag = field.Type.Kind() == reflect.Ptr + nullableFlag = isOptional || field.Type.Kind() == reflect.Ptr } else { // Native mode: Go's natural semantics - all nil-able types are nullable - nullableFlag = field.Type.Kind() == reflect.Ptr || + nullableFlag = isOptional || field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Map || field.Type.Kind() == reflect.Interface } // Override nullable flag if explicitly set in fory tag - if foryTag.NullableSet { + if foryTag.NullableSet && !isOptional { nullableFlag = foryTag.Nullable } // Primitives are never nullable, regardless of tag @@ -968,6 +975,10 @@ func (d *DynamicFieldType) getTypeInfoWithResolver(resolver *TypeResolver) (Type // buildFieldType builds field type from reflect.Type, handling collection, map recursively func buildFieldType(fory *Fory, fieldValue reflect.Value) (FieldType, error) { fieldType := fieldValue.Type() + if info, ok := getOptionalInfo(fieldType); ok { + fieldType = info.valueType + fieldValue = reflect.Zero(fieldType) + } // Handle Interface type, we can't determine the actual type here, so leave it as dynamic type if fieldType.Kind() == reflect.Interface { return NewDynamicFieldType(UNKNOWN), nil diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index f95b4ba990..1643027e10 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -436,15 +436,22 @@ func (r *TypeResolver) RegisterStruct(type_ reflect.Type, fullTypeID uint32) err tag := type_.Name() serializer := newStructSerializer(type_, tag) r.typeToSerializers[type_] = serializer + if err := serializer.initialize(r); err != nil { + delete(r.typeToSerializers, type_) + return err + } r.typeToTypeInfo[type_] = "@" + tag r.typeInfoToType["@"+tag] = type_ // Create pointer serializer ptrType := reflect.PtrTo(type_) - ptrSerializer := &ptrToValueSerializer{ - valueSerializer: serializer, + ptrSerializer, ok := r.typeToSerializers[ptrType] + if !ok { + ptrSerializer = &ptrToValueSerializer{ + valueSerializer: serializer, + } + r.typeToSerializers[ptrType] = ptrSerializer } - r.typeToSerializers[ptrType] = ptrSerializer r.typeTagToSerializers[tag] = ptrSerializer r.typeToTypeInfo[ptrType] = "*@" + tag r.typeInfoToType["*@"+tag] = ptrType @@ -845,6 +852,9 @@ func (r *TypeResolver) getSerializerByType(type_ reflect.Type, mapInStruct bool) // getTypeIdByType returns the TypeId for a given type, or 0 if not found in typesInfo. // This is used to get the type ID without calling Serializer.TypeId(). func (r *TypeResolver) getTypeIdByType(type_ reflect.Type) TypeId { + if info, ok := getOptionalInfo(type_); ok { + type_ = info.valueType + } if info, ok := r.typesInfo[type_]; ok { return TypeId(info.TypeID & 0xFF) // Extract base type ID } @@ -1398,6 +1408,23 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI } func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s Serializer, err error) { + if info, ok := getOptionalInfo(type_); ok { + optionalType := type_ + if optionalType.Kind() == reflect.Ptr { + optionalType = optionalType.Elem() + } + if err := validateOptionalValueType(info.valueType); err != nil { + return nil, err + } + valueSerializer, err := r.getSerializerByType(info.valueType, false) + if err != nil { + return nil, err + } + if valueSerializer == nil { + return nil, fmt.Errorf("no serializer found for optional element type %s", info.valueType) + } + return newOptionalSerializer(optionalType, info, valueSerializer), nil + } kind := type_.Kind() switch kind { case reflect.Ptr: diff --git a/integration_tests/idl_tests/cpp/main.cc b/integration_tests/idl_tests/cpp/main.cc index 9b87409443..25a544e8b8 100644 --- a/integration_tests/idl_tests/cpp/main.cc +++ b/integration_tests/idl_tests/cpp/main.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -28,6 +29,7 @@ #include "complex_fbs.h" #include "fory/serialization/fory.h" #include "monster.h" +#include "optional_types.h" namespace { @@ -68,6 +70,7 @@ fory::Result RunRoundTrip() { addressbook::RegisterTypes(fory); monster::RegisterTypes(fory); complex_fbs::RegisterTypes(fory); + optional_types::RegisterTypes(fory); addressbook::Person::PhoneNumber mobile; mobile.set_number("555-0100"); @@ -201,6 +204,53 @@ fory::Result RunRoundTrip() { fory::Error::invalid("flatbuffers container roundtrip mismatch")); } + optional_types::AllOptionalTypes all_types; + all_types.set_bool_value(true); + all_types.set_int8_value(12); + all_types.set_int16_value(1234); + all_types.set_int32_value(-123456); + all_types.set_fixed_int32_value(-123456); + all_types.set_varint32_value(-12345); + all_types.set_int64_value(-123456789); + all_types.set_fixed_int64_value(-123456789); + all_types.set_varint64_value(-987654321); + all_types.set_tagged_int64_value(123456789); + all_types.set_uint8_value(200); + all_types.set_uint16_value(60000); + all_types.set_uint32_value(1234567890); + all_types.set_fixed_uint32_value(1234567890); + all_types.set_var_uint32_value(1234567890); + all_types.set_uint64_value(9876543210ULL); + all_types.set_fixed_uint64_value(9876543210ULL); + all_types.set_var_uint64_value(12345678901ULL); + all_types.set_tagged_uint64_value(2222222222ULL); + all_types.set_float16_value(1.5F); + all_types.set_float32_value(2.5F); + all_types.set_float64_value(3.5); + all_types.set_string_value("optional"); + all_types.set_bytes_value({static_cast(1), static_cast(2), + static_cast(3)}); + all_types.set_date_value(fory::serialization::LocalDate(19724)); + all_types.set_timestamp_value( + fory::serialization::Timestamp(std::chrono::seconds(1704164645))); + all_types.set_int32_list({1, 2, 3}); + all_types.set_string_list({"alpha", "beta"}); + all_types.set_int64_map({{"alpha", 10}, {"beta", 20}}); + + optional_types::OptionalHolder holder; + *holder.mutable_all_types() = all_types; + holder.set_choice(optional_types::OptionalUnion::note("optional")); + + FORY_TRY(optional_bytes, fory.serialize(holder)); + FORY_TRY(optional_roundtrip, + fory.deserialize( + optional_bytes.data(), optional_bytes.size())); + + if (!(optional_roundtrip == holder)) { + return fory::Unexpected( + fory::Error::invalid("optional types roundtrip mismatch")); + } + const char *data_file = std::getenv("DATA_FILE"); if (data_file != nullptr && data_file[0] != '\0') { FORY_TRY(payload, ReadFile(data_file)); @@ -252,6 +302,19 @@ fory::Result RunRoundTrip() { FORY_RETURN_IF_ERROR(WriteFile(container_file, peer_bytes)); } + const char *optional_file = std::getenv("DATA_FILE_OPTIONAL_TYPES"); + if (optional_file != nullptr && optional_file[0] != '\0') { + FORY_TRY(payload, ReadFile(optional_file)); + FORY_TRY(peer_holder, fory.deserialize( + payload.data(), payload.size())); + if (!(peer_holder == holder)) { + return fory::Unexpected( + fory::Error::invalid("peer optional payload mismatch")); + } + FORY_TRY(peer_bytes, fory.serialize(peer_holder)); + FORY_RETURN_IF_ERROR(WriteFile(optional_file, peer_bytes)); + } + return fory::Result(); } diff --git a/integration_tests/idl_tests/generate_idl.py b/integration_tests/idl_tests/generate_idl.py index 52164a54ee..c4a5a68fbe 100755 --- a/integration_tests/idl_tests/generate_idl.py +++ b/integration_tests/idl_tests/generate_idl.py @@ -26,6 +26,7 @@ IDL_DIR = Path(__file__).resolve().parent SCHEMAS = [ IDL_DIR / "idl" / "addressbook.fdl", + IDL_DIR / "idl" / "optional_types.fdl", IDL_DIR / "idl" / "monster.fbs", IDL_DIR / "idl" / "complex_fbs.fbs", ] @@ -41,6 +42,7 @@ GO_OUTPUT_OVERRIDES = { "monster.fbs": IDL_DIR / "go" / "monster", "complex_fbs.fbs": IDL_DIR / "go" / "complex_fbs", + "optional_types.fdl": IDL_DIR / "go" / "optional_types", } diff --git a/integration_tests/idl_tests/go/idl_roundtrip_test.go b/integration_tests/idl_tests/go/idl_roundtrip_test.go index 3e4cd4e7ce..510000a57c 100644 --- a/integration_tests/idl_tests/go/idl_roundtrip_test.go +++ b/integration_tests/idl_tests/go/idl_roundtrip_test.go @@ -21,10 +21,13 @@ import ( "os" "reflect" "testing" + "time" fory "github.com/apache/fory/go/fory" + "github.com/apache/fory/go/fory/optional" complexfbs "github.com/apache/fory/integration_tests/idl_tests/go/complex_fbs" monster "github.com/apache/fory/integration_tests/idl_tests/go/monster" + optionaltypes "github.com/apache/fory/integration_tests/idl_tests/go/optional_types" ) func buildAddressBook() AddressBook { @@ -74,6 +77,9 @@ func TestAddressBookRoundTrip(t *testing.T) { if err := complexfbs.RegisterTypes(f); err != nil { t.Fatalf("register flatbuffers types: %v", err) } + if err := optionaltypes.RegisterTypes(f); err != nil { + t.Fatalf("register optional types: %v", err) + } book := buildAddressBook() runLocalRoundTrip(t, f, book) @@ -90,6 +96,10 @@ func TestAddressBookRoundTrip(t *testing.T) { container := buildContainer() runLocalContainerRoundTrip(t, f, container) runFileContainerRoundTrip(t, f, container) + + holder := buildOptionalHolder() + runLocalOptionalRoundTrip(t, f, holder) + runFileOptionalRoundTrip(t, f, holder) } func runLocalRoundTrip(t *testing.T, f *fory.Fory, book AddressBook) { @@ -334,3 +344,233 @@ func runFileContainerRoundTrip(t *testing.T, f *fory.Fory, container complexfbs. t.Fatalf("write data file: %v", err) } } + +func buildOptionalHolder() optionaltypes.OptionalHolder { + dateValue := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + timestampValue := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) + allTypes := &optionaltypes.AllOptionalTypes{ + BoolValue: optional.Some(true), + Int8Value: optional.Some(int8(12)), + Int16Value: optional.Some(int16(1234)), + Int32Value: optional.Some(int32(-123456)), + FixedInt32Value: optional.Some(int32(-123456)), + Varint32Value: optional.Some(int32(-12345)), + Int64Value: optional.Some(int64(-123456789)), + FixedInt64Value: optional.Some(int64(-123456789)), + Varint64Value: optional.Some(int64(-987654321)), + TaggedInt64Value: optional.Some(int64(123456789)), + Uint8Value: optional.Some(uint8(200)), + Uint16Value: optional.Some(uint16(60000)), + Uint32Value: optional.Some(uint32(1234567890)), + FixedUint32Value: optional.Some(uint32(1234567890)), + VarUint32Value: optional.Some(uint32(1234567890)), + Uint64Value: optional.Some(uint64(9876543210)), + FixedUint64Value: optional.Some(uint64(9876543210)), + VarUint64Value: optional.Some(uint64(12345678901)), + TaggedUint64Value: optional.Some(uint64(2222222222)), + Float16Value: optional.Some(float32(1.5)), + Float32Value: optional.Some(float32(2.5)), + Float64Value: optional.Some(3.5), + StringValue: optional.Some("optional"), + BytesValue: []byte{1, 2, 3}, + DateValue: &dateValue, + TimestampValue: ×tampValue, + Int32List: []int32{1, 2, 3}, + StringList: []string{"alpha", "beta"}, + Int64Map: map[string]int64{"alpha": 10, "beta": 20}, + } + unionValue := optionaltypes.NoteOptionalUnion("optional") + return optionaltypes.OptionalHolder{ + AllTypes: allTypes, + Choice: &unionValue, + } +} + +func runLocalOptionalRoundTrip(t *testing.T, f *fory.Fory, holder optionaltypes.OptionalHolder) { + data, err := f.Serialize(holder) + if err != nil { + t.Fatalf("serialize: %v", err) + } + + var out optionaltypes.OptionalHolder + if err := f.Deserialize(data, &out); err != nil { + t.Fatalf("deserialize: %v", err) + } + + assertOptionalHolderEqual(t, holder, out) +} + +func runFileOptionalRoundTrip(t *testing.T, f *fory.Fory, holder optionaltypes.OptionalHolder) { + dataFile := os.Getenv("DATA_FILE_OPTIONAL_TYPES") + if dataFile == "" { + return + } + payload, err := os.ReadFile(dataFile) + if err != nil { + t.Fatalf("read data file: %v", err) + } + + var decoded optionaltypes.OptionalHolder + if err := f.Deserialize(payload, &decoded); err != nil { + t.Fatalf("deserialize peer payload: %v", err) + } + assertOptionalHolderEqual(t, holder, decoded) + + out, err := f.Serialize(decoded) + if err != nil { + t.Fatalf("serialize peer payload: %v", err) + } + if err := os.WriteFile(dataFile, out, 0o644); err != nil { + t.Fatalf("write data file: %v", err) + } +} + +func assertOptionalHolderEqual(t *testing.T, expected, actual optionaltypes.OptionalHolder) { + t.Helper() + if expected.AllTypes == nil || actual.AllTypes == nil { + if expected.AllTypes != actual.AllTypes { + t.Fatalf("optional holder all_types mismatch: %#v != %#v", expected.AllTypes, actual.AllTypes) + } + } else { + assertOptionalTypesEqual(t, expected.AllTypes, actual.AllTypes) + } + if expected.Choice == nil || actual.Choice == nil { + if expected.Choice != actual.Choice { + t.Fatalf("optional holder choice mismatch: %#v != %#v", expected.Choice, actual.Choice) + } + } else { + assertOptionalUnionEqual(t, *expected.Choice, *actual.Choice) + } +} + +func assertOptionalUnionEqual(t *testing.T, expected, actual optionaltypes.OptionalUnion) { + t.Helper() + if expected.Case() != actual.Case() { + t.Fatalf("optional union case mismatch: %v != %v", expected.Case(), actual.Case()) + } + switch expected.Case() { + case optionaltypes.OptionalUnionCaseNote: + expValue, _ := expected.AsNote() + actValue, _ := actual.AsNote() + if expValue != actValue { + t.Fatalf("optional union note mismatch: %v != %v", expValue, actValue) + } + case optionaltypes.OptionalUnionCaseCode: + expValue, _ := expected.AsCode() + actValue, _ := actual.AsCode() + if expValue != actValue { + t.Fatalf("optional union code mismatch: %v != %v", expValue, actValue) + } + case optionaltypes.OptionalUnionCasePayload: + expValue, _ := expected.AsPayload() + actValue, _ := actual.AsPayload() + if expValue == nil || actValue == nil { + if expValue != actValue { + t.Fatalf("optional union payload mismatch: %#v != %#v", expValue, actValue) + } + return + } + assertOptionalTypesEqual(t, expValue, actValue) + default: + t.Fatalf("unexpected optional union case: %v", expected.Case()) + } +} + +func assertOptionalTypesEqual(t *testing.T, expected, actual *optionaltypes.AllOptionalTypes) { + t.Helper() + if expected.BoolValue != actual.BoolValue { + t.Fatalf("bool_value mismatch: %#v != %#v", expected.BoolValue, actual.BoolValue) + } + if expected.Int8Value != actual.Int8Value { + t.Fatalf("int8_value mismatch: %#v != %#v", expected.Int8Value, actual.Int8Value) + } + if expected.Int16Value != actual.Int16Value { + t.Fatalf("int16_value mismatch: %#v != %#v", expected.Int16Value, actual.Int16Value) + } + if expected.Int32Value != actual.Int32Value { + t.Fatalf("int32_value mismatch: %#v != %#v", expected.Int32Value, actual.Int32Value) + } + if expected.FixedInt32Value != actual.FixedInt32Value { + t.Fatalf("fixed_int32_value mismatch: %#v != %#v", expected.FixedInt32Value, actual.FixedInt32Value) + } + if expected.Varint32Value != actual.Varint32Value { + t.Fatalf("varint32_value mismatch: %#v != %#v", expected.Varint32Value, actual.Varint32Value) + } + if expected.Int64Value != actual.Int64Value { + t.Fatalf("int64_value mismatch: %#v != %#v", expected.Int64Value, actual.Int64Value) + } + if expected.FixedInt64Value != actual.FixedInt64Value { + t.Fatalf("fixed_int64_value mismatch: %#v != %#v", expected.FixedInt64Value, actual.FixedInt64Value) + } + if expected.Varint64Value != actual.Varint64Value { + t.Fatalf("varint64_value mismatch: %#v != %#v", expected.Varint64Value, actual.Varint64Value) + } + if expected.TaggedInt64Value != actual.TaggedInt64Value { + t.Fatalf("tagged_int64_value mismatch: %#v != %#v", expected.TaggedInt64Value, actual.TaggedInt64Value) + } + if expected.Uint8Value != actual.Uint8Value { + t.Fatalf("uint8_value mismatch: %#v != %#v", expected.Uint8Value, actual.Uint8Value) + } + if expected.Uint16Value != actual.Uint16Value { + t.Fatalf("uint16_value mismatch: %#v != %#v", expected.Uint16Value, actual.Uint16Value) + } + if expected.Uint32Value != actual.Uint32Value { + t.Fatalf("uint32_value mismatch: %#v != %#v", expected.Uint32Value, actual.Uint32Value) + } + if expected.FixedUint32Value != actual.FixedUint32Value { + t.Fatalf("fixed_uint32_value mismatch: %#v != %#v", expected.FixedUint32Value, actual.FixedUint32Value) + } + if expected.VarUint32Value != actual.VarUint32Value { + t.Fatalf("var_uint32_value mismatch: %#v != %#v", expected.VarUint32Value, actual.VarUint32Value) + } + if expected.Uint64Value != actual.Uint64Value { + t.Fatalf("uint64_value mismatch: %#v != %#v", expected.Uint64Value, actual.Uint64Value) + } + if expected.FixedUint64Value != actual.FixedUint64Value { + t.Fatalf("fixed_uint64_value mismatch: %#v != %#v", expected.FixedUint64Value, actual.FixedUint64Value) + } + if expected.VarUint64Value != actual.VarUint64Value { + t.Fatalf("var_uint64_value mismatch: %#v != %#v", expected.VarUint64Value, actual.VarUint64Value) + } + if expected.TaggedUint64Value != actual.TaggedUint64Value { + t.Fatalf("tagged_uint64_value mismatch: %#v != %#v", expected.TaggedUint64Value, actual.TaggedUint64Value) + } + if expected.Float16Value != actual.Float16Value { + t.Fatalf("float16_value mismatch: %#v != %#v", expected.Float16Value, actual.Float16Value) + } + if expected.Float32Value != actual.Float32Value { + t.Fatalf("float32_value mismatch: %#v != %#v", expected.Float32Value, actual.Float32Value) + } + if expected.Float64Value != actual.Float64Value { + t.Fatalf("float64_value mismatch: %#v != %#v", expected.Float64Value, actual.Float64Value) + } + if expected.StringValue != actual.StringValue { + t.Fatalf("string_value mismatch: %#v != %#v", expected.StringValue, actual.StringValue) + } + if !reflect.DeepEqual(expected.BytesValue, actual.BytesValue) { + t.Fatalf("bytes_value mismatch: %#v != %#v", expected.BytesValue, actual.BytesValue) + } + if expected.DateValue == nil || actual.DateValue == nil { + if expected.DateValue != actual.DateValue { + t.Fatalf("date_value mismatch: %#v != %#v", expected.DateValue, actual.DateValue) + } + } else if !expected.DateValue.Equal(*actual.DateValue) { + t.Fatalf("date_value mismatch: %v != %v", expected.DateValue, actual.DateValue) + } + if expected.TimestampValue == nil || actual.TimestampValue == nil { + if expected.TimestampValue != actual.TimestampValue { + t.Fatalf("timestamp_value mismatch: %#v != %#v", expected.TimestampValue, actual.TimestampValue) + } + } else if !expected.TimestampValue.Equal(*actual.TimestampValue) { + t.Fatalf("timestamp_value mismatch: %v != %v", expected.TimestampValue, actual.TimestampValue) + } + if !reflect.DeepEqual(expected.Int32List, actual.Int32List) { + t.Fatalf("int32_list mismatch: %#v != %#v", expected.Int32List, actual.Int32List) + } + if !reflect.DeepEqual(expected.StringList, actual.StringList) { + t.Fatalf("string_list mismatch: %#v != %#v", expected.StringList, actual.StringList) + } + if !reflect.DeepEqual(expected.Int64Map, actual.Int64Map) { + t.Fatalf("int64_map mismatch: %#v != %#v", expected.Int64Map, actual.Int64Map) + } +} diff --git a/integration_tests/idl_tests/idl/optional_types.fdl b/integration_tests/idl_tests/idl/optional_types.fdl new file mode 100644 index 0000000000..a59baec9b3 --- /dev/null +++ b/integration_tests/idl_tests/idl/optional_types.fdl @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package optional_types; + +message AllOptionalTypes [id=120] { + optional bool bool_value = 1; + optional int8 int8_value = 2; + optional int16 int16_value = 3; + optional int32 int32_value = 4; + optional fixed_int32 fixed_int32_value = 5; + optional varint32 varint32_value = 6; + optional int64 int64_value = 7; + optional fixed_int64 fixed_int64_value = 8; + optional varint64 varint64_value = 9; + optional tagged_int64 tagged_int64_value = 10; + optional uint8 uint8_value = 11; + optional uint16 uint16_value = 12; + optional uint32 uint32_value = 13; + optional fixed_uint32 fixed_uint32_value = 14; + optional var_uint32 var_uint32_value = 15; + optional uint64 uint64_value = 16; + optional fixed_uint64 fixed_uint64_value = 17; + optional var_uint64 var_uint64_value = 18; + optional tagged_uint64 tagged_uint64_value = 19; + optional float16 float16_value = 20; + optional float32 float32_value = 21; + optional float64 float64_value = 22; + optional string string_value = 23; + optional bytes bytes_value = 24; + optional date date_value = 25; + optional timestamp timestamp_value = 26; + optional repeated int32 int32_list = 27; + optional repeated string string_list = 28; + optional map int64_map = 29; +} + +union OptionalUnion [id=121] { + string note = 1; + int32 code = 2; + AllOptionalTypes payload = 3; +} + +message OptionalHolder [id=122] { + AllOptionalTypes all_types = 1; + optional OptionalUnion choice = 2; +} diff --git a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java index dd67f51538..9695559c27 100644 --- a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java +++ b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java @@ -35,6 +35,10 @@ import complex_fbs.Payload; import complex_fbs.ScalarPack; import complex_fbs.Status; +import optional_types.AllOptionalTypes; +import optional_types.OptionalHolder; +import optional_types.OptionalTypesForyRegistration; +import optional_types.OptionalUnion; import monster.Color; import monster.Monster; import monster.MonsterForyRegistration; @@ -44,6 +48,8 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; +import java.time.Instant; +import java.time.LocalDate; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -117,6 +123,35 @@ public void testPrimitiveTypesRoundTrip() throws Exception { } } + @Test + public void testOptionalTypesRoundTrip() throws Exception { + Fory fory = Fory.builder().withLanguage(Language.XLANG).build(); + OptionalTypesForyRegistration.register(fory); + + OptionalHolder holder = buildOptionalHolder(); + byte[] bytes = fory.serialize(holder); + Object decoded = fory.deserialize(bytes); + + Assert.assertTrue(decoded instanceof OptionalHolder); + Assert.assertEquals(decoded, holder); + + for (String peer : resolvePeers()) { + Path dataFile = Files.createTempFile("idl-optional-" + peer + "-", ".bin"); + dataFile.toFile().deleteOnExit(); + Files.write(dataFile, bytes); + + Map env = new HashMap<>(); + env.put("DATA_FILE_OPTIONAL_TYPES", dataFile.toAbsolutePath().toString()); + PeerCommand command = buildPeerCommand(peer, env); + runPeer(command, peer); + + byte[] peerBytes = Files.readAllBytes(dataFile); + Object roundTrip = fory.deserialize(peerBytes); + Assert.assertTrue(roundTrip instanceof OptionalHolder); + Assert.assertEquals(roundTrip, holder); + } + } + @Test public void testFlatbuffersRoundTrip() throws Exception { Fory fory = Fory.builder().withLanguage(Language.XLANG).build(); @@ -347,6 +382,47 @@ private Monster buildMonster() { return monster; } + private OptionalHolder buildOptionalHolder() { + AllOptionalTypes allTypes = new AllOptionalTypes(); + allTypes.setBoolValue(true); + allTypes.setInt8Value((byte) 12); + allTypes.setInt16Value((short) 1234); + allTypes.setInt32Value(-123456); + allTypes.setFixedInt32Value(-123456); + allTypes.setVarint32Value(-12345); + allTypes.setInt64Value(-123456789L); + allTypes.setFixedInt64Value(-123456789L); + allTypes.setVarint64Value(-987654321L); + allTypes.setTaggedInt64Value(123456789L); + allTypes.setUint8Value((byte) 200); + allTypes.setUint16Value((short) 60000); + allTypes.setUint32Value(1234567890); + allTypes.setFixedUint32Value(1234567890); + allTypes.setVarUint32Value(1234567890); + allTypes.setUint64Value(9876543210L); + allTypes.setFixedUint64Value(9876543210L); + allTypes.setVarUint64Value(12345678901L); + allTypes.setTaggedUint64Value(2222222222L); + allTypes.setFloat16Value(1.5f); + allTypes.setFloat32Value(2.5f); + allTypes.setFloat64Value(3.5); + allTypes.setStringValue("optional"); + allTypes.setBytesValue(new byte[] {1, 2, 3}); + allTypes.setDateValue(LocalDate.of(2024, 1, 2)); + allTypes.setTimestampValue(Instant.parse("2024-01-02T03:04:05Z")); + allTypes.setInt32List(new int[] {1, 2, 3}); + allTypes.setStringList(Arrays.asList("alpha", "beta")); + Map int64Map = new HashMap<>(); + int64Map.put("alpha", 10L); + int64Map.put("beta", 20L); + allTypes.setInt64Map(int64Map); + + OptionalHolder holder = new OptionalHolder(); + holder.setAllTypes(allTypes); + holder.setChoice(OptionalUnion.ofNote("optional")); + return holder; + } + private Container buildContainer() { ScalarPack pack = new ScalarPack(); pack.setB((byte) -8); diff --git a/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py b/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py index 33c1c82a0b..c57dbbe4e6 100644 --- a/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py +++ b/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py @@ -17,12 +17,14 @@ from __future__ import annotations +import datetime import os from pathlib import Path import addressbook import complex_fbs import monster +import optional_types import numpy as np import pyfory @@ -101,6 +103,42 @@ def build_primitive_types() -> "addressbook.PrimitiveTypes": ) +def build_optional_holder() -> "optional_types.OptionalHolder": + all_types = optional_types.AllOptionalTypes( + bool_value=True, + int8_value=pyfory.int8(12), + int16_value=pyfory.int16(1234), + int32_value=pyfory.int32(-123456), + fixed_int32_value=pyfory.fixed_int32(-123456), + varint32_value=pyfory.int32(-12345), + int64_value=pyfory.int64(-123456789), + fixed_int64_value=pyfory.fixed_int64(-123456789), + varint64_value=pyfory.int64(-987654321), + tagged_int64_value=pyfory.tagged_int64(123456789), + uint8_value=pyfory.uint8(200), + uint16_value=pyfory.uint16(60000), + uint32_value=pyfory.uint32(1234567890), + fixed_uint32_value=pyfory.fixed_uint32(1234567890), + var_uint32_value=pyfory.uint32(1234567890), + uint64_value=pyfory.uint64(9876543210), + fixed_uint64_value=pyfory.fixed_uint64(9876543210), + var_uint64_value=pyfory.uint64(12345678901), + tagged_uint64_value=pyfory.tagged_uint64(2222222222), + float16_value=pyfory.float32(1.5), + float32_value=pyfory.float32(2.5), + float64_value=pyfory.float64(3.5), + string_value="optional", + bytes_value=b"\x01\x02\x03", + date_value=datetime.date(2024, 1, 2), + timestamp_value=datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.timezone.utc), + int32_list=np.array([1, 2, 3], dtype=np.int32), + string_list=["alpha", "beta"], + int64_map={"alpha": pyfory.int64(10), "beta": pyfory.int64(20)}, + ) + union_value = optional_types.OptionalUnion.note("optional") + return optional_types.OptionalHolder(all_types=all_types, choice=union_value) + + def build_monster() -> "monster.Monster": pos = monster.Vec3(x=1.0, y=2.0, z=3.0) return monster.Monster( @@ -231,11 +269,92 @@ def file_roundtrip_primitives( Path(data_file).write_bytes(fory.serialize(decoded)) +def assert_optional_types_equal( + decoded: "optional_types.AllOptionalTypes", + expected: "optional_types.AllOptionalTypes", +) -> None: + assert decoded.bool_value == expected.bool_value + assert decoded.int8_value == expected.int8_value + assert decoded.int16_value == expected.int16_value + assert decoded.int32_value == expected.int32_value + assert decoded.fixed_int32_value == expected.fixed_int32_value + assert decoded.varint32_value == expected.varint32_value + assert decoded.int64_value == expected.int64_value + assert decoded.fixed_int64_value == expected.fixed_int64_value + assert decoded.varint64_value == expected.varint64_value + assert decoded.tagged_int64_value == expected.tagged_int64_value + assert decoded.uint8_value == expected.uint8_value + assert decoded.uint16_value == expected.uint16_value + assert decoded.uint32_value == expected.uint32_value + assert decoded.fixed_uint32_value == expected.fixed_uint32_value + assert decoded.var_uint32_value == expected.var_uint32_value + assert decoded.uint64_value == expected.uint64_value + assert decoded.fixed_uint64_value == expected.fixed_uint64_value + assert decoded.var_uint64_value == expected.var_uint64_value + assert decoded.tagged_uint64_value == expected.tagged_uint64_value + assert decoded.float16_value == expected.float16_value + assert decoded.float32_value == expected.float32_value + assert decoded.float64_value == expected.float64_value + assert decoded.string_value == expected.string_value + assert decoded.bytes_value == expected.bytes_value + assert decoded.date_value == expected.date_value + assert decoded.timestamp_value == expected.timestamp_value + if expected.int32_list is None: + assert decoded.int32_list is None + else: + np.testing.assert_array_equal(decoded.int32_list, expected.int32_list) + assert decoded.string_list == expected.string_list + assert decoded.int64_map == expected.int64_map + + +def assert_optional_holder_equal( + decoded: "optional_types.OptionalHolder", + expected: "optional_types.OptionalHolder", +) -> None: + assert decoded.all_types is not None + assert expected.all_types is not None + assert_optional_types_equal(decoded.all_types, expected.all_types) + assert decoded.choice is not None + assert expected.choice is not None + assert decoded.choice.case() == expected.choice.case() + if decoded.choice.is_payload(): + assert_optional_types_equal( + decoded.choice.payload_value(), expected.choice.payload_value() + ) + elif decoded.choice.is_note(): + assert decoded.choice.note_value() == expected.choice.note_value() + else: + assert decoded.choice.code_value() == expected.choice.code_value() + + +def local_roundtrip_optional_types( + fory: pyfory.Fory, holder: "optional_types.OptionalHolder" +) -> None: + data = fory.serialize(holder) + decoded = fory.deserialize(data) + assert isinstance(decoded, optional_types.OptionalHolder) + assert_optional_holder_equal(decoded, holder) + + +def file_roundtrip_optional_types( + fory: pyfory.Fory, holder: "optional_types.OptionalHolder" +) -> None: + data_file = os.environ.get("DATA_FILE_OPTIONAL_TYPES") + if not data_file: + return + payload = Path(data_file).read_bytes() + decoded = fory.deserialize(payload) + assert isinstance(decoded, optional_types.OptionalHolder) + assert_optional_holder_equal(decoded, holder) + Path(data_file).write_bytes(fory.serialize(decoded)) + + def main() -> int: fory = pyfory.Fory(xlang=True) addressbook.register_addressbook_types(fory) monster.register_monster_types(fory) complex_fbs.register_complex_fbs_types(fory) + optional_types.register_optional_types_types(fory) book = build_address_book() local_roundtrip(fory, book) @@ -252,6 +371,10 @@ def main() -> int: container = build_container() local_roundtrip_container(fory, container) file_roundtrip_container(fory, container) + + holder = build_optional_holder() + local_roundtrip_optional_types(fory, holder) + file_roundtrip_optional_types(fory, holder) return 0 diff --git a/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs b/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs index dfe3572bf6..e62178965b 100644 --- a/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs +++ b/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs @@ -18,6 +18,7 @@ use std::collections::HashMap; use std::{env, fs}; +use chrono::NaiveDate; use fory::Fory; use idl_tests::addressbook::{ self, @@ -26,6 +27,7 @@ use idl_tests::addressbook::{ }; use idl_tests::complex_fbs::{self, Container, Note, Payload, ScalarPack, Status}; use idl_tests::monster::{self, Color, Monster, Vec3}; +use idl_tests::optional_types::{self, AllOptionalTypes, OptionalHolder, OptionalUnion}; fn build_address_book() -> AddressBook { let mobile = PhoneNumber { @@ -139,12 +141,57 @@ fn build_container() -> Container { } } +fn build_optional_holder() -> OptionalHolder { + let all_types = AllOptionalTypes { + bool_value: Some(true), + int8_value: Some(12), + int16_value: Some(1234), + int32_value: Some(-123456), + fixed_int32_value: Some(-123456), + varint32_value: Some(-12345), + int64_value: Some(-123456789), + fixed_int64_value: Some(-123456789), + varint64_value: Some(-987654321), + tagged_int64_value: Some(123456789), + uint8_value: Some(200), + uint16_value: Some(60000), + uint32_value: Some(1234567890), + fixed_uint32_value: Some(1234567890), + var_uint32_value: Some(1234567890), + uint64_value: Some(9876543210), + fixed_uint64_value: Some(9876543210), + var_uint64_value: Some(12345678901), + tagged_uint64_value: Some(2222222222), + float16_value: Some(1.5), + float32_value: Some(2.5), + float64_value: Some(3.5), + string_value: Some("optional".to_string()), + bytes_value: Some(vec![1, 2, 3]), + date_value: Some(NaiveDate::from_ymd_opt(2024, 1, 2).unwrap()), + timestamp_value: Some( + NaiveDate::from_ymd_opt(2024, 1, 2) + .unwrap() + .and_hms_opt(3, 4, 5) + .expect("timestamp"), + ), + int32_list: Some(vec![1, 2, 3]), + string_list: Some(vec!["alpha".to_string(), "beta".to_string()]), + int64_map: Some(HashMap::from([("alpha".to_string(), 10), ("beta".to_string(), 20)])), + }; + + OptionalHolder { + all_types: Some(all_types.clone()), + choice: Some(OptionalUnion::Note("optional".to_string())), + } +} + #[test] fn test_address_book_roundtrip() { let mut fory = Fory::default().xlang(true); addressbook::register_types(&mut fory).expect("register types"); monster::register_types(&mut fory).expect("register monster types"); complex_fbs::register_types(&mut fory).expect("register flatbuffers types"); + optional_types::register_types(&mut fory).expect("register optional types"); let book = build_address_book(); let bytes = fory.serialize(&book).expect("serialize"); @@ -214,4 +261,21 @@ fn test_address_book_roundtrip() { .expect("serialize peer payload"); fs::write(data_file, encoded).expect("write data file"); } + + let holder = build_optional_holder(); + let bytes = fory.serialize(&holder).expect("serialize"); + let roundtrip: OptionalHolder = fory.deserialize(&bytes).expect("deserialize"); + assert_eq!(holder, roundtrip); + + if let Ok(data_file) = env::var("DATA_FILE_OPTIONAL_TYPES") { + let payload = fs::read(&data_file).expect("read data file"); + let peer_holder: OptionalHolder = fory + .deserialize(&payload) + .expect("deserialize peer payload"); + assert_eq!(holder, peer_holder); + let encoded = fory + .serialize(&peer_holder) + .expect("serialize peer payload"); + fs::write(data_file, encoded).expect("write data file"); + } } From dc3b5d036920e8feb099de6fffc609c1347629a2 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 14:02:51 +0800 Subject: [PATCH 02/21] feat(go): add optional field support --- .gitignore | 6 +- compiler/fory_compiler/generators/go.py | 6 +- cpp/fory/meta/field.h | 85 +++++++++++++++++++ cpp/fory/serialization/temporal_serializers.h | 22 ++--- cpp/fory/serialization/xlang_test_main.cc | 2 +- go/fory/field_info.go | 6 ++ go/fory/struct.go | 30 ++++++- go/fory/type_resolver.go | 5 ++ .../idl_tests/go/idl_roundtrip_test.go | 8 +- .../python/src/idl_tests/roundtrip.py | 46 +++++----- integration_tests/idl_tests/rust/Cargo.lock | 1 + integration_tests/idl_tests/rust/Cargo.toml | 1 + integration_tests/idl_tests/rust/src/lib.rs | 1 + python/pyfory/serializer.py | 6 ++ python/pyfory/struct.py | 3 + 15 files changed, 185 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index 35b82d6f12..e8ad03cffa 100644 --- a/.gitignore +++ b/.gitignore @@ -43,15 +43,19 @@ integration_tests/idl_tests/cpp/generated/ integration_tests/idl_tests/go/addressbook*.go integration_tests/idl_tests/go/complex_fbs/ integration_tests/idl_tests/go/monster/ +integration_tests/idl_tests/go/optional_types/ integration_tests/idl_tests/java/src/main/java/addressbook/ integration_tests/idl_tests/java/src/main/java/complex_fbs/ integration_tests/idl_tests/java/src/main/java/monster/ +integration_tests/idl_tests/java/src/main/java/optional_types/ integration_tests/idl_tests/python/src/addressbook.py integration_tests/idl_tests/python/src/complex_fbs.py integration_tests/idl_tests/python/src/monster.py +integration_tests/idl_tests/python/src/optional_types.py integration_tests/idl_tests/rust/src/addressbook.rs integration_tests/idl_tests/rust/src/complex_fbs.rs integration_tests/idl_tests/rust/src/monster.rs +integration_tests/idl_tests/rust/src/optional_types.rs javascript/**/dist/ javascript/**/node_modules/ javascript/**/build @@ -115,4 +119,4 @@ examples/cpp/cmake_example/build **/benchmark_*.png **/results/ benchmarks/**/report/ -ignored/** \ No newline at end of file +ignored/** diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index e377f08c33..3afe3e395a 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -182,7 +182,7 @@ def message_has_unions(self, message: Message) -> bool: PrimitiveKind.FLOAT64: "float64", PrimitiveKind.STRING: "string", PrimitiveKind.BYTES: "[]byte", - PrimitiveKind.DATE: "time.Time", + PrimitiveKind.DATE: "fory.Date", PrimitiveKind.TIMESTAMP: "time.Time", } @@ -685,7 +685,7 @@ def field_uses_option(self, field: Field) -> bool: return False if isinstance(field.field_type, PrimitiveType): base_type = self.PRIMITIVE_MAP[field.field_type.kind] - return base_type not in ("[]byte", "time.Time") + return base_type not in ("[]byte", "time.Time", "fory.Date") if isinstance(field.field_type, NamedType): named_type = self.schema.get_type(field.field_type.name) return isinstance(named_type, Enum) @@ -705,7 +705,7 @@ def generate_type( if isinstance(field_type, PrimitiveType): base_type = self.PRIMITIVE_MAP[field_type.kind] if nullable and base_type not in ("[]byte",): - if use_option and not ref and base_type != "time.Time": + if use_option and not ref and base_type not in ("time.Time", "fory.Date"): return f"optional.Optional[{base_type}]" return f"*{base_type}" return base_type diff --git a/cpp/fory/meta/field.h b/cpp/fory/meta/field.h index 3998ab186b..471f92e04d 100644 --- a/cpp/fory/meta/field.h +++ b/cpp/fory/meta/field.h @@ -1035,6 +1035,91 @@ struct GetFieldTagEntry; + std::chrono::microseconds>; /// LocalDate: naive date without timezone as days since Unix epoch struct LocalDate { @@ -131,7 +131,7 @@ template <> struct Serializer { // ============================================================================ /// Serializer for Timestamp -/// Per xlang spec: serialized as int64 nanosecond count since Unix epoch +/// Per xlang spec: serialized as int64 microsecond count since Unix epoch template <> struct Serializer { static constexpr TypeId type_id = TypeId::TIMESTAMP; @@ -161,8 +161,8 @@ template <> struct Serializer { } static inline void write_data(const Timestamp ×tamp, WriteContext &ctx) { - int64_t nanos = timestamp.time_since_epoch().count(); - ctx.write_bytes(&nanos, sizeof(int64_t)); + int64_t micros = timestamp.time_since_epoch().count(); + ctx.write_bytes(µs, sizeof(int64_t)); } static inline void write_data_generic(const Timestamp ×tamp, @@ -174,26 +174,26 @@ template <> struct Serializer { bool read_type) { bool has_value = read_null_only_flag(ctx, ref_mode); if (ctx.has_error() || !has_value) { - return Timestamp(Duration(0)); + return Timestamp(std::chrono::microseconds(0)); } if (read_type) { uint32_t type_id_read = ctx.read_varuint32(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return Timestamp(Duration(0)); + return Timestamp(std::chrono::microseconds(0)); } if (type_id_read != static_cast(type_id)) { ctx.set_error( Error::type_mismatch(type_id_read, static_cast(type_id))); - return Timestamp(Duration(0)); + return Timestamp(std::chrono::microseconds(0)); } } return read_data(ctx); } static inline Timestamp read_data(ReadContext &ctx) { - int64_t nanos; - ctx.read_bytes(&nanos, sizeof(int64_t), ctx.error()); - return Timestamp(Duration(nanos)); + int64_t micros; + ctx.read_bytes(µs, sizeof(int64_t), ctx.error()); + return Timestamp(std::chrono::microseconds(micros)); } static inline Timestamp read_with_type_info(ReadContext &ctx, diff --git a/cpp/fory/serialization/xlang_test_main.cc b/cpp/fory/serialization/xlang_test_main.cc index f13823268f..0bd8da85b5 100644 --- a/cpp/fory/serialization/xlang_test_main.cc +++ b/cpp/fory/serialization/xlang_test_main.cc @@ -1276,7 +1276,7 @@ void RunTestCrossLanguageSerializer(const std::string &data_file) { std::map str_map = {{"hello", "world"}, {"foo", "bar"}}; LocalDate day(18954); // 2021-11-23 - Timestamp instant(std::chrono::nanoseconds(100000000)); + Timestamp instant(std::chrono::seconds(100)); std::vector copy = bytes; Buffer buffer = MakeBuffer(copy); diff --git a/go/fory/field_info.go b/go/fory/field_info.go index 7bff20735f..bd841910ee 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -825,6 +825,12 @@ func typeIdFromKind(type_ reflect.Type) TypeId { if info, ok := getOptionalInfo(type_); ok { return typeIdFromKind(info.valueType) } + if type_ == dateType { + return LOCAL_DATE + } + if type_ == timestampType { + return TIMESTAMP + } switch type_.Kind() { case reflect.Bool: return BOOL diff --git a/go/fory/struct.go b/go/fory/struct.go index 3110a76717..ac945a0a54 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -908,7 +908,7 @@ func (s *structSerializer) computeHash() int32 { if field.Meta.IsOptional { fieldTypeForNullable = field.Meta.OptionalInfo.valueType } - if isNonNullablePrimitiveKind(fieldTypeForNullable.Kind()) && !isEnumField { + if !field.Meta.IsOptional && isNonNullablePrimitiveKind(fieldTypeForNullable.Kind()) && !isEnumField { nullable = false } @@ -1414,30 +1414,50 @@ func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Poi if field.RefMode == RefModeTracking { break } + if field.Meta.HasGenerics && field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, value.Field(field.Meta.FieldIndex)) + return + } ctx.WriteStringStringMap(*(*map[string]string)(fieldPtr), field.RefMode, false) return case StringInt64MapDispatchId: if field.RefMode == RefModeTracking { break } + if field.Meta.HasGenerics && field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, value.Field(field.Meta.FieldIndex)) + return + } ctx.WriteStringInt64Map(*(*map[string]int64)(fieldPtr), field.RefMode, false) return case StringInt32MapDispatchId: if field.RefMode == RefModeTracking { break } + if field.Meta.HasGenerics && field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, value.Field(field.Meta.FieldIndex)) + return + } ctx.WriteStringInt32Map(*(*map[string]int32)(fieldPtr), field.RefMode, false) return case StringIntMapDispatchId: if field.RefMode == RefModeTracking { break } + if field.Meta.HasGenerics && field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, value.Field(field.Meta.FieldIndex)) + return + } ctx.WriteStringIntMap(*(*map[string]int)(fieldPtr), field.RefMode, false) return case StringFloat64MapDispatchId: if field.RefMode == RefModeTracking { break } + if field.Meta.HasGenerics && field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, value.Field(field.Meta.FieldIndex)) + return + } ctx.WriteStringFloat64Map(*(*map[string]float64)(fieldPtr), field.RefMode, false) return case StringBoolMapDispatchId: @@ -1446,12 +1466,20 @@ func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Poi if field.RefMode == RefModeTracking { break } + if field.Meta.HasGenerics && field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, value.Field(field.Meta.FieldIndex)) + return + } ctx.WriteStringBoolMap(*(*map[string]bool)(fieldPtr), field.RefMode, false) return case IntIntMapDispatchId: if field.RefMode == RefModeTracking { break } + if field.Meta.HasGenerics && field.Serializer != nil { + field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, value.Field(field.Meta.FieldIndex)) + return + } ctx.WriteIntIntMap(*(*map[int]int)(fieldPtr), field.RefMode, false) return case NullableTaggedInt64DispatchId: diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index 1643027e10..91b9264220 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -858,6 +858,11 @@ func (r *TypeResolver) getTypeIdByType(type_ reflect.Type) TypeId { if info, ok := r.typesInfo[type_]; ok { return TypeId(info.TypeID & 0xFF) // Extract base type ID } + if type_ != nil && type_.Kind() == reflect.Ptr { + if info, ok := r.typesInfo[type_.Elem()]; ok { + return TypeId(info.TypeID & 0xFF) + } + } return 0 } diff --git a/integration_tests/idl_tests/go/idl_roundtrip_test.go b/integration_tests/idl_tests/go/idl_roundtrip_test.go index 510000a57c..738a1d24ad 100644 --- a/integration_tests/idl_tests/go/idl_roundtrip_test.go +++ b/integration_tests/idl_tests/go/idl_roundtrip_test.go @@ -346,7 +346,7 @@ func runFileContainerRoundTrip(t *testing.T, f *fory.Fory, container complexfbs. } func buildOptionalHolder() optionaltypes.OptionalHolder { - dateValue := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + dateValue := fory.Date{Year: 2024, Month: time.January, Day: 2} timestampValue := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) allTypes := &optionaltypes.AllOptionalTypes{ BoolValue: optional.Some(true), @@ -554,8 +554,10 @@ func assertOptionalTypesEqual(t *testing.T, expected, actual *optionaltypes.AllO if expected.DateValue != actual.DateValue { t.Fatalf("date_value mismatch: %#v != %#v", expected.DateValue, actual.DateValue) } - } else if !expected.DateValue.Equal(*actual.DateValue) { - t.Fatalf("date_value mismatch: %v != %v", expected.DateValue, actual.DateValue) + } else if expected.DateValue.Year != actual.DateValue.Year || + expected.DateValue.Month != actual.DateValue.Month || + expected.DateValue.Day != actual.DateValue.Day { + t.Fatalf("date_value mismatch: %#v != %#v", expected.DateValue, actual.DateValue) } if expected.TimestampValue == nil || actual.TimestampValue == nil { if expected.TimestampValue != actual.TimestampValue { diff --git a/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py b/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py index c57dbbe4e6..9b00f8419a 100644 --- a/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py +++ b/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py @@ -106,34 +106,34 @@ def build_primitive_types() -> "addressbook.PrimitiveTypes": def build_optional_holder() -> "optional_types.OptionalHolder": all_types = optional_types.AllOptionalTypes( bool_value=True, - int8_value=pyfory.int8(12), - int16_value=pyfory.int16(1234), - int32_value=pyfory.int32(-123456), - fixed_int32_value=pyfory.fixed_int32(-123456), - varint32_value=pyfory.int32(-12345), - int64_value=pyfory.int64(-123456789), - fixed_int64_value=pyfory.fixed_int64(-123456789), - varint64_value=pyfory.int64(-987654321), - tagged_int64_value=pyfory.tagged_int64(123456789), - uint8_value=pyfory.uint8(200), - uint16_value=pyfory.uint16(60000), - uint32_value=pyfory.uint32(1234567890), - fixed_uint32_value=pyfory.fixed_uint32(1234567890), - var_uint32_value=pyfory.uint32(1234567890), - uint64_value=pyfory.uint64(9876543210), - fixed_uint64_value=pyfory.fixed_uint64(9876543210), - var_uint64_value=pyfory.uint64(12345678901), - tagged_uint64_value=pyfory.tagged_uint64(2222222222), - float16_value=pyfory.float32(1.5), - float32_value=pyfory.float32(2.5), - float64_value=pyfory.float64(3.5), + int8_value=12, + int16_value=1234, + int32_value=-123456, + fixed_int32_value=-123456, + varint32_value=-12345, + int64_value=-123456789, + fixed_int64_value=-123456789, + varint64_value=-987654321, + tagged_int64_value=123456789, + uint8_value=200, + uint16_value=60000, + uint32_value=1234567890, + fixed_uint32_value=1234567890, + var_uint32_value=1234567890, + uint64_value=9876543210, + fixed_uint64_value=9876543210, + var_uint64_value=12345678901, + tagged_uint64_value=2222222222, + float16_value=1.5, + float32_value=2.5, + float64_value=3.5, string_value="optional", bytes_value=b"\x01\x02\x03", date_value=datetime.date(2024, 1, 2), - timestamp_value=datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.timezone.utc), + timestamp_value=datetime.datetime.fromtimestamp(1704164645), int32_list=np.array([1, 2, 3], dtype=np.int32), string_list=["alpha", "beta"], - int64_map={"alpha": pyfory.int64(10), "beta": pyfory.int64(20)}, + int64_map={"alpha": 10, "beta": 20}, ) union_value = optional_types.OptionalUnion.note("optional") return optional_types.OptionalHolder(all_types=all_types, choice=union_value) diff --git a/integration_tests/idl_tests/rust/Cargo.lock b/integration_tests/idl_tests/rust/Cargo.lock index e3e33b11be..e569aad634 100644 --- a/integration_tests/idl_tests/rust/Cargo.lock +++ b/integration_tests/idl_tests/rust/Cargo.lock @@ -143,6 +143,7 @@ dependencies = [ name = "idl_tests" version = "0.1.0" dependencies = [ + "chrono", "fory", "fory-core", ] diff --git a/integration_tests/idl_tests/rust/Cargo.toml b/integration_tests/idl_tests/rust/Cargo.toml index 3699505780..2c476ccaa7 100644 --- a/integration_tests/idl_tests/rust/Cargo.toml +++ b/integration_tests/idl_tests/rust/Cargo.toml @@ -22,5 +22,6 @@ edition = "2021" license = "Apache-2.0" [dependencies] +chrono = "0.4" fory = { path = "../../../rust/fory" } fory-core = { path = "../../../rust/fory-core" } diff --git a/integration_tests/idl_tests/rust/src/lib.rs b/integration_tests/idl_tests/rust/src/lib.rs index b993e8f0eb..b1e1e8883c 100644 --- a/integration_tests/idl_tests/rust/src/lib.rs +++ b/integration_tests/idl_tests/rust/src/lib.rs @@ -18,3 +18,4 @@ pub mod addressbook; pub mod complex_fbs; pub mod monster; +pub mod optional_types; diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index c5e3457491..b3d3ef53de 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -560,6 +560,12 @@ def read(self, buffer): return fory_buf return fory_buf.to_pybytes() + def xwrite(self, buffer, value): + buffer.write_bytes_and_size(value) + + def xread(self, buffer): + return buffer.read_bytes_and_size() + class BytesBufferObject(BufferObject): __slots__ = ("binary",) diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index 0dfcb6177d..7d3d1a4f19 100644 --- a/python/pyfory/struct.py +++ b/python/pyfory/struct.py @@ -1330,6 +1330,9 @@ def compute_struct_fingerprint(type_resolver, field_names, serializers, nullable nullable_flag = "1" if nullable_map.get(field_name, False) else "0" else: type_id = type_resolver.get_typeinfo(serializer.type_).type_id & 0xFF + # For xlang, user-defined types use UNKNOWN in fingerprint to match other languages. + if not type_resolver.fory.is_py and type_id >= TypeId.BOUND: + type_id = TypeId.UNKNOWN if type_id in {TypeId.TYPED_UNION, TypeId.NAMED_UNION}: type_id = TypeId.UNION is_nullable = nullable_map.get(field_name, False) From 21e68157c64d5e384595c5136e096a7a8b827b13 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 14:42:31 +0800 Subject: [PATCH 03/21] fix(python): restore bytes xlang out-of-band path --- compiler/fory_compiler/generators/go.py | 6 +++++- javascript/packages/hps/src/fastcall.cc | 5 +++-- python/pyfory/serializer.py | 6 ------ 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index 3afe3e395a..def3658499 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -705,7 +705,11 @@ def generate_type( if isinstance(field_type, PrimitiveType): base_type = self.PRIMITIVE_MAP[field_type.kind] if nullable and base_type not in ("[]byte",): - if use_option and not ref and base_type not in ("time.Time", "fory.Date"): + if ( + use_option + and not ref + and base_type not in ("time.Time", "fory.Date") + ): return f"optional.Optional[{base_type}]" return f"*{base_type}" return base_type diff --git a/javascript/packages/hps/src/fastcall.cc b/javascript/packages/hps/src/fastcall.cc index 66ffd2bddc..94e8c0ad78 100644 --- a/javascript/packages/hps/src/fastcall.cc +++ b/javascript/packages/hps/src/fastcall.cc @@ -110,8 +110,9 @@ static void serializeString(const v8::FunctionCallbackInfo &args) { int flags = String::HINT_MANY_WRITES_EXPECTED | String::NO_NULL_TERMINATION | String::REPLACE_INVALID_UTF8; if (is_one_byte) { - offset += writeVarUint32(dst_data, offset, - (str->Length() << 2) | Encoding::LATIN1); // length + offset += + writeVarUint32(dst_data, offset, + (str->Length() << 2) | Encoding::LATIN1); // length offset += str->WriteOneByte(isolate, dst_data + offset, 0, str->Length(), flags); } else { diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index b3d3ef53de..c5e3457491 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -560,12 +560,6 @@ def read(self, buffer): return fory_buf return fory_buf.to_pybytes() - def xwrite(self, buffer, value): - buffer.write_bytes_and_size(value) - - def xread(self, buffer): - return buffer.read_bytes_and_size() - class BytesBufferObject(BufferObject): __slots__ = ("binary",) From e7466f0647e6ca89201b3e754f52c1ea05d46f4a Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 15:29:21 +0800 Subject: [PATCH 04/21] update DATE/TIMESTAMP --- compiler/README.md | 2 +- compiler/fory_compiler/generators/cpp.py | 2 +- compiler/fory_compiler/generators/go.py | 2 +- compiler/fory_compiler/generators/java.py | 2 +- cpp/README.md | 2 +- cpp/fory/meta/field.h | 802 +++++++++++++++++- cpp/fory/meta/preprocessor.h | 20 +- cpp/fory/row/schema.cc | 2 +- cpp/fory/row/schema.h | 2 +- cpp/fory/serialization/skip.cc | 22 +- cpp/fory/serialization/struct_serializer.h | 2 +- cpp/fory/serialization/temporal_serializers.h | 81 +- cpp/fory/serialization/type_resolver.cc | 4 +- cpp/fory/serialization/xlang_test_main.cc | 8 +- cpp/fory/type/type.h | 2 +- dart/example/example.g.dart | 2 +- dart/packages/fory/example/example.g.dart | 2 +- .../fory/lib/src/const/dart_type.dart | 2 +- .../packages/fory/lib/src/const/obj_type.dart | 4 +- .../src/serializer/time/date_serializer.dart | 2 +- .../serializer/time/timestamp_serializer.dart | 16 +- docs/compiler/type-system.md | 14 +- docs/guide/cpp/cross-language.md | 2 +- docs/guide/cpp/supported-types.md | 6 +- docs/guide/cpp/type-registration.md | 2 +- .../specification/xlang_serialization_spec.md | 29 +- docs/specification/xlang_type_mapping.md | 2 +- go/fory/codegen/decoder.go | 5 +- go/fory/codegen/encoder.go | 2 +- go/fory/codegen/utils.go | 6 +- go/fory/field_info.go | 2 +- go/fory/skip.go | 5 +- go/fory/time.go | 10 +- go/fory/type_resolver.go | 2 +- go/fory/types.go | 8 +- go/fory/util.go | 18 +- go/fory/util_test.go | 9 +- integration_tests/idl_tests/cpp/main.cc | 2 +- .../apache/fory/resolver/XtypeResolver.java | 2 +- .../fory/serializer/TimeSerializers.java | 16 +- .../main/java/org/apache/fory/type/Types.java | 2 +- .../apache/fory/format/type/DataTypes.java | 2 +- .../fory/format/type/SchemaEncoder.java | 2 +- javascript/packages/fory/lib/gen/datetime.ts | 32 +- javascript/packages/fory/lib/type.ts | 2 +- python/pyfory/_serializer.py | 14 +- python/pyfory/format/encoder.pxi | 2 +- python/pyfory/format/infer.py | 2 +- python/pyfory/format/row.pxi | 2 +- python/pyfory/format/schema.pxi | 2 +- python/pyfory/format/schema.py | 4 +- python/pyfory/format/tests/test_infer.py | 2 +- python/pyfory/includes/libformat.pxd | 2 +- python/pyfory/includes/libserialization.pxd | 2 +- python/pyfory/primitive.pxi | 25 +- python/pyfory/registry.py | 2 +- python/pyfory/tests/test_serializer.py | 2 +- python/pyfory/types.py | 2 +- rust/fory-core/src/resolver/type_resolver.rs | 2 +- rust/fory-core/src/serializer/datetime.rs | 24 +- rust/fory-core/src/serializer/skip.rs | 4 +- rust/fory-core/src/types.rs | 8 +- rust/fory-derive/src/object/util.rs | 4 +- 63 files changed, 1087 insertions(+), 183 deletions(-) diff --git a/compiler/README.md b/compiler/README.md index 452bce9132..a6f9959e84 100644 --- a/compiler/README.md +++ b/compiler/README.md @@ -196,7 +196,7 @@ message Config { ... } // Registered as "package.Config" | `float64` | `double` | `pyfory.float64` | `float64` | `f64` | `double` | | `string` | `String` | `str` | `string` | `String` | `std::string` | | `bytes` | `byte[]` | `bytes` | `[]byte` | `Vec` | `std::vector` | -| `date` | `LocalDate` | `datetime.date` | `time.Time` | `chrono::NaiveDate` | `fory::LocalDate` | +| `date` | `LocalDate` | `datetime.date` | `time.Time` | `chrono::NaiveDate` | `fory::Date` | | `timestamp` | `Instant` | `datetime.datetime` | `time.Time` | `chrono::NaiveDateTime` | `fory::Timestamp` | ### Collection Types diff --git a/compiler/fory_compiler/generators/cpp.py b/compiler/fory_compiler/generators/cpp.py index ffe3b9bf3d..c69a9647cb 100644 --- a/compiler/fory_compiler/generators/cpp.py +++ b/compiler/fory_compiler/generators/cpp.py @@ -63,7 +63,7 @@ class CppGenerator(BaseGenerator): PrimitiveKind.FLOAT64: "double", PrimitiveKind.STRING: "std::string", PrimitiveKind.BYTES: "std::vector", - PrimitiveKind.DATE: "fory::serialization::LocalDate", + PrimitiveKind.DATE: "fory::serialization::Date", PrimitiveKind.TIMESTAMP: "fory::serialization::Timestamp", } diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index def3658499..d0974c73cb 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -460,7 +460,7 @@ def get_union_case_type_id_expr( PrimitiveKind.FLOAT64: "fory.FLOAT64", PrimitiveKind.STRING: "fory.STRING", PrimitiveKind.BYTES: "fory.BINARY", - PrimitiveKind.DATE: "fory.LOCAL_DATE", + PrimitiveKind.DATE: "fory.DATE", PrimitiveKind.TIMESTAMP: "fory.TIMESTAMP", } return primitive_type_ids.get(kind, "fory.UNKNOWN") diff --git a/compiler/fory_compiler/generators/java.py b/compiler/fory_compiler/generators/java.py index beec788ba7..b005da4d23 100644 --- a/compiler/fory_compiler/generators/java.py +++ b/compiler/fory_compiler/generators/java.py @@ -645,7 +645,7 @@ def get_union_case_type_id_expr( PrimitiveKind.FLOAT64: "Types.FLOAT64", PrimitiveKind.STRING: "Types.STRING", PrimitiveKind.BYTES: "Types.BINARY", - PrimitiveKind.DATE: "Types.LOCAL_DATE", + PrimitiveKind.DATE: "Types.DATE", PrimitiveKind.TIMESTAMP: "Types.TIMESTAMP", } return primitive_type_ids.get(kind, "Types.UNKNOWN") diff --git a/cpp/README.md b/cpp/README.md index 60a084398f..ba0ae70e84 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -333,7 +333,7 @@ cpp/fory/ │ ├── collection_serializer.h # vector, set serializers │ ├── map_serializer.h # map serializers │ ├── smart_ptr_serializers.h # optional, shared_ptr, unique_ptr -│ ├── temporal_serializers.h # Duration, Timestamp, LocalDate +│ ├── temporal_serializers.h # Duration, Timestamp, Date │ ├── variant_serializer.h # std::variant support │ ├── type_resolver.h # Type resolution and registration │ └── context.h # Read/Write context diff --git a/cpp/fory/meta/field.h b/cpp/fory/meta/field.h index 471f92e04d..78b445dccd 100644 --- a/cpp/fory/meta/field.h +++ b/cpp/fory/meta/field.h @@ -880,7 +880,7 @@ struct GetFieldTagEntry ReadType(Buffer &buffer) { return duration(); case TypeId::TIMESTAMP: return timestamp(); - case TypeId::LOCAL_DATE: + case TypeId::DATE: return date32(); case TypeId::DECIMAL: { uint8_t precision = buffer.ReadUint8(error); diff --git a/cpp/fory/row/schema.h b/cpp/fory/row/schema.h index 05e72a8d0d..2be3d33a2c 100644 --- a/cpp/fory/row/schema.h +++ b/cpp/fory/row/schema.h @@ -200,7 +200,7 @@ class TimestampType : public FixedWidthType { /// Date stored as 32-bit integer (days since epoch). class LocalDateType : public FixedWidthType { public: - LocalDateType() : FixedWidthType(TypeId::LOCAL_DATE, 32) {} + LocalDateType() : FixedWidthType(TypeId::DATE, 32) {} std::string name() const override { return "date32"; } }; diff --git a/cpp/fory/serialization/skip.cc b/cpp/fory/serialization/skip.cc index 0af4913ec3..d58b170909 100644 --- a/cpp/fory/serialization/skip.cc +++ b/cpp/fory/serialization/skip.cc @@ -520,10 +520,8 @@ void skip_field_value(ReadContext &ctx, const FieldType &field_type, skip_map(ctx, field_type); return; - case TypeId::DURATION: - case TypeId::TIMESTAMP: { - // Duration/Timestamp are stored as fixed 8-byte - // nanosecond counts. + case TypeId::DURATION: { + // Duration is stored as fixed 8-byte nanosecond count. constexpr uint32_t kBytes = static_cast(sizeof(int64_t)); if (ctx.buffer().reader_index() + kBytes > ctx.buffer().size()) { ctx.set_error(Error::buffer_out_of_bound(ctx.buffer().reader_index(), @@ -533,9 +531,21 @@ void skip_field_value(ReadContext &ctx, const FieldType &field_type, ctx.buffer().IncreaseReaderIndex(kBytes); return; } + case TypeId::TIMESTAMP: { + // Timestamp is stored as int64 seconds + uint32 nanoseconds. + constexpr uint32_t kBytes = + static_cast(sizeof(int64_t) + sizeof(uint32_t)); + if (ctx.buffer().reader_index() + kBytes > ctx.buffer().size()) { + ctx.set_error(Error::buffer_out_of_bound(ctx.buffer().reader_index(), + kBytes, ctx.buffer().size())); + return; + } + ctx.buffer().IncreaseReaderIndex(kBytes); + return; + } - case TypeId::LOCAL_DATE: { - // LocalDate is stored as fixed 4-byte day count. + case TypeId::DATE: { + // Date is stored as fixed 4-byte day count. constexpr uint32_t kBytes = static_cast(sizeof(int32_t)); if (ctx.buffer().reader_index() + kBytes > ctx.buffer().size()) { ctx.set_error(Error::buffer_out_of_bound(ctx.buffer().reader_index(), diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 9276542c05..31fe1c0d39 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -1044,7 +1044,7 @@ template struct CompileTimeFieldHelpers { } /// Check if a type ID is an internal (built-in, final) type for group 2. - /// Internal types are STRING, DURATION, TIMESTAMP, LOCAL_DATE, DECIMAL, + /// Internal types are STRING, DURATION, TIMESTAMP, DATE, DECIMAL, /// BINARY, ARRAY, and primitive arrays. Java xlang DescriptorGrouper excludes /// enums from finals (line 897 in XtypeResolver). Excludes: ENUM (13-14), /// STRUCT (15-18), EXT (19-20), LIST (21), SET (22), MAP (23) diff --git a/cpp/fory/serialization/temporal_serializers.h b/cpp/fory/serialization/temporal_serializers.h index b279d77c79..34874be069 100644 --- a/cpp/fory/serialization/temporal_serializers.h +++ b/cpp/fory/serialization/temporal_serializers.h @@ -32,22 +32,22 @@ namespace serialization { /// Duration: absolute length of time as nanoseconds using Duration = std::chrono::nanoseconds; -/// Timestamp: point in time as microseconds since Unix epoch (Jan 1, 1970 UTC) +/// Timestamp: point in time as nanoseconds since Unix epoch (Jan 1, 1970 UTC) using Timestamp = std::chrono::time_point; + std::chrono::nanoseconds>; -/// LocalDate: naive date without timezone as days since Unix epoch -struct LocalDate { +/// Date: naive date without timezone as days since Unix epoch +struct Date { int32_t days_since_epoch; // Days since Jan 1, 1970 UTC - LocalDate() : days_since_epoch(0) {} - explicit LocalDate(int32_t days) : days_since_epoch(days) {} + Date() : days_since_epoch(0) {} + explicit Date(int32_t days) : days_since_epoch(days) {} - bool operator==(const LocalDate &other) const { + bool operator==(const Date &other) const { return days_since_epoch == other.days_since_epoch; } - bool operator!=(const LocalDate &other) const { return !(*this == other); } + bool operator!=(const Date &other) const { return !(*this == other); } }; // ============================================================================ @@ -131,7 +131,7 @@ template <> struct Serializer { // ============================================================================ /// Serializer for Timestamp -/// Per xlang spec: serialized as int64 microsecond count since Unix epoch +/// Per xlang spec: serialized as int64 seconds + uint32 nanoseconds since Unix epoch template <> struct Serializer { static constexpr TypeId type_id = TypeId::TIMESTAMP; @@ -161,8 +161,17 @@ template <> struct Serializer { } static inline void write_data(const Timestamp ×tamp, WriteContext &ctx) { - int64_t micros = timestamp.time_since_epoch().count(); - ctx.write_bytes(µs, sizeof(int64_t)); + auto nanos = timestamp.time_since_epoch(); + auto seconds = std::chrono::duration_cast(nanos); + auto remainder = nanos - seconds; + if (remainder.count() < 0) { + seconds -= std::chrono::seconds(1); + remainder += std::chrono::seconds(1); + } + int64_t seconds_count = seconds.count(); + uint32_t nanos_count = static_cast(remainder.count()); + ctx.write_bytes(&seconds_count, sizeof(int64_t)); + ctx.write_bytes(&nanos_count, sizeof(uint32_t)); } static inline void write_data_generic(const Timestamp ×tamp, @@ -174,26 +183,32 @@ template <> struct Serializer { bool read_type) { bool has_value = read_null_only_flag(ctx, ref_mode); if (ctx.has_error() || !has_value) { - return Timestamp(std::chrono::microseconds(0)); + return Timestamp(std::chrono::nanoseconds(0)); } if (read_type) { uint32_t type_id_read = ctx.read_varuint32(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return Timestamp(std::chrono::microseconds(0)); + return Timestamp(std::chrono::nanoseconds(0)); } if (type_id_read != static_cast(type_id)) { ctx.set_error( Error::type_mismatch(type_id_read, static_cast(type_id))); - return Timestamp(std::chrono::microseconds(0)); + return Timestamp(std::chrono::nanoseconds(0)); } } return read_data(ctx); } static inline Timestamp read_data(ReadContext &ctx) { - int64_t micros; - ctx.read_bytes(µs, sizeof(int64_t), ctx.error()); - return Timestamp(std::chrono::microseconds(micros)); + int64_t seconds; + uint32_t nanos; + ctx.read_bytes(&seconds, sizeof(int64_t), ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return Timestamp(std::chrono::nanoseconds(0)); + } + ctx.read_bytes(&nanos, sizeof(uint32_t), ctx.error()); + return Timestamp(std::chrono::seconds(seconds) + + std::chrono::nanoseconds(nanos)); } static inline Timestamp read_with_type_info(ReadContext &ctx, @@ -204,13 +219,13 @@ template <> struct Serializer { }; // ============================================================================ -// LocalDate Serializer +// Date Serializer // ============================================================================ -/// Serializer for LocalDate +/// Serializer for Date /// Per xlang spec: serialized as int32 day count since Unix epoch -template <> struct Serializer { - static constexpr TypeId type_id = TypeId::LOCAL_DATE; +template <> struct Serializer { + static constexpr TypeId type_id = TypeId::DATE; static inline void write_type_info(WriteContext &ctx) { ctx.write_varuint32(static_cast(type_id)); @@ -227,7 +242,7 @@ template <> struct Serializer { } } - static inline void write(const LocalDate &date, WriteContext &ctx, + static inline void write(const Date &date, WriteContext &ctx, RefMode ref_mode, bool write_type, bool has_generics = false) { write_not_null_ref_flag(ctx, ref_mode); @@ -237,44 +252,42 @@ template <> struct Serializer { write_data(date, ctx); } - static inline void write_data(const LocalDate &date, WriteContext &ctx) { + static inline void write_data(const Date &date, WriteContext &ctx) { ctx.write_bytes(&date.days_since_epoch, sizeof(int32_t)); } - static inline void write_data_generic(const LocalDate &date, + static inline void write_data_generic(const Date &date, WriteContext &ctx, bool has_generics) { write_data(date, ctx); } - static inline LocalDate read(ReadContext &ctx, RefMode ref_mode, - bool read_type) { + static inline Date read(ReadContext &ctx, RefMode ref_mode, bool read_type) { bool has_value = read_null_only_flag(ctx, ref_mode); if (ctx.has_error() || !has_value) { - return LocalDate(); + return Date(); } if (read_type) { uint32_t type_id_read = ctx.read_varuint32(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return LocalDate(); + return Date(); } if (type_id_read != static_cast(type_id)) { ctx.set_error( Error::type_mismatch(type_id_read, static_cast(type_id))); - return LocalDate(); + return Date(); } } return read_data(ctx); } - static inline LocalDate read_data(ReadContext &ctx) { - LocalDate date; + static inline Date read_data(ReadContext &ctx) { + Date date; ctx.read_bytes(&date.days_since_epoch, sizeof(int32_t), ctx.error()); return date; } - static inline LocalDate read_with_type_info(ReadContext &ctx, - RefMode ref_mode, - const TypeInfo &type_info) { + static inline Date read_with_type_info(ReadContext &ctx, RefMode ref_mode, + const TypeInfo &type_info) { return read(ctx, ref_mode, false); } }; diff --git a/cpp/fory/serialization/type_resolver.cc b/cpp/fory/serialization/type_resolver.cc index e5873eb4d6..606f170a88 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -719,7 +719,7 @@ bool name_sorter(const FieldInfo &a, const FieldInfo &b) { } // Check if a type ID is a "final" type for field group 2 in field ordering. -// Final types are STRING, DURATION, TIMESTAMP, LOCAL_DATE, DECIMAL, BINARY, +// Final types are STRING, DURATION, TIMESTAMP, DATE, DECIMAL, BINARY, // ARRAY, and primitive arrays. // These are types with fixed serializers that don't need type info written. // Excludes: ENUM (13-14), STRUCT (15-18), EXT (19-20), LIST (21), SET (22), MAP @@ -1371,7 +1371,7 @@ void TypeResolver::register_builtin_types() { register_type_id_only(TypeId::NONE); register_type_id_only(TypeId::DURATION); register_type_id_only(TypeId::TIMESTAMP); - register_type_id_only(TypeId::LOCAL_DATE); + register_type_id_only(TypeId::DATE); register_type_id_only(TypeId::DECIMAL); register_type_id_only(TypeId::ARRAY); } diff --git a/cpp/fory/serialization/xlang_test_main.cc b/cpp/fory/serialization/xlang_test_main.cc index 0bd8da85b5..d4fe9a12d3 100644 --- a/cpp/fory/serialization/xlang_test_main.cc +++ b/cpp/fory/serialization/xlang_test_main.cc @@ -45,7 +45,7 @@ using ::fory::Error; using ::fory::Result; using ::fory::serialization::Fory; using ::fory::serialization::ForyBuilder; -using ::fory::serialization::LocalDate; +using ::fory::serialization::Date; using ::fory::serialization::Serializer; using ::fory::serialization::Timestamp; @@ -1275,7 +1275,7 @@ void RunTestCrossLanguageSerializer(const std::string &data_file) { std::set str_set = {"hello", "world"}; std::map str_map = {{"hello", "world"}, {"foo", "bar"}}; - LocalDate day(18954); // 2021-11-23 + Date day(18954); // 2021-11-23 Timestamp instant(std::chrono::seconds(100)); std::vector copy = bytes; @@ -1332,8 +1332,8 @@ void RunTestCrossLanguageSerializer(const std::string &data_file) { if (ReadNext(fory, buffer) != "str") { Fail("String mismatch"); } - if (ReadNext(fory, buffer) != day) { - Fail("LocalDate mismatch"); + if (ReadNext(fory, buffer) != day) { + Fail("Date mismatch"); } if (ReadNext(fory, buffer) != instant) { Fail("Timestamp mismatch"); diff --git a/cpp/fory/type/type.h b/cpp/fory/type/type.h index 4195442dec..61edd861bb 100644 --- a/cpp/fory/type/type.h +++ b/cpp/fory/type/type.h @@ -101,7 +101,7 @@ enum class TypeId : int32_t { TIMESTAMP = 36, // a naive date without timezone. The count is days relative to an // epoch at UTC midnight on Jan 1, 1970. - LOCAL_DATE = 37, + DATE = 37, // exact decimal value represented as an integer value in two's // complement. DECIMAL = 38, diff --git a/dart/example/example.g.dart b/dart/example/example.g.dart index 7374096ced..41db39aa6d 100644 --- a/dart/example/example.g.dart +++ b/dart/example/example.g.dart @@ -49,7 +49,7 @@ final $Person = ClassSpec( 'dateOfBirth', TypeSpec( LocalDate, - ObjType.LOCAL_DATE, + ObjType.DATE, false, true, null, diff --git a/dart/packages/fory/example/example.g.dart b/dart/packages/fory/example/example.g.dart index 7374096ced..41db39aa6d 100644 --- a/dart/packages/fory/example/example.g.dart +++ b/dart/packages/fory/example/example.g.dart @@ -49,7 +49,7 @@ final $Person = ClassSpec( 'dateOfBirth', TypeSpec( LocalDate, - ObjType.LOCAL_DATE, + ObjType.DATE, false, true, null, diff --git a/dart/packages/fory/lib/src/const/dart_type.dart b/dart/packages/fory/lib/src/const/dart_type.dart index c52189979c..e2fdf36fd5 100644 --- a/dart/packages/fory/lib/src/const/dart_type.dart +++ b/dart/packages/fory/lib/src/const/dart_type.dart @@ -52,7 +52,7 @@ enum DartTypeEnum{ DOUBLE(double,true, 'double', 'dart', 'core', ObjType.FLOAT64, true, 'dart:core@double'), STRING(String,true, 'String', 'dart', 'core', ObjType.STRING, true, 'dart:core@String'), - LOCALDATE(LocalDate, true, 'LocalDate', 'package', 'fory/src/datatype/local_date.dart', ObjType.LOCAL_DATE, true, 'dart:core@LocalDate'), + LOCALDATE(LocalDate, true, 'LocalDate', 'package', 'fory/src/datatype/local_date.dart', ObjType.DATE, true, 'dart:core@LocalDate'), TIMESTAMP(TimeStamp, false, 'TimeStamp', 'package', 'fory/src/datatype/timestamp.dart', ObjType.TIMESTAMP, true, 'dart:core@DateTime'), BOOLLIST(BoolList, true, 'BoolList', 'package', 'collection/src/boollist.dart', ObjType.BOOL_ARRAY, true, 'dart:typed_data@BoolList'), diff --git a/dart/packages/fory/lib/src/const/obj_type.dart b/dart/packages/fory/lib/src/const/obj_type.dart index a7fb48df68..42e7f4d9b5 100644 --- a/dart/packages/fory/lib/src/const/obj_type.dart +++ b/dart/packages/fory/lib/src/const/obj_type.dart @@ -126,7 +126,7 @@ enum ObjType { // TODO: here time /// A naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, /// 1970. - LOCAL_DATE(26, true), // 26 + DATE(26, true), // 26 /// Exact decimal value represented as an integer value in two's complement. DECIMAL(27, true), // 27 @@ -213,7 +213,7 @@ enum ObjType { bool isTimeType() { return this == TIMESTAMP - || this == LOCAL_DATE + || this == DATE || this == DURATION; } } diff --git a/dart/packages/fory/lib/src/serializer/time/date_serializer.dart b/dart/packages/fory/lib/src/serializer/time/date_serializer.dart index ba75e4c4ce..ea1a48bc6d 100644 --- a/dart/packages/fory/lib/src/serializer/time/date_serializer.dart +++ b/dart/packages/fory/lib/src/serializer/time/date_serializer.dart @@ -51,7 +51,7 @@ final class DateSerializer extends Serializer { static const SerializerCache cache = _DateSerializerCache(); - DateSerializer._(bool writeRef) : super(ObjType.LOCAL_DATE, writeRef); + DateSerializer._(bool writeRef) : super(ObjType.DATE, writeRef); @override LocalDate read(ByteReader br, int refId, DeserializerPack pack) { diff --git a/dart/packages/fory/lib/src/serializer/time/timestamp_serializer.dart b/dart/packages/fory/lib/src/serializer/time/timestamp_serializer.dart index 6e5c4eef3a..c80b90161d 100644 --- a/dart/packages/fory/lib/src/serializer/time/timestamp_serializer.dart +++ b/dart/packages/fory/lib/src/serializer/time/timestamp_serializer.dart @@ -53,13 +53,23 @@ final class TimestampSerializer extends Serializer { @override TimeStamp read(ByteReader br, int refId, DeserializerPack pack) { - int microseconds = br.readInt64(); + final int seconds = br.readInt64(); + final int nanos = br.readUint32(); + final int microseconds = seconds * 1000000 + (nanos ~/ 1000); // attention: UTC return TimeStamp(microseconds); } @override void write(ByteWriter bw, TimeStamp v, SerializerPack pack) { - bw.writeInt64(v.microsecondsSinceEpoch); + int seconds = v.microsecondsSinceEpoch ~/ 1000000; + int microsRem = v.microsecondsSinceEpoch % 1000000; + if (microsRem < 0) { + seconds -= 1; + microsRem += 1000000; + } + final int nanos = microsRem * 1000; + bw.writeInt64(seconds); + bw.writeUint32(nanos); } -} \ No newline at end of file +} diff --git a/docs/compiler/type-system.md b/docs/compiler/type-system.md index 141c8d6604..8a4ea01753 100644 --- a/docs/compiler/type-system.md +++ b/docs/compiler/type-system.md @@ -159,13 +159,13 @@ Calendar date without time: date birth_date = 1; ``` -| Language | Type | Notes | -| -------- | -------------------------------- | ----------------------- | -| Java | `java.time.LocalDate` | | -| Python | `datetime.date` | | -| Go | `time.Time` | Time portion ignored | -| Rust | `chrono::NaiveDate` | Requires `chrono` crate | -| C++ | `fory::serialization::LocalDate` | | +| Language | Type | Notes | +| -------- | --------------------------- | ----------------------- | +| Java | `java.time.LocalDate` | | +| Python | `datetime.date` | | +| Go | `time.Time` | Time portion ignored | +| Rust | `chrono::NaiveDate` | Requires `chrono` crate | +| C++ | `fory::serialization::Date` | | #### Timestamp diff --git a/docs/guide/cpp/cross-language.md b/docs/guide/cpp/cross-language.md index ce192de8bb..e443b2ff50 100644 --- a/docs/guide/cpp/cross-language.md +++ b/docs/guide/cpp/cross-language.md @@ -167,7 +167,7 @@ print(f"Timestamp: {msg.timestamp}") | ----------- | ----------- | --------------- | --------------- | | `Timestamp` | `Instant` | `datetime` | `time.Time` | | `Duration` | `Duration` | `timedelta` | `time.Duration` | -| `LocalDate` | `LocalDate` | `datetime.date` | `time.Time` | +| `Date` | `LocalDate` | `datetime.date` | `time.Time` | ## Field Order Requirements diff --git a/docs/guide/cpp/supported-types.md b/docs/guide/cpp/supported-types.md index 889540c343..3963852f2c 100644 --- a/docs/guide/cpp/supported-types.md +++ b/docs/guide/cpp/supported-types.md @@ -211,15 +211,15 @@ auto bytes = fory.serialize(now).value(); auto decoded = fory.deserialize(bytes).value(); ``` -### LocalDate +### Date Days since Unix epoch: ```cpp -LocalDate date{18628}; // Days since 1970-01-01 +Date date{18628}; // Days since 1970-01-01 auto bytes = fory.serialize(date).value(); -auto decoded = fory.deserialize(bytes).value(); +auto decoded = fory.deserialize(bytes).value(); ``` ## User-Defined Structs diff --git a/docs/guide/cpp/type-registration.md b/docs/guide/cpp/type-registration.md index e256184954..21fa4985d0 100644 --- a/docs/guide/cpp/type-registration.md +++ b/docs/guide/cpp/type-registration.md @@ -184,7 +184,7 @@ Built-in types have pre-assigned type IDs and don't need registration: | 15 | SET | | 16 | TIMESTAMP | | 17 | DURATION | -| 18 | LOCAL_DATE | +| 18 | DATE | | 19 | DECIMAL | | 20 | BINARY | | 21 | ARRAY | diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index e02ee979c2..362d2784c4 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -69,9 +69,9 @@ This specification defines the Fory xlang binary format. The format is dynamic r - set: an unordered set of unique elements. - map: a map of key-value pairs. Mutable types such as `list/map/set/array` are not allowed as key of map. - duration: an absolute length of time, independent of any calendar/timezone, as a count of nanoseconds. -- timestamp: a point in time, independent of any calendar/timezone, as a count of nanoseconds. The count is relative - to an epoch at UTC midnight on January 1, 1970. -- local_date: a naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, 1970. +- timestamp: a point in time, independent of any calendar/timezone, encoded as seconds (int64) and nanoseconds + (uint32) since the epoch at UTC midnight on January 1, 1970. +- date: a naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, 1970. - decimal: exact decimal value represented as an integer value in two's complement. - binary: an variable-length array of bytes. - array: only allow 1d numeric components. Other arrays will be taken as List. The implementation should support the @@ -197,8 +197,8 @@ Named types (`NAMED_*`) do not embed a user ID; their names are carried in metad | 33 | NAMED_UNION | Union with embedded union type name/TypeDef | | 34 | NONE | Empty/unit type (no data) | | 35 | DURATION | Time duration (seconds + nanoseconds) | -| 36 | TIMESTAMP | Point in time (nanoseconds since epoch) | -| 37 | LOCAL_DATE | Date without timezone (days since epoch) | +| 36 | TIMESTAMP | Point in time (seconds + nanoseconds since epoch) | +| 37 | DATE | Date without timezone (days since epoch) | | 38 | DECIMAL | Arbitrary precision decimal | | 39 | BINARY | Raw binary data | | 40 | ARRAY | Generic array type | @@ -1220,6 +1220,21 @@ Enums are serialized as an unsigned var int. If the order of enum values change, the value users expect. In such cases, users must register enum serializer by make it write enum value as an enumerated string with unique hash disabled. +### timestamp + +Timestamp represents a point in time independent of any calendar/timezone. It is encoded as: + +- `seconds` (int64): seconds since Unix epoch (1970-01-01T00:00:00Z) +- `nanos` (uint32): nanosecond adjustment within the second + +On write, implementations must normalize negative timestamps so that `nanos` is always in `[0, 1_000_000_000)`. +This is a fixed-size 12-byte payload (8 bytes seconds + 4 bytes nanos). + +### date + +Date represents a date without timezone. It is encoded as an `int32` count of days since the Unix epoch +(1970-01-01). This is a fixed-size 4-byte payload. + ### decimal Not supported for now. @@ -1532,8 +1547,8 @@ This section provides a step-by-step guide for implementing Fory xlang serializa 6. **Temporal Types** - [ ] Duration (seconds + nanoseconds) - - [ ] Timestamp (nanoseconds since epoch) - - [ ] LocalDate (days since epoch) + - [ ] Timestamp (seconds + nanoseconds since epoch) + - [ ] Date (days since epoch) 7. **Reference Tracking** - [ ] Implement write-side object tracking (object → ref_id map) diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index 8a1b1c18b4..9390e56d6d 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -63,7 +63,7 @@ Note: | none | 32 | null | None | null | `std::monostate` | nil | `()` | | duration | 33 | Duration | timedelta | Number | duration | Duration | Duration | | timestamp | 34 | Instant | datetime | Number | std::chrono::nanoseconds | Time | DateTime | -| local_date | 35 | Date | datetime | Number | std::chrono::nanoseconds | Time | DateTime | +| date | 35 | Date | datetime | Number | fory::serialization::Date | Time | DateTime | | decimal | 36 | BigDecimal | Decimal | bigint | / | / | / | | binary | 37 | byte[] | bytes | / | `uint8_t[n]/vector` | `[n]uint8/[]T` | `Vec` | | array | 38 | array | np.ndarray | / | / | array/slice | Vec | diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 15aedb8db2..2327588a8c 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -456,8 +456,9 @@ func generateSliceElementRead(buf *bytes.Buffer, elemType types.Type, elemAccess typeStr := named.String() switch typeStr { case "time.Time": - fmt.Fprintf(buf, "\t\t\t\tusec := buf.ReadInt64()\n") - fmt.Fprintf(buf, "\t\t\t\t%s = fory.CreateTimeFromUnixMicro(usec)\n", elemAccess) + fmt.Fprintf(buf, "\t\t\t\tseconds := buf.ReadInt64()\n") + fmt.Fprintf(buf, "\t\t\t\tnanos := buf.ReadUint32()\n") + fmt.Fprintf(buf, "\t\t\t\t%s = fory.CreateTimeFromUnixSecondsAndNanos(seconds, nanos)\n", elemAccess) return nil case "github.com/apache/fory/go/fory.Date": fmt.Fprintf(buf, "\t\t\t\tdays := buf.ReadInt32()\n") diff --git a/go/fory/codegen/encoder.go b/go/fory/codegen/encoder.go index b2178d3d3d..537d2569bd 100644 --- a/go/fory/codegen/encoder.go +++ b/go/fory/codegen/encoder.go @@ -372,7 +372,7 @@ func generateElementTypeIDWrite(buf *bytes.Buffer, elemType types.Type) error { fmt.Fprintf(buf, "\t\tbuf.WriteVaruint32(%d) // TIMESTAMP\n", fory.TIMESTAMP) return nil case "github.com/apache/fory/go/fory.Date": - fmt.Fprintf(buf, "\t\tbuf.WriteVaruint32(%d) // LOCAL_DATE\n", fory.LOCAL_DATE) + fmt.Fprintf(buf, "\t\tbuf.WriteVaruint32(%d) // DATE\n", fory.DATE) return nil } // Check if it's a struct diff --git a/go/fory/codegen/utils.go b/go/fory/codegen/utils.go index d18ac944f7..82fc4e2be0 100644 --- a/go/fory/codegen/utils.go +++ b/go/fory/codegen/utils.go @@ -237,7 +237,7 @@ func getTypeID(t types.Type) string { case "time.Time": return "TIMESTAMP" case "github.com/apache/fory/go/fory.Date": - return "LOCAL_DATE" + return "DATE" } // Struct types if _, ok := named.Underlying().(*types.Struct); ok { @@ -355,8 +355,8 @@ func getTypeIDValue(typeID string) int { return int(fory.MAP) // 22 case "TIMESTAMP": return int(fory.TIMESTAMP) // 25 - case "LOCAL_DATE": - return int(fory.LOCAL_DATE) // 26 + case "DATE": + return int(fory.DATE) // 26 case "NAMED_STRUCT": return int(fory.NAMED_STRUCT) // 17 // Primitive array types diff --git a/go/fory/field_info.go b/go/fory/field_info.go index bd841910ee..48de0d16a3 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -826,7 +826,7 @@ func typeIdFromKind(type_ reflect.Type) TypeId { return typeIdFromKind(info.valueType) } if type_ == dateType { - return LOCAL_DATE + return DATE } if type_ == timestampType { return TIMESTAMP diff --git a/go/fory/skip.go b/go/fory/skip.go index 06b5a38ece..89df58e712 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -606,10 +606,11 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo _ = ctx.buffer.ReadBinary(int(length), err) // Date/Time types - case LOCAL_DATE: + case DATE: _ = ctx.buffer.ReadVaruint32Small7(err) case TIMESTAMP: - _ = ctx.buffer.ReadVarint64(err) + _ = ctx.buffer.ReadInt64(err) + _ = ctx.buffer.ReadUint32(err) // Container types case LIST, SET: diff --git a/go/fory/time.go b/go/fory/time.go index 570093bb9b..ed76f4bd7e 100644 --- a/go/fory/time.go +++ b/go/fory/time.go @@ -36,7 +36,7 @@ func (s dateSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool ctx.buffer.WriteInt8(NotNullValueFlag) } if writeType { - ctx.buffer.WriteVaruint32Small7(uint32(LOCAL_DATE)) + ctx.buffer.WriteVaruint32Small7(uint32(DATE)) } s.WriteData(ctx, value) } @@ -78,7 +78,9 @@ func (s dateSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, type type timeSerializer struct{} func (s timeSerializer) WriteData(ctx *WriteContext, value reflect.Value) { - ctx.buffer.WriteInt64(GetUnixMicro(value.Interface().(time.Time))) + seconds, nanos := GetUnixSecondsAndNanos(value.Interface().(time.Time)) + ctx.buffer.WriteInt64(seconds) + ctx.buffer.WriteUint32(nanos) } func (s timeSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { @@ -93,7 +95,9 @@ func (s timeSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool func (s timeSerializer) ReadData(ctx *ReadContext, value reflect.Value) { err := ctx.Err() - value.Set(reflect.ValueOf(CreateTimeFromUnixMicro(ctx.buffer.ReadInt64(err)))) + seconds := ctx.buffer.ReadInt64(err) + nanos := ctx.buffer.ReadUint32(err) + value.Set(reflect.ValueOf(CreateTimeFromUnixSecondsAndNanos(seconds, nanos))) } func (s timeSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index 91b9264220..e8e5e2753c 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -361,7 +361,7 @@ func (r *TypeResolver) initialize() { {intType, VARINT64, intSerializer{}}, // int maps to int64 for xlang {float32Type, FLOAT32, float32Serializer{}}, {float64Type, FLOAT64, float64Serializer{}}, - {dateType, LOCAL_DATE, dateSerializer{}}, + {dateType, DATE, dateSerializer{}}, {timestampType, TIMESTAMP, timeSerializer{}}, {genericSetType, SET, setSerializer{}}, } diff --git a/go/fory/types.go b/go/fory/types.go index fb00798d4c..715f8813d3 100644 --- a/go/fory/types.go +++ b/go/fory/types.go @@ -94,10 +94,10 @@ const ( NONE = 34 // DURATION Measure of elapsed time in either seconds milliseconds microseconds DURATION = 35 - // TIMESTAMP Exact timestamp encoded with int64 since UNIX epoch + // TIMESTAMP Exact timestamp encoded with seconds(int64) + nanos(uint32) since UNIX epoch TIMESTAMP = 36 - // LOCAL_DATE a naive date without timezone - LOCAL_DATE = 37 + // DATE a naive date without timezone + DATE = 37 // DECIMAL Precision- and scale-based decimal type DECIMAL = 38 // BINARY Variable-length bytes (no guarantee of UTF8-ness) @@ -185,7 +185,7 @@ func NeedWriteRef(typeID TypeId) bool { switch typeID { case BOOL, INT8, INT16, INT32, INT64, VARINT32, VARINT64, TAGGED_INT64, FLOAT32, FLOAT64, FLOAT16, - STRING, TIMESTAMP, LOCAL_DATE, DURATION, NONE: + STRING, TIMESTAMP, DATE, DURATION, NONE: return false default: return true diff --git a/go/fory/util.go b/go/fory/util.go index b0131af035..2d3bc95e5c 100644 --- a/go/fory/util.go +++ b/go/fory/util.go @@ -118,17 +118,15 @@ const ( MaxUint64 = 1<<64 - 1 ) -// GetUnixMicro returns t as a Unix time, the number of microseconds elapsed since -// January 1, 1970 UTC. The result is undefined if the Unix time in -// microseconds cannot be represented by an int64 (a date before year -290307 or -// after year 294246). The result does not depend on the location associated +// GetUnixSecondsAndNanos returns t as Unix seconds and nanoseconds since +// January 1, 1970 UTC. The result does not depend on the location associated // with t. -func GetUnixMicro(t time.Time) int64 { - return int64(t.Unix())*1e6 + int64(t.Nanosecond())/1e3 +func GetUnixSecondsAndNanos(t time.Time) (int64, uint32) { + return t.Unix(), uint32(t.Nanosecond()) } -// CreateTimeFromUnixMicro returns the local Time corresponding to the given Unix time, -// usec microseconds since January 1, 1970 UTC. -func CreateTimeFromUnixMicro(usec int64) time.Time { - return time.Unix(usec/1e6, (usec%1e6)*1e3) +// CreateTimeFromUnixSecondsAndNanos returns the local Time corresponding to +// the given Unix seconds and nanoseconds since January 1, 1970 UTC. +func CreateTimeFromUnixSecondsAndNanos(seconds int64, nanos uint32) time.Time { + return time.Unix(seconds, int64(nanos)) } diff --git a/go/fory/util_test.go b/go/fory/util_test.go index 6dbad69b76..250342644e 100644 --- a/go/fory/util_test.go +++ b/go/fory/util_test.go @@ -36,10 +36,9 @@ func TestSnake(t *testing.T) { func TestTime(t *testing.T) { t1 := time.Now() - ts := GetUnixMicro(t1) - t2 := CreateTimeFromUnixMicro(ts) + seconds, nanos := GetUnixSecondsAndNanos(t1) + t2 := CreateTimeFromUnixSecondsAndNanos(seconds, nanos) require.Equal(t, t1.Second(), t2.Second()) - // Micro doesn't preserve Nanosecond precision. - require.Equal(t, t1.Nanosecond()/1000, t2.Nanosecond()/1000) - require.WithinDuration(t, t1, t2, 1000) + require.Equal(t, t1.Nanosecond(), t2.Nanosecond()) + require.Equal(t, t1.Unix(), t2.Unix()) } diff --git a/integration_tests/idl_tests/cpp/main.cc b/integration_tests/idl_tests/cpp/main.cc index 25a544e8b8..72ef5383d5 100644 --- a/integration_tests/idl_tests/cpp/main.cc +++ b/integration_tests/idl_tests/cpp/main.cc @@ -230,7 +230,7 @@ fory::Result RunRoundTrip() { all_types.set_string_value("optional"); all_types.set_bytes_value({static_cast(1), static_cast(2), static_cast(3)}); - all_types.set_date_value(fory::serialization::LocalDate(19724)); + all_types.set_date_value(fory::serialization::Date(19724)); all_types.set_timestamp_value( fory::serialization::Timestamp(std::chrono::seconds(1704164645))); all_types.set_int32_list({1, 2, 3}); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index 145e2e14d8..3e76943613 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -833,7 +833,7 @@ private void registerDefaultTypes() { registerType(Types.TIMESTAMP, Timestamp.class, new TimeSerializers.TimestampSerializer(fory)); registerType( Types.TIMESTAMP, LocalDateTime.class, new TimeSerializers.LocalDateTimeSerializer(fory)); - registerType(Types.LOCAL_DATE, LocalDate.class, new TimeSerializers.LocalDateSerializer(fory)); + registerType(Types.DATE, LocalDate.class, new TimeSerializers.LocalDateSerializer(fory)); // Decimal types registerType(Types.DECIMAL, BigDecimal.class, new Serializers.BigDecimalSerializer(fory)); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java index de0b6314ec..110017fb26 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java @@ -174,12 +174,16 @@ public TimestampSerializer(Fory fory, boolean needToWriteRef) { @Override public void xwrite(MemoryBuffer buffer, Timestamp value) { - buffer.writeInt64(DateTimeUtils.fromJavaTimestamp(value)); + Instant instant = value.toInstant(); + buffer.writeInt64(instant.getEpochSecond()); + buffer.writeInt32(instant.getNano()); } @Override public Timestamp xread(MemoryBuffer buffer) { - return DateTimeUtils.toJavaTimestamp(buffer.readInt64()); + long seconds = buffer.readInt64(); + int nanos = buffer.readInt32(); + return Timestamp.from(Instant.ofEpochSecond(seconds, nanos)); } @Override @@ -257,13 +261,15 @@ public InstantSerializer(Fory fory, boolean needToWriteRef) { @Override public void xwrite(MemoryBuffer buffer, Instant value) { - // FIXME JDK17 may have higher precision than millisecond - buffer.writeInt64(DateTimeUtils.instantToMicros(value)); + buffer.writeInt64(value.getEpochSecond()); + buffer.writeInt32(value.getNano()); } @Override public Instant xread(MemoryBuffer buffer) { - return DateTimeUtils.microsToInstant(buffer.readInt64()); + long seconds = buffer.readInt64(); + int nanos = buffer.readInt32(); + return Instant.ofEpochSecond(seconds, nanos); } @Override diff --git a/java/fory-core/src/main/java/org/apache/fory/type/Types.java b/java/fory-core/src/main/java/org/apache/fory/type/Types.java index cf65a8d724..08b3f6dd6d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/Types.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/Types.java @@ -157,7 +157,7 @@ public class Types { * A naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, * 1970. */ - public static final int LOCAL_DATE = 37; + public static final int DATE = 37; /** Exact decimal value represented as an integer value in two's complement. */ public static final int DECIMAL = 38; diff --git a/java/fory-format/src/main/java/org/apache/fory/format/type/DataTypes.java b/java/fory-format/src/main/java/org/apache/fory/format/type/DataTypes.java index f3007bbf6d..8c40f27ae8 100644 --- a/java/fory-format/src/main/java/org/apache/fory/format/type/DataTypes.java +++ b/java/fory-format/src/main/java/org/apache/fory/format/type/DataTypes.java @@ -71,7 +71,7 @@ public class DataTypes { public static final int TYPE_MAP = Types.MAP; public static final int TYPE_DURATION = Types.DURATION; public static final int TYPE_TIMESTAMP = Types.TIMESTAMP; - public static final int TYPE_LOCAL_DATE = Types.LOCAL_DATE; + public static final int TYPE_LOCAL_DATE = Types.DATE; public static final int TYPE_DECIMAL = Types.DECIMAL; public static final int TYPE_BINARY = Types.BINARY; diff --git a/java/fory-format/src/main/java/org/apache/fory/format/type/SchemaEncoder.java b/java/fory-format/src/main/java/org/apache/fory/format/type/SchemaEncoder.java index a88b24811e..12e78c4015 100644 --- a/java/fory-format/src/main/java/org/apache/fory/format/type/SchemaEncoder.java +++ b/java/fory-format/src/main/java/org/apache/fory/format/type/SchemaEncoder.java @@ -227,7 +227,7 @@ private static DataType readType(MemoryBuffer buffer) { return DataTypes.duration(); case Types.TIMESTAMP: return DataTypes.timestamp(); - case Types.LOCAL_DATE: + case Types.DATE: return DataTypes.date32(); case Types.DECIMAL: int precision = buffer.readByte() & 0xFF; diff --git a/javascript/packages/fory/lib/gen/datetime.ts b/javascript/packages/fory/lib/gen/datetime.ts index 7a0f8e69fe..9c96538347 100644 --- a/javascript/packages/fory/lib/gen/datetime.ts +++ b/javascript/packages/fory/lib/gen/datetime.ts @@ -34,17 +34,41 @@ class TimestampSerializerGenerator extends BaseSerializerGenerator { writeStmt(accessor: string): string { if (/^-?[0-9]+$/.test(accessor)) { - return this.builder.writer.int64(`BigInt(${accessor})`); + const msVar = this.scope.uniqueName("ts_ms"); + const secondsVar = this.scope.uniqueName("ts_sec"); + const nanosVar = this.scope.uniqueName("ts_nanos"); + return ` + { + const ${msVar} = ${accessor}; + const ${secondsVar} = Math.floor(${msVar} / 1000); + const ${nanosVar} = (${msVar} - ${secondsVar} * 1000) * 1000000; + ${this.builder.writer.int64(`BigInt(${secondsVar})`)} + ${this.builder.writer.uint32(`${nanosVar}`)} + } + `; } - return this.builder.writer.int64(`BigInt(${accessor}.getTime())`); + const msVar = this.scope.uniqueName("ts_ms"); + const secondsVar = this.scope.uniqueName("ts_sec"); + const nanosVar = this.scope.uniqueName("ts_nanos"); + return ` + { + const ${msVar} = ${accessor}.getTime(); + const ${secondsVar} = Math.floor(${msVar} / 1000); + const ${nanosVar} = (${msVar} - ${secondsVar} * 1000) * 1000000; + ${this.builder.writer.int64(`BigInt(${secondsVar})`)} + ${this.builder.writer.uint32(`${nanosVar}`)} + } + `; } readStmt(accessor: (expr: string) => string): string { - return accessor(`new Date(Number(${this.builder.reader.int64()}))`); + const seconds = this.builder.reader.int64(); + const nanos = this.builder.reader.uint32(); + return accessor(`new Date(Number(${seconds}) * 1000 + Math.floor(${nanos} / 1000000))`); } getFixedSize(): number { - return 11; + return 12; } needToWriteRef(): boolean { diff --git a/javascript/packages/fory/lib/type.ts b/javascript/packages/fory/lib/type.ts index 4da8e6727a..0cc9655800 100644 --- a/javascript/packages/fory/lib/type.ts +++ b/javascript/packages/fory/lib/type.ts @@ -95,7 +95,7 @@ export const TypeId = { // a point in time, independent of any calendar/timezone, as a count of nanoseconds. TIMESTAMP: 36, // a naive date without timezone. - LOCAL_DATE: 37, + DATE: 37, // exact decimal value represented as an integer value in two's complement. DECIMAL: 38, // a variable-length array of bytes. diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index f45aac12e6..505be3e6e2 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -291,16 +291,22 @@ def _get_timestamp(self, value: datetime.datetime): is_dst = time.daylight and time.localtime().tm_isdst > 0 seconds_offset = time.altzone if is_dst else time.timezone value = value.replace(tzinfo=datetime.timezone.utc) - return int((value.timestamp() + seconds_offset) * 1000000) + micros = int((value.timestamp() + seconds_offset) * 1_000_000) + seconds, micros_rem = divmod(micros, 1_000_000) + nanos = micros_rem * 1000 + return seconds, nanos def write(self, buffer, value: datetime.datetime): if not isinstance(value, datetime.datetime): raise TypeError("{} should be {} instead of {}".format(value, datetime, type(value))) - # TimestampType represent micro seconds - buffer.write_int64(self._get_timestamp(value)) + seconds, nanos = self._get_timestamp(value) + buffer.write_int64(seconds) + buffer.write_uint32(nanos) def read(self, buffer): - ts = buffer.read_int64() / 1000000 + seconds = buffer.read_int64() + nanos = buffer.read_uint32() + ts = seconds + nanos / 1_000_000_000 # TODO support timezone return datetime.datetime.fromtimestamp(ts) diff --git a/python/pyfory/format/encoder.pxi b/python/pyfory/format/encoder.pxi index 68be180b8f..56b929930f 100644 --- a/python/pyfory/format/encoder.pxi +++ b/python/pyfory/format/encoder.pxi @@ -455,7 +455,7 @@ cdef create_converter(Field field_, CWriter* writer): return create_atomic_encoder(FloatWriter, writer) elif type_id == CTypeId.FLOAT64: return create_atomic_encoder(DoubleWriter, writer) - elif type_id == CTypeId.LOCAL_DATE: + elif type_id == CTypeId.DATE: return create_atomic_encoder(DateWriter, writer) elif type_id == CTypeId.TIMESTAMP: return create_atomic_encoder(TimestampWriter, writer) diff --git a/python/pyfory/format/infer.py b/python/pyfory/format/infer.py index 007a9d7a68..fbb56a1d38 100644 --- a/python/pyfory/format/infer.py +++ b/python/pyfory/format/infer.py @@ -350,7 +350,7 @@ def convert_type(fory_type: DataType): return pa.string() elif type_id == TypeId.BINARY: return pa.binary() - elif type_id == TypeId.LOCAL_DATE: + elif type_id == TypeId.DATE: return pa.date32() elif type_id == TypeId.TIMESTAMP: return pa.timestamp("us") diff --git a/python/pyfory/format/row.pxi b/python/pyfory/format/row.pxi index ec19a65a8c..312ecda480 100644 --- a/python/pyfory/format/row.pxi +++ b/python/pyfory/format/row.pxi @@ -410,7 +410,7 @@ def get_reader(data_type, type_): return type_.get_float elif type_id == CTypeId.FLOAT64: return type_.get_double - elif type_id == CTypeId.LOCAL_DATE: + elif type_id == CTypeId.DATE: return type_.get_date elif type_id == CTypeId.TIMESTAMP: return type_.get_datetime diff --git a/python/pyfory/format/schema.pxi b/python/pyfory/format/schema.pxi index ac1c25e1de..74480761dc 100644 --- a/python/pyfory/format/schema.pxi +++ b/python/pyfory/format/schema.pxi @@ -69,7 +69,7 @@ class TypeId: NONE = 34 DURATION = 35 TIMESTAMP = 36 - LOCAL_DATE = 37 + DATE = 37 DECIMAL = 38 BINARY = 39 diff --git a/python/pyfory/format/schema.py b/python/pyfory/format/schema.py index a0a7cea10b..83ed44a309 100644 --- a/python/pyfory/format/schema.py +++ b/python/pyfory/format/schema.py @@ -69,7 +69,7 @@ def arrow_type_to_fory_type_id(arrow_type): # Date/time types if pa_types.is_date32(arrow_type): - return 26 # LOCAL_DATE + return 26 # DATE if pa_types.is_timestamp(arrow_type): return 25 # TIMESTAMP if pa_types.is_duration(arrow_type): @@ -121,7 +121,7 @@ def fory_type_id_to_arrow_type(type_id, precision=None, scale=None, list_type=No 12: pa.utf8(), # STRING 24: pa.duration("ns"), # DURATION 25: pa.timestamp("us"), # TIMESTAMP - 26: pa.date32(), # LOCAL_DATE + 26: pa.date32(), # DATE 28: pa.binary(), # BINARY } diff --git a/python/pyfory/format/tests/test_infer.py b/python/pyfory/format/tests/test_infer.py index c982001469..ea3a517c00 100644 --- a/python/pyfory/format/tests/test_infer.py +++ b/python/pyfory/format/tests/test_infer.py @@ -85,7 +85,7 @@ def test_infer_class_schema(): def test_type_id(): assert pyfory.format.infer.get_type_id(str) == TypeId.STRING - assert pyfory.format.infer.get_type_id(datetime.date) == TypeId.LOCAL_DATE + assert pyfory.format.infer.get_type_id(datetime.date) == TypeId.DATE assert pyfory.format.infer.get_type_id(datetime.datetime) == TypeId.TIMESTAMP diff --git a/python/pyfory/includes/libformat.pxd b/python/pyfory/includes/libformat.pxd index 8ec42118f6..0c1129dcdf 100755 --- a/python/pyfory/includes/libformat.pxd +++ b/python/pyfory/includes/libformat.pxd @@ -77,7 +77,7 @@ cdef extern from "fory/type/type.h" namespace "fory" nogil: NONE = 34 DURATION = 35 TIMESTAMP = 36 - LOCAL_DATE = 37 + DATE = 37 DECIMAL = 38 BINARY = 39 ARRAY = 40 diff --git a/python/pyfory/includes/libserialization.pxd b/python/pyfory/includes/libserialization.pxd index 62cb1485df..8a4b860476 100644 --- a/python/pyfory/includes/libserialization.pxd +++ b/python/pyfory/includes/libserialization.pxd @@ -60,7 +60,7 @@ cdef extern from "fory/type/type.h" namespace "fory" nogil: NONE = 34 DURATION = 35 TIMESTAMP = 36 - LOCAL_DATE = 37 + DATE = 37 DECIMAL = 38 BINARY = 39 ARRAY = 40 diff --git a/python/pyfory/primitive.pxi b/python/pyfory/primitive.pxi index ed25317779..7f932962cf 100644 --- a/python/pyfory/primitive.pxi +++ b/python/pyfory/primitive.pxi @@ -252,17 +252,34 @@ cdef class TimestampSerializer(XlangCompatibleSerializer): is_dst = time.daylight and time.localtime().tm_isdst > 0 seconds_offset = time.altzone if is_dst else time.timezone value = value.replace(tzinfo=datetime.timezone.utc) - return int((value.timestamp() + seconds_offset) * 1000000) + cdef long long micros = ((value.timestamp() + seconds_offset) * 1000000) + cdef long long seconds + cdef long long micros_rem + if micros >= 0: + seconds = micros // 1000000 + micros_rem = micros % 1000000 + else: + seconds = -((-micros) // 1000000) + micros_rem = micros - seconds * 1000000 + if micros_rem < 0: + seconds -= 1 + micros_rem += 1000000 + return seconds, (micros_rem * 1000) cpdef inline write(self, Buffer buffer, value): if type(value) is not datetime.datetime: raise TypeError( "{} should be {} instead of {}".format(value, datetime, type(value)) ) - # TimestampType represent micro seconds - buffer.write_int64(self._get_timestamp(value)) + cdef long long seconds + cdef unsigned int nanos + seconds, nanos = self._get_timestamp(value) + buffer.write_int64(seconds) + buffer.write_uint32(nanos) cpdef inline read(self, Buffer buffer): - ts = buffer.read_int64() / 1000000 + cdef long long seconds = buffer.read_int64() + cdef unsigned int nanos = buffer.read_uint32() + ts = seconds + nanos / 1000000000 # TODO support timezone return datetime.datetime.fromtimestamp(ts) diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index 43547a2268..ad037d32b0 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -317,7 +317,7 @@ def _initialize_common(self): register(str, type_id=TypeId.STRING, serializer=StringSerializer) # TODO(chaokunyang) DURATION DECIMAL register(datetime.datetime, type_id=TypeId.TIMESTAMP, serializer=TimestampSerializer) - register(datetime.date, type_id=TypeId.LOCAL_DATE, serializer=DateSerializer) + register(datetime.date, type_id=TypeId.DATE, serializer=DateSerializer) register(bytes, type_id=TypeId.BINARY, serializer=BytesSerializer) for itemsize, ftype, typeid in PyArraySerializer.typecode_dict.values(): register( diff --git a/python/pyfory/tests/test_serializer.py b/python/pyfory/tests/test_serializer.py index d5899e43bd..c49b45730f 100644 --- a/python/pyfory/tests/test_serializer.py +++ b/python/pyfory/tests/test_serializer.py @@ -136,7 +136,7 @@ def test_basic_serializer(language): typeinfo = fory.type_resolver.get_typeinfo(datetime.date) assert isinstance(typeinfo.serializer, (DateSerializer, serialization.DateSerializer)) if language == Language.XLANG: - assert typeinfo.type_id == TypeId.LOCAL_DATE + assert typeinfo.type_id == TypeId.DATE assert ser_de(fory, True) is True assert ser_de(fory, False) is False assert ser_de(fory, -1) == -1 diff --git a/python/pyfory/types.py b/python/pyfory/types.py index 9aeb5ca8f3..efcad1d66f 100644 --- a/python/pyfory/types.py +++ b/python/pyfory/types.py @@ -112,7 +112,7 @@ class TypeId: # to an epoch at UTC midnight on January 1, 1970. TIMESTAMP = 36 # a naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, 1970. - LOCAL_DATE = 37 + DATE = 37 # exact decimal value represented as an integer value in two's complement. DECIMAL = 38 # a variable-length array of bytes. diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index 170d69bb6a..b47723c008 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -604,7 +604,7 @@ impl TypeResolver { self.register_internal_serializer::(TypeId::U128)?; self.register_internal_serializer::(TypeId::STRING)?; self.register_internal_serializer::(TypeId::TIMESTAMP)?; - self.register_internal_serializer::(TypeId::LOCAL_DATE)?; + self.register_internal_serializer::(TypeId::DATE)?; self.register_internal_serializer::>(TypeId::BOOL_ARRAY)?; self.register_internal_serializer::>(TypeId::INT8_ARRAY)?; diff --git a/rust/fory-core/src/serializer/datetime.rs b/rust/fory-core/src/serializer/datetime.rs index a0c0f6ae88..bbeaaff624 100644 --- a/rust/fory-core/src/serializer/datetime.rs +++ b/rust/fory-core/src/serializer/datetime.rs @@ -32,25 +32,25 @@ impl Serializer for NaiveDateTime { #[inline(always)] fn fory_write_data(&self, context: &mut WriteContext) -> Result<(), Error> { let dt = self.and_utc(); - let micros = dt.timestamp() * 1_000_000 + dt.timestamp_subsec_micros() as i64; - context.writer.write_i64(micros); + let seconds = dt.timestamp(); + let nanos = dt.timestamp_subsec_nanos(); + context.writer.write_i64(seconds); + context.writer.write_u32(nanos); Ok(()) } #[inline(always)] fn fory_read_data(context: &mut ReadContext) -> Result { - let micros = context.reader.read_i64()?; - use chrono::TimeDelta; - let duration = TimeDelta::microseconds(micros); + let seconds = context.reader.read_i64()?; + let nanos = context.reader.read_u32()?; #[allow(deprecated)] - let epoch_datetime = NaiveDateTime::from_timestamp(0, 0); - let result = epoch_datetime + duration; + let result = NaiveDateTime::from_timestamp(seconds, nanos); Ok(result) } #[inline(always)] fn fory_reserved_space() -> usize { - mem::size_of::() + mem::size_of::() + mem::size_of::() } #[inline(always)] @@ -109,17 +109,17 @@ impl Serializer for NaiveDate { #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { - Ok(TypeId::LOCAL_DATE as u32) + Ok(TypeId::DATE as u32) } #[inline(always)] fn fory_type_id_dyn(&self, _: &TypeResolver) -> Result { - Ok(TypeId::LOCAL_DATE as u32) + Ok(TypeId::DATE as u32) } #[inline(always)] fn fory_static_type_id() -> TypeId { - TypeId::LOCAL_DATE + TypeId::DATE } #[inline(always)] @@ -129,7 +129,7 @@ impl Serializer for NaiveDate { #[inline(always)] fn fory_write_type_info(context: &mut WriteContext) -> Result<(), Error> { - context.writer.write_varuint32(TypeId::LOCAL_DATE as u32); + context.writer.write_varuint32(TypeId::DATE as u32); Ok(()) } diff --git a/rust/fory-core/src/serializer/skip.rs b/rust/fory-core/src/serializer/skip.rs index 158d883016..bb3f902176 100644 --- a/rust/fory-core/src/serializer/skip.rs +++ b/rust/fory-core/src/serializer/skip.rs @@ -649,8 +649,8 @@ fn skip_value( ::fory_read_data(context)?; } - // ============ LOCAL_DATE (TypeId = 37) ============ - types::LOCAL_DATE => { + // ============ DATE (TypeId = 37) ============ + types::DATE => { ::fory_read_data(context)?; } diff --git a/rust/fory-core/src/types.rs b/rust/fory-core/src/types.rs index 1ca0037035..391eef992f 100644 --- a/rust/fory-core/src/types.rs +++ b/rust/fory-core/src/types.rs @@ -140,7 +140,7 @@ pub enum TypeId { NONE = 34, DURATION = 35, TIMESTAMP = 36, - LOCAL_DATE = 37, + DATE = 37, DECIMAL = 38, BINARY = 39, ARRAY = 40, @@ -207,7 +207,7 @@ pub const SET: u32 = TypeId::SET as u32; pub const MAP: u32 = TypeId::MAP as u32; pub const DURATION: u32 = TypeId::DURATION as u32; pub const TIMESTAMP: u32 = TypeId::TIMESTAMP as u32; -pub const LOCAL_DATE: u32 = TypeId::LOCAL_DATE as u32; +pub const DATE: u32 = TypeId::DATE as u32; pub const DECIMAL: u32 = TypeId::DECIMAL as u32; pub const BINARY: u32 = TypeId::BINARY as u32; pub const ARRAY: u32 = TypeId::ARRAY as u32; @@ -280,7 +280,7 @@ pub static BASIC_TYPES: [TypeId; 33] = [ TypeId::FLOAT32, TypeId::FLOAT64, TypeId::STRING, - TypeId::LOCAL_DATE, + TypeId::DATE, TypeId::TIMESTAMP, TypeId::BOOL_ARRAY, TypeId::BINARY, @@ -591,7 +591,7 @@ pub fn format_type_id(type_id: u32) -> String { 34 => "NONE", 35 => "DURATION", 36 => "TIMESTAMP", - 37 => "LOCAL_DATE", + 37 => "DATE", 38 => "DECIMAL", 39 => "BINARY", 40 => "ARRAY", diff --git a/rust/fory-derive/src/object/util.rs b/rust/fory-derive/src/object/util.rs index d8182828fa..711a4e0113 100644 --- a/rust/fory-derive/src/object/util.rs +++ b/rust/fory-derive/src/object/util.rs @@ -956,7 +956,7 @@ pub(crate) fn get_type_id_by_name(ty: &str) -> u32 { // Check internal types match ty { "String" => return TypeId::STRING as u32, - "NaiveDate" => return TypeId::LOCAL_DATE as u32, + "NaiveDate" => return TypeId::DATE as u32, "NaiveDateTime" => return TypeId::TIMESTAMP as u32, "Duration" => return TypeId::DURATION as u32, "Decimal" => return TypeId::DECIMAL as u32, @@ -1083,7 +1083,7 @@ fn is_compress(type_id: u32) -> bool { fn is_internal_type_id(type_id: u32) -> bool { [ TypeId::STRING as u32, - TypeId::LOCAL_DATE as u32, + TypeId::DATE as u32, TypeId::TIMESTAMP as u32, TypeId::DURATION as u32, TypeId::DECIMAL as u32, From 466b1b37e86ceb0aaa27d960813739740a6132a8 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 15:36:13 +0800 Subject: [PATCH 05/21] fix benchmark --- benchmarks/go_benchmark/benchmark_test.go | 28 +++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/benchmarks/go_benchmark/benchmark_test.go b/benchmarks/go_benchmark/benchmark_test.go index e6c12ada22..2366d0a112 100644 --- a/benchmarks/go_benchmark/benchmark_test.go +++ b/benchmarks/go_benchmark/benchmark_test.go @@ -37,13 +37,27 @@ func newFory() *fory.Fory { fory.WithTrackRef(false), ) // Register types with IDs matching C++ benchmark - f.Register(NumericStruct{}, 1) - f.Register(Sample{}, 2) - f.Register(Media{}, 3) - f.Register(Image{}, 4) - f.Register(MediaContent{}, 5) - f.RegisterEnum(Player(0), 6) - f.RegisterEnum(Size(0), 7) + if err := f.RegisterStruct(NumericStruct{}, 1); err != nil { + panic(err) + } + if err := f.RegisterStruct(Sample{}, 2); err != nil { + panic(err) + } + if err := f.RegisterStruct(Media{}, 3); err != nil { + panic(err) + } + if err := f.RegisterStruct(Image{}, 4); err != nil { + panic(err) + } + if err := f.RegisterStruct(MediaContent{}, 5); err != nil { + panic(err) + } + if err := f.RegisterEnum(Player(0), 6); err != nil { + panic(err) + } + if err := f.RegisterEnum(Size(0), 7); err != nil { + panic(err) + } return f } From 346ddd9841046ebf9fc799c58a648300b1a3e33c Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 16:42:56 +0800 Subject: [PATCH 06/21] refactor go struct serialization --- go/fory/field_info.go | 75 ++- go/fory/primitive.go | 82 ++-- go/fory/struct.go | 967 ++++++++++++++++++--------------------- go/fory/struct_test.go | 165 +++++++ go/fory/type_resolver.go | 2 + go/fory/types.go | 132 +----- 6 files changed, 719 insertions(+), 704 deletions(-) diff --git a/go/fory/field_info.go b/go/fory/field_info.go index 48de0d16a3..b12fdc620c 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -24,13 +24,23 @@ import ( "strings" ) +// FieldKind describes how a field is stored in Go memory. +type FieldKind uint8 + +const ( + FieldKindValue FieldKind = iota + FieldKindPointer + FieldKindOptional +) + // PrimitiveFieldInfo contains only the fields needed for hot primitive serialization loops. // This minimal struct improves cache efficiency during iteration. -// Size: 16 bytes (vs full FieldInfo) type PrimitiveFieldInfo struct { Offset uintptr // Field offset for unsafe access DispatchId DispatchId // Type dispatch ID WriteOffset uint8 // Offset within fixed-fields buffer (0-255, sufficient for fixed primitives) + Kind FieldKind + Meta *FieldMeta } // FieldMeta contains cold/rarely-accessed field metadata. @@ -43,8 +53,7 @@ type FieldMeta struct { FieldIndex int // -1 if field doesn't exist in current struct (for compatible mode) FieldDef FieldDef // original FieldDef from remote TypeDef (for compatible mode skip) - // Optional fields (fory/optional.Optional[T]) - IsOptional bool + // Optional fields (fory/optional.Optional[T]) - only valid when FieldKindOptional OptionalInfo optionalInfo // Pre-computed sizes (for fixed primitives) @@ -71,7 +80,7 @@ type FieldInfo struct { DispatchId DispatchId // Type dispatch ID WriteOffset int // Offset within fixed-fields buffer region (sum of preceding field sizes) RefMode RefMode // ref mode for serializer.Write/Read - IsPtr bool // True if field.Type.Kind() == reflect.Ptr + Kind FieldKind Serializer Serializer // Serializer for this field // Cold fields - accessed less frequently @@ -151,15 +160,11 @@ func GroupFields(fields []FieldInfo) FieldGroup { // Categorize fields for i := range fields { field := &fields[i] - if field.Meta.IsOptional { - g.RemainingFields = append(g.RemainingFields, *field) - continue - } - if isFixedSizePrimitive(field.DispatchId, field.Meta.Nullable) { + if isFixedSizePrimitive(field.DispatchId) { // Non-nullable fixed-size primitives only field.Meta.FixedSize = getFixedSizeByDispatchId(field.DispatchId) g.FixedFields = append(g.FixedFields, *field) - } else if isVarintPrimitive(field.DispatchId, field.Meta.Nullable) { + } else if isVarintPrimitive(field.DispatchId) { // Non-nullable varint primitives only g.VarintFields = append(g.VarintFields, *field) } else { @@ -188,6 +193,8 @@ func GroupFields(fields []FieldInfo) FieldGroup { Offset: g.FixedFields[i].Offset, DispatchId: g.FixedFields[i].DispatchId, WriteOffset: uint8(g.FixedSize), + Kind: g.FixedFields[i].Kind, + Meta: g.FixedFields[i].Meta, } g.FixedSize += g.FixedFields[i].Meta.FixedSize } @@ -214,6 +221,8 @@ func GroupFields(fields []FieldInfo) FieldGroup { g.PrimitiveVarintFields[i] = PrimitiveFieldInfo{ Offset: g.VarintFields[i].Offset, DispatchId: g.VarintFields[i].DispatchId, + Kind: g.VarintFields[i].Kind, + Meta: g.VarintFields[i].Meta, // WriteOffset not used for varint fields (variable length) } } @@ -357,27 +366,19 @@ func getUnderlyingTypeSize(dispatchId DispatchId) int { switch dispatchId { // 64-bit types case PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, PrimitiveFloat64DispatchId, - NotnullInt64PtrDispatchId, NotnullUint64PtrDispatchId, NotnullFloat64PtrDispatchId, PrimitiveVarint64DispatchId, PrimitiveVarUint64DispatchId, - NotnullVarint64PtrDispatchId, NotnullVarUint64PtrDispatchId, PrimitiveTaggedInt64DispatchId, PrimitiveTaggedUint64DispatchId, - NotnullTaggedInt64PtrDispatchId, NotnullTaggedUint64PtrDispatchId, - PrimitiveIntDispatchId, PrimitiveUintDispatchId, - NotnullIntPtrDispatchId, NotnullUintPtrDispatchId: + PrimitiveIntDispatchId, PrimitiveUintDispatchId: return 8 // 32-bit types case PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveFloat32DispatchId, - NotnullInt32PtrDispatchId, NotnullUint32PtrDispatchId, NotnullFloat32PtrDispatchId, - PrimitiveVarint32DispatchId, PrimitiveVarUint32DispatchId, - NotnullVarint32PtrDispatchId, NotnullVarUint32PtrDispatchId: + PrimitiveVarint32DispatchId, PrimitiveVarUint32DispatchId: return 4 // 16-bit types - case PrimitiveInt16DispatchId, PrimitiveUint16DispatchId, - NotnullInt16PtrDispatchId, NotnullUint16PtrDispatchId: + case PrimitiveInt16DispatchId, PrimitiveUint16DispatchId: return 2 // 8-bit types - case PrimitiveBoolDispatchId, PrimitiveInt8DispatchId, PrimitiveUint8DispatchId, - NotnullBoolPtrDispatchId, NotnullInt8PtrDispatchId, NotnullUint8PtrDispatchId: + case PrimitiveBoolDispatchId, PrimitiveInt8DispatchId, PrimitiveUint8DispatchId: return 1 // Nullable types case NullableInt64DispatchId, NullableUint64DispatchId, NullableFloat64DispatchId, @@ -771,6 +772,12 @@ func typesCompatible(actual, expected reflect.Type) bool { if actual == expected { return true } + if (actual.Kind() == reflect.Int && expected.Kind() == reflect.Int64) || + (actual.Kind() == reflect.Int64 && expected.Kind() == reflect.Int) || + (actual.Kind() == reflect.Uint && expected.Kind() == reflect.Uint64) || + (actual.Kind() == reflect.Uint64 && expected.Kind() == reflect.Uint) { + return true + } // any can accept any value if actual.Kind() == reflect.Interface && actual.NumMethod() == 0 { return true @@ -784,6 +791,24 @@ func typesCompatible(actual, expected reflect.Type) bool { if expected.Kind() == reflect.Ptr && expected.Elem() == actual { return true } + if actual.Kind() == reflect.Ptr && expected.Kind() != reflect.Ptr { + elem := actual.Elem() + if (elem.Kind() == reflect.Int && expected.Kind() == reflect.Int64) || + (elem.Kind() == reflect.Int64 && expected.Kind() == reflect.Int) || + (elem.Kind() == reflect.Uint && expected.Kind() == reflect.Uint64) || + (elem.Kind() == reflect.Uint64 && expected.Kind() == reflect.Uint) { + return true + } + } + if expected.Kind() == reflect.Ptr && actual.Kind() != reflect.Ptr { + elem := expected.Elem() + if (elem.Kind() == reflect.Int && actual.Kind() == reflect.Int64) || + (elem.Kind() == reflect.Int64 && actual.Kind() == reflect.Int) || + (elem.Kind() == reflect.Uint && actual.Kind() == reflect.Uint64) || + (elem.Kind() == reflect.Uint64 && actual.Kind() == reflect.Uint) { + return true + } + } if actual.Kind() == expected.Kind() { switch actual.Kind() { case reflect.Slice, reflect.Array: @@ -812,6 +837,12 @@ func elementTypesCompatible(actual, expected reflect.Type) bool { if actual == expected || actual.AssignableTo(expected) || expected.AssignableTo(actual) { return true } + if (actual.Kind() == reflect.Int && expected.Kind() == reflect.Int64) || + (actual.Kind() == reflect.Int64 && expected.Kind() == reflect.Int) || + (actual.Kind() == reflect.Uint && expected.Kind() == reflect.Uint64) || + (actual.Kind() == reflect.Uint64 && expected.Kind() == reflect.Uint) { + return true + } if actual.Kind() == reflect.Ptr { return elementTypesCompatible(actual, expected.Elem()) } diff --git a/go/fory/primitive.go b/go/fory/primitive.go index a2341ab435..62ab243ea2 100644 --- a/go/fory/primitive.go +++ b/go/fory/primitive.go @@ -17,10 +17,7 @@ package fory -import ( - "reflect" - "unsafe" -) +import "reflect" // ============================================================================ // Primitive Serializers - implement unified Serializer interface @@ -303,6 +300,50 @@ func (s uint64Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, ty s.Read(ctx, refMode, false, false, value) } +// uintSerializer handles uint type with variable-length encoding (VAR_UINT64) +type uintSerializer struct{} + +var globalUintSerializer = uintSerializer{} + +func (s uintSerializer) WriteData(ctx *WriteContext, value reflect.Value) { + ctx.buffer.WriteVaruint64(value.Uint()) +} + +func (s uintSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + if refMode != RefModeNone { + ctx.buffer.WriteInt8(NotNullValueFlag) + } + if writeType { + ctx.buffer.WriteVaruint32Small7(uint32(VAR_UINT64)) + } + s.WriteData(ctx, value) +} + +func (s uintSerializer) ReadData(ctx *ReadContext, value reflect.Value) { + err := ctx.Err() + value.SetUint(ctx.buffer.ReadVaruint64(err)) +} + +func (s uintSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + err := ctx.Err() + if refMode != RefModeNone { + if ctx.buffer.ReadInt8(err) == NullFlag { + return + } + } + if readType { + _ = ctx.buffer.ReadVaruint32Small7(err) + } + if ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s uintSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} + // int16Serializer handles int16 type type int16Serializer struct{} @@ -564,36 +605,3 @@ func (s float64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool func (s float64Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { s.Read(ctx, refMode, false, false, value) } - -// ============================================================================ -// Notnull Pointer Helper Functions for Varint Types -// These are used by struct serializer for the rare case of *T with nullable=false -// ============================================================================ - -// writeNotnullVarintPtrUnsafe writes a notnull pointer varint type at the given offset. -// Used by struct serializer for rare notnull pointer types. -// Returns the number of bytes written. -// -//go:inline -func writeNotnullVarintPtrUnsafe(buf *ByteBuffer, offset int, fieldPtr unsafe.Pointer, dispatchId DispatchId) int { - switch dispatchId { - case NotnullVarint32PtrDispatchId: - return buf.UnsafePutVarInt32(offset, **(**int32)(fieldPtr)) - case NotnullVarint64PtrDispatchId: - return buf.UnsafePutVarInt64(offset, **(**int64)(fieldPtr)) - case NotnullIntPtrDispatchId: - return buf.UnsafePutVarInt64(offset, int64(**(**int)(fieldPtr))) - case NotnullVarUint32PtrDispatchId: - return buf.UnsafePutVaruint32(offset, **(**uint32)(fieldPtr)) - case NotnullVarUint64PtrDispatchId: - return buf.UnsafePutVaruint64(offset, **(**uint64)(fieldPtr)) - case NotnullUintPtrDispatchId: - return buf.UnsafePutVaruint64(offset, uint64(**(**uint)(fieldPtr))) - case NotnullTaggedInt64PtrDispatchId: - return buf.UnsafePutTaggedInt64(offset, **(**int64)(fieldPtr)) - case NotnullTaggedUint64PtrDispatchId: - return buf.UnsafePutTaggedUint64(offset, **(**uint64)(fieldPtr)) - default: - return 0 - } -} diff --git a/go/fory/struct.go b/go/fory/struct.go index ac945a0a54..8beea7f32c 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -160,7 +160,12 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { } baseType = optionalInfo.valueType } - + fieldKind := FieldKindValue + if isOptional { + fieldKind = FieldKindOptional + } else if fieldType.Kind() == reflect.Ptr { + fieldKind = FieldKindPointer + } var fieldSerializer Serializer // For any fields, don't get a serializer - use WriteValue/ReadValue instead // which will handle polymorphic types dynamically @@ -261,7 +266,7 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { fieldType.Kind() == reflect.Map || fieldType.Kind() == reflect.Interface } - if foryTag.NullableSet && !isOptional { + if foryTag.NullableSet { // Override nullable flag if explicitly set in fory tag nullableFlag = foryTag.Nullable } @@ -290,18 +295,13 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { writeType := typeResolver.Compatible() && isStructField(baseType) // Pre-compute DispatchId, with special handling for enum fields and pointer-to-numeric - var dispatchId DispatchId - if fieldType.Kind() == reflect.Ptr && isNumericKind(fieldType.Elem().Kind()) { - if nullableFlag { - dispatchId = getDispatchIdFromTypeId(fieldTypeId, true) - } else { - dispatchId = getNotnullPtrDispatchId(fieldType.Elem().Kind(), foryTag.Encoding) - } - } else { - dispatchId = getDispatchIdFromTypeId(fieldTypeId, nullableFlag) - if dispatchId == UnknownDispatchId { - dispatchId = GetDispatchId(fieldType) + dispatchId := getDispatchIdFromTypeId(fieldTypeId, nullableFlag) + if dispatchId == UnknownDispatchId { + dispatchType := baseType + if dispatchType.Kind() == reflect.Ptr { + dispatchType = dispatchType.Elem() } + dispatchId = GetDispatchId(dispatchType) } if fieldSerializer != nil { if _, ok := fieldSerializer.(*enumSerializer); ok { @@ -321,7 +321,7 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { Offset: field.Offset, DispatchId: dispatchId, RefMode: refMode, - IsPtr: fieldType.Kind() == reflect.Ptr, + Kind: fieldKind, Serializer: fieldSerializer, Meta: &FieldMeta{ Name: SnakeCase(field.Name), @@ -331,7 +331,6 @@ func (s *structSerializer) initFields(typeResolver *TypeResolver) error { FieldIndex: i, WriteType: writeType, HasGenerics: isCollectionType(fieldTypeId), // Container fields have declared element types - IsOptional: isOptional, OptionalInfo: optionalInfo, TagID: foryTag.ID, HasForyTag: foryTag.HasTag, @@ -425,7 +424,7 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err Offset: 0, DispatchId: dispatchId, RefMode: refMode, - IsPtr: remoteType != nil && remoteType.Kind() == reflect.Ptr, + Kind: FieldKindValue, Serializer: fieldSerializer, Meta: &FieldMeta{ Name: def.name, @@ -688,6 +687,17 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err } baseType = optionalInfo.valueType } + fieldKind := FieldKindValue + if isOptional { + fieldKind = FieldKindOptional + } else if fieldType.Kind() == reflect.Ptr { + fieldKind = FieldKindPointer + } + if fieldKind == FieldKindOptional { + // Use the Optional serializer for local Optional[T] fields. + // The serializer resolved from remote type IDs is for the element type. + fieldSerializer, _ = typeResolver.getSerializerByType(fieldType, true) + } // Get TypeId from FieldType's TypeId method fieldTypeId := def.fieldType.TypeId() @@ -714,26 +724,26 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err localIsPrimitive := isPrimitiveDispatchKind(baseKind) || (localIsPtr && isPrimitiveDispatchKind(fieldType.Elem().Kind())) if localIsPrimitive { - if localIsPtr { - if def.nullable { - // Local is *T, remote is nullable - use nullable DispatchId - dispatchId = getDispatchIdFromTypeId(fieldTypeId, true) - } else { - // Local is *T, remote is NOT nullable - use notnull pointer DispatchId - encoding := getEncodingFromTypeId(fieldTypeId) - dispatchId = getNotnullPtrDispatchId(fieldType.Elem().Kind(), encoding) - } + if def.nullable { + // Remote is nullable - use nullable DispatchId + dispatchId = getDispatchIdFromTypeId(fieldTypeId, true) } else { - if def.nullable { - // Local is T (non-pointer), remote is nullable - use nullable DispatchId - dispatchId = getDispatchIdFromTypeId(fieldTypeId, true) - } else { - // Local is T, remote is NOT nullable - use primitive DispatchId - dispatchId = GetDispatchId(baseType) + // Remote is NOT nullable - use primitive DispatchId + dispatchId = getDispatchIdFromTypeId(fieldTypeId, false) + if dispatchId == UnknownDispatchId { + dispatchType := baseType + if dispatchType.Kind() == reflect.Ptr { + dispatchType = dispatchType.Elem() + } + dispatchId = GetDispatchId(dispatchType) } } } else { - dispatchId = GetDispatchId(baseType) + dispatchType := baseType + if dispatchType.Kind() == reflect.Ptr { + dispatchType = dispatchType.Elem() + } + dispatchId = GetDispatchId(dispatchType) } if fieldSerializer != nil { if _, ok := fieldSerializer.(*enumSerializer); ok { @@ -755,7 +765,7 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err Offset: offset, DispatchId: dispatchId, RefMode: refMode, - IsPtr: fieldType != nil && fieldType.Kind() == reflect.Ptr, + Kind: fieldKind, Serializer: fieldSerializer, Meta: &FieldMeta{ Name: fieldName, @@ -766,7 +776,6 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err FieldDef: def, // Save original FieldDef for skipping WriteType: writeType, HasGenerics: isCollectionType(fieldTypeId), // Container fields have declared element types - IsOptional: isOptional, OptionalInfo: optionalInfo, TagID: def.tagID, HasForyTag: def.tagID >= 0, @@ -802,8 +811,8 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err // Local nullable is determined by whether the Go field is a pointer type if i < len(s.fieldDefs) && field.Meta.FieldIndex >= 0 { remoteNullable := s.fieldDefs[i].nullable - // Check if local Go field is nullable (pointer or Option) - localNullable := field.IsPtr || field.Meta.IsOptional + // Check if local Go field is nullable based on computed field metadata + localNullable := field.Meta.Nullable if remoteNullable != localNullable { s.typeDefDiffers = true break @@ -850,7 +859,7 @@ func (s *structSerializer) computeHash() int32 { typeId = UNKNOWN } fieldTypeForHash := field.Meta.Type - if field.Meta.IsOptional { + if field.Kind == FieldKindOptional { fieldTypeForHash = field.Meta.OptionalInfo.valueType } // For fixed-size arrays with primitive elements, use primitive array type IDs @@ -898,17 +907,17 @@ func (s *structSerializer) computeHash() int32 { // - Default: false for ALL fields (xlang default - aligned with all languages) // - Primitives are always non-nullable // - Can be overridden by explicit fory tag - nullable := field.Meta.IsOptional // Optional fields are nullable by default - if field.Meta.TagNullableSet && !field.Meta.IsOptional { + nullable := field.Kind == FieldKindOptional // Optional fields are nullable by default + if field.Meta.TagNullableSet { // Use explicit tag value if set nullable = field.Meta.TagNullable } // Primitives are never nullable, regardless of tag fieldTypeForNullable := field.Meta.Type - if field.Meta.IsOptional { + if field.Kind == FieldKindOptional { fieldTypeForNullable = field.Meta.OptionalInfo.valueType } - if !field.Meta.IsOptional && isNonNullablePrimitiveKind(fieldTypeForNullable.Kind()) && !isEnumField { + if field.Kind != FieldKindOptional && isNonNullablePrimitiveKind(fieldTypeForNullable.Kind()) && !isEnumField { nullable = false } @@ -1022,122 +1031,111 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { for _, field := range s.fieldGroup.PrimitiveFixedFields { fieldPtr := unsafe.Add(ptr, field.Offset) bufOffset := baseOffset + int(field.WriteOffset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional && field.Meta != nil { + optInfo = field.Meta.OptionalInfo + } switch field.DispatchId { case PrimitiveBoolDispatchId: - if *(*bool)(fieldPtr) { + v, ok := loadFieldValue[bool](field.Kind, fieldPtr, optInfo) + if ok && v { data[bufOffset] = 1 } else { data[bufOffset] = 0 } - case NotnullBoolPtrDispatchId: - if **(**bool)(fieldPtr) { - data[bufOffset] = 1 + case PrimitiveInt8DispatchId: + v, ok := loadFieldValue[int8](field.Kind, fieldPtr, optInfo) + if ok { + data[bufOffset] = byte(v) } else { data[bufOffset] = 0 } - case PrimitiveInt8DispatchId: - data[bufOffset] = *(*byte)(fieldPtr) - case NotnullInt8PtrDispatchId: - data[bufOffset] = byte(**(**int8)(fieldPtr)) case PrimitiveUint8DispatchId: - data[bufOffset] = *(*uint8)(fieldPtr) - case NotnullUint8PtrDispatchId: - data[bufOffset] = **(**uint8)(fieldPtr) - case PrimitiveInt16DispatchId: - if isLittleEndian { - *(*int16)(unsafe.Pointer(&data[bufOffset])) = *(*int16)(fieldPtr) + v, ok := loadFieldValue[uint8](field.Kind, fieldPtr, optInfo) + if ok { + data[bufOffset] = v } else { - binary.LittleEndian.PutUint16(data[bufOffset:], uint16(*(*int16)(fieldPtr))) + data[bufOffset] = 0 + } + case PrimitiveInt16DispatchId: + v, ok := loadFieldValue[int16](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullInt16PtrDispatchId: if isLittleEndian { - *(*int16)(unsafe.Pointer(&data[bufOffset])) = **(**int16)(fieldPtr) + *(*int16)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint16(data[bufOffset:], uint16(**(**int16)(fieldPtr))) + binary.LittleEndian.PutUint16(data[bufOffset:], uint16(v)) } case PrimitiveUint16DispatchId: - if isLittleEndian { - *(*uint16)(unsafe.Pointer(&data[bufOffset])) = *(*uint16)(fieldPtr) - } else { - binary.LittleEndian.PutUint16(data[bufOffset:], *(*uint16)(fieldPtr)) + v, ok := loadFieldValue[uint16](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullUint16PtrDispatchId: if isLittleEndian { - *(*uint16)(unsafe.Pointer(&data[bufOffset])) = **(**uint16)(fieldPtr) + *(*uint16)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint16(data[bufOffset:], **(**uint16)(fieldPtr)) + binary.LittleEndian.PutUint16(data[bufOffset:], v) } case PrimitiveInt32DispatchId: - if isLittleEndian { - *(*int32)(unsafe.Pointer(&data[bufOffset])) = *(*int32)(fieldPtr) - } else { - binary.LittleEndian.PutUint32(data[bufOffset:], uint32(*(*int32)(fieldPtr))) + v, ok := loadFieldValue[int32](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullInt32PtrDispatchId: if isLittleEndian { - *(*int32)(unsafe.Pointer(&data[bufOffset])) = **(**int32)(fieldPtr) + *(*int32)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint32(data[bufOffset:], uint32(**(**int32)(fieldPtr))) + binary.LittleEndian.PutUint32(data[bufOffset:], uint32(v)) } case PrimitiveUint32DispatchId: - if isLittleEndian { - *(*uint32)(unsafe.Pointer(&data[bufOffset])) = *(*uint32)(fieldPtr) - } else { - binary.LittleEndian.PutUint32(data[bufOffset:], *(*uint32)(fieldPtr)) + v, ok := loadFieldValue[uint32](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullUint32PtrDispatchId: if isLittleEndian { - *(*uint32)(unsafe.Pointer(&data[bufOffset])) = **(**uint32)(fieldPtr) + *(*uint32)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint32(data[bufOffset:], **(**uint32)(fieldPtr)) + binary.LittleEndian.PutUint32(data[bufOffset:], v) } case PrimitiveInt64DispatchId: - if isLittleEndian { - *(*int64)(unsafe.Pointer(&data[bufOffset])) = *(*int64)(fieldPtr) - } else { - binary.LittleEndian.PutUint64(data[bufOffset:], uint64(*(*int64)(fieldPtr))) + v, ok := loadFieldValue[int64](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullInt64PtrDispatchId: if isLittleEndian { - *(*int64)(unsafe.Pointer(&data[bufOffset])) = **(**int64)(fieldPtr) + *(*int64)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint64(data[bufOffset:], uint64(**(**int64)(fieldPtr))) + binary.LittleEndian.PutUint64(data[bufOffset:], uint64(v)) } case PrimitiveUint64DispatchId: - if isLittleEndian { - *(*uint64)(unsafe.Pointer(&data[bufOffset])) = *(*uint64)(fieldPtr) - } else { - binary.LittleEndian.PutUint64(data[bufOffset:], *(*uint64)(fieldPtr)) + v, ok := loadFieldValue[uint64](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullUint64PtrDispatchId: if isLittleEndian { - *(*uint64)(unsafe.Pointer(&data[bufOffset])) = **(**uint64)(fieldPtr) + *(*uint64)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint64(data[bufOffset:], **(**uint64)(fieldPtr)) + binary.LittleEndian.PutUint64(data[bufOffset:], v) } case PrimitiveFloat32DispatchId: - if isLittleEndian { - *(*float32)(unsafe.Pointer(&data[bufOffset])) = *(*float32)(fieldPtr) - } else { - binary.LittleEndian.PutUint32(data[bufOffset:], math.Float32bits(*(*float32)(fieldPtr))) + v, ok := loadFieldValue[float32](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullFloat32PtrDispatchId: if isLittleEndian { - *(*float32)(unsafe.Pointer(&data[bufOffset])) = **(**float32)(fieldPtr) + *(*float32)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint32(data[bufOffset:], math.Float32bits(**(**float32)(fieldPtr))) + binary.LittleEndian.PutUint32(data[bufOffset:], math.Float32bits(v)) } case PrimitiveFloat64DispatchId: - if isLittleEndian { - *(*float64)(unsafe.Pointer(&data[bufOffset])) = *(*float64)(fieldPtr) - } else { - binary.LittleEndian.PutUint64(data[bufOffset:], math.Float64bits(*(*float64)(fieldPtr))) + v, ok := loadFieldValue[float64](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 } - case NotnullFloat64PtrDispatchId: if isLittleEndian { - *(*float64)(unsafe.Pointer(&data[bufOffset])) = **(**float64)(fieldPtr) + *(*float64)(unsafe.Pointer(&data[bufOffset])) = v } else { - binary.LittleEndian.PutUint64(data[bufOffset:], math.Float64bits(**(**float64)(fieldPtr))) + binary.LittleEndian.PutUint64(data[bufOffset:], math.Float64bits(v)) } } } @@ -1147,53 +1145,74 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { // Fallback to reflect-based access for unaddressable values for _, field := range s.fieldGroup.FixedFields { fieldValue := value.Field(field.Meta.FieldIndex) + val, ok := loadReflectFieldValue(&field, fieldValue) switch field.DispatchId { - // Primitive types (non-pointer) case PrimitiveBoolDispatchId: - buf.WriteBool(fieldValue.Bool()) + if ok { + buf.WriteBool(val.Bool()) + } else { + buf.WriteBool(false) + } case PrimitiveInt8DispatchId: - buf.WriteByte_(byte(fieldValue.Int())) + if ok { + buf.WriteByte_(byte(val.Int())) + } else { + buf.WriteByte_(0) + } case PrimitiveUint8DispatchId: - buf.WriteByte_(byte(fieldValue.Uint())) + if ok { + buf.WriteByte_(byte(val.Uint())) + } else { + buf.WriteByte_(0) + } case PrimitiveInt16DispatchId: - buf.WriteInt16(int16(fieldValue.Int())) + if ok { + buf.WriteInt16(int16(val.Int())) + } else { + buf.WriteInt16(0) + } case PrimitiveUint16DispatchId: - buf.WriteInt16(int16(fieldValue.Uint())) + if ok { + buf.WriteInt16(int16(val.Uint())) + } else { + buf.WriteInt16(0) + } case PrimitiveInt32DispatchId: - buf.WriteInt32(int32(fieldValue.Int())) + if ok { + buf.WriteInt32(int32(val.Int())) + } else { + buf.WriteInt32(0) + } case PrimitiveUint32DispatchId: - buf.WriteInt32(int32(fieldValue.Uint())) + if ok { + buf.WriteInt32(int32(val.Uint())) + } else { + buf.WriteInt32(0) + } case PrimitiveInt64DispatchId: - buf.WriteInt64(fieldValue.Int()) + if ok { + buf.WriteInt64(val.Int()) + } else { + buf.WriteInt64(0) + } case PrimitiveUint64DispatchId: - buf.WriteInt64(int64(fieldValue.Uint())) + if ok { + buf.WriteInt64(int64(val.Uint())) + } else { + buf.WriteInt64(0) + } case PrimitiveFloat32DispatchId: - buf.WriteFloat32(float32(fieldValue.Float())) + if ok { + buf.WriteFloat32(float32(val.Float())) + } else { + buf.WriteFloat32(0) + } case PrimitiveFloat64DispatchId: - buf.WriteFloat64(fieldValue.Float()) - // Notnull pointer types - dereference and write - case NotnullBoolPtrDispatchId: - buf.WriteBool(fieldValue.Elem().Bool()) - case NotnullInt8PtrDispatchId: - buf.WriteByte_(byte(fieldValue.Elem().Int())) - case NotnullUint8PtrDispatchId: - buf.WriteByte_(byte(fieldValue.Elem().Uint())) - case NotnullInt16PtrDispatchId: - buf.WriteInt16(int16(fieldValue.Elem().Int())) - case NotnullUint16PtrDispatchId: - buf.WriteInt16(int16(fieldValue.Elem().Uint())) - case NotnullInt32PtrDispatchId: - buf.WriteInt32(int32(fieldValue.Elem().Int())) - case NotnullUint32PtrDispatchId: - buf.WriteInt32(int32(fieldValue.Elem().Uint())) - case NotnullInt64PtrDispatchId: - buf.WriteInt64(fieldValue.Elem().Int()) - case NotnullUint64PtrDispatchId: - buf.WriteInt64(int64(fieldValue.Elem().Uint())) - case NotnullFloat32PtrDispatchId: - buf.WriteFloat32(float32(fieldValue.Elem().Float())) - case NotnullFloat64PtrDispatchId: - buf.WriteFloat64(fieldValue.Elem().Float()) + if ok { + buf.WriteFloat64(val.Float()) + } else { + buf.WriteFloat64(0) + } } } } @@ -1208,26 +1227,59 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { for _, field := range s.fieldGroup.PrimitiveVarintFields { fieldPtr := unsafe.Add(ptr, field.Offset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional && field.Meta != nil { + optInfo = field.Meta.OptionalInfo + } switch field.DispatchId { case PrimitiveVarint32DispatchId: - offset += buf.UnsafePutVarInt32(offset, *(*int32)(fieldPtr)) + v, ok := loadFieldValue[int32](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutVarInt32(offset, v) case PrimitiveVarint64DispatchId: - offset += buf.UnsafePutVarInt64(offset, *(*int64)(fieldPtr)) + v, ok := loadFieldValue[int64](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutVarInt64(offset, v) case PrimitiveIntDispatchId: - offset += buf.UnsafePutVarInt64(offset, int64(*(*int)(fieldPtr))) + v, ok := loadFieldValue[int](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutVarInt64(offset, int64(v)) case PrimitiveVarUint32DispatchId: - offset += buf.UnsafePutVaruint32(offset, *(*uint32)(fieldPtr)) + v, ok := loadFieldValue[uint32](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutVaruint32(offset, v) case PrimitiveVarUint64DispatchId: - offset += buf.UnsafePutVaruint64(offset, *(*uint64)(fieldPtr)) + v, ok := loadFieldValue[uint64](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutVaruint64(offset, v) case PrimitiveUintDispatchId: - offset += buf.UnsafePutVaruint64(offset, uint64(*(*uint)(fieldPtr))) + v, ok := loadFieldValue[uint](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutVaruint64(offset, uint64(v)) case PrimitiveTaggedInt64DispatchId: - offset += buf.UnsafePutTaggedInt64(offset, *(*int64)(fieldPtr)) + v, ok := loadFieldValue[int64](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutTaggedInt64(offset, v) case PrimitiveTaggedUint64DispatchId: - offset += buf.UnsafePutTaggedUint64(offset, *(*uint64)(fieldPtr)) - default: - // Notnull pointer types (rare case - pointers with nullable=false tag) - offset += writeNotnullVarintPtrUnsafe(buf, offset, fieldPtr, field.DispatchId) + v, ok := loadFieldValue[uint64](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + offset += buf.UnsafePutTaggedUint64(offset, v) } } // Update writer index ONCE after all varint fields @@ -1236,41 +1288,56 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { // Slow path for non-addressable values: use reflection for _, field := range s.fieldGroup.VarintFields { fieldValue := value.Field(field.Meta.FieldIndex) + val, ok := loadReflectFieldValue(&field, fieldValue) switch field.DispatchId { - // Primitive types (non-pointer) case PrimitiveVarint32DispatchId: - buf.WriteVarint32(int32(fieldValue.Int())) + if ok { + buf.WriteVarint32(int32(val.Int())) + } else { + buf.WriteVarint32(0) + } case PrimitiveVarint64DispatchId: - buf.WriteVarint64(fieldValue.Int()) + if ok { + buf.WriteVarint64(val.Int()) + } else { + buf.WriteVarint64(0) + } case PrimitiveIntDispatchId: - buf.WriteVarint64(fieldValue.Int()) + if ok { + buf.WriteVarint64(val.Int()) + } else { + buf.WriteVarint64(0) + } case PrimitiveVarUint32DispatchId: - buf.WriteVaruint32(uint32(fieldValue.Uint())) + if ok { + buf.WriteVaruint32(uint32(val.Uint())) + } else { + buf.WriteVaruint32(0) + } case PrimitiveVarUint64DispatchId: - buf.WriteVaruint64(fieldValue.Uint()) + if ok { + buf.WriteVaruint64(val.Uint()) + } else { + buf.WriteVaruint64(0) + } case PrimitiveUintDispatchId: - buf.WriteVaruint64(fieldValue.Uint()) + if ok { + buf.WriteVaruint64(val.Uint()) + } else { + buf.WriteVaruint64(0) + } case PrimitiveTaggedInt64DispatchId: - buf.WriteTaggedInt64(fieldValue.Int()) + if ok { + buf.WriteTaggedInt64(val.Int()) + } else { + buf.WriteTaggedInt64(0) + } case PrimitiveTaggedUint64DispatchId: - buf.WriteTaggedUint64(fieldValue.Uint()) - // Notnull pointer types - dereference and write - case NotnullVarint32PtrDispatchId: - buf.WriteVarint32(int32(fieldValue.Elem().Int())) - case NotnullVarint64PtrDispatchId: - buf.WriteVarint64(fieldValue.Elem().Int()) - case NotnullIntPtrDispatchId: - buf.WriteVarint64(fieldValue.Elem().Int()) - case NotnullVarUint32PtrDispatchId: - buf.WriteVaruint32(uint32(fieldValue.Elem().Uint())) - case NotnullVarUint64PtrDispatchId: - buf.WriteVaruint64(fieldValue.Elem().Uint()) - case NotnullUintPtrDispatchId: - buf.WriteVaruint64(fieldValue.Elem().Uint()) - case NotnullTaggedInt64PtrDispatchId: - buf.WriteTaggedInt64(fieldValue.Elem().Int()) - case NotnullTaggedUint64PtrDispatchId: - buf.WriteTaggedUint64(fieldValue.Elem().Uint()) + if ok { + buf.WriteTaggedUint64(val.Uint()) + } else { + buf.WriteTaggedUint64(0) + } } } } @@ -1288,7 +1355,7 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { // writeRemainingField writes a non-primitive field (string, slice, map, struct, enum) func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Pointer, field *FieldInfo, value reflect.Value) { buf := ctx.Buffer() - if field.Meta.IsOptional { + if field.Kind == FieldKindOptional { if ptr != nil { if writeOptionFast(ctx, field, unsafe.Add(ptr, field.Offset)) { return @@ -1309,7 +1376,7 @@ func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Poi switch field.DispatchId { case StringDispatchId: // Check isPtr first for better branch prediction - if !field.IsPtr { + if field.Kind != FieldKindPointer { // Non-pointer string: always non-null, no ref tracking needed in fast path if field.RefMode == RefModeNone { ctx.WriteString(*(*string)(fieldPtr)) @@ -1793,6 +1860,94 @@ func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Poi } } +func loadFieldValue[T any](kind FieldKind, fieldPtr unsafe.Pointer, opt optionalInfo) (T, bool) { + var zero T + switch kind { + case FieldKindPointer: + ptr := *(**T)(fieldPtr) + if ptr == nil { + return zero, false + } + return *ptr, true + case FieldKindOptional: + if !*(*bool)(unsafe.Add(fieldPtr, opt.hasOffset)) { + return zero, false + } + return *(*T)(unsafe.Add(fieldPtr, opt.valueOffset)), true + default: + return *(*T)(fieldPtr), true + } +} + +func storeFieldValue[T any](kind FieldKind, fieldPtr unsafe.Pointer, opt optionalInfo, value T) { + switch kind { + case FieldKindPointer: + ptr := *(**T)(fieldPtr) + if ptr == nil { + ptr = new(T) + *(**T)(fieldPtr) = ptr + } + *ptr = value + case FieldKindOptional: + *(*bool)(unsafe.Add(fieldPtr, opt.hasOffset)) = true + *(*T)(unsafe.Add(fieldPtr, opt.valueOffset)) = value + default: + *(*T)(fieldPtr) = value + } +} + +func clearFieldValue(kind FieldKind, fieldPtr unsafe.Pointer, opt optionalInfo) { + switch kind { + case FieldKindPointer: + *(*unsafe.Pointer)(fieldPtr) = nil + case FieldKindOptional: + *(*bool)(unsafe.Add(fieldPtr, opt.hasOffset)) = false + default: + } +} + +func loadReflectFieldValue(field *FieldInfo, fieldValue reflect.Value) (reflect.Value, bool) { + switch field.Kind { + case FieldKindPointer: + if fieldValue.IsNil() { + return reflect.Value{}, false + } + return fieldValue.Elem(), true + case FieldKindOptional: + if !fieldValue.FieldByName("Has").Bool() { + return reflect.Value{}, false + } + return fieldValue.FieldByName("Value"), true + default: + return fieldValue, true + } +} + +func storeReflectFieldValue(field *FieldInfo, fieldValue reflect.Value, value reflect.Value) { + switch field.Kind { + case FieldKindPointer: + ptr := reflect.New(value.Type()) + ptr.Elem().Set(value) + fieldValue.Set(ptr) + case FieldKindOptional: + fieldValue.FieldByName("Has").SetBool(true) + fieldValue.FieldByName("Value").Set(value) + default: + fieldValue.Set(value) + } +} + +func clearReflectFieldValue(field *FieldInfo, fieldValue reflect.Value) { + switch field.Kind { + case FieldKindPointer: + fieldValue.Set(reflect.Zero(fieldValue.Type())) + case FieldKindOptional: + fieldValue.FieldByName("Has").SetBool(false) + default: + fieldValue.Set(reflect.Zero(fieldValue.Type())) + } +} + func writeOptionFast(ctx *WriteContext, field *FieldInfo, optPtr unsafe.Pointer) bool { buf := ctx.Buffer() has := *(*bool)(unsafe.Add(optPtr, field.Meta.OptionalInfo.hasOffset)) @@ -2331,138 +2486,81 @@ func (s *structSerializer) ReadData(ctx *ReadContext, value reflect.Value) { for _, field := range s.fieldGroup.PrimitiveFixedFields { fieldPtr := unsafe.Add(ptr, field.Offset) bufOffset := baseOffset + int(field.WriteOffset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional && field.Meta != nil { + optInfo = field.Meta.OptionalInfo + } switch field.DispatchId { case PrimitiveBoolDispatchId: - *(*bool)(fieldPtr) = data[bufOffset] != 0 + storeFieldValue(field.Kind, fieldPtr, optInfo, data[bufOffset] != 0) case PrimitiveInt8DispatchId: - *(*int8)(fieldPtr) = int8(data[bufOffset]) + storeFieldValue(field.Kind, fieldPtr, optInfo, int8(data[bufOffset])) case PrimitiveUint8DispatchId: - *(*uint8)(fieldPtr) = data[bufOffset] + storeFieldValue(field.Kind, fieldPtr, optInfo, data[bufOffset]) case PrimitiveInt16DispatchId: + var v int16 if isLittleEndian { - *(*int16)(fieldPtr) = *(*int16)(unsafe.Pointer(&data[bufOffset])) + v = *(*int16)(unsafe.Pointer(&data[bufOffset])) } else { - *(*int16)(fieldPtr) = int16(binary.LittleEndian.Uint16(data[bufOffset:])) + v = int16(binary.LittleEndian.Uint16(data[bufOffset:])) } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) case PrimitiveUint16DispatchId: + var v uint16 if isLittleEndian { - *(*uint16)(fieldPtr) = *(*uint16)(unsafe.Pointer(&data[bufOffset])) + v = *(*uint16)(unsafe.Pointer(&data[bufOffset])) } else { - *(*uint16)(fieldPtr) = binary.LittleEndian.Uint16(data[bufOffset:]) + v = binary.LittleEndian.Uint16(data[bufOffset:]) } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) case PrimitiveInt32DispatchId: + var v int32 if isLittleEndian { - *(*int32)(fieldPtr) = *(*int32)(unsafe.Pointer(&data[bufOffset])) + v = *(*int32)(unsafe.Pointer(&data[bufOffset])) } else { - *(*int32)(fieldPtr) = int32(binary.LittleEndian.Uint32(data[bufOffset:])) + v = int32(binary.LittleEndian.Uint32(data[bufOffset:])) } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) case PrimitiveUint32DispatchId: + var v uint32 if isLittleEndian { - *(*uint32)(fieldPtr) = *(*uint32)(unsafe.Pointer(&data[bufOffset])) + v = *(*uint32)(unsafe.Pointer(&data[bufOffset])) } else { - *(*uint32)(fieldPtr) = binary.LittleEndian.Uint32(data[bufOffset:]) + v = binary.LittleEndian.Uint32(data[bufOffset:]) } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) case PrimitiveInt64DispatchId: + var v int64 if isLittleEndian { - *(*int64)(fieldPtr) = *(*int64)(unsafe.Pointer(&data[bufOffset])) + v = *(*int64)(unsafe.Pointer(&data[bufOffset])) } else { - *(*int64)(fieldPtr) = int64(binary.LittleEndian.Uint64(data[bufOffset:])) + v = int64(binary.LittleEndian.Uint64(data[bufOffset:])) } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) case PrimitiveUint64DispatchId: + var v uint64 if isLittleEndian { - *(*uint64)(fieldPtr) = *(*uint64)(unsafe.Pointer(&data[bufOffset])) + v = *(*uint64)(unsafe.Pointer(&data[bufOffset])) } else { - *(*uint64)(fieldPtr) = binary.LittleEndian.Uint64(data[bufOffset:]) + v = binary.LittleEndian.Uint64(data[bufOffset:]) } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) case PrimitiveFloat32DispatchId: + var v float32 if isLittleEndian { - *(*float32)(fieldPtr) = *(*float32)(unsafe.Pointer(&data[bufOffset])) + v = *(*float32)(unsafe.Pointer(&data[bufOffset])) } else { - *(*float32)(fieldPtr) = math.Float32frombits(binary.LittleEndian.Uint32(data[bufOffset:])) + v = math.Float32frombits(binary.LittleEndian.Uint32(data[bufOffset:])) } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) case PrimitiveFloat64DispatchId: + var v float64 if isLittleEndian { - *(*float64)(fieldPtr) = *(*float64)(unsafe.Pointer(&data[bufOffset])) - } else { - *(*float64)(fieldPtr) = math.Float64frombits(binary.LittleEndian.Uint64(data[bufOffset:])) - } - // Notnull pointer types - allocate and set pointer - case NotnullBoolPtrDispatchId: - v := new(bool) - *v = data[bufOffset] != 0 - *(**bool)(fieldPtr) = v - case NotnullInt8PtrDispatchId: - v := new(int8) - *v = int8(data[bufOffset]) - *(**int8)(fieldPtr) = v - case NotnullUint8PtrDispatchId: - v := new(uint8) - *v = data[bufOffset] - *(**uint8)(fieldPtr) = v - case NotnullInt16PtrDispatchId: - v := new(int16) - if isLittleEndian { - *v = *(*int16)(unsafe.Pointer(&data[bufOffset])) - } else { - *v = int16(binary.LittleEndian.Uint16(data[bufOffset:])) - } - *(**int16)(fieldPtr) = v - case NotnullUint16PtrDispatchId: - v := new(uint16) - if isLittleEndian { - *v = *(*uint16)(unsafe.Pointer(&data[bufOffset])) - } else { - *v = binary.LittleEndian.Uint16(data[bufOffset:]) - } - *(**uint16)(fieldPtr) = v - case NotnullInt32PtrDispatchId: - v := new(int32) - if isLittleEndian { - *v = *(*int32)(unsafe.Pointer(&data[bufOffset])) + v = *(*float64)(unsafe.Pointer(&data[bufOffset])) } else { - *v = int32(binary.LittleEndian.Uint32(data[bufOffset:])) + v = math.Float64frombits(binary.LittleEndian.Uint64(data[bufOffset:])) } - *(**int32)(fieldPtr) = v - case NotnullUint32PtrDispatchId: - v := new(uint32) - if isLittleEndian { - *v = *(*uint32)(unsafe.Pointer(&data[bufOffset])) - } else { - *v = binary.LittleEndian.Uint32(data[bufOffset:]) - } - *(**uint32)(fieldPtr) = v - case NotnullInt64PtrDispatchId: - v := new(int64) - if isLittleEndian { - *v = *(*int64)(unsafe.Pointer(&data[bufOffset])) - } else { - *v = int64(binary.LittleEndian.Uint64(data[bufOffset:])) - } - *(**int64)(fieldPtr) = v - case NotnullUint64PtrDispatchId: - v := new(uint64) - if isLittleEndian { - *v = *(*uint64)(unsafe.Pointer(&data[bufOffset])) - } else { - *v = binary.LittleEndian.Uint64(data[bufOffset:]) - } - *(**uint64)(fieldPtr) = v - case NotnullFloat32PtrDispatchId: - v := new(float32) - if isLittleEndian { - *v = *(*float32)(unsafe.Pointer(&data[bufOffset])) - } else { - *v = math.Float32frombits(binary.LittleEndian.Uint32(data[bufOffset:])) - } - *(**float32)(fieldPtr) = v - case NotnullFloat64PtrDispatchId: - v := new(float64) - if isLittleEndian { - *v = *(*float64)(unsafe.Pointer(&data[bufOffset])) - } else { - *v = math.Float64frombits(binary.LittleEndian.Uint64(data[bufOffset:])) - } - *(**float64)(fieldPtr) = v + storeFieldValue(field.Kind, fieldPtr, optInfo, v) } } // Update reader index ONCE after all fixed fields @@ -2475,58 +2573,29 @@ func (s *structSerializer) ReadData(ctx *ReadContext, value reflect.Value) { err := ctx.Err() for _, field := range s.fieldGroup.PrimitiveVarintFields { fieldPtr := unsafe.Add(ptr, field.Offset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional && field.Meta != nil { + optInfo = field.Meta.OptionalInfo + } switch field.DispatchId { case PrimitiveVarint32DispatchId: - *(*int32)(fieldPtr) = buf.ReadVarint32(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint32(err)) case PrimitiveVarint64DispatchId: - *(*int64)(fieldPtr) = buf.ReadVarint64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint64(err)) case PrimitiveIntDispatchId: - *(*int)(fieldPtr) = int(buf.ReadVarint64(err)) + storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.ReadVarint64(err))) case PrimitiveVarUint32DispatchId: - *(*uint32)(fieldPtr) = buf.ReadVaruint32(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVaruint32(err)) case PrimitiveVarUint64DispatchId: - *(*uint64)(fieldPtr) = buf.ReadVaruint64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVaruint64(err)) case PrimitiveUintDispatchId: - *(*uint)(fieldPtr) = uint(buf.ReadVaruint64(err)) + storeFieldValue(field.Kind, fieldPtr, optInfo, uint(buf.ReadVaruint64(err))) case PrimitiveTaggedInt64DispatchId: // Tagged INT64: use buffer's tagged decoding (4 bytes for small, 9 for large) - *(*int64)(fieldPtr) = buf.ReadTaggedInt64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedInt64(err)) case PrimitiveTaggedUint64DispatchId: // Tagged UINT64: use buffer's tagged decoding (4 bytes for small, 9 for large) - *(*uint64)(fieldPtr) = buf.ReadTaggedUint64(err) - // Notnull pointer types - allocate and set pointer - case NotnullVarint32PtrDispatchId: - v := new(int32) - *v = buf.ReadVarint32(err) - *(**int32)(fieldPtr) = v - case NotnullVarint64PtrDispatchId: - v := new(int64) - *v = buf.ReadVarint64(err) - *(**int64)(fieldPtr) = v - case NotnullIntPtrDispatchId: - v := new(int) - *v = int(buf.ReadVarint64(err)) - *(**int)(fieldPtr) = v - case NotnullVarUint32PtrDispatchId: - v := new(uint32) - *v = buf.ReadVaruint32(err) - *(**uint32)(fieldPtr) = v - case NotnullVarUint64PtrDispatchId: - v := new(uint64) - *v = buf.ReadVaruint64(err) - *(**uint64)(fieldPtr) = v - case NotnullUintPtrDispatchId: - v := new(uint) - *v = uint(buf.ReadVaruint64(err)) - *(**uint)(fieldPtr) = v - case NotnullTaggedInt64PtrDispatchId: - v := new(int64) - *v = buf.ReadTaggedInt64(err) - *(**int64)(fieldPtr) = v - case NotnullTaggedUint64PtrDispatchId: - v := new(uint64) - *v = buf.ReadTaggedUint64(err) - *(**uint64)(fieldPtr) = v + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedUint64(err)) } } } @@ -2542,7 +2611,7 @@ func (s *structSerializer) ReadData(ctx *ReadContext, value reflect.Value) { func (s *structSerializer) readRemainingField(ctx *ReadContext, ptr unsafe.Pointer, field *FieldInfo, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() - if field.Meta.IsOptional { + if field.Kind == FieldKindOptional { if ptr != nil { if readOptionFast(ctx, field, unsafe.Add(ptr, field.Offset)) { return @@ -2563,7 +2632,7 @@ func (s *structSerializer) readRemainingField(ctx *ReadContext, ptr unsafe.Point switch field.DispatchId { case StringDispatchId: // Check isPtr first for better branch prediction - if !field.IsPtr { + if field.Kind != FieldKindPointer { // Non-pointer string: no ref tracking needed in fast path if field.RefMode == RefModeNone { *(*string)(fieldPtr) = ctx.ReadString() @@ -3198,7 +3267,7 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val s.skipField(ctx, field) return } - if field.Meta.IsOptional { + if field.Kind == FieldKindOptional { fieldValue := value.Field(field.Meta.FieldIndex) if field.Serializer != nil { field.Serializer.Read(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, fieldValue) @@ -3209,135 +3278,63 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val } // Fast path for fixed-size primitive types (no ref flag from remote schema) - if isFixedSizePrimitive(field.DispatchId, field.Meta.Nullable) { + if isFixedSizePrimitive(field.DispatchId) { fieldPtr := unsafe.Add(ptr, field.Offset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional { + optInfo = field.Meta.OptionalInfo + } switch field.DispatchId { - // PrimitiveXxxDispatchId: local field is non-pointer type case PrimitiveBoolDispatchId: - *(*bool)(fieldPtr) = buf.ReadBool(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadBool(err)) case PrimitiveInt8DispatchId: - *(*int8)(fieldPtr) = buf.ReadInt8(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt8(err)) case PrimitiveUint8DispatchId: - *(*uint8)(fieldPtr) = uint8(buf.ReadInt8(err)) + storeFieldValue(field.Kind, fieldPtr, optInfo, uint8(buf.ReadInt8(err))) case PrimitiveInt16DispatchId: - *(*int16)(fieldPtr) = buf.ReadInt16(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt16(err)) case PrimitiveUint16DispatchId: - *(*uint16)(fieldPtr) = buf.ReadUint16(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint16(err)) case PrimitiveInt32DispatchId: - *(*int32)(fieldPtr) = buf.ReadInt32(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt32(err)) case PrimitiveUint32DispatchId: - *(*uint32)(fieldPtr) = buf.ReadUint32(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint32(err)) case PrimitiveInt64DispatchId: - *(*int64)(fieldPtr) = buf.ReadInt64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt64(err)) case PrimitiveUint64DispatchId: - *(*uint64)(fieldPtr) = buf.ReadUint64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint64(err)) case PrimitiveFloat32DispatchId: - *(*float32)(fieldPtr) = buf.ReadFloat32(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadFloat32(err)) case PrimitiveFloat64DispatchId: - *(*float64)(fieldPtr) = buf.ReadFloat64(err) - // NotnullXxxPtrDispatchId: local field is *T with nullable=false - case NotnullBoolPtrDispatchId: - v := new(bool) - *v = buf.ReadBool(err) - *(**bool)(fieldPtr) = v - case NotnullInt8PtrDispatchId: - v := new(int8) - *v = buf.ReadInt8(err) - *(**int8)(fieldPtr) = v - case NotnullUint8PtrDispatchId: - v := new(uint8) - *v = uint8(buf.ReadInt8(err)) - *(**uint8)(fieldPtr) = v - case NotnullInt16PtrDispatchId: - v := new(int16) - *v = buf.ReadInt16(err) - *(**int16)(fieldPtr) = v - case NotnullUint16PtrDispatchId: - v := new(uint16) - *v = buf.ReadUint16(err) - *(**uint16)(fieldPtr) = v - case NotnullInt32PtrDispatchId: - v := new(int32) - *v = buf.ReadInt32(err) - *(**int32)(fieldPtr) = v - case NotnullUint32PtrDispatchId: - v := new(uint32) - *v = buf.ReadUint32(err) - *(**uint32)(fieldPtr) = v - case NotnullInt64PtrDispatchId: - v := new(int64) - *v = buf.ReadInt64(err) - *(**int64)(fieldPtr) = v - case NotnullUint64PtrDispatchId: - v := new(uint64) - *v = buf.ReadUint64(err) - *(**uint64)(fieldPtr) = v - case NotnullFloat32PtrDispatchId: - v := new(float32) - *v = buf.ReadFloat32(err) - *(**float32)(fieldPtr) = v - case NotnullFloat64PtrDispatchId: - v := new(float64) - *v = buf.ReadFloat64(err) - *(**float64)(fieldPtr) = v + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadFloat64(err)) } return } // Fast path for varint primitive types (no ref flag from remote schema) - if isVarintPrimitive(field.DispatchId, field.Meta.Nullable) && !fieldHasNonPrimitiveSerializer(field) { + if isVarintPrimitive(field.DispatchId) && !fieldHasNonPrimitiveSerializer(field) { fieldPtr := unsafe.Add(ptr, field.Offset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional { + optInfo = field.Meta.OptionalInfo + } switch field.DispatchId { - // PrimitiveXxxDispatchId: local field is non-pointer type case PrimitiveVarint32DispatchId: - *(*int32)(fieldPtr) = buf.ReadVarint32(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint32(err)) case PrimitiveVarint64DispatchId: - *(*int64)(fieldPtr) = buf.ReadVarint64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint64(err)) case PrimitiveVarUint32DispatchId: - *(*uint32)(fieldPtr) = buf.ReadVaruint32(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVaruint32(err)) case PrimitiveVarUint64DispatchId: - *(*uint64)(fieldPtr) = buf.ReadVaruint64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVaruint64(err)) case PrimitiveTaggedInt64DispatchId: - *(*int64)(fieldPtr) = buf.ReadTaggedInt64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedInt64(err)) case PrimitiveTaggedUint64DispatchId: - *(*uint64)(fieldPtr) = buf.ReadTaggedUint64(err) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedUint64(err)) case PrimitiveIntDispatchId: - *(*int)(fieldPtr) = int(buf.ReadVarint64(err)) + storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.ReadVarint64(err))) case PrimitiveUintDispatchId: - *(*uint)(fieldPtr) = uint(buf.ReadVaruint64(err)) - // NotnullXxxPtrDispatchId: local field is *T with nullable=false - case NotnullVarint32PtrDispatchId: - v := new(int32) - *v = buf.ReadVarint32(err) - *(**int32)(fieldPtr) = v - case NotnullVarint64PtrDispatchId: - v := new(int64) - *v = buf.ReadVarint64(err) - *(**int64)(fieldPtr) = v - case NotnullVarUint32PtrDispatchId: - v := new(uint32) - *v = buf.ReadVaruint32(err) - *(**uint32)(fieldPtr) = v - case NotnullVarUint64PtrDispatchId: - v := new(uint64) - *v = buf.ReadVaruint64(err) - *(**uint64)(fieldPtr) = v - case NotnullTaggedInt64PtrDispatchId: - v := new(int64) - *v = buf.ReadTaggedInt64(err) - *(**int64)(fieldPtr) = v - case NotnullTaggedUint64PtrDispatchId: - v := new(uint64) - *v = buf.ReadTaggedUint64(err) - *(**uint64)(fieldPtr) = v - case NotnullIntPtrDispatchId: - v := new(int) - *v = int(buf.ReadVarint64(err)) - *(**int)(fieldPtr) = v - case NotnullUintPtrDispatchId: - v := new(uint) - *v = uint(buf.ReadVaruint64(err)) - *(**uint)(fieldPtr) = v + storeFieldValue(field.Kind, fieldPtr, optInfo, uint(buf.ReadVaruint64(err))) } return } @@ -3350,89 +3347,44 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val if isNullableFixedSizePrimitive(field.DispatchId) { refFlag := buf.ReadInt8(err) if refFlag == NullFlag { - // Leave pointer as nil (or zero for non-pointer local types) + clearReflectFieldValue(field, fieldValue) return } // Read fixed-size value based on dispatch ID - // Handle both pointer and non-pointer local field types (schema evolution) switch field.DispatchId { case NullableBoolDispatchId: v := buf.ReadBool(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetBool(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableInt8DispatchId: v := buf.ReadInt8(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(int64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableUint8DispatchId: v := uint8(buf.ReadInt8(err)) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(uint64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableInt16DispatchId: v := buf.ReadInt16(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(int64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableUint16DispatchId: v := buf.ReadUint16(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(uint64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableInt32DispatchId: v := buf.ReadInt32(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(int64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableUint32DispatchId: v := buf.ReadUint32(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(uint64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableInt64DispatchId: v := buf.ReadInt64(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableUint64DispatchId: v := buf.ReadUint64(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableFloat32DispatchId: v := buf.ReadFloat32(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetFloat(float64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableFloat64DispatchId: v := buf.ReadFloat64(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetFloat(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) } return } @@ -3441,68 +3393,35 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val if isNullableVarintPrimitive(field.DispatchId) { refFlag := buf.ReadInt8(err) if refFlag == NullFlag { - // Leave pointer as nil (or zero for non-pointer local types) + clearReflectFieldValue(field, fieldValue) return } // Read varint value based on dispatch ID - // Handle both pointer and non-pointer local field types (schema evolution) switch field.DispatchId { case NullableVarint32DispatchId: v := buf.ReadVarint32(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(int64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableVarint64DispatchId: v := buf.ReadVarint64(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableVarUint32DispatchId: v := buf.ReadVaruint32(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(uint64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableVarUint64DispatchId: v := buf.ReadVaruint64(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableTaggedInt64DispatchId: v := buf.ReadTaggedInt64(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableTaggedUint64DispatchId: v := buf.ReadTaggedUint64(err) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(v) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableIntDispatchId: v := int(buf.ReadVarint64(err)) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetInt(int64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) case NullableUintDispatchId: v := uint(buf.ReadVaruint64(err)) - if field.IsPtr { - fieldValue.Set(reflect.ValueOf(&v)) - } else { - fieldValue.SetUint(uint64(v)) - } + storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) } return } @@ -3574,7 +3493,7 @@ func (s *structSerializer) skipField(ctx *ReadContext, field *FieldInfo) { // This is important for compatible mode where remote TypeDef's nullable flag controls the wire format. func writeEnumField(ctx *WriteContext, field *FieldInfo, fieldValue reflect.Value) { buf := ctx.Buffer() - isPointer := field.IsPtr + isPointer := field.Kind == FieldKindPointer // Write null flag based on RefMode only (not based on whether local type is pointer) if field.RefMode != RefModeNone { @@ -3612,7 +3531,7 @@ func writeEnumField(ctx *WriteContext, field *FieldInfo, fieldValue reflect.Valu // Uses context error state for deferred error checking. func readEnumField(ctx *ReadContext, field *FieldInfo, fieldValue reflect.Value) { buf := ctx.Buffer() - isPointer := field.IsPtr + isPointer := field.Kind == FieldKindPointer // Read null flag based on RefMode only (not based on whether local type is pointer) if field.RefMode != RefModeNone { diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index c6c0a48388..82d680b84f 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -25,6 +25,10 @@ import ( "github.com/stretchr/testify/require" ) +func ptr[T any](v T) *T { + return &v +} + func TestUnsignedTypeSerialization(t *testing.T) { type TestStruct struct { U32Var uint32 `fory:"compress=true"` @@ -133,6 +137,167 @@ func TestOptionFieldUnsupportedTypes(t *testing.T) { require.Error(t, f.RegisterStruct(OptionMap{}, 1103)) } +func TestNumericPointerOptionalInterop(t *testing.T) { + type NumericPtrStruct struct { + I8 *int8 + I16 *int16 + I32 *int32 + I64 *int64 + I *int + U8 *uint8 + U16 *uint16 + U32 *uint32 + U64 *uint64 + U *uint + F32 *float32 + F64 *float64 + } + type NumericOptStruct struct { + I8 optional.Optional[int8] + I16 optional.Optional[int16] + I32 optional.Optional[int32] + I64 optional.Optional[int64] + I optional.Optional[int] + U8 optional.Optional[uint8] + U16 optional.Optional[uint16] + U32 optional.Optional[uint32] + U64 optional.Optional[uint64] + U optional.Optional[uint] + F32 optional.Optional[float32] + F64 optional.Optional[float64] + } + + ptrValues := NumericPtrStruct{ + I8: ptr(int8(-8)), + I16: ptr(int16(-16)), + I32: ptr(int32(-32)), + I64: ptr(int64(-64)), + I: ptr(int(-7)), + U8: ptr(uint8(8)), + U16: ptr(uint16(16)), + U32: ptr(uint32(32)), + U64: ptr(uint64(64)), + U: ptr(uint(7)), + F32: ptr(float32(3.25)), + F64: ptr(float64(-6.5)), + } + optValues := NumericOptStruct{ + I8: optional.Some[int8](-8), + I16: optional.Some[int16](-16), + I32: optional.Some[int32](-32), + I64: optional.Some[int64](-64), + I: optional.Some[int](-7), + U8: optional.Some[uint8](8), + U16: optional.Some[uint16](16), + U32: optional.Some[uint32](32), + U64: optional.Some[uint64](64), + U: optional.Some[uint](7), + F32: optional.Some[float32](3.25), + F64: optional.Some[float64](-6.5), + } + + t.Run("PointerToOptionalNull", func(t *testing.T) { + writer := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, writer.RegisterNamedStruct(NumericPtrStruct{}, "NumericInterop")) + data, err := writer.Marshal(NumericPtrStruct{}) + require.NoError(t, err) + + reader := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, reader.RegisterNamedStruct(NumericOptStruct{}, "NumericInterop")) + var out NumericOptStruct + require.NoError(t, reader.Unmarshal(data, &out)) + + require.False(t, out.I8.Has) + require.False(t, out.I16.Has) + require.False(t, out.I32.Has) + require.False(t, out.I64.Has) + require.False(t, out.I.Has) + require.False(t, out.U8.Has) + require.False(t, out.U16.Has) + require.False(t, out.U32.Has) + require.False(t, out.U64.Has) + require.False(t, out.U.Has) + require.False(t, out.F32.Has) + require.False(t, out.F64.Has) + }) + + t.Run("PointerToOptionalValue", func(t *testing.T) { + writer := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, writer.RegisterNamedStruct(NumericPtrStruct{}, "NumericInterop")) + data, err := writer.Marshal(ptrValues) + require.NoError(t, err) + + reader := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, reader.RegisterNamedStruct(NumericOptStruct{}, "NumericInterop")) + var out NumericOptStruct + require.NoError(t, reader.Unmarshal(data, &out)) + + require.Equal(t, optValues, out) + }) + + t.Run("OptionalToPointerNull", func(t *testing.T) { + writer := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, writer.RegisterNamedStruct(NumericOptStruct{}, "NumericInterop")) + data, err := writer.Marshal(NumericOptStruct{}) + require.NoError(t, err) + + reader := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, reader.RegisterNamedStruct(NumericPtrStruct{}, "NumericInterop")) + var out NumericPtrStruct + require.NoError(t, reader.Unmarshal(data, &out)) + + require.Nil(t, out.I8) + require.Nil(t, out.I16) + require.Nil(t, out.I32) + require.Nil(t, out.I64) + require.Nil(t, out.I) + require.Nil(t, out.U8) + require.Nil(t, out.U16) + require.Nil(t, out.U32) + require.Nil(t, out.U64) + require.Nil(t, out.U) + require.Nil(t, out.F32) + require.Nil(t, out.F64) + }) + + t.Run("OptionalToPointerValue", func(t *testing.T) { + writer := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, writer.RegisterNamedStruct(NumericOptStruct{}, "NumericInterop")) + data, err := writer.Marshal(optValues) + require.NoError(t, err) + + reader := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, reader.RegisterNamedStruct(NumericPtrStruct{}, "NumericInterop")) + var out NumericPtrStruct + require.NoError(t, reader.Unmarshal(data, &out)) + + require.NotNil(t, out.I8) + require.Equal(t, *ptrValues.I8, *out.I8) + require.NotNil(t, out.I16) + require.Equal(t, *ptrValues.I16, *out.I16) + require.NotNil(t, out.I32) + require.Equal(t, *ptrValues.I32, *out.I32) + require.NotNil(t, out.I64) + require.Equal(t, *ptrValues.I64, *out.I64) + require.NotNil(t, out.I) + require.Equal(t, *ptrValues.I, *out.I) + require.NotNil(t, out.U8) + require.Equal(t, *ptrValues.U8, *out.U8) + require.NotNil(t, out.U16) + require.Equal(t, *ptrValues.U16, *out.U16) + require.NotNil(t, out.U32) + require.Equal(t, *ptrValues.U32, *out.U32) + require.NotNil(t, out.U64) + require.Equal(t, *ptrValues.U64, *out.U64) + require.NotNil(t, out.U) + require.Equal(t, *ptrValues.U, *out.U) + require.NotNil(t, out.F32) + require.Equal(t, *ptrValues.F32, *out.F32) + require.NotNil(t, out.F64) + require.Equal(t, *ptrValues.F64, *out.F64) + }) +} + // Test struct for compatible mode tests (must be named struct at package level) type SetFieldsStruct struct { SetField Set[string] diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index e8e5e2753c..ab7e86cc8d 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -85,6 +85,7 @@ var ( uint16Type = reflect.TypeOf((*uint16)(nil)).Elem() uint32Type = reflect.TypeOf((*uint32)(nil)).Elem() uint64Type = reflect.TypeOf((*uint64)(nil)).Elem() + uintType = reflect.TypeOf((*uint)(nil)).Elem() int8Type = reflect.TypeOf((*int8)(nil)).Elem() int16Type = reflect.TypeOf((*int16)(nil)).Elem() int32Type = reflect.TypeOf((*int32)(nil)).Elem() @@ -354,6 +355,7 @@ func (r *TypeResolver) initialize() { {uint16Type, UINT16, uint16Serializer{}}, {uint32Type, VAR_UINT32, uint32Serializer{}}, {uint64Type, VAR_UINT64, uint64Serializer{}}, + {uintType, VAR_UINT64, uintSerializer{}}, {int8Type, INT8, int8Serializer{}}, {int16Type, INT16, int16Serializer{}}, {int32Type, VARINT32, int32Serializer{}}, diff --git a/go/fory/types.go b/go/fory/types.go index 715f8813d3..6555222eac 100644 --- a/go/fory/types.go +++ b/go/fory/types.go @@ -336,28 +336,6 @@ const ( NullableIntDispatchId // Go-specific: *int NullableUintDispatchId // Go-specific: *uint - // ========== NOTNULL POINTER DISPATCH IDs ========== - // Pointer types with nullable=false - write without null flag - NotnullBoolPtrDispatchId - NotnullInt8PtrDispatchId - NotnullInt16PtrDispatchId - NotnullInt32PtrDispatchId - NotnullVarint32PtrDispatchId - NotnullInt64PtrDispatchId - NotnullVarint64PtrDispatchId - NotnullTaggedInt64PtrDispatchId - NotnullFloat32PtrDispatchId - NotnullFloat64PtrDispatchId - NotnullUint8PtrDispatchId - NotnullUint16PtrDispatchId - NotnullUint32PtrDispatchId - NotnullVarUint32PtrDispatchId - NotnullUint64PtrDispatchId - NotnullVarUint64PtrDispatchId - NotnullTaggedUint64PtrDispatchId - NotnullIntPtrDispatchId - NotnullUintPtrDispatchId - // String dispatch ID StringDispatchId @@ -493,21 +471,15 @@ func IsPrimitiveTypeId(typeId TypeId) bool { } } -// isFixedSizePrimitive returns true for fixed-size primitives and notnull pointer types. +// isFixedSizePrimitive returns true for fixed-size primitives. // Includes INT32/UINT32/INT64/UINT64 (fixed encoding), NOT VARINT32/VAR_UINT32 etc. -func isFixedSizePrimitive(dispatchId DispatchId, referencable bool) bool { +func isFixedSizePrimitive(dispatchId DispatchId) bool { switch dispatchId { case PrimitiveBoolDispatchId, PrimitiveInt8DispatchId, PrimitiveUint8DispatchId, PrimitiveInt16DispatchId, PrimitiveUint16DispatchId, PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, PrimitiveFloat32DispatchId, PrimitiveFloat64DispatchId: - return !referencable - case NotnullBoolPtrDispatchId, NotnullInt8PtrDispatchId, NotnullUint8PtrDispatchId, - NotnullInt16PtrDispatchId, NotnullUint16PtrDispatchId, - NotnullInt32PtrDispatchId, NotnullUint32PtrDispatchId, - NotnullInt64PtrDispatchId, NotnullUint64PtrDispatchId, - NotnullFloat32PtrDispatchId, NotnullFloat64PtrDispatchId: return true default: return false @@ -543,19 +515,14 @@ func isNullableVarintPrimitive(dispatchId DispatchId) bool { } } -// isVarintPrimitive returns true for varint primitives and notnull pointer types. +// isVarintPrimitive returns true for varint primitives. // Includes VARINT32/VAR_UINT32/VARINT64/VAR_UINT64 (variable encoding), NOT INT32/UINT32 etc. -func isVarintPrimitive(dispatchId DispatchId, referencable bool) bool { +func isVarintPrimitive(dispatchId DispatchId) bool { switch dispatchId { case PrimitiveVarint32DispatchId, PrimitiveVarint64DispatchId, PrimitiveVarUint32DispatchId, PrimitiveVarUint64DispatchId, PrimitiveTaggedInt64DispatchId, PrimitiveTaggedUint64DispatchId, PrimitiveIntDispatchId, PrimitiveUintDispatchId: - return !referencable - case NotnullVarint32PtrDispatchId, NotnullVarint64PtrDispatchId, - NotnullVarUint32PtrDispatchId, NotnullVarUint64PtrDispatchId, - NotnullTaggedInt64PtrDispatchId, NotnullTaggedUint64PtrDispatchId, - NotnullIntPtrDispatchId, NotnullUintPtrDispatchId: return true default: return false @@ -575,24 +542,6 @@ func isPrimitiveDispatchId(dispatchId DispatchId) bool { } } -// isNotnullPtrDispatchId returns true if the dispatchId represents a notnull pointer type -func isNotnullPtrDispatchId(dispatchId DispatchId) bool { - switch dispatchId { - case NotnullBoolPtrDispatchId, NotnullInt8PtrDispatchId, NotnullUint8PtrDispatchId, - NotnullInt16PtrDispatchId, NotnullUint16PtrDispatchId, - NotnullInt32PtrDispatchId, NotnullUint32PtrDispatchId, - NotnullInt64PtrDispatchId, NotnullUint64PtrDispatchId, - NotnullFloat32PtrDispatchId, NotnullFloat64PtrDispatchId, - NotnullVarint32PtrDispatchId, NotnullVarint64PtrDispatchId, - NotnullVarUint32PtrDispatchId, NotnullVarUint64PtrDispatchId, - NotnullTaggedInt64PtrDispatchId, NotnullTaggedUint64PtrDispatchId, - NotnullIntPtrDispatchId, NotnullUintPtrDispatchId: - return true - default: - return false - } -} - // isNumericKind returns true for numeric types (Go enums are typically int-based) func isNumericKind(kind reflect.Kind) bool { switch kind { @@ -718,17 +667,13 @@ func IsNullablePrimitiveDispatchId(id DispatchId) bool { // getFixedSizeByDispatchId returns byte size for fixed primitives (0 if not fixed) func getFixedSizeByDispatchId(dispatchId DispatchId) int { switch dispatchId { - case PrimitiveBoolDispatchId, PrimitiveInt8DispatchId, PrimitiveUint8DispatchId, - NotnullBoolPtrDispatchId, NotnullInt8PtrDispatchId, NotnullUint8PtrDispatchId: + case PrimitiveBoolDispatchId, PrimitiveInt8DispatchId, PrimitiveUint8DispatchId: return 1 - case PrimitiveInt16DispatchId, PrimitiveUint16DispatchId, - NotnullInt16PtrDispatchId, NotnullUint16PtrDispatchId: + case PrimitiveInt16DispatchId, PrimitiveUint16DispatchId: return 2 - case PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveFloat32DispatchId, - NotnullInt32PtrDispatchId, NotnullUint32PtrDispatchId, NotnullFloat32PtrDispatchId: + case PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveFloat32DispatchId: return 4 - case PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, PrimitiveFloat64DispatchId, - NotnullInt64PtrDispatchId, NotnullUint64PtrDispatchId, NotnullFloat64PtrDispatchId: + case PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, PrimitiveFloat64DispatchId: return 8 default: return 0 @@ -738,14 +683,11 @@ func getFixedSizeByDispatchId(dispatchId DispatchId) int { // getVarintMaxSizeByDispatchId returns max byte size for varint primitives (0 if not varint) func getVarintMaxSizeByDispatchId(dispatchId DispatchId) int { switch dispatchId { - case PrimitiveVarint32DispatchId, PrimitiveVarUint32DispatchId, - NotnullVarint32PtrDispatchId, NotnullVarUint32PtrDispatchId: + case PrimitiveVarint32DispatchId, PrimitiveVarUint32DispatchId: return 5 - case PrimitiveVarint64DispatchId, PrimitiveVarUint64DispatchId, PrimitiveIntDispatchId, PrimitiveUintDispatchId, - NotnullVarint64PtrDispatchId, NotnullVarUint64PtrDispatchId, NotnullIntPtrDispatchId, NotnullUintPtrDispatchId: + case PrimitiveVarint64DispatchId, PrimitiveVarUint64DispatchId, PrimitiveIntDispatchId, PrimitiveUintDispatchId: return 10 - case PrimitiveTaggedInt64DispatchId, PrimitiveTaggedUint64DispatchId, - NotnullTaggedInt64PtrDispatchId, NotnullTaggedUint64PtrDispatchId: + case PrimitiveTaggedInt64DispatchId, PrimitiveTaggedUint64DispatchId: return 9 default: return 0 @@ -767,58 +709,6 @@ func getEncodingFromTypeId(typeId TypeId) string { } } -// getNotnullPtrDispatchId returns the NotnullXxxPtrDispatchId for a pointer-to-numeric type. -// elemKind is the kind of the element type (e.g., reflect.Uint8 for *uint8). -// encoding specifies the encoding type (fixed, varint, tagged) for int32/int64/uint32/uint64. -func getNotnullPtrDispatchId(elemKind reflect.Kind, encoding string) DispatchId { - switch elemKind { - case reflect.Bool: - return NotnullBoolPtrDispatchId - case reflect.Int8: - return NotnullInt8PtrDispatchId - case reflect.Int16: - return NotnullInt16PtrDispatchId - case reflect.Int32: - if encoding == "fixed" { - return NotnullInt32PtrDispatchId - } - return NotnullVarint32PtrDispatchId - case reflect.Int64: - if encoding == "fixed" { - return NotnullInt64PtrDispatchId - } else if encoding == "tagged" { - return NotnullTaggedInt64PtrDispatchId - } - return NotnullVarint64PtrDispatchId - case reflect.Int: - return NotnullIntPtrDispatchId - case reflect.Uint8: - return NotnullUint8PtrDispatchId - case reflect.Uint16: - return NotnullUint16PtrDispatchId - case reflect.Uint32: - if encoding == "fixed" { - return NotnullUint32PtrDispatchId - } - return NotnullVarUint32PtrDispatchId - case reflect.Uint64: - if encoding == "fixed" { - return NotnullUint64PtrDispatchId - } else if encoding == "tagged" { - return NotnullTaggedUint64PtrDispatchId - } - return NotnullVarUint64PtrDispatchId - case reflect.Uint: - return NotnullUintPtrDispatchId - case reflect.Float32: - return NotnullFloat32PtrDispatchId - case reflect.Float64: - return NotnullFloat64PtrDispatchId - default: - return UnknownDispatchId - } -} - // isPrimitiveFixedDispatchId returns true if the dispatch ID is for a non-nullable fixed-size primitive. // Note: int32/int64/uint32/uint64 are NOT included here because they default to varint encoding. // Only types that are always fixed-size are included (bool, int8/uint8, int16/uint16, float32/float64). From bd09956e5aaabbef0b4fbab50c8779236c909a8f Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 17:22:52 +0800 Subject: [PATCH 07/21] refactor optional --- go/fory/codegen/decoder.go | 30 ++++--- go/fory/codegen/encoder.go | 15 ++-- go/fory/codegen/generator.go | 14 ++++ go/fory/codegen/utils.go | 19 ++--- go/fory/optional/optional.go | 148 +++++++-------------------------- go/fory/optional_serializer.go | 87 ++++++++++++++----- go/fory/struct_test.go | 40 ++++----- 7 files changed, 169 insertions(+), 184 deletions(-) diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 2327588a8c..ccd208b3f8 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -249,6 +249,7 @@ func generateOptionReadTyped(buf *bytes.Buffer, field *FieldInfo, fieldAccess st } fmt.Fprintf(buf, "\t{\n") if isReferencableType(elemType) { + fmt.Fprintf(buf, "\t\tvar optValue %s\n", elemType.String()) fmt.Fprintf(buf, "\t\tif ctx.TrackRef() {\n") fmt.Fprintf(buf, "\t\t\trefID, refErr := ctx.RefResolver().TryPreserveRefId(buf)\n") fmt.Fprintf(buf, "\t\t\tif refErr != nil {\n") @@ -256,39 +257,40 @@ func generateOptionReadTyped(buf *bytes.Buffer, field *FieldInfo, fieldAccess st fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif refID < int32(fory.NotNullValueFlag) {\n") fmt.Fprintf(buf, "\t\t\t\tif refID == int32(fory.NullFlag) {\n") - fmt.Fprintf(buf, "\t\t\t\t\t%s.Has = false\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\t\t%s = optional.None[%s]()\n", fieldAccess, elemType.String()) fmt.Fprintf(buf, "\t\t\t\t\treturn nil\n") fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tobj := ctx.RefResolver().GetReadObject(refID)\n") fmt.Fprintf(buf, "\t\t\t\tif obj.IsValid() {\n") - fmt.Fprintf(buf, "\t\t\t\t\ttarget := reflect.ValueOf(&%s.Value).Elem()\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\t\ttarget := reflect.ValueOf(&optValue).Elem()\n") fmt.Fprintf(buf, "\t\t\t\t\tif obj.Type().AssignableTo(target.Type()) {\n") fmt.Fprintf(buf, "\t\t\t\t\t\ttarget.Set(obj)\n") - fmt.Fprintf(buf, "\t\t\t\t\t\t%s.Has = true\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\t\t\t%s = optional.Some(optValue)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t\t\treturn nil\n") fmt.Fprintf(buf, "\t\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\t%s.Has = false\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\t%s = optional.None[%s]()\n", fieldAccess, elemType.String()) fmt.Fprintf(buf, "\t\t\t\treturn nil\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t%s.Has = true\n", fieldAccess) - if err := generateOptionValueRead(buf, elemType, fmt.Sprintf("%s.Value", fieldAccess)); err != nil { + if err := generateOptionValueRead(buf, elemType, "optValue"); err != nil { return err } fmt.Fprintf(buf, "\t\t\tif refID >= 0 {\n") - fmt.Fprintf(buf, "\t\t\t\tctx.RefResolver().SetReadObject(refID, reflect.ValueOf(%s.Value))\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t\tctx.RefResolver().SetReadObject(refID, reflect.ValueOf(optValue))\n") fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t%s = optional.Some(optValue)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\treturn nil\n") fmt.Fprintf(buf, "\t\t}\n") } fmt.Fprintf(buf, "\t\tflag := buf.ReadInt8(err)\n") fmt.Fprintf(buf, "\t\tif flag == fory.NullFlag {\n") - fmt.Fprintf(buf, "\t\t\t%s.Has = false\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\t%s = optional.None[%s]()\n", fieldAccess, elemType.String()) fmt.Fprintf(buf, "\t\t} else {\n") - fmt.Fprintf(buf, "\t\t\t%s.Has = true\n", fieldAccess) - if err := generateOptionValueRead(buf, elemType, fmt.Sprintf("%s.Value", fieldAccess)); err != nil { + fmt.Fprintf(buf, "\t\t\tvar optValue %s\n", elemType.String()) + if err := generateOptionValueRead(buf, elemType, "optValue"); err != nil { return err } + fmt.Fprintf(buf, "\t\t\t%s = optional.Some(optValue)\n", fieldAccess) fmt.Fprintf(buf, "\t\t}\n") fmt.Fprintf(buf, "\t}\n") return nil @@ -319,7 +321,9 @@ func generateOptionValueRead(buf *bytes.Buffer, elemType types.Type, valueExpr s fmt.Fprintf(buf, "\t\t\t%s = buf.ReadInt16(err)\n", valueExpr) case types.Int32: fmt.Fprintf(buf, "\t\t\t%s = buf.ReadVarint32(err)\n", valueExpr) - case types.Int, types.Int64: + case types.Int: + fmt.Fprintf(buf, "\t\t\t%s = int(buf.ReadVarint64(err))\n", valueExpr) + case types.Int64: fmt.Fprintf(buf, "\t\t\t%s = buf.ReadVarint64(err)\n", valueExpr) case types.Uint8: fmt.Fprintf(buf, "\t\t\t%s = buf.ReadByte(err)\n", valueExpr) @@ -327,7 +331,9 @@ func generateOptionValueRead(buf *bytes.Buffer, elemType types.Type, valueExpr s fmt.Fprintf(buf, "\t\t\t%s = uint16(buf.ReadInt16(err))\n", valueExpr) case types.Uint32: fmt.Fprintf(buf, "\t\t\t%s = uint32(buf.ReadInt32(err))\n", valueExpr) - case types.Uint, types.Uint64: + case types.Uint: + fmt.Fprintf(buf, "\t\t\t%s = uint(buf.ReadInt64(err))\n", valueExpr) + case types.Uint64: fmt.Fprintf(buf, "\t\t\t%s = uint64(buf.ReadInt64(err))\n", valueExpr) case types.Float32: fmt.Fprintf(buf, "\t\t\t%s = buf.ReadFloat32(err)\n", valueExpr) diff --git a/go/fory/codegen/encoder.go b/go/fory/codegen/encoder.go index 537d2569bd..efb2d8ecae 100644 --- a/go/fory/codegen/encoder.go +++ b/go/fory/codegen/encoder.go @@ -233,12 +233,13 @@ func generateOptionWriteTyped(buf *bytes.Buffer, field *FieldInfo, fieldAccess s fmt.Fprintf(buf, "\tctx.WriteValue(reflect.ValueOf(%s), fory.RefModeTracking, true)\n", fieldAccess) return nil } - fmt.Fprintf(buf, "\tif !%s.Has {\n", fieldAccess) + fmt.Fprintf(buf, "\tif !%s.IsSome() {\n", fieldAccess) fmt.Fprintf(buf, "\t\tbuf.WriteInt8(fory.NullFlag)\n") fmt.Fprintf(buf, "\t} else {\n") + fmt.Fprintf(buf, "\t\toptValue := %s.Unwrap()\n", fieldAccess) if isReferencableType(elemType) { fmt.Fprintf(buf, "\t\tif ctx.TrackRef() {\n") - fmt.Fprintf(buf, "\t\t\trefWritten, err := ctx.RefResolver().WriteRefOrNull(buf, reflect.ValueOf(%s.Value))\n", fieldAccess) + fmt.Fprintf(buf, "\t\t\trefWritten, err := ctx.RefResolver().WriteRefOrNull(buf, reflect.ValueOf(optValue))\n") fmt.Fprintf(buf, "\t\t\tif err != nil {\n") fmt.Fprintf(buf, "\t\t\t\treturn err\n") fmt.Fprintf(buf, "\t\t\t}\n") @@ -251,7 +252,7 @@ func generateOptionWriteTyped(buf *bytes.Buffer, field *FieldInfo, fieldAccess s } else { fmt.Fprintf(buf, "\t\tbuf.WriteInt8(fory.NotNullValueFlag)\n") } - if err := generateOptionValueWrite(buf, elemType, fmt.Sprintf("%s.Value", fieldAccess)); err != nil { + if err := generateOptionValueWrite(buf, elemType, "optValue"); err != nil { return err } fmt.Fprintf(buf, "\t}\n") @@ -283,7 +284,9 @@ func generateOptionValueWrite(buf *bytes.Buffer, elemType types.Type, valueExpr fmt.Fprintf(buf, "\t\tbuf.WriteInt16(%s)\n", valueExpr) case types.Int32: fmt.Fprintf(buf, "\t\tbuf.WriteVarint32(%s)\n", valueExpr) - case types.Int, types.Int64: + case types.Int: + fmt.Fprintf(buf, "\t\tbuf.WriteVarint64(int64(%s))\n", valueExpr) + case types.Int64: fmt.Fprintf(buf, "\t\tbuf.WriteVarint64(%s)\n", valueExpr) case types.Uint8: fmt.Fprintf(buf, "\t\tbuf.WriteByte_(%s)\n", valueExpr) @@ -291,7 +294,9 @@ func generateOptionValueWrite(buf *bytes.Buffer, elemType types.Type, valueExpr fmt.Fprintf(buf, "\t\tbuf.WriteInt16(int16(%s))\n", valueExpr) case types.Uint32: fmt.Fprintf(buf, "\t\tbuf.WriteInt32(int32(%s))\n", valueExpr) - case types.Uint, types.Uint64: + case types.Uint: + fmt.Fprintf(buf, "\t\tbuf.WriteInt64(int64(%s))\n", valueExpr) + case types.Uint64: fmt.Fprintf(buf, "\t\tbuf.WriteInt64(int64(%s))\n", valueExpr) case types.Float32: fmt.Fprintf(buf, "\t\tbuf.WriteFloat32(%s)\n", valueExpr) diff --git a/go/fory/codegen/generator.go b/go/fory/codegen/generator.go index 65d0cf22fd..a6077d9211 100644 --- a/go/fory/codegen/generator.go +++ b/go/fory/codegen/generator.go @@ -283,6 +283,7 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil // Determine which imports are needed needsTime := false needsReflect := false + needsOptional := false for _, s := range structs { for _, field := range s.Fields { @@ -290,6 +291,9 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil if typeStr == "time.Time" || typeStr == "github.com/apache/fory/go/fory.Date" { needsTime = true } + if field.IsOptional { + needsOptional = true + } // We need reflect for the interface compatibility methods needsReflect = true } @@ -305,6 +309,9 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil fmt.Fprintf(&buf, "\t\"time\"\n") } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") + if needsOptional { + fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") + } fmt.Fprintf(&buf, ")\n\n") // Generate init function to register serializer factories @@ -536,6 +543,7 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { // Determine which imports are needed needsTime := false needsReflect := false + needsOptional := false for _, s := range structs { for _, field := range s.Fields { @@ -543,6 +551,9 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { if typeStr == "time.Time" || typeStr == "github.com/apache/fory/go/fory.Date" { needsTime = true } + if field.IsOptional { + needsOptional = true + } // We need reflect for the interface compatibility methods needsReflect = true } @@ -558,6 +569,9 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { fmt.Fprintf(&buf, "\t\"time\"\n") } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") + if needsOptional { + fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") + } fmt.Fprintf(&buf, ")\n\n") // Generate init function to register serializer factories diff --git a/go/fory/codegen/utils.go b/go/fory/codegen/utils.go index 82fc4e2be0..7062ee1bc3 100644 --- a/go/fory/codegen/utils.go +++ b/go/fory/codegen/utils.go @@ -520,17 +520,16 @@ func analyzeField(field *types.Var, index int) (*FieldInfo, error) { optionalElem, isOptional := getOptionalElementType(fieldType) if isOptional && optionalElem != nil { - base := optionalElem - for { - if ptr, ok := base.(*types.Pointer); ok { - base = ptr.Elem() - continue + if ptr, ok := optionalElem.(*types.Pointer); ok { + switch ptr.Elem().Underlying().(type) { + case *types.Slice, *types.Map: + return nil, fmt.Errorf("field %s: optional.Optional is not allowed for slice/map", goName) + } + } else { + switch optionalElem.Underlying().(type) { + case *types.Struct, *types.Slice, *types.Map: + return nil, fmt.Errorf("field %s: optional.Optional is not allowed for struct/slice/map", goName) } - break - } - switch base.Underlying().(type) { - case *types.Struct, *types.Slice, *types.Map: - return nil, fmt.Errorf("field %s: optional.Optional is not supported for struct/slice/map", goName) } } diff --git a/go/fory/optional/optional.go b/go/fory/optional/optional.go index d66bcba072..857fa7d3db 100644 --- a/go/fory/optional/optional.go +++ b/go/fory/optional/optional.go @@ -17,15 +17,17 @@ package optional -// Optional represents an optional value without pointer indirection. +// Optional represents an immutable optional value without pointer indirection. +// Optional is intended for scalar values. Do not wrap structs; prefer *Struct +// (or Optional[*Struct] if you need explicit optional semantics). type Optional[T any] struct { - Value T - Has bool + value T + has bool } // Some returns an Optional containing a value. func Some[T any](v T) Optional[T] { - return Optional[T]{Value: v, Has: true} + return Optional[T]{value: v, has: true} } // None returns an empty Optional. @@ -41,49 +43,40 @@ func FromPtr[T any](v *T) Optional[T] { return Some(*v) } -// Ptr returns a pointer to the contained value or nil. -func (o Optional[T]) Ptr() *T { - if !o.Has { - return nil - } - v := o.Value - return &v -} - // IsSome reports whether the optional contains a value. -func (o Optional[T]) IsSome() bool { return o.Has } +func (o Optional[T]) IsSome() bool { return o.has } // IsNone reports whether the optional is empty. -func (o Optional[T]) IsNone() bool { return !o.Has } +func (o Optional[T]) IsNone() bool { return !o.has } // Expect returns the contained value or panics with the provided message. func (o Optional[T]) Expect(message string) T { - if o.Has { - return o.Value + if o.has { + return o.value } panic(message) } // Unwrap returns the contained value or panics. func (o Optional[T]) Unwrap() T { - if o.Has { - return o.Value + if o.has { + return o.value } panic("optional: unwrap on None") } // UnwrapOr returns the contained value or a default. func (o Optional[T]) UnwrapOr(defaultValue T) T { - if o.Has { - return o.Value + if o.has { + return o.value } return defaultValue } // UnwrapOrDefault returns the contained value or the zero value. func (o Optional[T]) UnwrapOrDefault() T { - if o.Has { - return o.Value + if o.has { + return o.value } var zero T return zero @@ -91,55 +84,24 @@ func (o Optional[T]) UnwrapOrDefault() T { // UnwrapOrElse returns the contained value or computes a default. func (o Optional[T]) UnwrapOrElse(defaultFn func() T) T { - if o.Has { - return o.Value - } - return defaultFn() -} - -// Map maps an Optional[T] to Optional[U] by applying a function. -func Map[T, U any](o Optional[T], f func(T) U) Optional[U] { - if o.Has { - return Some(f(o.Value)) - } - return None[U]() -} - -// MapOr applies a function to the contained value or returns a default. -func MapOr[T, U any](o Optional[T], defaultValue U, f func(T) U) U { - if o.Has { - return f(o.Value) - } - return defaultValue -} - -// MapOrElse applies a function to the contained value or computes a default. -func MapOrElse[T, U any](o Optional[T], defaultFn func() U, f func(T) U) U { - if o.Has { - return f(o.Value) + if o.has { + return o.value } return defaultFn() } -// And returns None if either option is None, otherwise returns the second option. -func And[T, U any](o Optional[T], other Optional[U]) Optional[U] { - if o.Has { - return other - } - return None[U]() -} - -// AndThen returns None if this option is None, otherwise calls f and returns its result. -func AndThen[T, U any](o Optional[T], f func(T) Optional[U]) Optional[U] { - if o.Has { - return f(o.Value) +// OkOr returns the contained value or the provided error. +func (o Optional[T]) OkOr(err error) (T, error) { + if o.has { + return o.value, nil } - return None[U]() + var zero T + return zero, err } // Or returns the option if it is Some, otherwise returns other. func (o Optional[T]) Or(other Optional[T]) Optional[T] { - if o.Has { + if o.has { return o } return other @@ -147,71 +109,25 @@ func (o Optional[T]) Or(other Optional[T]) Optional[T] { // OrElse returns the option if it is Some, otherwise returns the result of f. func (o Optional[T]) OrElse(f func() Optional[T]) Optional[T] { - if o.Has { + if o.has { return o } return f() } +// ValueOrZero returns the contained value or the zero value. +func (o Optional[T]) ValueOrZero() T { + return o.UnwrapOrDefault() +} + // Filter returns None if the predicate returns false. func (o Optional[T]) Filter(predicate func(T) bool) Optional[T] { - if o.Has && predicate(o.Value) { + if o.has && predicate(o.value) { return o } return None[T]() } -// Result represents a simplified Result type for OkOr helpers. -type Result[T any] struct { - Value T - Err error -} - -// OkOr transforms the option into a Result, using err if None. -func (o Optional[T]) OkOr(err error) Result[T] { - if o.Has { - return Result[T]{Value: o.Value} - } - return Result[T]{Err: err} -} - -// OkOrElse transforms the option into a Result, using a function to produce the error. -func (o Optional[T]) OkOrElse(errFn func() error) Result[T] { - if o.Has { - return Result[T]{Value: o.Value} - } - return Result[T]{Err: errFn()} -} - -// Take takes the value out, leaving None in its place. -func (o *Optional[T]) Take() Optional[T] { - if o == nil || !o.Has { - return None[T]() - } - v := o.Value - o.Has = false - var zero T - o.Value = zero - return Some(v) -} - -// Set sets the option to Some(value). -func (o *Optional[T]) Set(v T) { - if o == nil { - return - } - o.Value = v - o.Has = true -} - -// Flatten transforms Optional[Optional[T]] into Optional[T]. -func Flatten[T any](o Optional[Optional[T]]) Optional[T] { - if !o.Has { - return None[T]() - } - return o.Value -} - // Int8 wraps an int8 value in Optional. func Int8(v int8) Optional[int8] { return Some(v) } diff --git a/go/fory/optional_serializer.go b/go/fory/optional_serializer.go index b96a31a412..55aca4bdbb 100644 --- a/go/fory/optional_serializer.go +++ b/go/fory/optional_serializer.go @@ -21,6 +21,7 @@ import ( "fmt" "reflect" "strings" + "unsafe" ) const optionalPkgPath = "github.com/apache/fory/go/fory/optional" @@ -49,11 +50,11 @@ func getOptionalInfo(type_ reflect.Type) (optionalInfo, bool) { if name != "Optional" && !strings.HasPrefix(name, "Optional[") { return optionalInfo{}, false } - valueField, ok := type_.FieldByName("Value") + valueField, ok := type_.FieldByName("value") if !ok { return optionalInfo{}, false } - hasField, ok := type_.FieldByName("Has") + hasField, ok := type_.FieldByName("has") if !ok || hasField.Type.Kind() != reflect.Bool { return optionalInfo{}, false } @@ -68,16 +69,18 @@ func validateOptionalValueType(valueType reflect.Type) error { if valueType == nil { return fmt.Errorf("optional value type is nil") } - base := valueType - for base.Kind() == reflect.Ptr { - base = base.Elem() - } - switch base.Kind() { - case reflect.Struct, reflect.Slice, reflect.Map: - return fmt.Errorf("optional.Optional[%s] is not supported for struct/slice/map", valueType.String()) - default: - return nil + switch valueType.Kind() { + case reflect.Struct: + return fmt.Errorf("optional.Optional[%s] is not supported for struct values", valueType.String()) + case reflect.Slice, reflect.Map: + return fmt.Errorf("optional.Optional[%s] is not supported for slice/map values", valueType.String()) + case reflect.Ptr: + elem := valueType.Elem() + if elem.Kind() == reflect.Slice || elem.Kind() == reflect.Map { + return fmt.Errorf("optional.Optional[%s] is not supported for slice/map values", valueType.String()) + } } + return nil } func isOptionalType(type_ reflect.Type) bool { @@ -100,32 +103,41 @@ func optionalHasValue(value reflect.Value, info optionalInfo) bool { } value = value.Elem() } - return value.FieldByName("Has").Bool() + if value.CanAddr() { + hasPtr := (*bool)(unsafe.Add(unsafe.Pointer(value.UnsafeAddr()), info.hasOffset)) + return *hasPtr + } + field := value.FieldByName("has") + if !field.IsValid() { + return false + } + return field.Bool() } // optionalSerializer handles Optional[T] values by writing null flags and delegating to the element serializer. type optionalSerializer struct { optionalType reflect.Type valueType reflect.Type - valueIndex int - hasIndex int + valueOffset uintptr + hasOffset uintptr valueSerializer Serializer } func newOptionalSerializer(optionalType reflect.Type, info optionalInfo, valueSerializer Serializer) *optionalSerializer { - valueField, _ := optionalType.FieldByName("Value") - hasField, _ := optionalType.FieldByName("Has") return &optionalSerializer{ optionalType: optionalType, valueType: info.valueType, - valueIndex: valueField.Index[0], - hasIndex: hasField.Index[0], + valueOffset: info.valueOffset, + hasOffset: info.hasOffset, valueSerializer: valueSerializer, } } func (s *optionalSerializer) unwrap(value reflect.Value) reflect.Value { if value.Kind() == reflect.Ptr { + if value.IsNil() { + return reflect.Value{} + } return value.Elem() } return value @@ -133,17 +145,50 @@ func (s *optionalSerializer) unwrap(value reflect.Value) reflect.Value { func (s *optionalSerializer) has(value reflect.Value) bool { value = s.unwrap(value) - return value.Field(s.hasIndex).Bool() + if !value.IsValid() { + return false + } + if value.CanAddr() { + hasPtr := (*bool)(unsafe.Add(unsafe.Pointer(value.UnsafeAddr()), s.hasOffset)) + return *hasPtr + } + field := value.FieldByName("has") + if !field.IsValid() { + return false + } + return field.Bool() } func (s *optionalSerializer) valueField(value reflect.Value) reflect.Value { value = s.unwrap(value) - return value.Field(s.valueIndex) + if !value.IsValid() { + return reflect.New(s.valueType).Elem() + } + if value.CanAddr() { + ptr := unsafe.Add(unsafe.Pointer(value.UnsafeAddr()), s.valueOffset) + return reflect.NewAt(s.valueType, ptr).Elem() + } + field := value.FieldByName("value") + if field.IsValid() { + return field + } + return reflect.New(s.valueType).Elem() } func (s *optionalSerializer) setHas(value reflect.Value, has bool) { value = s.unwrap(value) - value.Field(s.hasIndex).SetBool(has) + if !value.IsValid() { + return + } + if value.CanAddr() { + hasPtr := (*bool)(unsafe.Add(unsafe.Pointer(value.UnsafeAddr()), s.hasOffset)) + *hasPtr = has + return + } + field := value.FieldByName("has") + if field.IsValid() && field.CanSet() { + field.SetBool(has) + } } func (s *optionalSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index 82d680b84f..d2b504c0aa 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -107,14 +107,14 @@ func TestOptionFieldSerialization(t *testing.T) { require.NoError(t, err) out := result.(*OptionStruct) - require.True(t, out.OptInt.Has) - require.Equal(t, int32(123), out.OptInt.Value) - require.True(t, out.OptZero.Has) - require.Equal(t, int32(0), out.OptZero.Value) - require.True(t, out.OptString.Has) - require.Equal(t, "hello", out.OptString.Value) - require.True(t, out.OptBool.Has) - require.Equal(t, true, out.OptBool.Value) + require.True(t, out.OptInt.IsSome()) + require.Equal(t, int32(123), out.OptInt.Unwrap()) + require.True(t, out.OptZero.IsSome()) + require.Equal(t, int32(0), out.OptZero.Unwrap()) + require.True(t, out.OptString.IsSome()) + require.Equal(t, "hello", out.OptString.Unwrap()) + require.True(t, out.OptBool.IsSome()) + require.Equal(t, true, out.OptBool.Unwrap()) } func TestOptionFieldUnsupportedTypes(t *testing.T) { @@ -207,18 +207,18 @@ func TestNumericPointerOptionalInterop(t *testing.T) { var out NumericOptStruct require.NoError(t, reader.Unmarshal(data, &out)) - require.False(t, out.I8.Has) - require.False(t, out.I16.Has) - require.False(t, out.I32.Has) - require.False(t, out.I64.Has) - require.False(t, out.I.Has) - require.False(t, out.U8.Has) - require.False(t, out.U16.Has) - require.False(t, out.U32.Has) - require.False(t, out.U64.Has) - require.False(t, out.U.Has) - require.False(t, out.F32.Has) - require.False(t, out.F64.Has) + require.False(t, out.I8.IsSome()) + require.False(t, out.I16.IsSome()) + require.False(t, out.I32.IsSome()) + require.False(t, out.I64.IsSome()) + require.False(t, out.I.IsSome()) + require.False(t, out.U8.IsSome()) + require.False(t, out.U16.IsSome()) + require.False(t, out.U32.IsSome()) + require.False(t, out.U64.IsSome()) + require.False(t, out.U.IsSome()) + require.False(t, out.F32.IsSome()) + require.False(t, out.F64.IsSome()) }) t.Run("PointerToOptionalValue", func(t *testing.T) { From a689b651a483d66cda3f85199b8405051bfe16c5 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 17:27:56 +0800 Subject: [PATCH 08/21] fix refl --- go/fory/optional/optional.go | 11 +++++++++++ go/fory/refl/refl.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 go/fory/refl/refl.go diff --git a/go/fory/optional/optional.go b/go/fory/optional/optional.go index 857fa7d3db..06f9e7a543 100644 --- a/go/fory/optional/optional.go +++ b/go/fory/optional/optional.go @@ -17,6 +17,12 @@ package optional +import ( + "unsafe" + + "github.com/apache/fory/go/fory/refl" +) + // Optional represents an immutable optional value without pointer indirection. // Optional is intended for scalar values. Do not wrap structs; prefer *Struct // (or Optional[*Struct] if you need explicit optional semantics). @@ -163,3 +169,8 @@ func String(v string) Optional[string] { return Some(v) } // Bool wraps a bool value in Optional. func Bool(v bool) Optional[bool] { return Some(v) } + +// ForyReflect exposes the address of the Optional for unsafe fast paths. +func (o *Optional[T]) ForyReflect() refl.ForyReflectValue { + return refl.NewForyReflectValue(unsafe.Pointer(o)) +} diff --git a/go/fory/refl/refl.go b/go/fory/refl/refl.go new file mode 100644 index 0000000000..595256a5af --- /dev/null +++ b/go/fory/refl/refl.go @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package refl + +import "unsafe" + +// ForyReflectValue wraps the address returned by ForyReflect. +type ForyReflectValue struct { + Ptr unsafe.Pointer +} + +// NewForyReflectValue constructs a ForyReflectValue from a pointer. +func NewForyReflectValue(ptr unsafe.Pointer) ForyReflectValue { + return ForyReflectValue{Ptr: ptr} +} + +// ForyAddressable exposes an address for unsafe fast paths. +type ForyAddressable interface { + ForyReflect() ForyReflectValue +} From 491e8435864bc8874099e99f90933d63b9bfa554 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 17:34:06 +0800 Subject: [PATCH 09/21] delete fory reflect in go --- go/fory/optional/optional.go | 11 ----------- go/fory/refl/refl.go | 35 ----------------------------------- 2 files changed, 46 deletions(-) delete mode 100644 go/fory/refl/refl.go diff --git a/go/fory/optional/optional.go b/go/fory/optional/optional.go index 06f9e7a543..857fa7d3db 100644 --- a/go/fory/optional/optional.go +++ b/go/fory/optional/optional.go @@ -17,12 +17,6 @@ package optional -import ( - "unsafe" - - "github.com/apache/fory/go/fory/refl" -) - // Optional represents an immutable optional value without pointer indirection. // Optional is intended for scalar values. Do not wrap structs; prefer *Struct // (or Optional[*Struct] if you need explicit optional semantics). @@ -169,8 +163,3 @@ func String(v string) Optional[string] { return Some(v) } // Bool wraps a bool value in Optional. func Bool(v bool) Optional[bool] { return Some(v) } - -// ForyReflect exposes the address of the Optional for unsafe fast paths. -func (o *Optional[T]) ForyReflect() refl.ForyReflectValue { - return refl.NewForyReflectValue(unsafe.Pointer(o)) -} diff --git a/go/fory/refl/refl.go b/go/fory/refl/refl.go deleted file mode 100644 index 595256a5af..0000000000 --- a/go/fory/refl/refl.go +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package refl - -import "unsafe" - -// ForyReflectValue wraps the address returned by ForyReflect. -type ForyReflectValue struct { - Ptr unsafe.Pointer -} - -// NewForyReflectValue constructs a ForyReflectValue from a pointer. -func NewForyReflectValue(ptr unsafe.Pointer) ForyReflectValue { - return ForyReflectValue{Ptr: ptr} -} - -// ForyAddressable exposes an address for unsafe fast paths. -type ForyAddressable interface { - ForyReflect() ForyReflectValue -} From 06ff6fd44985661c0ec684ac15433cfb0f8bcb46 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 19:03:10 +0800 Subject: [PATCH 10/21] refactor struct serializer --- go/fory/buffer.go | 21 +- go/fory/struct.go | 493 +++++++++++----------------------------------- 2 files changed, 126 insertions(+), 388 deletions(-) diff --git a/go/fory/buffer.go b/go/fory/buffer.go index 74e4498d75..2aa9cd4ba4 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -969,14 +969,14 @@ func (b *ByteBuffer) WriteVarint64(value int64) { b.WriteVaruint64(u) } -// WriteVaruint64 writes to unsigned varint (up to 9 bytes) +// WriteVaruint64 writes to unsigned varint (up to 10 bytes) func (b *ByteBuffer) WriteVaruint64(value uint64) { - b.grow(9) + b.grow(10) offset := b.writerIndex - data := b.data[offset : offset+9] + data := b.data[offset : offset+10] i := 0 - for i < 8 && value >= 0x80 { + for value >= 0x80 { data[i] = byte(value&0x7F) | 0x80 value >>= 7 i++ @@ -1206,7 +1206,7 @@ func (b *ByteBuffer) ReadTaggedUint64(err *Error) uint64 { // //go:inline func (b *ByteBuffer) ReadVaruint64(err *Error) uint64 { - if b.remaining() >= 9 { + if b.remaining() >= 10 { return b.readVaruint64Fast() } return b.readVaruint64Slow(err) @@ -1249,9 +1249,16 @@ func (b *ByteBuffer) readVaruint64Fast() uint64 { result |= (bulk >> 7) & 0xFE000000000000 readLength = 8 if (bulk & 0x8000000000000000) != 0 { - // Need 9th byte - result |= uint64(b.data[b.readerIndex+8]) << 56 + // Need 9th byte (and possibly 10th if continuation bit is set) + b9 := b.data[b.readerIndex+8] + result |= uint64(b9&0x7F) << 56 readLength = 9 + if (b9 & 0x80) != 0 { + // 10th byte carries the remaining bits + b10 := b.data[b.readerIndex+9] + result |= uint64(b10) << 63 + readLength = 10 + } } } } diff --git a/go/fory/struct.go b/go/fory/struct.go index 8beea7f32c..d6f539d27b 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -60,6 +60,9 @@ type structSerializer struct { // Initialization state initialized bool + + // Cached addressable value for non-addressable writes. + tempValue *reflect.Value } // newStructSerializerFromTypeDef creates a new structSerializer with the given parameters. @@ -118,6 +121,10 @@ func (s *structSerializer) initialize(typeResolver *TypeResolver) error { } // Compute struct hash s.structHash = s.computeHash() + if s.tempValue == nil { + tmp := reflect.New(s.type_).Elem() + s.tempValue = &tmp + } s.initialized = true return nil } @@ -1011,19 +1018,36 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { buf.WriteInt32(s.structHash) } - // Check if value is addressable for unsafe access - canUseUnsafe := value.CanAddr() - var ptr unsafe.Pointer - if canUseUnsafe { - ptr = unsafe.Pointer(value.UnsafeAddr()) + // Ensure value is addressable for unsafe access + if !value.CanAddr() { + reuseCache := s.tempValue != nil + if ctx.RefResolver().refTracking && len(ctx.RefResolver().writtenObjects) > 0 { + reuseCache = false + } + if reuseCache { + tempValue := s.tempValue + s.tempValue = nil + defer func() { + tempValue.SetZero() + s.tempValue = tempValue + }() + addrValue := *tempValue + addrValue.Set(value) + value = addrValue + } else { + tmp := reflect.New(value.Type()).Elem() + tmp.Set(value) + value = tmp + } } + ptr := unsafe.Pointer(value.UnsafeAddr()) // ========================================================================== // Phase 1: Fixed-size primitives (bool, int8, int16, float32, float64) // - Reserve once, inline unsafe writes with endian handling, update index once // - field.WriteOffset computed at init time // ========================================================================== - if canUseUnsafe && s.fieldGroup.FixedSize > 0 { + if s.fieldGroup.FixedSize > 0 { buf.Reserve(s.fieldGroup.FixedSize) baseOffset := buf.WriterIndex() data := buf.GetData() @@ -1141,87 +1165,13 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { } // Update writer index ONCE after all fixed fields buf.SetWriterIndex(baseOffset + s.fieldGroup.FixedSize) - } else if len(s.fieldGroup.FixedFields) > 0 { - // Fallback to reflect-based access for unaddressable values - for _, field := range s.fieldGroup.FixedFields { - fieldValue := value.Field(field.Meta.FieldIndex) - val, ok := loadReflectFieldValue(&field, fieldValue) - switch field.DispatchId { - case PrimitiveBoolDispatchId: - if ok { - buf.WriteBool(val.Bool()) - } else { - buf.WriteBool(false) - } - case PrimitiveInt8DispatchId: - if ok { - buf.WriteByte_(byte(val.Int())) - } else { - buf.WriteByte_(0) - } - case PrimitiveUint8DispatchId: - if ok { - buf.WriteByte_(byte(val.Uint())) - } else { - buf.WriteByte_(0) - } - case PrimitiveInt16DispatchId: - if ok { - buf.WriteInt16(int16(val.Int())) - } else { - buf.WriteInt16(0) - } - case PrimitiveUint16DispatchId: - if ok { - buf.WriteInt16(int16(val.Uint())) - } else { - buf.WriteInt16(0) - } - case PrimitiveInt32DispatchId: - if ok { - buf.WriteInt32(int32(val.Int())) - } else { - buf.WriteInt32(0) - } - case PrimitiveUint32DispatchId: - if ok { - buf.WriteInt32(int32(val.Uint())) - } else { - buf.WriteInt32(0) - } - case PrimitiveInt64DispatchId: - if ok { - buf.WriteInt64(val.Int()) - } else { - buf.WriteInt64(0) - } - case PrimitiveUint64DispatchId: - if ok { - buf.WriteInt64(int64(val.Uint())) - } else { - buf.WriteInt64(0) - } - case PrimitiveFloat32DispatchId: - if ok { - buf.WriteFloat32(float32(val.Float())) - } else { - buf.WriteFloat32(0) - } - case PrimitiveFloat64DispatchId: - if ok { - buf.WriteFloat64(val.Float()) - } else { - buf.WriteFloat64(0) - } - } - } } // ========================================================================== // Phase 2: Varint primitives (int32, int64, int, uint32, uint64, uint, tagged int64/uint64) // - Reserve max size once, track offset locally, update writerIndex once at end // ========================================================================== - if canUseUnsafe && s.fieldGroup.MaxVarintSize > 0 { + if s.fieldGroup.MaxVarintSize > 0 { buf.Reserve(s.fieldGroup.MaxVarintSize) offset := buf.WriterIndex() @@ -1284,62 +1234,6 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { } // Update writer index ONCE after all varint fields buf.SetWriterIndex(offset) - } else if len(s.fieldGroup.VarintFields) > 0 { - // Slow path for non-addressable values: use reflection - for _, field := range s.fieldGroup.VarintFields { - fieldValue := value.Field(field.Meta.FieldIndex) - val, ok := loadReflectFieldValue(&field, fieldValue) - switch field.DispatchId { - case PrimitiveVarint32DispatchId: - if ok { - buf.WriteVarint32(int32(val.Int())) - } else { - buf.WriteVarint32(0) - } - case PrimitiveVarint64DispatchId: - if ok { - buf.WriteVarint64(val.Int()) - } else { - buf.WriteVarint64(0) - } - case PrimitiveIntDispatchId: - if ok { - buf.WriteVarint64(val.Int()) - } else { - buf.WriteVarint64(0) - } - case PrimitiveVarUint32DispatchId: - if ok { - buf.WriteVaruint32(uint32(val.Uint())) - } else { - buf.WriteVaruint32(0) - } - case PrimitiveVarUint64DispatchId: - if ok { - buf.WriteVaruint64(val.Uint()) - } else { - buf.WriteVaruint64(0) - } - case PrimitiveUintDispatchId: - if ok { - buf.WriteVaruint64(val.Uint()) - } else { - buf.WriteVaruint64(0) - } - case PrimitiveTaggedInt64DispatchId: - if ok { - buf.WriteTaggedInt64(val.Int()) - } else { - buf.WriteTaggedInt64(0) - } - case PrimitiveTaggedUint64DispatchId: - if ok { - buf.WriteTaggedUint64(val.Uint()) - } else { - buf.WriteTaggedUint64(0) - } - } - } } // ========================================================================== @@ -1709,150 +1603,13 @@ func (s *structSerializer) writeRemainingField(ctx *WriteContext, ptr unsafe.Poi } } - // Slow path: use reflection for non-addressable values - fieldValue := value.Field(field.Meta.FieldIndex) - - // Handle nullable types via reflection when ptr is nil (non-addressable) - switch field.DispatchId { - case NullableTaggedInt64DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteTaggedInt64(fieldValue.Elem().Int()) - return - case NullableTaggedUint64DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteTaggedUint64(fieldValue.Elem().Uint()) - return - case NullableBoolDispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteBool(fieldValue.Elem().Bool()) - return - case NullableInt8DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteInt8(int8(fieldValue.Elem().Int())) - return - case NullableUint8DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteUint8(uint8(fieldValue.Elem().Uint())) - return - case NullableInt16DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteInt16(int16(fieldValue.Elem().Int())) - return - case NullableUint16DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteUint16(uint16(fieldValue.Elem().Uint())) - return - case NullableInt32DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteInt32(int32(fieldValue.Elem().Int())) - return - case NullableUint32DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteUint32(uint32(fieldValue.Elem().Uint())) - return - case NullableInt64DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteInt64(fieldValue.Elem().Int()) - return - case NullableUint64DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteUint64(fieldValue.Elem().Uint()) - return - case NullableFloat32DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteFloat32(float32(fieldValue.Elem().Float())) - return - case NullableFloat64DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteFloat64(fieldValue.Elem().Float()) - return - case NullableVarint32DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteVarint32(int32(fieldValue.Elem().Int())) - return - case NullableVarUint32DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteVaruint32(uint32(fieldValue.Elem().Uint())) - return - case NullableVarint64DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteVarint64(fieldValue.Elem().Int()) - return - case NullableVarUint64DispatchId: - if fieldValue.IsNil() { - buf.WriteInt8(NullFlag) - return - } - buf.WriteInt8(NotNullValueFlag) - buf.WriteVaruint64(fieldValue.Elem().Uint()) + if ptr == nil { + ctx.SetError(SerializationError("cannot write struct field without addressable value")) return } // Fall back to serializer for other types + fieldValue := value.Field(field.Meta.FieldIndex) if field.Serializer != nil { field.Serializer.Write(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, fieldValue) } else { @@ -1906,48 +1663,6 @@ func clearFieldValue(kind FieldKind, fieldPtr unsafe.Pointer, opt optionalInfo) } } -func loadReflectFieldValue(field *FieldInfo, fieldValue reflect.Value) (reflect.Value, bool) { - switch field.Kind { - case FieldKindPointer: - if fieldValue.IsNil() { - return reflect.Value{}, false - } - return fieldValue.Elem(), true - case FieldKindOptional: - if !fieldValue.FieldByName("Has").Bool() { - return reflect.Value{}, false - } - return fieldValue.FieldByName("Value"), true - default: - return fieldValue, true - } -} - -func storeReflectFieldValue(field *FieldInfo, fieldValue reflect.Value, value reflect.Value) { - switch field.Kind { - case FieldKindPointer: - ptr := reflect.New(value.Type()) - ptr.Elem().Set(value) - fieldValue.Set(ptr) - case FieldKindOptional: - fieldValue.FieldByName("Has").SetBool(true) - fieldValue.FieldByName("Value").Set(value) - default: - fieldValue.Set(value) - } -} - -func clearReflectFieldValue(field *FieldInfo, fieldValue reflect.Value) { - switch field.Kind { - case FieldKindPointer: - fieldValue.Set(reflect.Zero(fieldValue.Type())) - case FieldKindOptional: - fieldValue.FieldByName("Has").SetBool(false) - default: - fieldValue.Set(reflect.Zero(fieldValue.Type())) - } -} - func writeOptionFast(ctx *WriteContext, field *FieldInfo, optPtr unsafe.Pointer) bool { buf := ctx.Buffer() has := *(*bool)(unsafe.Add(optPtr, field.Meta.OptionalInfo.hasOffset)) @@ -2666,8 +2381,7 @@ func (s *structSerializer) readRemainingField(ctx *ReadContext, ptr unsafe.Point return case EnumDispatchId: // Enums don't track refs - always use fast path - fieldValue := value.Field(field.Meta.FieldIndex) - readEnumField(ctx, field, fieldValue) + readEnumFieldUnsafe(ctx, field, fieldPtr) return case StringSliceDispatchId: if field.RefMode == RefModeTracking { @@ -3339,52 +3053,44 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val return } - // Get field value for nullable primitives and non-primitives - fieldValue := value.Field(field.Meta.FieldIndex) + fieldPtr := unsafe.Add(ptr, field.Offset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional { + optInfo = field.Meta.OptionalInfo + } // Handle nullable fixed-size primitives (read ref flag + fixed bytes) // These have Nullable=true but use fixed encoding, not varint if isNullableFixedSizePrimitive(field.DispatchId) { refFlag := buf.ReadInt8(err) if refFlag == NullFlag { - clearReflectFieldValue(field, fieldValue) + clearFieldValue(field.Kind, fieldPtr, optInfo) return } // Read fixed-size value based on dispatch ID switch field.DispatchId { case NullableBoolDispatchId: - v := buf.ReadBool(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadBool(err)) case NullableInt8DispatchId: - v := buf.ReadInt8(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt8(err)) case NullableUint8DispatchId: - v := uint8(buf.ReadInt8(err)) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, uint8(buf.ReadInt8(err))) case NullableInt16DispatchId: - v := buf.ReadInt16(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt16(err)) case NullableUint16DispatchId: - v := buf.ReadUint16(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint16(err)) case NullableInt32DispatchId: - v := buf.ReadInt32(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt32(err)) case NullableUint32DispatchId: - v := buf.ReadUint32(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint32(err)) case NullableInt64DispatchId: - v := buf.ReadInt64(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt64(err)) case NullableUint64DispatchId: - v := buf.ReadUint64(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint64(err)) case NullableFloat32DispatchId: - v := buf.ReadFloat32(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadFloat32(err)) case NullableFloat64DispatchId: - v := buf.ReadFloat64(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadFloat64(err)) } return } @@ -3393,44 +3099,37 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val if isNullableVarintPrimitive(field.DispatchId) { refFlag := buf.ReadInt8(err) if refFlag == NullFlag { - clearReflectFieldValue(field, fieldValue) + clearFieldValue(field.Kind, fieldPtr, optInfo) return } // Read varint value based on dispatch ID switch field.DispatchId { case NullableVarint32DispatchId: - v := buf.ReadVarint32(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint32(err)) case NullableVarint64DispatchId: - v := buf.ReadVarint64(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint64(err)) case NullableVarUint32DispatchId: - v := buf.ReadVaruint32(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVaruint32(err)) case NullableVarUint64DispatchId: - v := buf.ReadVaruint64(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVaruint64(err)) case NullableTaggedInt64DispatchId: - v := buf.ReadTaggedInt64(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedInt64(err)) case NullableTaggedUint64DispatchId: - v := buf.ReadTaggedUint64(err) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedUint64(err)) case NullableIntDispatchId: - v := int(buf.ReadVarint64(err)) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.ReadVarint64(err))) case NullableUintDispatchId: - v := uint(buf.ReadVaruint64(err)) - storeReflectFieldValue(field, fieldValue, reflect.ValueOf(v)) + storeFieldValue(field.Kind, fieldPtr, optInfo, uint(buf.ReadVaruint64(err))) } return } if isEnumField(field) { - readEnumField(ctx, field, fieldValue) + readEnumFieldUnsafe(ctx, field, fieldPtr) return } // Slow path for non-primitives (all need ref flag per xlang spec) + fieldValue := value.Field(field.Meta.FieldIndex) if field.Serializer != nil { // Use pre-computed RefMode and WriteType from field initialization field.Serializer.Read(ctx, field.RefMode, field.Meta.WriteType, field.Meta.HasGenerics, fieldValue) @@ -3525,11 +3224,40 @@ func writeEnumField(ctx *WriteContext, field *FieldInfo, fieldValue reflect.Valu } } -// readEnumField reads an enum field respecting the field's RefMode. +func setEnumValue(ctx *ReadContext, ptr unsafe.Pointer, kind reflect.Kind, ordinal uint32) bool { + switch kind { + case reflect.Int: + *(*int)(ptr) = int(ordinal) + case reflect.Int8: + *(*int8)(ptr) = int8(ordinal) + case reflect.Int16: + *(*int16)(ptr) = int16(ordinal) + case reflect.Int32: + *(*int32)(ptr) = int32(ordinal) + case reflect.Int64: + *(*int64)(ptr) = int64(ordinal) + case reflect.Uint: + *(*uint)(ptr) = uint(ordinal) + case reflect.Uint8: + *(*uint8)(ptr) = uint8(ordinal) + case reflect.Uint16: + *(*uint16)(ptr) = uint16(ordinal) + case reflect.Uint32: + *(*uint32)(ptr) = ordinal + case reflect.Uint64: + *(*uint64)(ptr) = uint64(ordinal) + default: + ctx.SetError(DeserializationErrorf("enum serializer: unsupported kind %v", kind)) + return false + } + return true +} + +// readEnumFieldUnsafe reads an enum field respecting the field's RefMode. // RefMode determines whether null flag is read, regardless of whether the local type is a pointer. // This is important for compatible mode where remote TypeDef's nullable flag controls the wire format. // Uses context error state for deferred error checking. -func readEnumField(ctx *ReadContext, field *FieldInfo, fieldValue reflect.Value) { +func readEnumFieldUnsafe(ctx *ReadContext, field *FieldInfo, fieldPtr unsafe.Pointer) { buf := ctx.Buffer() isPointer := field.Kind == FieldKindPointer @@ -3537,29 +3265,32 @@ func readEnumField(ctx *ReadContext, field *FieldInfo, fieldValue reflect.Value) if field.RefMode != RefModeNone { nullFlag := buf.ReadInt8(ctx.Err()) if nullFlag == NullFlag { - // For pointer enum fields, leave as nil; for non-pointer, set to zero - if !isPointer { - fieldValue.SetInt(0) + if isPointer { + *(*unsafe.Pointer)(fieldPtr) = nil + } else { + setEnumValue(ctx, fieldPtr, field.Meta.Type.Kind(), 0) } return } } - // For pointer enum fields, allocate a new value - targetValue := fieldValue - if isPointer { - newVal := reflect.New(field.Meta.Type.Elem()) - fieldValue.Set(newVal) - targetValue = newVal.Elem() + ordinal := buf.ReadVaruint32Small7(ctx.Err()) + if ctx.HasError() { + return } - // For pointer enum fields, the serializer is ptrToValueSerializer wrapping enumSerializer. - // We need to call the inner enumSerializer directly with the dereferenced value. - if ptrSer, ok := field.Serializer.(*ptrToValueSerializer); ok { - ptrSer.valueSerializer.ReadData(ctx, targetValue) - } else { - field.Serializer.ReadData(ctx, targetValue) + if isPointer { + elemType := field.Meta.Type.Elem() + newVal := reflect.New(elemType) + elemPtr := unsafe.Pointer(newVal.Pointer()) + if !setEnumValue(ctx, elemPtr, elemType.Kind(), ordinal) { + return + } + *(*unsafe.Pointer)(fieldPtr) = elemPtr + return } + + setEnumValue(ctx, fieldPtr, field.Meta.Type.Kind(), ordinal) } // skipStructSerializer is a serializer that skips unknown struct data From 5d28d41f559348d1c583c464a2c6e7f615bd94e1 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 19:30:21 +0800 Subject: [PATCH 11/21] remove unrolloed int64 write --- go/fory/buffer.go | 145 --------------- go/fory/buffer_bench_test.go | 340 +++++++++++++++++++++++++++++++++++ go/fory/string.go | 4 +- 3 files changed, 342 insertions(+), 147 deletions(-) create mode 100644 go/fory/buffer_bench_test.go diff --git a/go/fory/buffer.go b/go/fory/buffer.go index 2aa9cd4ba4..4e571e979d 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -593,151 +593,6 @@ func (b *ByteBuffer) UnsafeWriteBool(v bool) { b.writerIndex++ } -// UnsafeWriteVarint64 writes a varint64 without grow check. -// Caller must have called Reserve(10) beforehand. -// -//go:inline -func (b *ByteBuffer) UnsafeWriteVarint64(value int64) { - u := uint64((value << 1) ^ (value >> 63)) - b.UnsafeWriteVaruint64(u) -} - -// UnsafeWriteVaruint64 writes a varuint64 without grow check. -// Caller must have called Reserve(16) beforehand (for bulk writes). -func (b *ByteBuffer) UnsafeWriteVaruint64(value uint64) { - if value>>7 == 0 { - b.data[b.writerIndex] = byte(value) - b.writerIndex++ - return - } - if value>>14 == 0 { - encoded := uint16((value&0x7F)|0x80) | uint16(value>>7)<<8 - if isLittleEndian { - *(*uint16)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - b.data[b.writerIndex] = byte(encoded) - b.data[b.writerIndex+1] = byte(encoded >> 8) - } - b.writerIndex += 2 - return - } - if value>>21 == 0 { - encoded := uint32((value&0x7F)|0x80) | - uint32(((value>>7)&0x7F)|0x80)<<8 | - uint32(value>>14)<<16 - if isLittleEndian { - *(*uint32)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - b.data[b.writerIndex] = byte(encoded) - b.data[b.writerIndex+1] = byte(encoded >> 8) - b.data[b.writerIndex+2] = byte(encoded >> 16) - } - b.writerIndex += 3 - return - } - if value>>28 == 0 { - encoded := uint32((value&0x7F)|0x80) | - uint32(((value>>7)&0x7F)|0x80)<<8 | - uint32(((value>>14)&0x7F)|0x80)<<16 | - uint32(value>>21)<<24 - if isLittleEndian { - *(*uint32)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - binary.LittleEndian.PutUint32(b.data[b.writerIndex:], encoded) - } - b.writerIndex += 4 - return - } - if value>>35 == 0 { - encoded := uint64((value&0x7F)|0x80) | - uint64(((value>>7)&0x7F)|0x80)<<8 | - uint64(((value>>14)&0x7F)|0x80)<<16 | - uint64(((value>>21)&0x7F)|0x80)<<24 | - uint64(value>>28)<<32 - if isLittleEndian { - *(*uint64)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - binary.LittleEndian.PutUint64(b.data[b.writerIndex:], encoded) - } - b.writerIndex += 5 - return - } - if value>>42 == 0 { - encoded := uint64((value&0x7F)|0x80) | - uint64(((value>>7)&0x7F)|0x80)<<8 | - uint64(((value>>14)&0x7F)|0x80)<<16 | - uint64(((value>>21)&0x7F)|0x80)<<24 | - uint64(((value>>28)&0x7F)|0x80)<<32 | - uint64(value>>35)<<40 - if isLittleEndian { - *(*uint64)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - binary.LittleEndian.PutUint64(b.data[b.writerIndex:], encoded) - } - b.writerIndex += 6 - return - } - if value>>49 == 0 { - encoded := uint64((value&0x7F)|0x80) | - uint64(((value>>7)&0x7F)|0x80)<<8 | - uint64(((value>>14)&0x7F)|0x80)<<16 | - uint64(((value>>21)&0x7F)|0x80)<<24 | - uint64(((value>>28)&0x7F)|0x80)<<32 | - uint64(((value>>35)&0x7F)|0x80)<<40 | - uint64(value>>42)<<48 - if isLittleEndian { - *(*uint64)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - binary.LittleEndian.PutUint64(b.data[b.writerIndex:], encoded) - } - b.writerIndex += 7 - return - } - if value>>56 == 0 { - encoded := uint64((value&0x7F)|0x80) | - uint64(((value>>7)&0x7F)|0x80)<<8 | - uint64(((value>>14)&0x7F)|0x80)<<16 | - uint64(((value>>21)&0x7F)|0x80)<<24 | - uint64(((value>>28)&0x7F)|0x80)<<32 | - uint64(((value>>35)&0x7F)|0x80)<<40 | - uint64(((value>>42)&0x7F)|0x80)<<48 | - uint64(value>>49)<<56 - if isLittleEndian { - *(*uint64)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - binary.LittleEndian.PutUint64(b.data[b.writerIndex:], encoded) - } - b.writerIndex += 8 - return - } - // 9 or 10 bytes needed - write 8 bytes bulk then remaining - encoded := uint64((value&0x7F)|0x80) | - uint64(((value>>7)&0x7F)|0x80)<<8 | - uint64(((value>>14)&0x7F)|0x80)<<16 | - uint64(((value>>21)&0x7F)|0x80)<<24 | - uint64(((value>>28)&0x7F)|0x80)<<32 | - uint64(((value>>35)&0x7F)|0x80)<<40 | - uint64(((value>>42)&0x7F)|0x80)<<48 | - uint64(((value>>49)&0x7F)|0x80)<<56 - if isLittleEndian { - *(*uint64)(unsafe.Pointer(&b.data[b.writerIndex])) = encoded - } else { - binary.LittleEndian.PutUint64(b.data[b.writerIndex:], encoded) - } - b.writerIndex += 8 - - // Remaining 1-2 bytes - remaining := value >> 56 - if remaining>>7 == 0 { - b.data[b.writerIndex] = byte(remaining) - b.writerIndex++ - } else { - b.data[b.writerIndex] = byte(remaining) | 0x80 - b.data[b.writerIndex+1] = byte(remaining >> 7) - b.writerIndex += 2 - } -} - // UnsafePutVarInt32 writes a zigzag-encoded varint32 at the given offset without advancing writerIndex. // Caller must have called Reserve() to ensure capacity. // Returns the number of bytes written (1-5). diff --git a/go/fory/buffer_bench_test.go b/go/fory/buffer_bench_test.go new file mode 100644 index 0000000000..aa15163349 --- /dev/null +++ b/go/fory/buffer_bench_test.go @@ -0,0 +1,340 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import "testing" + +var benchVaruint64Values = []uint64{ + 0, + 1, + 127, + 128, + 16384, + 1 << 20, + 1 << 40, + 1<<63 - 1, + ^uint64(0), +} + +var benchVaruint64SmallValues = []uint64{ + 0, + 1, + 2, + 3, + 7, + 15, + 31, + 63, + 127, +} + +var benchVaruint64MidValues = []uint64{ + 128, + 129, + 16383, + 16384, + 1<<20 - 1, + 1 << 20, + 1<<27 - 1, + 1 << 27, + 1<<34 - 1, + 1 << 34, +} + +var benchVaruint64LargeValues = []uint64{ + 1<<40 - 1, + 1 << 40, + 1<<55 - 1, + 1 << 55, + 1<<63 - 1, + ^uint64(0), +} + +var benchVaruint32SmallValues = []uint32{ + 0, + 1, + 2, + 3, + 7, + 15, + 31, + 63, + 127, +} + +var benchVaruint32MidValues = []uint32{ + 128, + 129, + 16383, + 16384, + 1<<20 - 1, + 1 << 20, + 1<<27 - 1, + 1 << 27, +} + +var benchVaruint32LargeValues = []uint32{ + 1<<29 - 1, + 1 << 29, + 1<<31 - 1, + ^uint32(0), +} + +var benchVaruint36SmallValues = []uint64{ + 0, + 1, + 127, + 128, + 16383, + 16384, + 1<<20 - 1, + 1 << 20, +} + +var benchVaruint36MidValues = []uint64{ + 1<<27 - 1, + 1 << 27, + 1<<34 - 1, + 1 << 34, +} + +var benchVaruint36LargeValues = []uint64{ + 1<<35 - 1, + 1<<35 + 123, + 1<<36 - 1, +} + +func writeVaruint32Loop(buf *ByteBuffer, value uint32) int8 { + buf.grow(5) + offset := buf.writerIndex + data := buf.data[offset : offset+5] + i := 0 + for value >= 0x80 { + data[i] = byte(value&0x7F) | 0x80 + value >>= 7 + i++ + } + data[i] = byte(value) + i++ + buf.writerIndex += i + return int8(i) +} + +func writeVaruint32Unrolled(buf *ByteBuffer, value uint32) int8 { + buf.grow(5) + return buf.UnsafeWriteVaruint32(value) +} + +func writeVaruint36SmallLoop(buf *ByteBuffer, value uint64) { + buf.grow(5) + offset := buf.writerIndex + data := buf.data[offset : offset+5] + i := 0 + for i < 4 && value >= 0x80 { + data[i] = byte(value&0x7F) | 0x80 + value >>= 7 + i++ + } + if i < 4 { + data[i] = byte(value) + buf.writerIndex += i + 1 + return + } + data[4] = byte(value) + buf.writerIndex += 5 +} + +func writeVaruint36SmallUnrolled(buf *ByteBuffer, value uint64) { + buf.WriteVaruint36Small(value) +} + +func BenchmarkWriteVaruint64Loop(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint64Values + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + buf.WriteVaruint64(values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint64LoopSmall(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint64SmallValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + buf.WriteVaruint64(values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint64LoopMid(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint64MidValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + buf.WriteVaruint64(values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint64LoopLarge(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint64LargeValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + buf.WriteVaruint64(values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint32LoopSmall(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint32SmallValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint32Loop(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint32UnrolledSmall(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint32SmallValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint32Unrolled(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint32LoopMid(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint32MidValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint32Loop(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint32UnrolledMid(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint32MidValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint32Unrolled(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint32LoopLarge(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint32LargeValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint32Loop(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint32UnrolledLarge(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint32LargeValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint32Unrolled(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint36SmallLoopSmall(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint36SmallValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint36SmallLoop(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint36SmallUnrolledSmall(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint36SmallValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint36SmallUnrolled(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint36SmallLoopMid(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint36MidValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint36SmallLoop(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint36SmallUnrolledMid(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint36MidValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint36SmallUnrolled(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint36SmallLoopLarge(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint36LargeValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint36SmallLoop(buf, values[i%len(values)]) + } +} + +func BenchmarkWriteVaruint36SmallUnrolledLarge(b *testing.B) { + buf := NewByteBuffer(make([]byte, 0, 1024)) + values := benchVaruint36LargeValues + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.writerIndex = 0 + writeVaruint36SmallUnrolled(buf, values[i%len(values)]) + } +} diff --git a/go/fory/string.go b/go/fory/string.go index 48b942cffa..d334cf9494 100644 --- a/go/fory/string.go +++ b/go/fory/string.go @@ -39,8 +39,8 @@ func writeString(buf *ByteBuffer, value string) { // Reserve space for header (max 5 bytes) + data in one call buf.Reserve(5 + dataLen) - // Write header inline without grow check - buf.UnsafeWriteVaruint64(header) + // Write header inline + buf.WriteVaruint36Small(header) // Write data inline without grow check if dataLen > 0 { copy(buf.data[buf.writerIndex:], data) From a368d479eeab98bea56dca55afaed562a7322912 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 21:12:48 +0800 Subject: [PATCH 12/21] fix test error --- AGENTS.md | 5 ++ cpp/fory/meta/preprocessor.h | 6 +- cpp/fory/serialization/context.h | 10 +++ cpp/fory/serialization/temporal_serializers.h | 17 +++-- cpp/fory/serialization/xlang_test_main.cc | 2 +- cpp/fory/util/buffer.h | 8 +++ docs/guide/cpp/type-registration.md | 2 +- go/fory/buffer.go | 64 ++++++++----------- go/fory/optional/optional.go | 8 --- go/fory/struct_test.go | 24 +++---- integration_tests/idl_tests/cpp/main.cc | 2 - .../idl_tests/go/idl_roundtrip_test.go | 5 -- .../idl_tests/idl/complex_pb.proto | 1 - .../idl_tests/idl/optional_types.fdl | 1 - .../fory/idl_tests/IdlRoundTripTest.java | 2 - .../python/src/idl_tests/roundtrip.py | 3 - .../idl_tests/rust/tests/idl_roundtrip.rs | 2 - 17 files changed, 76 insertions(+), 86 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 415d213ce7..f8f59c3050 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -35,6 +35,7 @@ While working on Fory, please remember: - Fory java needs JDK `17+` installed. - Modules target different bytecode levels (fory-core Java 8, fory-format Java 11); avoid using newer APIs in those modules. - Use '.\*' form of import is not allowed. +- If you run temporary tests using `java -cp`, you must run `mvn -T16 install -DskipTests` to get latest jars for fory java library. ```bash # Clean the build @@ -302,6 +303,10 @@ sbt scalafmt - All commands must be executed within the `integration_tests` directory. - For java related integration tests, please install the java libraries first by `cd ../java && mvn -T16 install -DskipTests`. If no code changes after installed fory java, you can skip the installation step. - For mac, graalvm is installed at `/Library/Java/JavaVirtualMachines/graalvm-xxx` by default. +- For `integration_tests/idl_tests`: + - you must install fory java first before runing tests under this dir if any code changes under `java` dir. + - you must install fory python first before runing tests under this dir if any code changes under `python` dir. +- You are never allowed to manual edit generated code by fory compiler for `IDL` files, you must invoke fory compiler to regenerate code. ```bash it_dir=$(pwd) diff --git a/cpp/fory/meta/preprocessor.h b/cpp/fory/meta/preprocessor.h index 30339321ce..5ebe684ece 100644 --- a/cpp/fory/meta/preprocessor.h +++ b/cpp/fory/meta/preprocessor.h @@ -42,9 +42,9 @@ N #define FORY_PP_NARG_REV() \ - 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, \ - 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, \ - 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, \ + 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, \ + 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, \ + 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, \ 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 #define FORY_PP_HAS_COMMA(...) \ diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 69ad1256bb..46999e2b12 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -180,6 +180,16 @@ class WriteContext { buffer().WriteUint16(value); } + /// Write uint32_t value to buffer. + FORY_ALWAYS_INLINE void write_uint32(uint32_t value) { + buffer().WriteUint32(value); + } + + /// Write int64_t value to buffer. + FORY_ALWAYS_INLINE void write_int64(int64_t value) { + buffer().WriteInt64(value); + } + /// Write uint32_t value as varint to buffer. FORY_ALWAYS_INLINE void write_varuint32(uint32_t value) { buffer().WriteVarUint32(value); diff --git a/cpp/fory/serialization/temporal_serializers.h b/cpp/fory/serialization/temporal_serializers.h index 34874be069..8e0b6b49bd 100644 --- a/cpp/fory/serialization/temporal_serializers.h +++ b/cpp/fory/serialization/temporal_serializers.h @@ -131,7 +131,8 @@ template <> struct Serializer { // ============================================================================ /// Serializer for Timestamp -/// Per xlang spec: serialized as int64 seconds + uint32 nanoseconds since Unix epoch +/// Per xlang spec: serialized as int64 seconds + uint32 nanoseconds since Unix +/// epoch template <> struct Serializer { static constexpr TypeId type_id = TypeId::TIMESTAMP; @@ -170,8 +171,8 @@ template <> struct Serializer { } int64_t seconds_count = seconds.count(); uint32_t nanos_count = static_cast(remainder.count()); - ctx.write_bytes(&seconds_count, sizeof(int64_t)); - ctx.write_bytes(&nanos_count, sizeof(uint32_t)); + ctx.write_int64(seconds_count); + ctx.write_uint32(nanos_count); } static inline void write_data_generic(const Timestamp ×tamp, @@ -200,13 +201,11 @@ template <> struct Serializer { } static inline Timestamp read_data(ReadContext &ctx) { - int64_t seconds; - uint32_t nanos; - ctx.read_bytes(&seconds, sizeof(int64_t), ctx.error()); + int64_t seconds = ctx.read_int64(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { return Timestamp(std::chrono::nanoseconds(0)); } - ctx.read_bytes(&nanos, sizeof(uint32_t), ctx.error()); + uint32_t nanos = ctx.read_uint32(ctx.error()); return Timestamp(std::chrono::seconds(seconds) + std::chrono::nanoseconds(nanos)); } @@ -256,8 +255,8 @@ template <> struct Serializer { ctx.write_bytes(&date.days_since_epoch, sizeof(int32_t)); } - static inline void write_data_generic(const Date &date, - WriteContext &ctx, bool has_generics) { + static inline void write_data_generic(const Date &date, WriteContext &ctx, + bool has_generics) { write_data(date, ctx); } diff --git a/cpp/fory/serialization/xlang_test_main.cc b/cpp/fory/serialization/xlang_test_main.cc index d4fe9a12d3..3a05997d83 100644 --- a/cpp/fory/serialization/xlang_test_main.cc +++ b/cpp/fory/serialization/xlang_test_main.cc @@ -43,9 +43,9 @@ using ::fory::Buffer; using ::fory::Error; using ::fory::Result; +using ::fory::serialization::Date; using ::fory::serialization::Fory; using ::fory::serialization::ForyBuilder; -using ::fory::serialization::Date; using ::fory::serialization::Serializer; using ::fory::serialization::Timestamp; diff --git a/cpp/fory/util/buffer.h b/cpp/fory/util/buffer.h index 1c3581715e..496e79fb00 100644 --- a/cpp/fory/util/buffer.h +++ b/cpp/fory/util/buffer.h @@ -549,6 +549,14 @@ class Buffer { IncreaseWriterIndex(4); } + /// Write uint32_t value as fixed 4 bytes to buffer at current writer index. + /// Automatically grows buffer and advances writer index. + FORY_ALWAYS_INLINE void WriteUint32(uint32_t value) { + Grow(4); + UnsafePut(writer_index_, value); + IncreaseWriterIndex(4); + } + /// Write int64_t value as fixed 8 bytes to buffer at current writer index. /// Automatically grows buffer and advances writer index. FORY_ALWAYS_INLINE void WriteInt64(int64_t value) { diff --git a/docs/guide/cpp/type-registration.md b/docs/guide/cpp/type-registration.md index 21fa4985d0..99a5ca599f 100644 --- a/docs/guide/cpp/type-registration.md +++ b/docs/guide/cpp/type-registration.md @@ -184,7 +184,7 @@ Built-in types have pre-assigned type IDs and don't need registration: | 15 | SET | | 16 | TIMESTAMP | | 17 | DURATION | -| 18 | DATE | +| 18 | DATE | | 19 | DECIMAL | | 20 | BINARY | | 21 | ARRAY | diff --git a/go/fory/buffer.go b/go/fory/buffer.go index 4e571e979d..1e90c40597 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -661,7 +661,7 @@ func (b *ByteBuffer) UnsafePutVaruint32(offset int, value uint32) int { // UnsafePutVarInt64 writes a zigzag-encoded varint64 at the given offset without advancing writerIndex. // Caller must have called Reserve() to ensure capacity. -// Returns the number of bytes written (1-10). +// Returns the number of bytes written (1-9). // //go:inline func (b *ByteBuffer) UnsafePutVarInt64(offset int, value int64) int { @@ -671,7 +671,7 @@ func (b *ByteBuffer) UnsafePutVarInt64(offset int, value int64) int { // UnsafePutVaruint64 writes an unsigned varuint64 at the given offset without advancing writerIndex. // Caller must have called Reserve(16) to ensure capacity (for bulk writes). -// Returns the number of bytes written (1-10). +// Returns the number of bytes written (1-9). func (b *ByteBuffer) UnsafePutVaruint64(offset int, value uint64) int { if value>>7 == 0 { b.data[offset] = byte(value) @@ -770,7 +770,7 @@ func (b *ByteBuffer) UnsafePutVaruint64(offset int, value uint64) int { } return 8 } - // 9 or 10 bytes needed + // 9 bytes needed encoded := uint64((value&0x7F)|0x80) | uint64(((value>>7)&0x7F)|0x80)<<8 | uint64(((value>>14)&0x7F)|0x80)<<16 | @@ -785,14 +785,8 @@ func (b *ByteBuffer) UnsafePutVaruint64(offset int, value uint64) int { binary.LittleEndian.PutUint64(b.data[offset:], encoded) } - remaining := value >> 56 - if remaining>>7 == 0 { - b.data[offset+8] = byte(remaining) - return 9 - } - b.data[offset+8] = byte(remaining) | 0x80 - b.data[offset+9] = byte(remaining >> 7) - return 10 + b.data[offset+8] = byte(value >> 56) + return 9 } //go:inline @@ -824,21 +818,23 @@ func (b *ByteBuffer) WriteVarint64(value int64) { b.WriteVaruint64(u) } -// WriteVaruint64 writes to unsigned varint (up to 10 bytes) +// WriteVaruint64 writes to unsigned varint (up to 9 bytes) func (b *ByteBuffer) WriteVaruint64(value uint64) { - b.grow(10) + b.grow(9) offset := b.writerIndex - data := b.data[offset : offset+10] + data := b.data[offset : offset+9] - i := 0 - for value >= 0x80 { + for i := 0; i < 8; i++ { + if value < 0x80 { + data[i] = byte(value) + b.writerIndex += i + 1 + return + } data[i] = byte(value&0x7F) | 0x80 value >>= 7 - i++ } - data[i] = byte(value) - i++ - b.writerIndex += i + data[8] = byte(value) + b.writerIndex += 9 } // WriteVaruint36Small writes a varint optimized for small values (up to 36 bits) @@ -1061,7 +1057,7 @@ func (b *ByteBuffer) ReadTaggedUint64(err *Error) uint64 { // //go:inline func (b *ByteBuffer) ReadVaruint64(err *Error) uint64 { - if b.remaining() >= 10 { + if b.remaining() >= 9 { return b.readVaruint64Fast() } return b.readVaruint64Slow(err) @@ -1104,16 +1100,10 @@ func (b *ByteBuffer) readVaruint64Fast() uint64 { result |= (bulk >> 7) & 0xFE000000000000 readLength = 8 if (bulk & 0x8000000000000000) != 0 { - // Need 9th byte (and possibly 10th if continuation bit is set) + // Need 9th byte (full 8 bits) b9 := b.data[b.readerIndex+8] - result |= uint64(b9&0x7F) << 56 + result |= uint64(b9) << 56 readLength = 9 - if (b9 & 0x80) != 0 { - // 10th byte carries the remaining bits - b10 := b.data[b.readerIndex+9] - result |= uint64(b10) << 63 - readLength = 10 - } } } } @@ -1130,7 +1120,7 @@ func (b *ByteBuffer) readVaruint64Fast() uint64 { func (b *ByteBuffer) readVaruint64Slow(err *Error) uint64 { var result uint64 var shift uint - for { + for i := 0; i < 8; i++ { if b.readerIndex >= len(b.data) { *err = BufferOutOfBoundError(b.readerIndex, 1, len(b.data)) return 0 @@ -1139,15 +1129,17 @@ func (b *ByteBuffer) readVaruint64Slow(err *Error) uint64 { b.readerIndex++ result |= (uint64(byteVal) & 0x7F) << shift if byteVal < 0x80 { - break + return result } shift += 7 - if shift >= 64 { - *err = DeserializationError("varuint64 overflow") - return 0 - } } - return result + if b.readerIndex >= len(b.data) { + *err = BufferOutOfBoundError(b.readerIndex, 1, len(b.data)) + return 0 + } + byteVal := b.data[b.readerIndex] + b.readerIndex++ + return result | (uint64(byteVal) << 56) } // Auxiliary function diff --git a/go/fory/optional/optional.go b/go/fory/optional/optional.go index 857fa7d3db..43cfbab2c5 100644 --- a/go/fory/optional/optional.go +++ b/go/fory/optional/optional.go @@ -35,14 +35,6 @@ func None[T any]() Optional[T] { return Optional[T]{} } -// FromPtr converts a pointer to an Optional. -func FromPtr[T any](v *T) Optional[T] { - if v == nil { - return None[T]() - } - return Some(*v) -} - // IsSome reports whether the optional contains a value. func (o Optional[T]) IsSome() bool { return o.has } diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index d2b504c0aa..7475a26cfe 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -207,18 +207,18 @@ func TestNumericPointerOptionalInterop(t *testing.T) { var out NumericOptStruct require.NoError(t, reader.Unmarshal(data, &out)) - require.False(t, out.I8.IsSome()) - require.False(t, out.I16.IsSome()) - require.False(t, out.I32.IsSome()) - require.False(t, out.I64.IsSome()) - require.False(t, out.I.IsSome()) - require.False(t, out.U8.IsSome()) - require.False(t, out.U16.IsSome()) - require.False(t, out.U32.IsSome()) - require.False(t, out.U64.IsSome()) - require.False(t, out.U.IsSome()) - require.False(t, out.F32.IsSome()) - require.False(t, out.F64.IsSome()) + require.False(t, out.I8.IsSome()) + require.False(t, out.I16.IsSome()) + require.False(t, out.I32.IsSome()) + require.False(t, out.I64.IsSome()) + require.False(t, out.I.IsSome()) + require.False(t, out.U8.IsSome()) + require.False(t, out.U16.IsSome()) + require.False(t, out.U32.IsSome()) + require.False(t, out.U64.IsSome()) + require.False(t, out.U.IsSome()) + require.False(t, out.F32.IsSome()) + require.False(t, out.F64.IsSome()) }) t.Run("PointerToOptionalValue", func(t *testing.T) { diff --git a/integration_tests/idl_tests/cpp/main.cc b/integration_tests/idl_tests/cpp/main.cc index 72ef5383d5..f0e65955cd 100644 --- a/integration_tests/idl_tests/cpp/main.cc +++ b/integration_tests/idl_tests/cpp/main.cc @@ -126,7 +126,6 @@ fory::Result RunRoundTrip() { types.set_uint64_value(9876543210ULL); types.set_var_uint64_value(12345678901ULL); types.set_tagged_uint64_value(2222222222ULL); - types.set_float16_value(1.5F); types.set_float32_value(2.5F); types.set_float64_value(3.5); types.set_contact(addressbook::PrimitiveTypes::Contact::email( @@ -224,7 +223,6 @@ fory::Result RunRoundTrip() { all_types.set_fixed_uint64_value(9876543210ULL); all_types.set_var_uint64_value(12345678901ULL); all_types.set_tagged_uint64_value(2222222222ULL); - all_types.set_float16_value(1.5F); all_types.set_float32_value(2.5F); all_types.set_float64_value(3.5); all_types.set_string_value("optional"); diff --git a/integration_tests/idl_tests/go/idl_roundtrip_test.go b/integration_tests/idl_tests/go/idl_roundtrip_test.go index 738a1d24ad..1c97cf406d 100644 --- a/integration_tests/idl_tests/go/idl_roundtrip_test.go +++ b/integration_tests/idl_tests/go/idl_roundtrip_test.go @@ -164,7 +164,6 @@ func buildPrimitiveTypes() PrimitiveTypes { Uint64Value: 9876543210, VarUint64Value: 12345678901, TaggedUint64Value: 2222222222, - Float16Value: 1.5, Float32Value: 2.5, Float64Value: 3.5, Contact: &contact, @@ -368,7 +367,6 @@ func buildOptionalHolder() optionaltypes.OptionalHolder { FixedUint64Value: optional.Some(uint64(9876543210)), VarUint64Value: optional.Some(uint64(12345678901)), TaggedUint64Value: optional.Some(uint64(2222222222)), - Float16Value: optional.Some(float32(1.5)), Float32Value: optional.Some(float32(2.5)), Float64Value: optional.Some(3.5), StringValue: optional.Some("optional"), @@ -535,9 +533,6 @@ func assertOptionalTypesEqual(t *testing.T, expected, actual *optionaltypes.AllO if expected.TaggedUint64Value != actual.TaggedUint64Value { t.Fatalf("tagged_uint64_value mismatch: %#v != %#v", expected.TaggedUint64Value, actual.TaggedUint64Value) } - if expected.Float16Value != actual.Float16Value { - t.Fatalf("float16_value mismatch: %#v != %#v", expected.Float16Value, actual.Float16Value) - } if expected.Float32Value != actual.Float32Value { t.Fatalf("float32_value mismatch: %#v != %#v", expected.Float32Value, actual.Float32Value) } diff --git a/integration_tests/idl_tests/idl/complex_pb.proto b/integration_tests/idl_tests/idl/complex_pb.proto index c3b51a4c4a..7c84ed5169 100644 --- a/integration_tests/idl_tests/idl/complex_pb.proto +++ b/integration_tests/idl_tests/idl/complex_pb.proto @@ -39,7 +39,6 @@ message PrimitiveTypes { fixed64 uint64_value = 13; uint64 var_uint64_value = 14; uint64 tagged_uint64_value = 15 [ (fory).type = "tagged_uint64" ]; - float16 float16_value = 16; float float32_value = 17; double float64_value = 18; diff --git a/integration_tests/idl_tests/idl/optional_types.fdl b/integration_tests/idl_tests/idl/optional_types.fdl index a59baec9b3..31d164e5c5 100644 --- a/integration_tests/idl_tests/idl/optional_types.fdl +++ b/integration_tests/idl_tests/idl/optional_types.fdl @@ -37,7 +37,6 @@ message AllOptionalTypes [id=120] { optional fixed_uint64 fixed_uint64_value = 17; optional var_uint64 var_uint64_value = 18; optional tagged_uint64 tagged_uint64_value = 19; - optional float16 float16_value = 20; optional float32 float32_value = 21; optional float64 float64_value = 22; optional string string_value = 23; diff --git a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java index 9695559c27..b075dda121 100644 --- a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java +++ b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java @@ -356,7 +356,6 @@ private PrimitiveTypes buildPrimitiveTypes() { types.setUint64Value(9876543210L); types.setVarUint64Value(12345678901L); types.setTaggedUint64Value(2222222222L); - types.setFloat16Value(1.5f); types.setFloat32Value(2.5f); types.setFloat64Value(3.5d); PrimitiveTypes.Contact contact = PrimitiveTypes.Contact.ofEmail("alice@example.com"); @@ -403,7 +402,6 @@ private OptionalHolder buildOptionalHolder() { allTypes.setFixedUint64Value(9876543210L); allTypes.setVarUint64Value(12345678901L); allTypes.setTaggedUint64Value(2222222222L); - allTypes.setFloat16Value(1.5f); allTypes.setFloat32Value(2.5f); allTypes.setFloat64Value(3.5); allTypes.setStringValue("optional"); diff --git a/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py b/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py index 9b00f8419a..54c3874b5f 100644 --- a/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py +++ b/integration_tests/idl_tests/python/src/idl_tests/roundtrip.py @@ -96,7 +96,6 @@ def build_primitive_types() -> "addressbook.PrimitiveTypes": uint64_value=9876543210, var_uint64_value=12345678901, tagged_uint64_value=2222222222, - float16_value=1.5, float32_value=2.5, float64_value=3.5, contact=contact, @@ -124,7 +123,6 @@ def build_optional_holder() -> "optional_types.OptionalHolder": fixed_uint64_value=9876543210, var_uint64_value=12345678901, tagged_uint64_value=2222222222, - float16_value=1.5, float32_value=2.5, float64_value=3.5, string_value="optional", @@ -292,7 +290,6 @@ def assert_optional_types_equal( assert decoded.fixed_uint64_value == expected.fixed_uint64_value assert decoded.var_uint64_value == expected.var_uint64_value assert decoded.tagged_uint64_value == expected.tagged_uint64_value - assert decoded.float16_value == expected.float16_value assert decoded.float32_value == expected.float32_value assert decoded.float64_value == expected.float64_value assert decoded.string_value == expected.string_value diff --git a/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs b/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs index e62178965b..3b7a2163ef 100644 --- a/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs +++ b/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs @@ -86,7 +86,6 @@ fn build_primitive_types() -> addressbook::PrimitiveTypes { uint64_value: 9876543210, var_uint64_value: 12345678901, tagged_uint64_value: 2222222222, - float16_value: 1.5, float32_value: 2.5, float64_value: 3.5, contact: Some(contact), @@ -162,7 +161,6 @@ fn build_optional_holder() -> OptionalHolder { fixed_uint64_value: Some(9876543210), var_uint64_value: Some(12345678901), tagged_uint64_value: Some(2222222222), - float16_value: Some(1.5), float32_value: Some(2.5), float64_value: Some(3.5), string_value: Some("optional".to_string()), From b4c44eeea73e49bc0f47e5ab33490546a70c2e51 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 22:00:32 +0800 Subject: [PATCH 13/21] fix(go): align nullable compatibility checks --- cpp/fory/meta/field.h | 410 +++++++++++++++++++----------------------- go/fory/struct.go | 51 ++++-- 2 files changed, 219 insertions(+), 242 deletions(-) diff --git a/cpp/fory/meta/field.h b/cpp/fory/meta/field.h index 78b445dccd..616facbff3 100644 --- a/cpp/fory/meta/field.h +++ b/cpp/fory/meta/field.h @@ -1106,51 +1106,46 @@ struct GetFieldTagEntry struct field index fieldTagIDToOffset := make(map[int]uintptr) // tag ID -> field offset fieldTagIDToType := make(map[int]reflect.Type) // tag ID -> field type @@ -473,6 +501,7 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err fieldNameToIndex[name] = i fieldNameToOffset[name] = field.Offset fieldNameToType[name] = field.Type + localNullableByIndex[i] = computeLocalNullable(typeResolver, field, foryTag) // Also index by tag ID if present if foryTag.ID >= 0 { @@ -818,8 +847,8 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err // Local nullable is determined by whether the Go field is a pointer type if i < len(s.fieldDefs) && field.Meta.FieldIndex >= 0 { remoteNullable := s.fieldDefs[i].nullable - // Check if local Go field is nullable based on computed field metadata - localNullable := field.Meta.Nullable + // Check if local Go field is nullable based on local field definitions + localNullable := localNullableByIndex[field.Meta.FieldIndex] if remoteNullable != localNullable { s.typeDefDiffers = true break @@ -3138,22 +3167,8 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val } } - for i := range s.fieldGroup.FixedFields { - field := &s.fieldGroup.FixedFields[i] - readField(field) - if ctx.HasError() { - return - } - } - for i := range s.fieldGroup.VarintFields { - field := &s.fieldGroup.VarintFields[i] - readField(field) - if ctx.HasError() { - return - } - } - for i := range s.fieldGroup.RemainingFields { - field := &s.fieldGroup.RemainingFields[i] + for i := range s.fields { + field := &s.fields[i] readField(field) if ctx.HasError() { return From 3fc1a6042f1dc6e9e319211a9e546f26f9c09855 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 22:31:11 +0800 Subject: [PATCH 14/21] fix python tests --- AGENTS.md | 2 +- python/pyfory/serializer.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index f8f59c3050..8d73f2c7fc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -305,7 +305,7 @@ sbt scalafmt - For mac, graalvm is installed at `/Library/Java/JavaVirtualMachines/graalvm-xxx` by default. - For `integration_tests/idl_tests`: - you must install fory java first before runing tests under this dir if any code changes under `java` dir. - - you must install fory python first before runing tests under this dir if any code changes under `python` dir. + - you must install fory python first before runing tests under this dir if any code cython code changes under `python` dir. - You are never allowed to manual edit generated code by fory compiler for `IDL` files, you must invoke fory compiler to regenerate code. ```bash diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index c5e3457491..b3d3ef53de 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -560,6 +560,12 @@ def read(self, buffer): return fory_buf return fory_buf.to_pybytes() + def xwrite(self, buffer, value): + buffer.write_bytes_and_size(value) + + def xread(self, buffer): + return buffer.read_bytes_and_size() + class BytesBufferObject(BufferObject): __slots__ = ("binary",) From 1c575a420505b578aca7b2c0a2435942d05b9eaf Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 23:18:06 +0800 Subject: [PATCH 15/21] fix array group --- go/fory/field_info.go | 42 ++++++++++++++++++------------- go/fory/struct.go | 6 +++++ go/fory/tests/structs_fory_gen.go | 2 +- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/go/fory/field_info.go b/go/fory/field_info.go index b12fdc620c..9c1fb1ce21 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -659,10 +659,12 @@ func sortFields( } typeTriples = append(typeTriples, triple{typeIds[i], ser, name, nullables[i], tagID}) } - // Java orders: primitives, boxed, finals, others, collections, maps + // Ordering: + // 1) primitives (nullable=false), 2) primitives (nullable=true), + // 3) built-in non-container, 4) list/set, 5) map, 6) user-defined/unknown // primitives = non-nullable primitive types (int, long, etc.) // boxed = nullable boxed types (Integer, Long, etc. which are pointers in Go) - var primitives, boxed, collection, otherInternalTypeFields []triple + var primitives, boxed, listSet, maps, otherInternalTypeFields []triple for _, t := range typeTriples { switch { @@ -674,11 +676,14 @@ func sortFields( primitives = append(primitives, t) } case isPrimitiveArrayType(t.typeID): - // Primitive arrays: sorted by name only (category 2 in reflection) - collection = append(collection, t) - case isListType(t.typeID), isSetType(t.typeID), isMapType(t.typeID): - // LIST, SET, MAP: sorted by typeId, name (category 1 in reflection) + // Primitive arrays: built-in non-container types (sorted by typeId then name) otherInternalTypeFields = append(otherInternalTypeFields, t) + case isListType(t.typeID), isSetType(t.typeID): + // LIST, SET: collection group + listSet = append(listSet, t) + case isMapType(t.typeID): + // MAP: map group + maps = append(maps, t) case isUserDefinedType(t.typeID): userDefined = append(userDefined, t) case t.typeID == UNKNOWN: @@ -734,21 +739,24 @@ func sortFields( }) } sortByTypeIDThenName(otherInternalTypeFields) + sortByTypeIDThenName(listSet) + sortByTypeIDThenName(maps) // Merge all category 2 fields (primitive arrays, userDefined, others) and sort by name - // This matches GroupFields' getFieldCategory which sorts all category 2 fields together - category2 := make([]triple, 0, len(collection)+len(userDefined)+len(others)) - category2 = append(category2, collection...) // primitive arrays - category2 = append(category2, userDefined...) // structs, enums - category2 = append(category2, others...) // unknown types - sortTuple(category2) - - // Order: primitives, boxed, internal types (STRING/BINARY/LIST/SET/MAP), category 2 (by name) - // This aligns with GroupFields' getFieldCategory sorting + // This matches GroupFields' getFieldCategory which sorts all category 4 fields together + otherGroup := make([]triple, 0, len(userDefined)+len(others)) + otherGroup = append(otherGroup, userDefined...) // structs, enums, ext + otherGroup = append(otherGroup, others...) // unknown types + sortTuple(otherGroup) + + // Order: primitives, boxed, built-in non-container, list/set, map, other (by name) + // This aligns with GroupFields' getFieldCategory sorting and spec ordering. all := make([]triple, 0, len(fieldNames)) all = append(all, primitives...) all = append(all, boxed...) - all = append(all, otherInternalTypeFields...) // STRING, BINARY, LIST, SET, MAP (category 1) - all = append(all, category2...) // all category 2 fields sorted by name + all = append(all, otherInternalTypeFields...) // STRING, BINARY, primitive arrays, time, unions, etc. + all = append(all, listSet...) + all = append(all, maps...) + all = append(all, otherGroup...) outSer := make([]Serializer, len(all)) outNam := make([]string, len(all)) diff --git a/go/fory/struct.go b/go/fory/struct.go index fa10b9e2ac..bb034fad89 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -3180,6 +3180,12 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val // Uses context error state for deferred error checking. func (s *structSerializer) skipField(ctx *ReadContext, field *FieldInfo) { if field.Meta.FieldDef.name != "" { + if DebugOutputEnabled() { + fmt.Printf("[Go][fory-debug] skipField name=%s typeId=%d fieldType=%s\n", + field.Meta.FieldDef.name, + field.Meta.FieldDef.fieldType.TypeId(), + fieldTypeToString(field.Meta.FieldDef.fieldType)) + } fieldDefIsStructType := isStructFieldType(field.Meta.FieldDef.fieldType) // Use FieldDef's trackingRef and nullable to determine if ref flag was written by Java // Java writes ref flag based on its FieldDef, not Go's field type diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index 707ac18ff5..c04d0702bc 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,6 +1,6 @@ // Code generated by forygen. DO NOT EDIT. // source: structs.go -// generated at: 2026-01-22T08:57:50+08:00 +// generated at: 2026-01-26T23:14:39+08:00 package fory From cef791d52a2c86bc562390f849aa0baff048d4ea Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 26 Jan 2026 23:46:08 +0800 Subject: [PATCH 16/21] skip out-of-band buffer --- .../test/java/org/apache/fory/xlang/PyCrossLanguageTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java index ac92223892..43f980e2b3 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java @@ -704,7 +704,7 @@ private byte[] roundBytes(String testName, byte[] bytes) throws IOException { return Files.readAllBytes(dataFile); } - @Test + @Test(enabled = false) public void testOutOfBandBuffer() throws Exception { Fory fory = Fory.builder() From bb98db15ce5427f210cdc1358f93b8b8fb8a7d5d Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 27 Jan 2026 00:15:30 +0800 Subject: [PATCH 17/21] fix lint --- README.md | 4 ++-- python/pyfory/struct.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 572b061fec..d752dfeed9 100644 --- a/README.md +++ b/README.md @@ -553,8 +553,8 @@ For more details on row format, see [Row Format Specification](docs/specificatio ### User Guides -| Guide | Description | Source | Website | -| -------------------------------- | ------------------------------------------ | ----------------------------------------------- | ------------------------------------------------------------------------ | +| Guide | Description | Source | Website | +| -------------------------------- | ------------------------------------------ | ----------------------------------------------- | ------------------------------------------------------------------- | | **Java Serialization** | Comprehensive guide for Java serialization | [java](docs/guide/java) | [📖 View](https://fory.apache.org/docs/guide/java/) | | **Python** | Python-specific features and usage | [python](docs/guide/python) | [📖 View](https://fory.apache.org/docs/guide/python/) | | **Rust** | Rust implementation and patterns | [rust](docs/guide/rust) | [📖 View](https://fory.apache.org/docs/guide/rust/) | diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index 7d3d1a4f19..0dfcb6177d 100644 --- a/python/pyfory/struct.py +++ b/python/pyfory/struct.py @@ -1330,9 +1330,6 @@ def compute_struct_fingerprint(type_resolver, field_names, serializers, nullable nullable_flag = "1" if nullable_map.get(field_name, False) else "0" else: type_id = type_resolver.get_typeinfo(serializer.type_).type_id & 0xFF - # For xlang, user-defined types use UNKNOWN in fingerprint to match other languages. - if not type_resolver.fory.is_py and type_id >= TypeId.BOUND: - type_id = TypeId.UNKNOWN if type_id in {TypeId.TYPED_UNION, TypeId.NAMED_UNION}: type_id = TypeId.UNION is_nullable = nullable_map.get(field_name, False) From 26a17513b35619c9d9781fb94bbaef6a8d7de295 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 27 Jan 2026 00:28:53 +0800 Subject: [PATCH 18/21] fix oob serialization in python --- .../fory/xlang/PyCrossLanguageTest.java | 2 +- python/pyfory/_fory.py | 39 +++++++++++++------ python/pyfory/serialization.pyx | 34 ++++++++++++---- python/pyfory/serializer.py | 11 +++--- 4 files changed, 61 insertions(+), 25 deletions(-) diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java index 43f980e2b3..ac92223892 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/PyCrossLanguageTest.java @@ -704,7 +704,7 @@ private byte[] roundBytes(String testName, byte[] bytes) throws IOException { return Files.readAllBytes(dataFile); } - @Test(enabled = false) + @Test public void testOutOfBandBuffer() throws Exception { Fory fory = Fory.builder() diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 0e029e7090..bc68c437de 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -154,12 +154,13 @@ class Fory: "serialization_context", "strict", "buffer", - "_buffer_callback", + "buffer_callback", "_buffers", "metastring_resolver", "_unsupported_callback", "_unsupported_objects", "_peer_language", + "is_peer_out_of_band_enabled", "max_depth", "depth", "field_nullable", @@ -249,11 +250,12 @@ def __init__( self.type_resolver.initialize() self.buffer = Buffer.allocate(32) - self._buffer_callback = None + self.buffer_callback = None self._buffers = None self._unsupported_callback = None self._unsupported_objects = None self._peer_language = None + self.is_peer_out_of_band_enabled = False self.max_depth = max_depth self.depth = 0 @@ -467,7 +469,7 @@ def _serialize( ) -> Union[Buffer, bytes]: assert self.depth == 0, "Nested serialization should use write_ref/write_no_ref/xwrite_ref/xwrite_no_ref." self.depth += 1 - self._buffer_callback = buffer_callback + self.buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback if buffer is None: self.buffer.writer_index = 0 @@ -489,7 +491,7 @@ def _serialize( else: # set reader as native. clear_bit(buffer, mask_index, 1) - if self._buffer_callback is not None: + if self.buffer_callback is not None: set_bit(buffer, mask_index, 2) else: clear_bit(buffer, mask_index, 2) @@ -618,8 +620,8 @@ def _deserialize( self._peer_language = Language(buffer.read_int8()) else: self._peer_language = Language.PYTHON - is_out_of_band_serialization_enabled = get_bit(buffer, reader_index, 2) - if is_out_of_band_serialization_enabled: + self.is_peer_out_of_band_enabled = get_bit(buffer, reader_index, 2) + if self.is_peer_out_of_band_enabled: assert buffers is not None, "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null." self._buffers = iter(buffers) else: @@ -692,7 +694,17 @@ def _xread_no_ref_internal(self, buffer, serializer): return o def write_buffer_object(self, buffer, buffer_object: BufferObject): - if self._buffer_callback is None or self._buffer_callback(buffer_object): + if self.buffer_callback is None: + size = buffer_object.total_bytes() + # writer length. + buffer.write_varuint32(size) + writer_index = buffer.writer_index + buffer.ensure(writer_index + size) + buf = buffer.slice(buffer.writer_index, size) + buffer_object.write_to(buf) + buffer.writer_index += size + return + if self.buffer_callback(buffer_object): buffer.write_bool(True) size = buffer_object.total_bytes() # writer length. @@ -706,15 +718,19 @@ def write_buffer_object(self, buffer, buffer_object: BufferObject): buffer.write_bool(False) def read_buffer_object(self, buffer) -> Buffer: - in_band = buffer.read_bool() - if in_band: + if not self.is_peer_out_of_band_enabled: size = buffer.read_varuint32() buf = buffer.slice(buffer.reader_index, size) buffer.reader_index += size return buf - else: + in_band = buffer.read_bool() + if not in_band: assert self._buffers is not None return next(self._buffers) + size = buffer.read_varuint32() + buf = buffer.slice(buffer.reader_index, size) + buffer.reader_index += size + return buf def handle_unsupported_write(self, buffer, obj): if self._unsupported_callback is None or self._unsupported_callback(obj): @@ -747,7 +763,7 @@ def reset_write(self): self.type_resolver.reset_write() self.serialization_context.reset_write() self.metastring_resolver.reset_write() - self._buffer_callback = None + self.buffer_callback = None self._unsupported_callback = None def reset_read(self): @@ -764,6 +780,7 @@ def reset_read(self): self.metastring_resolver.reset_write() self._buffers = None self._unsupported_objects = None + self.is_peer_out_of_band_enabled = False def reset(self): """ diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 55752d666f..8c6cf43378 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -943,6 +943,7 @@ cdef class Fory: cdef object _unsupported_callback cdef object _unsupported_objects # iterator cdef object _peer_language + cdef c_bool is_peer_out_of_band_enabled cdef int32_t max_depth cdef int32_t depth @@ -1030,6 +1031,7 @@ cdef class Fory: self._unsupported_callback = None self._unsupported_objects = None self._peer_language = None + self.is_peer_out_of_band_enabled = False self.depth = 0 self.max_depth = max_depth @@ -1376,9 +1378,8 @@ cdef class Fory: self._peer_language = Language(buffer.read_int8()) else: self._peer_language = Language.PYTHON - cdef c_bool is_out_of_band_serialization_enabled = \ - get_bit(buffer, reader_index, 2) - if is_out_of_band_serialization_enabled: + self.is_peer_out_of_band_enabled = get_bit(buffer, reader_index, 2) + if self.is_peer_out_of_band_enabled: assert buffers is not None, ( "buffers shouldn't be null when the serialized stream is " "produced with buffer_callback not null." @@ -1495,7 +1496,17 @@ cdef class Fory: cdef int32_t size cdef int32_t writer_index cdef Buffer buf - if self.buffer_callback is None or self.buffer_callback(buffer_object): + if self.buffer_callback is None: + size = buffer_object.total_bytes() + # writer length. + buffer.write_varuint32(size) + writer_index = buffer.writer_index + buffer.ensure(writer_index + size) + buf = buffer.slice(buffer.writer_index, size) + buffer_object.write_to(buf) + buffer.writer_index += size + return + if self.buffer_callback(buffer_object): buffer.write_bool(True) size = buffer_object.total_bytes() # writer length. @@ -1509,12 +1520,20 @@ cdef class Fory: buffer.write_bool(False) cpdef inline object read_buffer_object(self, Buffer buffer): - cdef c_bool in_band = buffer.read_bool() + cdef c_bool in_band + cdef int32_t size + cdef Buffer buf + if not self.is_peer_out_of_band_enabled: + size = buffer.read_varuint32() + buf = buffer.slice(buffer.reader_index, size) + buffer.reader_index += size + return buf + in_band = buffer.read_bool() if not in_band: assert self._buffers is not None return next(self._buffers) - cdef int32_t size = buffer.read_varuint32() - cdef Buffer buf = buffer.slice(buffer.reader_index, size) + size = buffer.read_varuint32() + buf = buffer.slice(buffer.reader_index, size) buffer.reader_index += size return buf @@ -1576,6 +1595,7 @@ cdef class Fory: self.serialization_context.reset_read() self._buffers = None self._unsupported_objects = None + self.is_peer_out_of_band_enabled = False cpdef inline reset(self): """ diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index b3d3ef53de..4aca4f07cf 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -550,9 +550,14 @@ def read(self, buffer): class BytesSerializer(XlangCompatibleSerializer): def write(self, buffer, value): + if self.fory.buffer_callback is None: + buffer.write_bytes_and_size(value) + return self.fory.write_buffer_object(buffer, BytesBufferObject(value)) def read(self, buffer): + if not self.fory.is_peer_out_of_band_enabled: + return buffer.read_bytes_and_size() fory_buf = self.fory.read_buffer_object(buffer) if isinstance(fory_buf, memoryview): return bytes(fory_buf) @@ -560,12 +565,6 @@ def read(self, buffer): return fory_buf return fory_buf.to_pybytes() - def xwrite(self, buffer, value): - buffer.write_bytes_and_size(value) - - def xread(self, buffer): - return buffer.read_bytes_and_size() - class BytesBufferObject(BufferObject): __slots__ = ("binary",) From a3dbaa58d86232173bb0165025b0684e819bc672 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 27 Jan 2026 00:54:13 +0800 Subject: [PATCH 19/21] fix ci --- .github/workflows/ci.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a17f40994..f056214cfe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,6 +17,10 @@ name: Fory CI +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + on: push: branches: @@ -110,7 +114,7 @@ jobs: java21_windows: name: Windows Java 21 CI - runs-on: windows-2022 + runs-on: windows-2025 env: MY_VAR: "PATH" strategy: @@ -252,7 +256,7 @@ jobs: strategy: matrix: node-version: [18, 20, 24] - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, windows-2025] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v5 @@ -341,7 +345,7 @@ jobs: name: C++ CI strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-2022] + os: [ubuntu-latest, macos-latest, windows-2025] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v5 @@ -510,7 +514,7 @@ jobs: strategy: matrix: python-version: [3.8, 3.12, 3.13.3] - os: [ubuntu-latest, ubuntu-24.04-arm, macos-latest, windows-2022] + os: [ubuntu-latest, ubuntu-24.04-arm, macos-latest, windows-2025] steps: - uses: actions/checkout@v5 - name: Set up Python ${{ matrix.python-version }} From 5c0200ebaa201a5df4c4daab6ab82f4a6b32b7be Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 27 Jan 2026 01:04:45 +0800 Subject: [PATCH 20/21] fix(python): expose peer buffer flag for bytes read --- python/pyfory/primitive.pxi | 2 +- python/pyfory/serialization.pyx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyfory/primitive.pxi b/python/pyfory/primitive.pxi index 7f932962cf..9c2eff67ed 100644 --- a/python/pyfory/primitive.pxi +++ b/python/pyfory/primitive.pxi @@ -280,6 +280,6 @@ cdef class TimestampSerializer(XlangCompatibleSerializer): cpdef inline read(self, Buffer buffer): cdef long long seconds = buffer.read_int64() cdef unsigned int nanos = buffer.read_uint32() - ts = seconds + nanos / 1000000000 + ts = seconds + (nanos) / 1000000000.0 # TODO support timezone return datetime.datetime.fromtimestamp(ts) diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 8c6cf43378..49aa54e59d 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -943,7 +943,7 @@ cdef class Fory: cdef object _unsupported_callback cdef object _unsupported_objects # iterator cdef object _peer_language - cdef c_bool is_peer_out_of_band_enabled + cdef public bint is_peer_out_of_band_enabled cdef int32_t max_depth cdef int32_t depth From e5045e6d2cbc12c51e425ac889e8ef9f5cf3fa93 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 27 Jan 2026 01:17:19 +0800 Subject: [PATCH 21/21] ci: revert windows-2025 runners --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f056214cfe..f126ded247 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -114,7 +114,7 @@ jobs: java21_windows: name: Windows Java 21 CI - runs-on: windows-2025 + runs-on: windows-2022 env: MY_VAR: "PATH" strategy: @@ -256,7 +256,7 @@ jobs: strategy: matrix: node-version: [18, 20, 24] - os: [ubuntu-latest, macos-latest, windows-2025] + os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v5 @@ -345,7 +345,7 @@ jobs: name: C++ CI strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-2025] + os: [ubuntu-latest, macos-latest, windows-2022] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v5 @@ -514,7 +514,7 @@ jobs: strategy: matrix: python-version: [3.8, 3.12, 3.13.3] - os: [ubuntu-latest, ubuntu-24.04-arm, macos-latest, windows-2025] + os: [ubuntu-latest, ubuntu-24.04-arm, macos-latest, windows-2022] steps: - uses: actions/checkout@v5 - name: Set up Python ${{ matrix.python-version }}