From d084ea59063e73d1c1fa53bbf071f814af375309 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 02:15:08 +0800 Subject: [PATCH 1/9] add scala docs --- docs/compiler/generated-code.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/compiler/generated-code.md b/docs/compiler/generated-code.md index e3a8f9b7eb..3f88028da8 100644 --- a/docs/compiler/generated-code.md +++ b/docs/compiler/generated-code.md @@ -1042,6 +1042,8 @@ void main() { } ``` +## Scala + ## Cross-Language Notes ### Type ID Behavior From fef775b6e100d540b953b46844b23404ee6140d1 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 07:26:21 +0800 Subject: [PATCH 2/9] feat(scala): add xlang schema idl support --- .github/workflows/ci.yml | 35 + compiler/README.md | 10 +- compiler/fory_compiler/cli.py | 11 +- compiler/fory_compiler/generators/__init__.py | 3 + compiler/fory_compiler/generators/scala.py | 820 +++++++++++ compiler/fory_compiler/ir/construction.py | 132 ++ .../tests/test_scala_generator.py | 116 ++ docs/compiler/compiler-guide.md | 25 + docs/compiler/generated-code.md | 113 ++ docs/compiler/index.md | 15 +- docs/compiler/schema-idl.md | 15 +- docs/guide/scala/fory-creation.md | 20 + docs/guide/scala/index.md | 4 +- docs/guide/scala/schema-idl.md | 156 +++ docs/guide/xlang/field-nullability.md | 1 + docs/guide/xlang/field-reference-tracking.md | 32 + .../specification/xlang_serialization_spec.md | 9 + docs/specification/xlang_type_mapping.md | 37 + integration_tests/idl_tests/generate_idl.py | 1 + .../fory/idl_tests/IdlRoundTripTest.java | 13 +- .../idl_tests/run_scala_tests.sh | 31 + integration_tests/idl_tests/scala/build.sbt | 31 + .../idl_tests/scala/project/build.properties | 1 + .../idl_tests/ScalaIdlRoundTripPeer.scala | 88 ++ .../idl_tests/ScalaIdlRoundTripTest.scala | 180 +++ .../org/apache/fory/annotation/ForyCase.java | 45 + .../org/apache/fory/annotation/ForyField.java | 2 +- .../org/apache/fory/annotation/ForyUnion.java | 32 + 
.../java/org/apache/fory/annotation/Ref.java | 7 +- .../java/org/apache/fory/meta/FieldTypes.java | 56 +- .../apache/fory/resolver/ClassResolver.java | 64 + .../StaticGeneratedSerializerRegistry.java | 38 + .../apache/fory/resolver/TypeResolver.java | 69 +- .../apache/fory/resolver/XtypeResolver.java | 115 +- .../StaticGeneratedStructSerializer.java | 128 +- ...taticGeneratedStructSerializerFactory.java | 49 + .../fory/serializer/UnionSerializer.java | 140 +- .../apache/fory/type/DescriptorGrouper.java | 9 +- .../java/org/apache/fory/type/TypeUtils.java | 9 +- .../org/apache/fory/xlang/ScalaXlangTest.java | 211 +++ .../org/apache/fory/scala/ForyScalaEnum.java | 25 + .../serializer/scala/ScalaDispatcher.java | 18 + .../serializer/scala/ScalaEnumSerializer.java | 124 ++ .../serializer/scala/ScalaSerializers.java | 13 + .../apache/fory/scala/ForySerializer.scala | 136 ++ .../scala/internal/ForySerializerMacros.scala | 1209 +++++++++++++++++ .../scala/XlangCollectionSerializer.scala | 196 +++ .../scala/ForySerializerDerivationTest.scala | 151 ++ .../serializer/scala/ScalaXlangPeer.scala | 99 ++ .../scala/ScalaXlangSerializerTest.scala | 65 + 50 files changed, 4828 insertions(+), 81 deletions(-) create mode 100644 compiler/fory_compiler/generators/scala.py create mode 100644 compiler/fory_compiler/ir/construction.py create mode 100644 compiler/fory_compiler/tests/test_scala_generator.py create mode 100644 docs/guide/scala/schema-idl.md create mode 100755 integration_tests/idl_tests/run_scala_tests.sh create mode 100644 integration_tests/idl_tests/scala/build.sbt create mode 100644 integration_tests/idl_tests/scala/project/build.properties create mode 100644 integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala create mode 100644 integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala create mode 100644 java/fory-core/src/main/java/org/apache/fory/annotation/ForyCase.java 
create mode 100644 java/fory-core/src/main/java/org/apache/fory/annotation/ForyUnion.java create mode 100644 java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java create mode 100644 java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java create mode 100644 scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java create mode 100644 scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java create mode 100644 scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala create mode 100644 scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala create mode 100644 scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala create mode 100644 scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala create mode 100644 scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala create mode 100644 scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5492574552..325336d811 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -597,6 +597,41 @@ jobs: run: | cd scala && sbt +test && cd - + scala_xlang: + name: Scala Xlang Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + java-version: 21 + distribution: "temurin" + cache: sbt + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: 3.11 + - name: Cache Maven local repository + uses: actions/cache@v4 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + - uses: sbt/setup-sbt@1cad58d595b729a71ca2254cdf5b43dd6f42d4bb # v1.1.18 + - name: Run Scala Xlang Test + env: + FORY_SCALA_JAVA_CI: "1" + ENABLE_FORY_DEBUG_OUTPUT: "1" + run: 
| + cd java + mvn -T16 --no-transfer-progress clean install -DskipTests -Dmaven.javadoc.skip=true -Dmaven.source.skip=true + cd fory-core + mvn -T16 --no-transfer-progress test -Dtest=org.apache.fory.xlang.ScalaXlangTest + - name: Run Scala IDL Tests + run: ./integration_tests/idl_tests/run_scala_tests.sh + integration_tests: name: Integration Tests runs-on: ubuntu-latest diff --git a/compiler/README.md b/compiler/README.md index afc533a9da..4d78589996 100644 --- a/compiler/README.md +++ b/compiler/README.md @@ -4,7 +4,7 @@ The FDL compiler generates cross-language serialization code from schema definit ## Features -- **Multi-language code generation**: Java, Python, Go, Rust, C++, C#, JavaScript, and Swift +- **Multi-language code generation**: Java, Python, Go, Rust, C++, C#, JavaScript, Swift, Dart, and Scala - **Rich type system**: Primitives, enums, messages, lists, dense arrays, maps - **Cross-language serialization**: Generated code works seamlessly with Apache Fory - **Type ID and namespace support**: Both numeric IDs and name-based type registration @@ -64,16 +64,16 @@ message Cat [id=103] { foryc schema.fdl --output ./generated # Generate for specific languages -foryc schema.fdl --lang java,python,csharp,javascript --output ./generated +foryc schema.fdl --lang java,python,csharp,javascript,scala --output ./generated # Override package name foryc schema.fdl --package myapp.models --output ./generated # Language-specific output directories (protoc-style) -foryc schema.fdl --java_out=./src/main/java --python_out=./python/src --csharp_out=./csharp/src/Generated --javascript_out=./javascript +foryc schema.fdl --java_out=./src/main/java --python_out=./python/src --csharp_out=./csharp/src/Generated --javascript_out=./javascript --scala_out=./scala/src/main/scala # Combine with other options -foryc schema.fdl --java_out=./gen --go_out=./gen/go --csharp_out=./gen/csharp --javascript_out=./gen/js -I ./proto +foryc schema.fdl --java_out=./gen --go_out=./gen/go 
--csharp_out=./gen/csharp --javascript_out=./gen/js --scala_out=./gen/scala -I ./proto ``` ### 3. Use Generated Code @@ -457,7 +457,7 @@ Arguments: FILES FDL files to compile Options: - --lang TEXT Target languages (java,python,cpp,rust,go,csharp,javascript or "all") + --lang TEXT Target languages (java,python,cpp,rust,go,csharp,javascript,swift,dart,scala or "all") Default: all --output, -o PATH Output directory Default: ./generated diff --git a/compiler/fory_compiler/cli.py b/compiler/fory_compiler/cli.py index c0620b8547..9ca8a9b4d2 100644 --- a/compiler/fory_compiler/cli.py +++ b/compiler/fory_compiler/cli.py @@ -264,7 +264,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: "--lang", type=str, default="all", - help="Comma-separated list of target languages (java,python,cpp,rust,go,csharp,javascript,swift,dart). Default: all", + help="Comma-separated list of target languages (java,python,cpp,rust,go,csharp,javascript,swift,dart,scala). Default: all", ) parser.add_argument( @@ -367,6 +367,14 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: help="Generate Dart code in DST_DIR", ) + parser.add_argument( + "--scala_out", + type=Path, + default=None, + metavar="DST_DIR", + help="Generate Scala 3 code in DST_DIR", + ) + parser.add_argument( "--go_nested_type_style", type=str, @@ -672,6 +680,7 @@ def cmd_compile(args: argparse.Namespace) -> int: "javascript": args.javascript_out, "swift": args.swift_out, "dart": args.dart_out, + "scala": args.scala_out, } # Determine which languages to generate diff --git a/compiler/fory_compiler/generators/__init__.py b/compiler/fory_compiler/generators/__init__.py index 4311fcab66..ad955bd8e4 100644 --- a/compiler/fory_compiler/generators/__init__.py +++ b/compiler/fory_compiler/generators/__init__.py @@ -27,6 +27,7 @@ from fory_compiler.generators.javascript import JavaScriptGenerator from fory_compiler.generators.swift import SwiftGenerator from fory_compiler.generators.dart 
import DartGenerator +from fory_compiler.generators.scala import ScalaGenerator GENERATORS = { "java": JavaGenerator, @@ -38,6 +39,7 @@ "javascript": JavaScriptGenerator, "swift": SwiftGenerator, "dart": DartGenerator, + "scala": ScalaGenerator, } __all__ = [ @@ -51,5 +53,6 @@ "JavaScriptGenerator", "SwiftGenerator", "DartGenerator", + "ScalaGenerator", "GENERATORS", ] diff --git a/compiler/fory_compiler/generators/scala.py b/compiler/fory_compiler/generators/scala.py new file mode 100644 index 0000000000..6211014c02 --- /dev/null +++ b/compiler/fory_compiler/generators/scala.py @@ -0,0 +1,820 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Scala 3 schema IDL generator.""" + +from __future__ import annotations + +from pathlib import Path +from typing import List, Optional, Set + +from fory_compiler.generators.base import BaseGenerator, GeneratedFile +from fory_compiler.frontend.utils import parse_idl_file +from fory_compiler.ir.ast import ( + ArrayType, + Enum, + Field, + FieldType, + ListType, + MapType, + Message, + NamedType, + Schema, + PrimitiveType, + Union, +) +from fory_compiler.ir.construction import analyze_message_construction_shapes +from fory_compiler.ir.types import PrimitiveKind + + +class ScalaGenerator(BaseGenerator): + """Generates Scala 3 models with Fory macro-derived serializers.""" + + language_name = "scala" + file_extension = ".scala" + + PRIMITIVE_MAP = { + PrimitiveKind.BOOL: "Boolean", + PrimitiveKind.INT8: "Byte", + PrimitiveKind.INT16: "Short", + PrimitiveKind.INT32: "Int", + PrimitiveKind.INT64: "Long", + PrimitiveKind.UINT8: "Int", + PrimitiveKind.UINT16: "Int", + PrimitiveKind.UINT32: "Long", + PrimitiveKind.UINT64: "Long", + PrimitiveKind.FLOAT16: "Float16", + PrimitiveKind.BFLOAT16: "BFloat16", + PrimitiveKind.FLOAT32: "Float", + PrimitiveKind.FLOAT64: "Double", + PrimitiveKind.STRING: "String", + PrimitiveKind.BYTES: "Array[Byte]", + PrimitiveKind.DATE: "LocalDate", + PrimitiveKind.TIMESTAMP: "Instant", + PrimitiveKind.DURATION: "Duration", + PrimitiveKind.DECIMAL: "BigDecimal", + PrimitiveKind.ANY: "AnyRef", + } + + ARRAY_ELEMENT_MAP = { + PrimitiveKind.BOOL: "Boolean", + PrimitiveKind.INT8: "Byte", + PrimitiveKind.INT16: "Short", + PrimitiveKind.INT32: "Int", + PrimitiveKind.INT64: "Long", + PrimitiveKind.UINT8: "Byte", + PrimitiveKind.UINT16: "Short", + PrimitiveKind.UINT32: "Int", + PrimitiveKind.UINT64: "Long", + PrimitiveKind.FLOAT16: "Short", + PrimitiveKind.BFLOAT16: "Short", + PrimitiveKind.FLOAT32: "Float", + PrimitiveKind.FLOAT64: "Double", + } + + RESERVED = { + "abstract", + "case", + "catch", + "class", + "def", + "do", + "else", + "enum", + 
"export", + "extends", + "false", + "final", + "finally", + "for", + "given", + "if", + "implicit", + "import", + "lazy", + "match", + "new", + "null", + "object", + "override", + "package", + "private", + "protected", + "return", + "sealed", + "super", + "then", + "this", + "throw", + "trait", + "true", + "try", + "type", + "val", + "var", + "while", + "with", + "yield", + } + + def __init__(self, schema, options): + super().__init__(schema, options) + self._construction_shapes = analyze_message_construction_shapes(schema) + + def get_scala_package(self) -> Optional[str]: + return self.options.package_override or self.schema.package + + def get_scala_package_path(self) -> str: + package = self.get_scala_package() + return package.replace(".", "/") if package else "" + + def get_registration_class_name(self) -> str: + package = self.get_scala_package() + if package: + return self.to_pascal_case(package.split(".")[-1]) + "ForyRegistration" + return "ForyRegistration" + + def is_imported_type(self, type_def: object) -> bool: + if not self.schema.source_file: + return False + location = getattr(type_def, "location", None) + if location is None or not location.file: + return False + try: + return ( + Path(location.file).resolve() != Path(self.schema.source_file).resolve() + ) + except Exception: + return location.file != self.schema.source_file + + def _normalize_import_path(self, path_str: str) -> str: + if not path_str: + return path_str + try: + return str(Path(path_str).resolve()) + except Exception: + return path_str + + def _load_schema(self, file_path: str) -> Optional[Schema]: + if not file_path: + return None + if not hasattr(self, "_schema_cache"): + self._schema_cache = {} + cache = self._schema_cache + path = Path(file_path).resolve() + if path in cache: + return cache[path] + try: + schema = parse_idl_file(path) + except Exception: + return None + cache[path] = schema + return schema + + def _scala_package_for_schema(self, schema: Schema) -> Optional[str]: 
+ return schema.package + + def _registration_class_name_for_schema(self, schema: Schema) -> str: + package = self._scala_package_for_schema(schema) + if package: + return self.to_pascal_case(package.split(".")[-1]) + "ForyRegistration" + return "ForyRegistration" + + def _scala_package_for_type(self, type_def: object) -> Optional[str]: + location = getattr(type_def, "location", None) + file_path = getattr(location, "file", None) if location else None + schema = self._load_schema(file_path) + if schema is None: + return None + return self._scala_package_for_schema(schema) + + def _collect_imported_registrations(self) -> List[tuple[str, str]]: + packages: dict[str, str] = {} + for type_def in self.schema.enums + self.schema.unions + self.schema.messages: + if not self.is_imported_type(type_def): + continue + package = self._scala_package_for_type(type_def) + if not package or package in packages: + continue + schema = self._load_schema( + getattr(getattr(type_def, "location", None), "file", None) + ) + if schema is None: + continue + packages[package] = self._registration_class_name_for_schema(schema) + + ordered: List[tuple[str, str]] = [] + used: Set[str] = set() + if self.schema.source_file: + base_dir = Path(self.schema.source_file).resolve().parent + for imp in self.schema.imports: + candidate = self._normalize_import_path(str((base_dir / imp.path).resolve())) + schema = self._load_schema(candidate) + if schema is None: + continue + package = self._scala_package_for_schema(schema) + if not package or package in used: + continue + ordered.append((package, self._registration_class_name_for_schema(schema))) + used.add(package) + for package, registration in sorted(packages.items()): + if package not in used: + ordered.append((package, registration)) + return ordered + + def generate(self) -> List[GeneratedFile]: + files: List[GeneratedFile] = [] + for enum in self.schema.enums: + if self.is_imported_type(enum): + continue + 
files.append(self.generate_enum_file(enum)) + for union in self.schema.unions: + if self.is_imported_type(union): + continue + files.append(self.generate_union_file(union)) + for message in self.schema.messages: + if self.is_imported_type(message): + continue + files.append(self.generate_message_file(message)) + files.append(self.generate_registration_file()) + return files + + def generate_enum_file(self, enum: Enum) -> GeneratedFile: + lines = self.source_header( + {"org.apache.fory.annotation.ForyEnumId", "org.apache.fory.scala.ForyScalaEnum"} + ) + comment = self.format_type_id_comment(enum, "//") + if comment: + lines.append(comment) + lines.extend(self.generate_enum(enum)) + return self.source_file(enum.name, lines) + + def generate_union_file(self, union: Union) -> GeneratedFile: + imports = { + "org.apache.fory.annotation.{ForyCase, ForyUnion}", + "org.apache.fory.scala.ForySerializer", + } + self.collect_union_imports(union, imports) + lines = self.source_header(imports) + comment = self.format_type_id_comment(union, "//") + if comment: + lines.append(comment) + lines.extend(self.generate_union(union, parent_stack=[])) + return self.source_file(union.name, lines) + + def generate_message_file(self, message: Message) -> GeneratedFile: + imports = { + "org.apache.fory.annotation.{ForyField, ForyStruct}", + "org.apache.fory.scala.ForySerializer", + } + self.collect_message_imports(message, imports) + lines = self.source_header(imports) + comment = self.format_type_id_comment(message, "//") + if comment: + lines.append(comment) + lines.extend(self.generate_message(message, parent_stack=[])) + return self.source_file(message.name, lines) + + def source_header(self, imports: Set[str]) -> List[str]: + lines = [self.get_license_header(), ""] + package = self.get_scala_package() + if package: + lines.append(f"package {package}") + lines.append("") + for item in sorted(imports): + lines.append(f"import {item}") + if imports: + lines.append("") + return lines + + def 
source_file(self, type_name: str, lines: List[str]) -> GeneratedFile: + path = self.get_scala_package_path() + file_path = f"{path}/{type_name}.scala" if path else f"{type_name}.scala" + return GeneratedFile(path=file_path, content="\n".join(lines) + "\n") + + def generate_enum(self, enum: Enum, indent: int = 0) -> List[str]: + ind = self.indent_str * indent + lines = [f"{ind}enum {enum.name}(val foryId: Int) extends ForyScalaEnum {{"] + for value in enum.values: + case_name = self.safe_identifier( + self.to_pascal_case(self.strip_enum_prefix(enum.name, value.name)) + ) + lines.append(f"{ind} case {case_name} extends {enum.name}({value.value})") + lines.append("") + lines.append(f"{ind} @ForyEnumId") + lines.append(f"{ind} def getForyId: Int = foryId") + lines.append(f"{ind}}}") + lines.append("") + return lines + + def generate_union( + self, + union: Union, + indent: int = 0, + parent_stack: Optional[List[Message]] = None, + ) -> List[str]: + ind = self.indent_str * indent + lines = [f"{ind}@ForyUnion", f"{ind}enum {union.name} derives ForySerializer {{"] + lines.append(f"{ind} @ForyCase(id = 0)") + lines.append(f"{ind} case UnknownCase(caseId: Int, value: Any)") + lines.append("") + for field in union.fields: + lines.append(f"{ind} @ForyCase(id = {field.number})") + case_name = self.to_pascal_case(field.name) + field_type = self.generate_type( + field.field_type, + nullable=False, + element_optional=field.element_optional, + element_ref=field.element_ref, + top_level_ref=field.ref, + parent_stack=parent_stack, + ) + lines.append(f"{ind} case {case_name}Case(value: {field_type})") + lines.append("") + lines.append(f"{ind}}}") + lines.append("") + return lines + + def generate_message( + self, + message: Message, + indent: int = 0, + parent_stack: Optional[List[Message]] = None, + ) -> List[str]: + if self._construction_shapes.get(message.name, None) and self._construction_shapes[ + message.name + ].cycle_owned: + return self.generate_normal_class(message, indent, 
parent_stack) + return self.generate_case_class(message, indent, parent_stack) + + def generate_case_class( + self, + message: Message, + indent: int = 0, + parent_stack: Optional[List[Message]] = None, + ) -> List[str]: + ind = self.indent_str * indent + current_stack = self.current_stack(parent_stack, message) + lines = [f"{ind}@ForyStruct", f"{ind}final case class {message.name}("] + for index, field in enumerate(message.fields): + suffix = "," if index < len(message.fields) - 1 else "" + lines.append( + f"{ind} {self.generate_parameter(field, current_stack)}{suffix}" + ) + lines.append(f"{ind}) derives ForySerializer") + lines.append("") + lines.extend(self.generate_nested_types(message, indent, current_stack)) + return lines + + def generate_normal_class( + self, + message: Message, + indent: int = 0, + parent_stack: Optional[List[Message]] = None, + ) -> List[str]: + ind = self.indent_str * indent + current_stack = self.current_stack(parent_stack, message) + lines = [f"{ind}@ForyStruct", f"{ind}final class {message.name}() derives ForySerializer {{"] + for field in message.fields: + field_type = self.generate_type( + field.field_type, + nullable=field.optional, + element_optional=field.element_optional, + element_ref=field.element_ref, + top_level_ref=field.ref, + parent_stack=current_stack, + ) + if field.ref and self.is_ref_target_type(field.field_type): + lines.append(f"{ind} @Ref") + lines.append(f"{ind} @ForyField(id = {field.number})") + lines.append( + f"{ind} var {self.safe_identifier(self.to_camel_case(field.name))}: {field_type} = {self.default_value(field)}" + ) + lines.append("") + lines.append(f"{ind}}}") + lines.append("") + lines.extend(self.generate_nested_types(message, indent, current_stack)) + return lines + + def generate_nested_types( + self, message: Message, indent: int, parent_stack: List[Message] + ) -> List[str]: + if not ( + message.nested_enums or message.nested_unions or message.nested_messages + ): + return [] + ind = 
self.indent_str * indent + lines = [f"{ind}object {message.name} {{"] + for enum in message.nested_enums: + lines.extend(self.generate_enum(enum, indent + 1)) + for union in message.nested_unions: + lines.extend(self.generate_union(union, indent + 1, parent_stack)) + for nested in message.nested_messages: + lines.extend(self.generate_message(nested, indent + 1, parent_stack)) + lines.append(f"{ind}}}") + lines.append("") + return lines + + def generate_parameter(self, field: Field, parent_stack: List[Message]) -> str: + field_name = self.safe_identifier(self.to_camel_case(field.name)) + field_type = self.generate_type( + field.field_type, + nullable=field.optional, + element_optional=field.element_optional, + element_ref=field.element_ref, + top_level_ref=field.ref, + parent_stack=parent_stack, + ) + ref_annotation = ( + "@Ref " if field.ref and self.is_ref_target_type(field.field_type) else "" + ) + return f"{ref_annotation}@ForyField(id = {field.number}) {field_name}: {field_type}" + + def generate_type( + self, + field_type: FieldType, + nullable: bool = False, + element_optional: bool = False, + element_ref: bool = False, + top_level_ref: bool = False, + parent_stack: Optional[List[Message]] = None, + ) -> str: + base = self._generate_non_optional_type( + field_type, element_optional, element_ref, parent_stack + ) + if top_level_ref and self.is_ref_target_type(field_type): + base = self.apply_type_annotation(base, "Ref") + return f"Option[{base}]" if nullable else base + + def _generate_non_optional_type( + self, + field_type: FieldType, + element_optional: bool = False, + element_ref: bool = False, + parent_stack: Optional[List[Message]] = None, + ) -> str: + if isinstance(field_type, PrimitiveType): + scala_type = self.PRIMITIVE_MAP[field_type.kind] + return self.apply_primitive_annotation(scala_type, field_type) + if isinstance(field_type, NamedType): + return self.resolve_scala_type_name(field_type.name, parent_stack) + if isinstance(field_type, ListType): 
+ element = self.generate_type( + field_type.element_type, + nullable=element_optional or field_type.element_optional, + element_optional=False, + element_ref=False, + top_level_ref=element_ref or field_type.element_ref, + parent_stack=parent_stack, + ) + return f"List[{element}]" + if isinstance(field_type, ArrayType): + element = self.generate_array_element_type(field_type.element_type) + return f"Array[{element}]" + if isinstance(field_type, MapType): + key = self.generate_type(field_type.key_type, parent_stack=parent_stack) + value = self.generate_type( + field_type.value_type, + nullable=field_type.value_optional, + top_level_ref=field_type.value_ref, + parent_stack=parent_stack, + ) + return f"Map[{key}, {value}]" + return "AnyRef" + + def generate_array_element_type(self, field_type: FieldType) -> str: + if not isinstance(field_type, PrimitiveType): + return self._generate_non_optional_type(field_type) + scala_type = self.ARRAY_ELEMENT_MAP[field_type.kind] + annotation = self.get_array_element_annotation(field_type) + if annotation is not None: + return self.apply_type_annotation(scala_type, annotation) + return scala_type + + def current_stack( + self, parent_stack: Optional[List[Message]], message: Message + ) -> List[Message]: + return [*(parent_stack or []), message] + + def resolve_scala_type_name( + self, name: str, parent_stack: Optional[List[Message]] + ) -> str: + if "." 
in name: + root_name = name.split(".", 1)[0] + named_type = self.schema.get_type(root_name) + if named_type is not None and self.is_imported_type(named_type): + package = self._scala_package_for_type(named_type) + if package and package != self.get_scala_package(): + return f"{package}.{name}" + return name + named_type = self.schema.get_type(name) + if named_type is not None and self.is_imported_type(named_type): + package = self._scala_package_for_type(named_type) + if package and package != self.get_scala_package(): + return f"{package}.{name}" + if parent_stack: + for index in range(len(parent_stack) - 1, -1, -1): + owner = parent_stack[index] + if owner.get_nested_type(name) is not None: + return ".".join([message.name for message in parent_stack[: index + 1]] + [name]) + return name + + def apply_type_annotation(self, scala_type: str, annotation: str) -> str: + return f"{scala_type} @{annotation}" + + def apply_primitive_annotation( + self, scala_type: str, field_type: PrimitiveType + ) -> str: + annotation = self.get_integer_annotation(field_type) + if annotation is not None: + return self.apply_type_annotation(scala_type, annotation) + return scala_type + + def get_integer_annotation(self, field_type: PrimitiveType) -> Optional[str]: + kind = field_type.kind + if kind == PrimitiveKind.INT32 and field_type.encoding_modifier == "fixed": + return "Int32Type(encoding = Int32Encoding.FIXED)" + if kind == PrimitiveKind.INT32 and field_type.encoding_modifier == "varint": + return "Int32Type(encoding = Int32Encoding.VARINT)" + if kind == PrimitiveKind.INT64 and field_type.encoding_modifier == "fixed": + return "Int64Type(encoding = Int64Encoding.FIXED)" + if kind == PrimitiveKind.INT64 and field_type.encoding_modifier == "varint": + return "Int64Type(encoding = Int64Encoding.VARINT)" + if kind == PrimitiveKind.INT64 and field_type.encoding_modifier == "tagged": + return "Int64Type(encoding = Int64Encoding.TAGGED)" + if kind == PrimitiveKind.UINT8: + return 
"UInt8Type" + if kind == PrimitiveKind.UINT16: + return "UInt16Type" + if kind == PrimitiveKind.UINT32 and field_type.encoding_modifier == "fixed": + return "UInt32Type(encoding = Int32Encoding.FIXED)" + if kind == PrimitiveKind.UINT32 and field_type.encoding_modifier == "varint": + return "UInt32Type(encoding = Int32Encoding.VARINT)" + if kind == PrimitiveKind.UINT32: + return "UInt32Type" + if kind == PrimitiveKind.UINT64 and field_type.encoding_modifier == "fixed": + return "UInt64Type(encoding = Int64Encoding.FIXED)" + if kind == PrimitiveKind.UINT64 and field_type.encoding_modifier == "varint": + return "UInt64Type(encoding = Int64Encoding.VARINT)" + if kind == PrimitiveKind.UINT64 and field_type.encoding_modifier == "tagged": + return "UInt64Type(encoding = Int64Encoding.TAGGED)" + if kind == PrimitiveKind.UINT64: + return "UInt64Type" + return None + + def get_array_element_annotation(self, field_type: PrimitiveType) -> Optional[str]: + kind = field_type.kind + if kind == PrimitiveKind.INT8: + return "Int8Type" + if kind == PrimitiveKind.UINT8: + return "UInt8Type" + if kind == PrimitiveKind.UINT16: + return "UInt16Type" + if kind == PrimitiveKind.UINT32: + return "UInt32Type" + if kind == PrimitiveKind.UINT64: + return "UInt64Type" + if kind == PrimitiveKind.FLOAT16: + return "Float16Type" + if kind == PrimitiveKind.BFLOAT16: + return "BFloat16Type" + return None + + def default_value(self, field: Field) -> str: + if field.optional: + return "None" + return self.default_value_for_type(field.field_type) + + def default_value_for_type(self, field_type: FieldType) -> str: + if isinstance(field_type, PrimitiveType): + defaults = { + PrimitiveKind.BOOL: "false", + PrimitiveKind.INT8: "0.toByte", + PrimitiveKind.INT16: "0.toShort", + PrimitiveKind.INT32: "0", + PrimitiveKind.INT64: "0L", + PrimitiveKind.UINT8: "0", + PrimitiveKind.UINT16: "0", + PrimitiveKind.UINT32: "0L", + PrimitiveKind.UINT64: "0L", + PrimitiveKind.FLOAT32: "0.0f", + PrimitiveKind.FLOAT64: 
"0.0d", + PrimitiveKind.STRING: '""', + PrimitiveKind.BYTES: "Array.emptyByteArray", + PrimitiveKind.DECIMAL: "BigDecimal.ZERO", + } + return defaults.get(field_type.kind, "null") + if isinstance(field_type, ListType): + return "List.empty" + if isinstance(field_type, ArrayType): + return "Array.empty" + if isinstance(field_type, MapType): + return "Map.empty" + return "null" + + def collect_message_imports(self, message: Message, imports: Set[str]) -> None: + for field in message.fields: + self.collect_type_imports(field.field_type, imports) + if field.ref or self.field_type_has_ref(field.field_type): + imports.add("org.apache.fory.annotation.Ref") + for enum in message.nested_enums: + imports.add("org.apache.fory.annotation.ForyEnumId") + imports.add("org.apache.fory.scala.ForyScalaEnum") + for union in message.nested_unions: + imports.add("org.apache.fory.annotation.{ForyCase, ForyUnion}") + for nested in message.nested_messages: + self.collect_message_imports(nested, imports) + + def collect_union_imports(self, union: Union, imports: Set[str]) -> None: + for field in union.fields: + self.collect_type_imports(field.field_type, imports) + if field.ref or self.field_type_has_ref(field.field_type): + imports.add("org.apache.fory.annotation.Ref") + + def collect_type_imports(self, field_type: FieldType, imports: Set[str]) -> None: + if isinstance(field_type, PrimitiveType): + if field_type.kind == PrimitiveKind.DATE: + imports.add("java.time.LocalDate") + elif field_type.kind == PrimitiveKind.TIMESTAMP: + imports.add("java.time.Instant") + elif field_type.kind == PrimitiveKind.DURATION: + imports.add("java.time.Duration") + elif field_type.kind == PrimitiveKind.DECIMAL: + imports.add("java.math.BigDecimal") + elif field_type.kind == PrimitiveKind.FLOAT16: + imports.add("org.apache.fory.`type`.Float16") + elif field_type.kind == PrimitiveKind.BFLOAT16: + imports.add("org.apache.fory.`type`.BFloat16") + self.collect_integer_imports(field_type, imports) + return + if 
isinstance(field_type, ListType): + self.collect_type_imports(field_type.element_type, imports) + return + if isinstance(field_type, ArrayType): + self.collect_type_imports(field_type.element_type, imports) + if isinstance(field_type.element_type, PrimitiveType): + self.collect_array_element_imports(field_type.element_type, imports) + return + if isinstance(field_type, MapType): + self.collect_type_imports(field_type.key_type, imports) + self.collect_type_imports(field_type.value_type, imports) + + def collect_integer_imports( + self, field_type: FieldType, imports: Set[str] + ) -> None: + if not isinstance(field_type, PrimitiveType): + return + kind = field_type.kind + if kind == PrimitiveKind.INT32: + if field_type.encoding_modifier in ("fixed", "varint"): + imports.add("org.apache.fory.annotation.Int32Type") + imports.add("org.apache.fory.config.Int32Encoding") + elif kind == PrimitiveKind.INT64: + if field_type.encoding_modifier in ("fixed", "varint", "tagged"): + imports.add("org.apache.fory.annotation.Int64Type") + imports.add("org.apache.fory.config.Int64Encoding") + elif kind == PrimitiveKind.UINT8: + imports.add("org.apache.fory.annotation.UInt8Type") + elif kind == PrimitiveKind.UINT16: + imports.add("org.apache.fory.annotation.UInt16Type") + elif kind == PrimitiveKind.UINT32: + imports.add("org.apache.fory.annotation.UInt32Type") + if field_type.encoding_modifier in ("fixed", "varint"): + imports.add("org.apache.fory.config.Int32Encoding") + elif kind == PrimitiveKind.UINT64: + imports.add("org.apache.fory.annotation.UInt64Type") + if field_type.encoding_modifier in ("fixed", "varint", "tagged"): + imports.add("org.apache.fory.config.Int64Encoding") + + def collect_array_element_imports( + self, field_type: PrimitiveType, imports: Set[str] + ) -> None: + kind = field_type.kind + if kind == PrimitiveKind.INT8: + imports.add("org.apache.fory.annotation.Int8Type") + elif kind == PrimitiveKind.UINT8: + imports.add("org.apache.fory.annotation.UInt8Type") + 
elif kind == PrimitiveKind.UINT16: + imports.add("org.apache.fory.annotation.UInt16Type") + elif kind == PrimitiveKind.UINT32: + imports.add("org.apache.fory.annotation.UInt32Type") + elif kind == PrimitiveKind.UINT64: + imports.add("org.apache.fory.annotation.UInt64Type") + elif kind == PrimitiveKind.FLOAT16: + imports.add("org.apache.fory.annotation.Float16Type") + elif kind == PrimitiveKind.BFLOAT16: + imports.add("org.apache.fory.annotation.BFloat16Type") + + def field_type_has_ref(self, field_type: FieldType) -> bool: + if isinstance(field_type, ListType): + return field_type.element_ref or self.field_type_has_ref(field_type.element_type) + if isinstance(field_type, MapType): + return field_type.value_ref or self.field_type_has_ref(field_type.value_type) + return False + + def is_ref_target_type(self, field_type: FieldType) -> bool: + if not isinstance(field_type, NamedType): + return False + return self.schema.get_type(field_type.name) is not None + + def generate_registration_file(self) -> GeneratedFile: + imports = { + "org.apache.fory.{Fory, ThreadSafeFory}", + "org.apache.fory.scala.ForySerializer", + "org.apache.fory.serializer.scala.ScalaSerializers", + } + lines = self.source_header(imports) + class_name = self.get_registration_class_name() + lines.append(f"object {class_name} {{") + lines.append(" private lazy val fory: ThreadSafeFory = createFory()") + lines.append("") + lines.append(" def getFory: ThreadSafeFory = fory") + lines.append("") + lines.append(" private def createFory(): ThreadSafeFory = {") + lines.append( + " val runtime = Fory.builder().withXlang(true).withCompatible(true).withRefTracking(true).withScalaOptimizationEnabled(true).buildThreadSafeFory()" + ) + imported_registrations = self._collect_imported_registrations() + if imported_registrations: + lines.append(" runtime.registerCallback((fory: Fory) => {") + for package, registration in imported_registrations: + lines.append(f" {package}.{registration}.register(fory)") + 
lines.append(" register(fory)") + lines.append(" })") + else: + lines.append(" runtime.registerCallback((fory: Fory) => register(fory))") + lines.append(" runtime") + lines.append(" }") + lines.append("") + lines.append(" def register(fory: Fory): Unit = {") + lines.append(" ScalaSerializers.registerSerializers(fory)") + for enum in self.schema.enums: + if self.is_imported_type(enum): + continue + self.generate_type_registration(lines, enum) + for message in self.schema.messages: + if self.is_imported_type(message): + continue + self.generate_type_registration(lines, message) + self.generate_nested_registration(lines, message.name, message) + for union in self.schema.unions: + if self.is_imported_type(union): + continue + self.generate_type_registration(lines, union) + lines.append(" }") + lines.append("}") + return self.source_file(class_name, lines) + + def generate_nested_registration( + self, lines: List[str], owner_path: str, message: Message + ) -> None: + for enum in message.nested_enums: + self.generate_type_registration(lines, enum, owner_path) + for nested in message.nested_messages: + nested_path = f"{owner_path}.{nested.name}" + self.generate_type_registration(lines, nested, owner_path) + self.generate_nested_registration(lines, nested_path, nested) + for union in message.nested_unions: + self.generate_type_registration(lines, union, owner_path) + + def generate_type_registration( + self, lines: List[str], type_def, owner_path: Optional[str] = None + ) -> None: + class_ref = f"{owner_path}.{type_def.name}" if owner_path else type_def.name + if isinstance(type_def, Enum): + if self.should_register_by_id(type_def): + lines.append( + f" ScalaSerializers.registerEnum(fory, classOf[{class_ref}], {type_def.type_id}L)" + ) + else: + namespace = self.schema.package or "default" + lines.append( + f' ScalaSerializers.registerEnum(fory, classOf[{class_ref}], "{namespace}", "{type_def.name}")' + ) + return + if self.should_register_by_id(type_def): + lines.append( 
+ f" ForySerializer.register(fory, classOf[{class_ref}], {type_def.type_id}L)" + ) + else: + namespace = self.schema.package or "default" + lines.append( + f' ForySerializer.register(fory, classOf[{class_ref}], "{namespace}", "{type_def.name}")' + ) + + def safe_identifier(self, name: str) -> str: + return f"`{name}`" if name in self.RESERVED else name diff --git a/compiler/fory_compiler/ir/construction.py b/compiler/fory_compiler/ir/construction.py new file mode 100644 index 0000000000..5612245d76 --- /dev/null +++ b/compiler/fory_compiler/ir/construction.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Construction shape analysis shared by JVM-family generators.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterable, List, Set + +from fory_compiler.ir.ast import ( + ArrayType, + FieldType, + ListType, + MapType, + Message, + NamedType, + PrimitiveType, + Schema, +) + + +@dataclass(frozen=True) +class MessageConstructionShape: + """Generated message construction shape.""" + + cycle_owned: bool + + +def analyze_message_construction_shapes( + schema: Schema, +) -> Dict[str, MessageConstructionShape]: + """Return construction shapes for all messages in ``schema``. + + A message becomes cycle-owned only when message dependencies form a real + construction cycle. A top-level ``ref`` marker, nested ``ref`` marker, or + ``any`` field does not force this shape by itself. + """ + + messages = {message.name: message for message in schema.messages} + graph = {name: set(_message_dependencies(message, messages)) for name, message in messages.items()} + cycle_owned = _cycle_nodes(graph) + return { + name: MessageConstructionShape(cycle_owned=name in cycle_owned) + for name in messages + } + + +def _message_dependencies( + message: Message, messages: Dict[str, Message] +) -> Iterable[str]: + for field in message.fields: + yield from _field_type_dependencies(field.field_type, messages) + + +def _field_type_dependencies( + field_type: FieldType, messages: Dict[str, Message] +) -> Iterable[str]: + if isinstance(field_type, PrimitiveType): + return + if isinstance(field_type, NamedType): + root_name = field_type.name.split(".", 1)[0] + if root_name in messages: + yield root_name + return + if isinstance(field_type, ListType): + yield from _field_type_dependencies(field_type.element_type, messages) + return + if isinstance(field_type, ArrayType): + yield from _field_type_dependencies(field_type.element_type, messages) + return + if isinstance(field_type, MapType): + yield from 
_field_type_dependencies(field_type.key_type, messages) + yield from _field_type_dependencies(field_type.value_type, messages) + + +def _cycle_nodes(graph: Dict[str, Set[str]]) -> Set[str]: + index = 0 + stack: List[str] = [] + on_stack: Set[str] = set() + indexes: Dict[str, int] = {} + lowlinks: Dict[str, int] = {} + result: Set[str] = set() + + def strong_connect(node: str) -> None: + nonlocal index + indexes[node] = index + lowlinks[node] = index + index += 1 + stack.append(node) + on_stack.add(node) + + for target in graph[node]: + if target not in graph: + continue + if target not in indexes: + strong_connect(target) + lowlinks[node] = min(lowlinks[node], lowlinks[target]) + elif target in on_stack: + lowlinks[node] = min(lowlinks[node], indexes[target]) + + if lowlinks[node] == indexes[node]: + component = [] + while True: + current = stack.pop() + on_stack.remove(current) + component.append(current) + if current == node: + break + if len(component) > 1: + result.update(component) + elif component[0] in graph[component[0]]: + result.add(component[0]) + + for node in graph: + if node not in indexes: + strong_connect(node) + return result diff --git a/compiler/fory_compiler/tests/test_scala_generator.py b/compiler/fory_compiler/tests/test_scala_generator.py new file mode 100644 index 0000000000..d072aef89c --- /dev/null +++ b/compiler/fory_compiler/tests/test_scala_generator.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pathlib import Path + +from fory_compiler.cli import resolve_imports +from fory_compiler.frontend.fdl.lexer import Lexer +from fory_compiler.frontend.fdl.parser import Parser +from fory_compiler.generators.base import GeneratorOptions +from fory_compiler.generators.scala import ScalaGenerator +from fory_compiler.ir.validator import SchemaValidator + + +def generate_scala(source: str): + schema = Parser(Lexer(source).tokenize()).parse() + validator = SchemaValidator(schema) + assert validator.validate(), validator.errors + generator = ScalaGenerator(schema, GeneratorOptions(output_dir=Path("/tmp"))) + return {item.path: item.content for item in generator.generate()} + + +def test_scala_generator_emits_case_classes_options_enums_and_unions(): + files = generate_scala( + """ + package demo; + + enum Status [id=101] { + STATUS_UNKNOWN = 0; + STATUS_OK = 7; + } + + message User [id=102] { + string name = 1; + optional int32 age = 2; + list tags = 3; + } + + union SearchTarget [id=103] { + User user = 1; + string note = 2; + } + """ + ) + + user = files["demo/User.scala"] + assert "final case class User(" in user + assert "@ForyField(id = 1) name: String" in user + assert "@ForyField(id = 2) age: Option[Int]" in user + assert "@ForyField(id = 3) tags: List[String]" in user + assert "derives ForySerializer" in user + + status = files["demo/Status.scala"] + assert "enum Status(val foryId: Int)" in status + assert "case Unknown extends Status(0)" in status + assert "case Ok extends Status(7)" in status + assert "@ForyEnumId" in status + + union = 
files["demo/SearchTarget.scala"] + assert "@ForyUnion" in union + assert "enum SearchTarget derives ForySerializer" in union + assert "@ForyCase(id = 0)" in union + assert "case UnknownCase(caseId: Int, value: Any)" in union + assert "@ForyCase(id = 1)" in union + assert "case UserCase(value: User)" in union + assert "@ForyCase(id = 2)" in union + assert "case NoteCase(value: String)" in union + + +def test_scala_generator_uses_mutable_normal_class_for_construction_cycles(): + files = generate_scala( + """ + package graph; + + message Node [id=110] { + string id = 1; + ref Node parent = 2; + } + """ + ) + + node = files["graph/Node.scala"] + assert "final class Node() derives ForySerializer" in node + assert "var id: String = \"\"" in node + assert "var parent: Option[Node @Ref] = None" in node + + +def test_scala_generator_keeps_imported_types_in_owner_package(): + repo_root = Path(__file__).resolve().parents[3] + idl_dir = repo_root / "integration_tests" / "idl_tests" / "idl" + schema = resolve_imports(idl_dir / "root.idl", [idl_dir]) + generator = ScalaGenerator(schema, GeneratorOptions(output_dir=Path("/tmp"))) + files = {item.path: item.content for item in generator.generate()} + + assert "root/MultiHolder.scala" in files + assert "root/PrimitiveTypes.scala" not in files + assert "addressbook.AddressBook" in files["root/MultiHolder.scala"] + assert "tree.TreeNode" in files["root/MultiHolder.scala"] + + registration = files["root/RootForyRegistration.scala"] + assert "addressbook.AddressbookForyRegistration.register(fory)" in registration + assert "tree.TreeForyRegistration.register(fory)" in registration + assert "classOf[PrimitiveTypes]" not in registration diff --git a/docs/compiler/compiler-guide.md b/docs/compiler/compiler-guide.md index 28ed04f1f1..a1c35e0268 100644 --- a/docs/compiler/compiler-guide.md +++ b/docs/compiler/compiler-guide.md @@ -67,6 +67,7 @@ Compile options: | `--javascript_out=DST_DIR` | Generate JavaScript code in DST_DIR | (none) | | 
`--swift_out=DST_DIR` | Generate Swift code in DST_DIR | (none) | | `--dart_out=DST_DIR` | Generate Dart code in DST_DIR | (none) | +| `--scala_out=DST_DIR` | Generate Scala 3 code in DST_DIR | (none) | | `--go_nested_type_style` | Go nested type naming: `camelcase` or `underscore` | `underscore` | | `--swift_namespace_style` | Swift namespace style: `enum` or `flatten` | `enum` | | `--emit-fdl` | Emit translated FDL (for non-FDL inputs) | `false` | @@ -174,6 +175,9 @@ foryc schema.fdl --java_out=./java/gen --python_out=./python/src --go_out=./go/g # Combine with import paths foryc schema.fdl --java_out=./gen/java -I proto/ -I common/ + +# Generate Scala 3 code to a specific directory +foryc schema.fdl --scala_out=./src/main/scala ``` When using `--{lang}_out` options: @@ -250,6 +254,7 @@ Compiling src/main.fdl... | JavaScript | `javascript` | `.ts` | Interfaces with registration function | | Swift | `swift` | `.swift` | Fory Swift model macros | | Dart | `dart` | `.dart` | `@ForyStruct` classes with annotations | +| Scala | `scala` | `.scala` | Scala 3 models with macro derivation | ## Output Structure @@ -379,6 +384,26 @@ generated/ - Registration helper class included in the part file - Typed arrays used for non-optional, non-ref primitive lists (e.g., `Int32List`) +### Scala + +``` +generated/ +└── scala/ + └── example/ + ├── User.scala + ├── Status.scala + ├── Animal.scala + └── ExampleForyRegistration.scala +``` + +- One Scala 3 source file per generated type +- Package structure matches the Fory IDL package +- Messages derive `org.apache.fory.scala.ForySerializer` +- `optional T` fields use `Option[T]` +- Enums use Scala 3 `enum` +- Unions use Scala 3 ADT `enum` with `@ForyUnion`, `@ForyCase`, and an `UnknownCase` +- Registration helper object included + ### C# IDL Matrix Verification Run the end-to-end C# IDL matrix (FDL/IDL/Proto/FBS generation plus roundtrip tests): diff --git a/docs/compiler/generated-code.md b/docs/compiler/generated-code.md index 
3f88028da8..5056d7822f 100644 --- a/docs/compiler/generated-code.md +++ b/docs/compiler/generated-code.md @@ -1044,6 +1044,119 @@ void main() { ## Scala +The Scala target emits Scala 3 source only. The `fory-scala` runtime artifact +still supports Scala 2.13 and Scala 3, but generated IDL source and macro +derivation require Scala 3. + +### Output Layout + +For `package addressbook`, Scala output is generated under: + +- `/addressbook/` +- Type files: `AddressBook.scala`, `Person.scala`, `Dog.scala`, `Cat.scala`, `Animal.scala` +- Registration helper: `AddressbookForyRegistration.scala` + +### Type Generation + +Messages outside compiler-detected construction cycles generate case classes: + +```scala +import org.apache.fory.annotation.{ForyField, ForyStruct} +import org.apache.fory.scala.ForySerializer + +@ForyStruct +final case class Person( + @ForyField(id = 1) name: String, + @ForyField(id = 3) email: Option[String], + @ForyField(id = 7) phones: List[Person.PhoneNumber], + @ForyField(id = 8) pet: Animal +) derives ForySerializer +``` + +Messages in circular construction cycles generate normal classes with mutable +serialized fields so reads can register the object before reading back-references: + +```scala +import org.apache.fory.annotation.{ForyField, ForyStruct, Ref} +import org.apache.fory.scala.ForySerializer + +@ForyStruct +final class Node() derives ForySerializer { + @ForyField(id = 1) + var id: String = "" + + @Ref + @ForyField(id = 2) + var parent: Option[Node @Ref] = None +} +``` + +Enums generate Scala 3 enums with stable Fory IDs: + +```scala +import org.apache.fory.annotation.ForyEnumId +import org.apache.fory.scala.ForyScalaEnum + +enum PhoneType(val foryId: Int) extends ForyScalaEnum { + case Mobile extends PhoneType(0) + case Home extends PhoneType(1) + case Work extends PhoneType(2) + + @ForyEnumId + def getForyId: Int = foryId +} +``` + +Unions generate Scala 3 ADT enums. 
Case ID `0` is reserved for the unknown-case +carrier; schema-defined cases start at `1`. + +```scala +import org.apache.fory.annotation.{ForyCase, ForyUnion} +import org.apache.fory.scala.ForySerializer + +@ForyUnion +enum Animal derives ForySerializer { + @ForyCase(id = 0) + case UnknownCase(caseId: Int, value: Any) + + @ForyCase(id = 1) + case DogCase(value: Dog) + + @ForyCase(id = 2) + case CatCase(value: Cat) +} +``` + +`optional T` fields generate `Option[T]`. Reference tracking uses `@Ref`; +`@ForyField(ref = true)` is not emitted by the Scala generator. + +### Registration + +Generated registration helpers register Scala serializers, enums, structs, and +unions for `Fory` and `ThreadSafeFory`: + +```scala +object AddressbookForyRegistration { + def register(fory: Fory): Unit = { + ScalaSerializers.registerSerializers(fory) + ScalaSerializers.registerEnum(fory, classOf[Person.PhoneType], 101L) + ForySerializer.register(fory, classOf[Person.PhoneNumber], 102L) + ForySerializer.register(fory, classOf[Person], 100L) + ForySerializer.register(fory, classOf[Animal], 106L) + } +} +``` + +Run the end-to-end Scala IDL matrix with: + +```bash +cd integration_tests/idl_tests +./run_scala_tests.sh +``` + +The runner regenerates Scala fixtures, runs Scala 3 IDL tests, and then runs the +Java peer matrix with `IDL_PEER_LANG=scala`. + ## Cross-Language Notes ### Type ID Behavior diff --git a/docs/compiler/index.md b/docs/compiler/index.md index 286de0b5cc..e0317ed22b 100644 --- a/docs/compiler/index.md +++ b/docs/compiler/index.md @@ -21,7 +21,7 @@ license: | Fory IDL is a schema definition language for Apache Fory that enables type-safe cross-language serialization. Define your data structures once and generate -native data structure code for Java, Python, Go, Rust, C++, C#, Swift, JavaScript, and Dart. +native data structure code for Java, Python, Go, Rust, C++, C#, Swift, JavaScript, Dart, and Scala. 
## Example Schema @@ -104,6 +104,7 @@ Generated code uses native language constructs: - JavaScript: Interfaces with registration function - Swift: Fory model macros with field/case metadata and registration helpers - Dart: `@ForyStruct` classes with `@ForyField` annotations and registration helpers +- Scala: Scala 3 `case class`, normal class, enum, and ADT enum models with macro-derived serializers ## Quick Start @@ -141,7 +142,7 @@ message Person { foryc example.fdl --output ./generated # Generate for specific languages -foryc example.fdl --lang java,python,csharp,javascript,swift,dart --output ./generated +foryc example.fdl --lang java,python,csharp,javascript,swift,dart,scala --output ./generated ``` ### 4. Use Generated Code @@ -197,11 +198,11 @@ message Example { Fory IDL types map to native types in each language: -| Fory IDL Type | Java | Python | Go | Rust | C++ | C# | JavaScript | Swift | Dart | -| ------------- | --------- | -------------- | -------- | -------- | ------------- | -------- | ---------- | -------- | -------- | -| `int32` | `int` | `pyfory.Int32` | `int32` | `i32` | `int32_t` | `int` | `number` | `Int32` | `int` | -| `string` | `String` | `str` | `string` | `String` | `std::string` | `string` | `string` | `String` | `String` | -| `bool` | `boolean` | `bool` | `bool` | `bool` | `bool` | `bool` | `boolean` | `Bool` | `bool` | +| Fory IDL Type | Java | Python | Go | Rust | C++ | C# | JavaScript | Swift | Dart | Scala | +| ------------- | --------- | -------------- | -------- | -------- | ------------- | -------- | ---------- | -------- | -------- | --------- | +| `int32` | `int` | `pyfory.Int32` | `int32` | `i32` | `int32_t` | `int` | `number` | `Int32` | `int` | `Int` | +| `string` | `String` | `str` | `string` | `String` | `std::string` | `string` | `string` | `String` | `String` | `String` | +| `bool` | `boolean` | `bool` | `bool` | `bool` | `bool` | `bool` | `boolean` | `Bool` | `bool` | `Boolean` | See [Type 
System](schema-idl.md#type-system) for complete mappings. diff --git a/docs/compiler/schema-idl.md b/docs/compiler/schema-idl.md index b7252ed6a8..e0a9dc6bfe 100644 --- a/docs/compiler/schema-idl.md +++ b/docs/compiler/schema-idl.md @@ -818,6 +818,7 @@ message Person [id=100] { ### Rules - Case IDs must be unique within the union +- Case IDs must be positive. Case ID `0` is reserved for generated unknown-case carriers in languages that expose one. - Cases cannot be `optional` or `ref` - Union cases do not support field options - Case types can be primitives, enums, messages, or other named types @@ -886,6 +887,7 @@ message User { | C++ | `std::string name` | `std::optional<std::string> name` | | JavaScript | `name: string` | `name?: string \| null` | | Dart | `String name` | `String? email` | +| Scala | `name: String` | `email: Option[String]` | **Default Values:** @@ -923,6 +925,7 @@ message Node { | C++ | `Node parent` | `std::shared_ptr<Node> parent` | | JavaScript | `parent: Node` | `parent: Node` (no ref distinction) | | Dart | `Node parent` | `Node parent` with `@ForyField(ref: true)` | +| Scala | `parent: Node` | `parent: Node @Ref` | Rust uses `Arc` by default; use `ref(thread_safe=false)` or `ref(weak=true)` to customize pointer types. For protobuf option syntax, see @@ -970,12 +973,12 @@ apply to elements. `repeated` is accepted as an alias for `list`.
**List modifier mapping:** -| Fory IDL | Java | Python | Go | Rust | C++ | Dart | -| ----------------------- | --------------------------------------- | --------------------------------------- | ----------------------- | --------------------- | ----------------------------------------- | ------------------------------------------------------------- | -| `optional list<string>` | `@Nullable List<String>` | `Optional[List[str]]` | `[]string` + `nullable` | `Option<Vec<String>>` | `std::optional<std::vector<std::string>>` | `List<String>?` | -| `list<optional string>` | `List<String>` (nullable elements) | `List[Optional[str]]` | `[]*string` | `Vec<Option<String>>` | `std::vector<std::optional<std::string>>` | `List<String?>` | -| `ref list<User>` | `List<User>` + `@ForyField(ref = true)` | `List[User]` + `pyfory.field(ref=True)` | `[]User` + `ref` | `Arc<Vec<User>>` | `std::shared_ptr<std::vector<User>>` | `List<User>` + `@ForyField(ref: true)` | -| `list<ref User>` | `List<User>` | `List[User]` | `[]*User` + `ref=false` | `Vec<Arc<User>>` | `std::vector<std::shared_ptr<User>>` | `List<User>` + `@ListField(element: DeclaredType(ref: true))` | +| Fory IDL | Java | Python | Go | Rust | C++ | Dart | Scala | +| ----------------------- | --------------------------------------- | --------------------------------------- | ----------------------- | --------------------- | ----------------------------------------- | ------------------------------------------------------------- | ---------------------- | +| `optional list<string>` | `@Nullable List<String>` | `Optional[List[str]]` | `[]string` + `nullable` | `Option<Vec<String>>` | `std::optional<std::vector<std::string>>` | `List<String>?` | `Option[List[String]]` | +| `list<optional string>` | `List<String>` (nullable elements) | `List[Optional[str]]` | `[]*string` | `Vec<Option<String>>` | `std::vector<std::optional<std::string>>` | `List<String?>` | `List[Option[String]]` | +| `ref list<User>` | `List<User>` + `@ForyField(ref = true)` | `List[User]` + `pyfory.field(ref=True)` | `[]User` + `ref` | `Arc<Vec<User>>` | `std::shared_ptr<std::vector<User>>` | `List<User>` + `@ForyField(ref: true)` | `List[User] @Ref` | +| `list<ref User>` | `List<User>` | `List[User]` | `[]*User` + `ref=false` | `Vec<Arc<User>>` | `std::vector<std::shared_ptr<User>>` | `List<User>` + `@ListField(element: DeclaredType(ref: true))` | `List[User @Ref]` | Use `ref(thread_safe=false)` in Fory IDL (or
`[(fory).thread_safe_pointer = false]` in protobuf) to generate `Rc` instead of `Arc` in Rust. diff --git a/docs/guide/scala/fory-creation.md b/docs/guide/scala/fory-creation.md index b79e618563..a3f0a2ba6e 100644 --- a/docs/guide/scala/fory-creation.md +++ b/docs/guide/scala/fory-creation.md @@ -133,3 +133,23 @@ val fory = Fory.builder() ScalaSerializers.registerSerializers(fory) ``` + +## Cross-Language Mode + +For Scala xlang or schema IDL generated code, enable xlang and register the +Scala serializers before registering generated model types: + +```scala +val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withScalaOptimizationEnabled(true) + .build() + +ScalaSerializers.registerSerializers(fory) +ExampleForyRegistration.register(fory) +``` + +In xlang mode, Scala collections use canonical `list`, `set`, and `map` +payloads instead of Scala factory payloads. Generated optional fields use +`Option[T]`. diff --git a/docs/guide/scala/index.md b/docs/guide/scala/index.md index 60d5da14c5..382b43e1c1 100644 --- a/docs/guide/scala/index.md +++ b/docs/guide/scala/index.md @@ -29,7 +29,8 @@ Apache Fory™ Scala provides optimized serializers for Scala types, built on to - `Option` types - Scala 2 and 3 enumerations -Both Scala 2 and Scala 3 are supported. +The runtime artifact supports Scala 2.13 and Scala 3. Schema IDL generated +Scala source and macro-derived xlang serializers require Scala 3. ## Features @@ -98,3 +99,4 @@ Fory Scala is built on top of Fory Java. 
Most configuration options, features, a - [Fory Creation](fory-creation.md) - Scala-specific Fory setup requirements - [Type Serialization](type-serialization.md) - Serializing Scala types - [Default Values](default-values.md) - Scala class default values support +- [Schema IDL And Xlang](schema-idl.md) - Scala 3 generated models and macro-derived xlang serializers diff --git a/docs/guide/scala/schema-idl.md b/docs/guide/scala/schema-idl.md new file mode 100644 index 0000000000..1b002434dd --- /dev/null +++ b/docs/guide/scala/schema-idl.md @@ -0,0 +1,156 @@ +--- +title: Schema IDL And Xlang +sidebar_position: 4 +id: schema_idl +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +The Fory schema IDL Scala target generates Scala 3 source for xlang payloads. +The runtime artifact remains cross-built for Scala 2.13 and Scala 3; only the +schema IDL output and quoted macro derivation require Scala 3. + +## Setup + +Generated Scala code uses the public macro API in `org.apache.fory.scala` and +the shared JVM annotations in `org.apache.fory.annotation`. Macro internals live +under `org.apache.fory.scala.internal`. 
+ +```scala +import org.apache.fory.Fory +import org.apache.fory.scala.ForySerializer +import org.apache.fory.serializer.scala.ScalaSerializers + +val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withScalaOptimizationEnabled(true) + .build() + +ScalaSerializers.registerSerializers(fory) +ExampleForyRegistration.register(fory) +``` + +For `ThreadSafeFory`, generated registration helpers install a callback so each +runtime instance gets the same serializers. + +## Generated Messages + +Acyclic messages generate case classes: + +```scala +import org.apache.fory.annotation.{ForyField, ForyStruct} +import org.apache.fory.scala.ForySerializer + +@ForyStruct +final case class Person( + @ForyField(id = 1) name: String, + @ForyField(id = 2) email: Option[String] +) derives ForySerializer +``` + +Schema `optional T` fields are stored as `Option[T]`. + +Messages in compiler-detected construction cycles generate normal classes with +mutable serialized fields so the deserializer can allocate and register the +object before reading fields that can point back to it. A top-level `ref Foo`, +nested `list`, or `any` field does not by itself force this shape. + +Reference tracking is expressed with the shared `@Ref` annotation, including +type-use positions: + +```scala +@ForyStruct +final class Node() derives ForySerializer { + @ForyField(id = 1) + var children: List[Node @Ref] = List.empty + + @Ref + @ForyField(id = 2) + var parent: Option[Node @Ref] = None +} +``` + +`@ForyField(ref = true)` is not the Scala macro or IDL API. + +## Generated Enums + +IDL enums generate Scala 3 enums only. No Java enum sidecar is emitted. 
+ +```scala +import org.apache.fory.annotation.ForyEnumId +import org.apache.fory.scala.ForyScalaEnum + +enum Status(val foryId: Int) extends ForyScalaEnum { + case Unknown extends Status(0) + case Ok extends Status(1) + + @ForyEnumId + def getForyId: Int = foryId +} +``` + +Generated registration uses `ScalaSerializers.registerEnum(...)` so stable Fory +enum IDs are used in xlang mode. + +## Generated Unions + +IDL unions generate Scala 3 ADT enums with macro-derived serializers: + +```scala +import org.apache.fory.annotation.{ForyCase, ForyUnion, UInt32Type} +import org.apache.fory.config.Int32Encoding +import org.apache.fory.scala.ForySerializer + +@ForyUnion +enum SearchTarget derives ForySerializer { + @ForyCase(id = 0) + case UnknownCase(caseId: Int, value: Any) + + @ForyCase(id = 1) + case UserCase(value: User) + + @ForyCase(id = 2) + case FixedIdCase(value: Long @UInt32Type(encoding = Int32Encoding.FIXED)) +} +``` + +Schema-defined union cases must use positive IDs. Case ID `0` is reserved for +the Scala unknown-case carrier, whose payload stores the original positive case +ID and the deserialized value. When a reader sees a newer positive case ID, it +returns `UnknownCase(originalId, value)` instead of failing solely because the +case ID is not known locally. + +The macro writes the existing xlang union envelope directly. It does not +allocate temporary Java `Union` carriers. + +## Manual Scala 3 Derivation + +Manual Scala 3 models can derive the same serializer typeclass: + +```scala +@ForyStruct +final class Record(@ForyField(id = 1) val id: Int) derives ForySerializer { + @ForyField(id = 2) + var name: String = "" +} +``` + +The macro generates direct constructor calls for constructor-owned fields and +direct assignments for mutable post-construction fields. It builds descriptor +metadata from Scala compile-time types, including nested generics, `Option`, +arrays, scalar encoding annotations, nullability, and `@Ref` metadata. 
Java +reflection is not the source of truth for generated Scala metadata. diff --git a/docs/guide/xlang/field-nullability.md b/docs/guide/xlang/field-nullability.md index bea7f3f092..5d891fc7d4 100644 --- a/docs/guide/xlang/field-nullability.md +++ b/docs/guide/xlang/field-nullability.md @@ -36,6 +36,7 @@ The following types are nullable by default: - Go pointer types (`*int32`, `*string`, etc.) - Rust `Option` - Python `Optional[T]` +- Scala `Option[T]` | Field Type | Default Nullable | Null Flag Written | | ------------------------------------------ | ---------------- | ----------------- | diff --git a/docs/guide/xlang/field-reference-tracking.md b/docs/guide/xlang/field-reference-tracking.md index 9d32c0e88b..1a5260c167 100644 --- a/docs/guide/xlang/field-reference-tracking.md +++ b/docs/guide/xlang/field-reference-tracking.md @@ -69,6 +69,17 @@ let fory = Fory::builder() .track_ref(true).build(); ``` +### Scala + +```scala +val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .build() +``` + ## Wire Format When reference tracking is enabled, nullable fields write a **ref flag byte** before the value: @@ -121,6 +132,7 @@ By default, **most fields do not track references** even when global `refTrackin | Go | No | None (use `fory:"ref"` to enable) | | C++ | Yes | `std::shared_ptr`, `fory::serialization::SharedWeak` | | Rust | No | `Rc`, `Arc`, `Weak` | +| Scala | No | None (use `@Ref` to enable) | ### Customizing Per-Field Ref Tracking @@ -183,6 +195,26 @@ struct Document { } ``` +#### Scala: @Ref Annotation + +Scala schema IDL and Scala 3 macro derivation use the shared JVM `@Ref` +annotation instead of `@ForyField(ref = true)`: + +```scala +import org.apache.fory.annotation.{ForyField, ForyStruct, Ref} +import org.apache.fory.scala.ForySerializer + +@ForyStruct +final class Node() derives ForySerializer { + @ForyField(id = 1) + var children: List[Node @Ref] = List.empty + + @Ref + 
@ForyField(id = 2) + var parent: Option[Node @Ref] = None +} +``` + #### Go: Struct Tags ```go diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index 6343392c7e..e6b901b415 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -1578,6 +1578,8 @@ Rules: - Each union alternative MUST have a stable tag number (`= 1`, `= 2`, ...). - Tag numbers MUST be unique within the union and MUST NOT be reused. +- Tag number `0` is reserved for language bindings that expose an unknown-case + carrier. Schema-defined alternatives MUST NOT use tag `0`. #### Type IDs and type meta @@ -1611,6 +1613,13 @@ A union payload is: This is required even for primitives so unknown alternatives can be skipped safely. +If a reader sees a positive `case_id` that is not present in its local union +schema, it SHOULD preserve the unknown case when the target language has a +language-neutral carrier for it. Such a carrier MUST store the original positive +case ID and the deserialized `case_value`. Writers MUST NOT serialize `0` as a +schema-defined case ID; `0` is only the local unknown-case slot used by bindings +that need one. + #### Wire layouts **UNION (schema known from context)** diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index 654fa14cfc..2a5fdcebe4 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -158,6 +158,43 @@ Notes: payload that declares nullable or ref-tracked elements must raise a compatible-read error when the local matched field is `array`. +### Scala IDL Mapping + +The Scala schema IDL target emits Scala 3 source only. The `fory-scala` runtime +artifact remains cross-built for Scala 2.13 and Scala 3. 
+ +| Fory schema kind | Scala generated carrier | +| ------------------------------------- | -------------------------------------------------------------------------- | +| `optional T` | `Option[T]` | +| `bool` | `Boolean` | +| `int8`, `int16`, `int32`, `int64` | `Byte`, `Short`, `Int`, `Long` | +| `uint8`, `uint16`, `uint32`, `uint64` | `Int`, `Int`, `Long`, `Long` plus unsigned Fory type metadata | +| `float16`, `bfloat16` | JVM half-float and bfloat16 carriers | +| `float32`, `float64` | `Float`, `Double` | +| `string` | `String` | +| `binary` | `Array[Byte]` | +| `list<T>`, `set<T>`, `map<K, V>` | `List[T]`, `Set[T]`, `Map[K, V]` | +| `array<bool>` | `Array[Boolean]` | +| `array<int8>`, `array<uint8>` | `Array[Byte]` with signed/unsigned descriptor metadata | +| `array<int16>`, `array<uint16>` | `Array[Short]` with signed/unsigned descriptor metadata | +| `array<int32>`, `array<uint32>` | `Array[Int]` with signed/unsigned descriptor metadata | +| `array<int64>`, `array<uint64>` | `Array[Long]` with signed/unsigned descriptor metadata | +| `array<float16>`, `array<bfloat16>` | `Array[Short]` with reduced-precision descriptor metadata | +| `array<float32>`, `array<float64>` | `Array[Float]`, `Array[Double]` | +| `date`, `timestamp`, `duration` | `java.time.LocalDate`, `java.time.Instant`, `java.time.Duration` | +| `decimal` | `java.math.BigDecimal` | +| `message` | Scala 3 `case class` by default; normal class only for construction cycles | +| `enum` | Scala 3 `enum` with stable Fory enum IDs | +| `union` | Scala 3 ADT `enum derives ForySerializer` | +| `any` | `AnyRef` | + +Generated Scala descriptor metadata is produced by Scala 3 macro derivation +from Scala compile-time types, including nested generics, `Option`, arrays, +scalar encoding annotations, nullability, and `@Ref`. Java reflection is not the +source of truth for generated Scala TypeDef metadata. Scala `@Ref` metadata is +represented by the shared `org.apache.fory.annotation.Ref` annotation; generated +Scala does not use `@ForyField(ref = true)`.
+ ## Type info Due to differences between type systems of languages, those types can't be mapped one-to-one between languages. diff --git a/integration_tests/idl_tests/generate_idl.py b/integration_tests/idl_tests/generate_idl.py index fba81bfb18..b53442e997 100755 --- a/integration_tests/idl_tests/generate_idl.py +++ b/integration_tests/idl_tests/generate_idl.py @@ -53,6 +53,7 @@ "swift": REPO_ROOT / "integration_tests/idl_tests/swift/idl_package/Sources/IdlGenerated/generated", "dart": REPO_ROOT / "integration_tests/idl_tests/dart/lib/generated", + "scala": REPO_ROOT / "integration_tests/idl_tests/scala/src/main/scala/generated", } GO_OUTPUT_OVERRIDES = { diff --git a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java index cd1f2ea0f0..44d320e50e 100644 --- a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java +++ b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java @@ -707,7 +707,8 @@ private List resolvePeers() { .filter(value -> !value.isEmpty()) .collect(Collectors.toList()); if (peers.contains("all")) { - return Arrays.asList("python", "go", "rust", "cpp", "swift", "javascript", "csharp", "dart"); + return Arrays.asList( + "python", "go", "rust", "cpp", "swift", "javascript", "csharp", "dart", "scala"); } return peers; } @@ -791,6 +792,16 @@ private PeerCommand buildPeerCommand( "dart", "test", "--name", "interop file roundtrip hooks when env vars are set"); peerCommand.environment.put("ENABLE_FORY_DEBUG_OUTPUT", "1"); break; + case "scala": + workDir = idlRoot.resolve("scala"); + command = + Arrays.asList( + "sbt", + "--batch", + "++3.3.1", + "Test/runMain org.apache.fory.idl_tests.ScalaIdlRoundTripPeer"); + peerCommand.environment.put("ENABLE_FORY_DEBUG_OUTPUT", "1"); + break; default: throw new IllegalArgumentException("Unknown 
peer language: " + peer); } diff --git a/integration_tests/idl_tests/run_scala_tests.sh b/integration_tests/idl_tests/run_scala_tests.sh new file mode 100755 index 0000000000..a1c8f353d5 --- /dev/null +++ b/integration_tests/idl_tests/run_scala_tests.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +python "${SCRIPT_DIR}/generate_idl.py" --lang scala + +cd "${ROOT_DIR}/integration_tests/idl_tests/scala" +ENABLE_FORY_DEBUG_OUTPUT=1 sbt ++3.3.1 clean test + +cd "${ROOT_DIR}/integration_tests/idl_tests" +IDL_PEER_LANG=scala IDL_JAVA_TEST_PATTERN=IdlRoundTripTest ./run_java_tests.sh diff --git a/integration_tests/idl_tests/scala/build.sbt b/integration_tests/idl_tests/scala/build.sbt new file mode 100644 index 0000000000..ff6bd548df --- /dev/null +++ b/integration_tests/idl_tests/scala/build.sbt @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +ThisBuild / scalaVersion := "3.3.1" +ThisBuild / organization := "org.apache.fory" + +lazy val foryScala = ProjectRef(file("../../../scala"), "fory-scala") + +lazy val root = (project in file(".")) + .dependsOn(foryScala) + .settings( + name := "fory-scala-idl-tests", + publish / skip := true, + libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.19" % Test + ) diff --git a/integration_tests/idl_tests/scala/project/build.properties b/integration_tests/idl_tests/scala/project/build.properties new file mode 100644 index 0000000000..04267b14af --- /dev/null +++ b/integration_tests/idl_tests/scala/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.9.9 diff --git a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala new file mode 100644 index 0000000000..0caabecd08 --- /dev/null +++ b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.idl_tests + +import addressbook.AddressbookForyRegistration +import auto_id.AutoIdForyRegistration +import collection.CollectionForyRegistration +import complex_fbs.ComplexFbsForyRegistration +import complex_pb.ComplexPbForyRegistration +import example.ExampleForyRegistration +import graph.GraphForyRegistration +import monster.MonsterForyRegistration +import optional_types.OptionalTypesForyRegistration +import org.apache.fory.Fory +import org.apache.fory.serializer.scala.ScalaSerializers +import tree.TreeForyRegistration + +import java.nio.file.{Files, Path} + +object ScalaIdlRoundTripPeer { + def main(args: Array[String]): Unit = { + val compatible = sys.env.get("IDL_COMPATIBLE").forall(_.toBoolean) + roundTrip("DATA_FILE", compatible, refTracking = false)(AddressbookForyRegistration.register) + roundTrip("DATA_FILE_AUTO_ID", compatible, refTracking = false)(AutoIdForyRegistration.register) + roundTrip("DATA_FILE_PRIMITIVES", compatible, refTracking = false) { fory => + AddressbookForyRegistration.register(fory) + ComplexPbForyRegistration.register(fory) + } + roundTrip("DATA_FILE_COLLECTION", compatible, refTracking = false)( + CollectionForyRegistration.register) + roundTrip("DATA_FILE_COLLECTION_UNION", compatible, refTracking = false)( + CollectionForyRegistration.register) + roundTrip("DATA_FILE_COLLECTION_ARRAY", compatible, refTracking = false)( 
+ CollectionForyRegistration.register) + roundTrip("DATA_FILE_COLLECTION_ARRAY_UNION", compatible, refTracking = false)( + CollectionForyRegistration.register) + roundTrip("DATA_FILE_EXAMPLE", compatible, refTracking = false)(ExampleForyRegistration.register) + roundTrip("DATA_FILE_EXAMPLE_UNION", compatible, refTracking = false)( + ExampleForyRegistration.register) + roundTrip("DATA_FILE_OPTIONAL_TYPES", compatible, refTracking = false)( + OptionalTypesForyRegistration.register) + roundTrip("DATA_FILE_TREE", compatible, refTracking = true)(TreeForyRegistration.register) + roundTrip("DATA_FILE_GRAPH", compatible, refTracking = true)(GraphForyRegistration.register) + roundTrip("DATA_FILE_FLATBUFFERS_MONSTER", compatible, refTracking = false)( + MonsterForyRegistration.register) + roundTrip("DATA_FILE_FLATBUFFERS_TEST2", compatible, refTracking = false) { fory => + MonsterForyRegistration.register(fory) + ComplexFbsForyRegistration.register(fory) + } + } + + private def roundTrip( + envName: String, + compatible: Boolean, + refTracking: Boolean)(register: Fory => Unit): Unit = { + sys.env.get(envName).foreach { file => + val fory = Fory.builder() + .withXlang(true) + .withCompatible(compatible) + .withRefTracking(refTracking) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + ScalaSerializers.registerSerializers(fory) + register(fory) + val path = Path.of(file) + val value = fory.deserialize(Files.readAllBytes(path)) + Files.write(path, fory.serialize(value)) + } + } +} diff --git a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala new file mode 100644 index 0000000000..5deb278663 --- /dev/null +++ b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.idl_tests + +import basic.{BasicEnvelope, BasicForyRegistration, BasicValue, Money} +import collection.{ + CollectionForyRegistration, + NumericCollectionArrayUnion, + NumericCollectionUnion, + NumericCollections, + NumericCollectionsArray +} +import example.{ExampleForyRegistration, ExampleMessage, ExampleState} +import org.apache.fory.Fory +import org.apache.fory.meta.FieldTypes +import org.apache.fory.scala.{ForyScalaEnum, ForySerializer} +import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory +import org.apache.fory.`type`.{TypeUtils, Types} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import tree.{TreeForyRegistration, TreeNode} + +import java.math.BigDecimal +import scala.jdk.CollectionConverters._ + +final class ScalaIdlRoundTripTest extends AnyWordSpec with Matchers { + "generated Scala IDL models" should { + "round trip case classes, Option fields, and ADT union cases" in { + val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + BasicForyRegistration.register(fory) + + val envelope = 
BasicEnvelope( + Some(Money(new BigDecimal("12.34"), "USD")), + BasicValue.MoneyCase(Money(new BigDecimal("56.78"), "EUR")), + None) + + fory.deserialize(fory.serialize(envelope)) shouldEqual envelope + } + + "round trip generated Scala collection metadata" in { + val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + CollectionForyRegistration.register(fory) + + val collections = NumericCollections( + int8Values = List(1.toByte, (-2).toByte), + int16Values = List(3.toShort, (-4).toShort), + int32Values = List(5, -6), + int64Values = List(7L, -8L), + uint8Values = List(9, 255), + uint16Values = List(10, 65535), + uint32Values = List(11L, 4294967295L), + uint64Values = List(12L, -1L), + float32Values = List(1.5f, -2.5f), + float64Values = List(3.5d, -4.5d)) + val union = NumericCollectionUnion.Uint32ValuesCase(List(1L, 4294967295L)) + val arrays = NumericCollectionsArray( + int8Values = Array[Byte](1, -2), + int16Values = Array[Short](3, -4), + int32Values = Array[Int](5, -6), + int64Values = Array[Long](7L, -8L), + uint8Values = Array[Byte](9, -1), + uint16Values = Array[Short](10, -1), + uint32Values = Array[Int](11, -1), + uint64Values = Array[Long](12L, -1L), + float32Values = Array[Float](1.5f, -2.5f), + float64Values = Array[Double](3.5d, -4.5d)) + val arrayUnion = NumericCollectionArrayUnion.Uint32ValuesCase(Array[Int](1, -1)) + + fory.deserialize(fory.serialize(collections)) shouldEqual collections + fory.deserialize(fory.serialize(union)) shouldEqual union + val arraysRoundTrip = fory.deserialize(fory.serialize(arrays)).asInstanceOf[NumericCollectionsArray] + arraysRoundTrip.int8Values.sameElements(arrays.int8Values) shouldBe true + arraysRoundTrip.uint32Values.sameElements(arrays.uint32Values) shouldBe true + val arrayUnionRoundTrip = + fory.deserialize(fory.serialize(arrayUnion)).asInstanceOf[NumericCollectionArrayUnion] + 
arrayUnionRoundTrip.asInstanceOf[NumericCollectionArrayUnion.Uint32ValuesCase].value + .sameElements(arrayUnion.asInstanceOf[NumericCollectionArrayUnion.Uint32ValuesCase].value) shouldBe true + } + + "preserve generated Scala enum metadata in nested descriptors" in { + classOf[ForyScalaEnum].isAssignableFrom(classOf[ExampleState]) shouldBe true + val factory = + summon[ForySerializer[ExampleMessage]] + .asInstanceOf[StaticGeneratedStructSerializerFactory[ExampleMessage]] + val descriptors = factory.getGeneratedDescriptors.asScala + val enumValue = descriptors.find(_.getName == "enumValue").get + val enumList = descriptors.find(_.getName == "enumList").get + val enumMap = descriptors.find(_.getName == "enumValuesByName").get + val uint8ArrayList = descriptors.find(_.getName == "uint8ArrayList").get + val uint8ArrayMap = descriptors.find(_.getName == "uint8ArrayValuesByName").get + + enumValue.getTypeRef.getTypeExtMeta.typeId() shouldBe Types.ENUM + TypeUtils.getElementType(enumList.getTypeRef).getTypeExtMeta.typeId() shouldBe Types.ENUM + TypeUtils.getMapKeyValueType(enumMap.getTypeRef).f1.getTypeExtMeta.typeId() shouldBe Types.ENUM + TypeUtils.getElementType(uint8ArrayList.getTypeRef).getComponentType.getTypeExtMeta + .typeId() shouldBe Types.UINT8 + TypeUtils.getMapKeyValueType(uint8ArrayMap.getTypeRef).f1.getComponentType.getTypeExtMeta + .typeId() shouldBe Types.UINT8 + + val fory = Fory.builder() + .withXlang(true) + .withCompatible(false) + .withRefTracking(false) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + ExampleForyRegistration.register(fory) + val serializer = factory.newSerializer(fory.getTypeResolver, classOf[ExampleMessage], null) + val fieldGroups = serializer.buildLocalFieldGroups(factory.getGeneratedDescriptors) + val enumListInfo = fieldGroups.allFields.find(_.descriptor.getName == "enumList").get + enumListInfo.genericType.getTypeParameter0.getTypeRef.getTypeExtMeta.typeId() shouldBe + Types.ENUM + 
enumListInfo.genericType.getTypeParameter0.isMonomorphic shouldBe true + FieldTypes + .buildFieldType(fory.getTypeResolver, uint8ArrayList) + .asInstanceOf[FieldTypes.CollectionFieldType] + .getElementType + .getTypeId shouldBe Types.UINT8_ARRAY + FieldTypes + .buildFieldType(fory.getTypeResolver, uint8ArrayMap) + .asInstanceOf[FieldTypes.MapFieldType] + .getValueType + .getTypeId shouldBe Types.UINT8_ARRAY + } + + "round trip generated cycle-owned normal classes" in { + val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + TreeForyRegistration.register(fory) + + val root = new TreeNode() + root.id = "root" + root.name = "Root" + val child = new TreeNode() + child.id = "child" + child.name = "Child" + child.parent = Some(root) + root.children = List(child) + + val roundTrip = fory.deserialize(fory.serialize(root)).asInstanceOf[TreeNode] + roundTrip.id shouldEqual "root" + roundTrip.children.head.id shouldEqual "child" + roundTrip.children.head.parent.get shouldBe theSameInstanceAs(roundTrip) + } + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/annotation/ForyCase.java b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyCase.java new file mode 100644 index 0000000000..8b3fbd39a9 --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyCase.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Declares a stable xlang union case ID. + * + *

<p>Case ID {@code 0} is reserved for unknown-case carriers. Schema-defined union cases must use + * positive IDs. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ + ElementType.TYPE, + ElementType.FIELD, + ElementType.METHOD, + ElementType.PARAMETER, + ElementType.CONSTRUCTOR +}) +public @interface ForyCase { + int id(); +} diff --git a/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java index f4f63c8bde..50815b6051 100644 --- a/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java +++ b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java @@ -25,7 +25,7 @@ import java.lang.annotation.Target; @Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.FIELD, ElementType.METHOD}) +@Target({ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER}) public @interface ForyField { /** Controls polymorphism behavior for struct fields in cross-language serialization. */ diff --git a/java/fory-core/src/main/java/org/apache/fory/annotation/ForyUnion.java b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyUnion.java new file mode 100644 index 0000000000..e354553432 --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyUnion.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** Marker annotation for generated or derived JVM xlang union types. */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ForyUnion {} diff --git a/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java b/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java index 1221305882..b90f40f85b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java +++ b/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java @@ -26,7 +26,12 @@ /** Type-use annotation to explicitly enable/disable reference tracking for generic elements. */ @Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.TYPE_USE, ElementType.TYPE_PARAMETER}) +@Target({ + ElementType.FIELD, + ElementType.PARAMETER, + ElementType.TYPE_USE, + ElementType.TYPE_PARAMETER +}) public @interface Ref { /** Whether to enable reference tracking for the annotated type. 
*/ boolean enable() default true; diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java index c07d2ecffb..ef6eb51c49 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java @@ -119,6 +119,10 @@ private static FieldType buildFieldType( && !primitiveList && TypeAnnotationUtils.isBoxedListArrayType(descriptor); TypeExtMeta typeExtMeta = genericType.getTypeRef().getTypeExtMeta(); + int primitiveArrayTypeIdFromComponentMeta = + rawType.isArray() && rawType.getComponentType().isPrimitive() + ? primitiveArrayTypeIdFromComponentMeta(genericType.getTypeRef()) + : Types.UNKNOWN; TypeExtMeta primitiveListArgumentMeta = primitiveList ? primitiveListArgumentMeta(genericType.getTypeRef()) : null; TypeExtMeta primitiveListInlineMeta = @@ -156,6 +160,8 @@ private static FieldType buildFieldType( // For primitive arrays with type annotations, use getDescriptorTypeId to parse annotation. // This allows @UInt8Type etc. to override the default byte[] bytes schema. typeId = Types.getDescriptorTypeId(resolver, descriptor); + } else if (primitiveArrayTypeIdFromComponentMeta != Types.UNKNOWN) { + typeId = primitiveArrayTypeIdFromComponentMeta; } else if (typeAnnotation != null && rawType.isArray() && descriptor != null) { typeId = Types.getDescriptorTypeId(resolver, descriptor); } else { @@ -254,7 +260,8 @@ private static FieldType buildFieldType( elementNullable, elementTrackingRef, primitiveListElementTypeId, -1)); } - if (COLLECTION_TYPE.isSupertypeOf(genericType.getTypeRef())) { + if (COLLECTION_TYPE.isSupertypeOf(genericType.getTypeRef()) + || (isXlang && (resolver.isCollection(rawType) || resolver.isSet(rawType)))) { return new CollectionFieldType( typeId, nullable, @@ -265,7 +272,8 @@ private static FieldType buildFieldType( genericType.getTypeParameter0() == null ? 
GenericType.build(Object.class) : genericType.getTypeParameter0())); - } else if (MAP_TYPE.isSupertypeOf(genericType.getTypeRef())) { + } else if (MAP_TYPE.isSupertypeOf(genericType.getTypeRef()) + || (isXlang && resolver.isMap(rawType))) { Tuple2, TypeRef> mapKeyValueType = getMapKeyValueType(genericType); return new MapFieldType( typeId, @@ -285,6 +293,8 @@ private static FieldType buildFieldType( : resolver.buildGenericType(mapKeyValueType.f1))); } else if (isUnionType || Union.class.isAssignableFrom(rawType)) { return new UnionFieldType(nullable, trackingRef); + } else if (Types.isEnumType(typeId)) { + return new EnumFieldType(nullable, Types.ENUM, -1); } else if (TypeUtils.unwrap(rawType).isPrimitive()) { // unified basic types for xlang and native mode return new RegisteredFieldType(nullable, trackingRef, typeId, -1); @@ -353,6 +363,18 @@ private static TypeExtMeta primitiveListArgumentMeta(TypeRef typeRef) { return elementMeta != null && Types.isPrimitiveType(elementMeta.typeId()) ? 
elementMeta : null; } + private static int primitiveArrayTypeIdFromComponentMeta(TypeRef typeRef) { + TypeRef componentType = typeRef.getComponentType(); + if (componentType == null) { + return Types.UNKNOWN; + } + TypeExtMeta componentMeta = componentType.getTypeExtMeta(); + if (componentMeta == null || componentMeta.typeId() == Types.UNKNOWN) { + return Types.UNKNOWN; + } + return TypeAnnotationUtils.getArrayTypeIdFromElementType(componentType); + } + public abstract static class FieldType implements Serializable { private static final int KIND_OBJECT = 0; private static final int KIND_MAP = 1; @@ -886,8 +908,13 @@ public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { if (declElementType.equals(elementType)) { return declared; } - return collectionOf( - declaredClass, elementType, new TypeExtMeta(typeId, nullable, trackingRef)); + TypeExtMeta extMeta = new TypeExtMeta(typeId, nullable, trackingRef); + if (!java.util.Collection.class.isAssignableFrom(declaredClass) + && resolver.isCollection(declaredClass)) { + return TypeRef.of( + declaredClass, extMeta, java.util.Collections.singletonList(elementType), null); + } + return collectionOf(declaredClass, elementType, extMeta); } // Build array type from element type // elementType could be base type (int) or intermediate array (int[]) @@ -984,13 +1011,18 @@ public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { } if (valueDecl.hasWildcard()) { // handle generic bound - valueDecl = keyDecl.resolveAllWildcards(); + valueDecl = valueDecl.resolveAllWildcards(); } - return mapOf( - declared.getRawType(), - keyType.toTypeToken(classResolver, keyDecl), - valueType.toTypeToken(classResolver, valueDecl), - new TypeExtMeta(typeId, nullable, trackingRef)); + TypeExtMeta extMeta = new TypeExtMeta(typeId, nullable, trackingRef); + TypeRef keyTypeRef = keyType.toTypeToken(classResolver, keyDecl); + TypeRef valueTypeRef = valueType.toTypeToken(classResolver, valueDecl); + Class declaredClass = 
declared.getRawType(); + if (!java.util.Map.class.isAssignableFrom(declaredClass) + && classResolver.isMap(declaredClass)) { + return TypeRef.of( + declaredClass, extMeta, java.util.Arrays.asList(keyTypeRef, valueTypeRef), null); + } + return mapOf(declaredClass, keyTypeRef, valueTypeRef, extMeta); } return mapOf( keyType.toTypeToken(classResolver, keyDecl), @@ -1040,8 +1072,8 @@ public EnumFieldType(boolean nullable, int typeId, int userTypeId) { @Override public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { - if (declared != null && declared.getRawType().isEnum()) { - return declared; + if (declared != null) { + return TypeRef.of(declared.getRawType(), new TypeExtMeta(Types.ENUM, nullable, false)); } return TypeRef.of(UnknownClass.UnknownEnum.class); } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 1c671c5ab2..7aa11ffeae 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -595,6 +595,70 @@ public void registerUnion(Class cls, String namespace, String name, Serialize registerGraalvmClass(cls); } + @Override + public void registerUnionCase(Class unionType, Class caseType) { + checkRegisterAllowed(); + TypeInfo typeInfo = classInfoMap.get(unionType); + Preconditions.checkArgument( + typeInfo != null && Types.isUnionType(typeInfo.typeId), + "Union type %s must be registered before case type %s", + unionType, + caseType); + TypeInfo existingInfo = classInfoMap.get(caseType); + Preconditions.checkArgument( + existingInfo == null || existingInfo == typeInfo, + "Union case type %s has been registered as %s", + caseType, + existingInfo); + classInfoMap.put(caseType, typeInfo); + extRegistry.registeredClasses.put(caseType.getName(), caseType); + registerGraalvmClass(caseType); + } + + @Override + public void 
registerEnum(Class cls, long userId, Serializer serializer) { + checkRegisterAllowed(); + int checkedUserId = toUserTypeId(userId); + Preconditions.checkNotNull(serializer); + checkRegistration(cls, checkedUserId, cls.getName(), false); + extRegistry.registeredClassIdMap.put(cls, checkedUserId); + TypeInfo typeInfo = classInfoMap.get(cls); + if (typeInfo == null) { + typeInfo = new TypeInfo(this, cls, serializer, Types.ENUM, checkedUserId); + } else { + typeInfo = typeInfo.copy(Types.ENUM, checkedUserId); + typeInfo.setSerializer(this, serializer); + } + updateTypeInfo(cls, typeInfo); + extRegistry.registeredClasses.put(cls.getName(), cls); + registerGraalvmClass(cls); + } + + @Override + public void registerEnum(Class cls, String namespace, String name, Serializer serializer) { + checkRegisterAllowed(); + Preconditions.checkNotNull(serializer); + Preconditions.checkArgument(!Functions.isLambda(cls)); + Preconditions.checkArgument(!ReflectionUtils.isJdkProxy(cls)); + Preconditions.checkArgument(!cls.isArray()); + String fullname = name; + if (namespace == null) { + namespace = ""; + } + if (!StringUtils.isBlank(namespace)) { + fullname = namespace + "." + name; + } + checkRegistration(cls, -1, fullname, false); + EncodedMetaString nsBytes = sharedRegistry.getPackageEncodedMetaString(namespace); + EncodedMetaString nameBytes = sharedRegistry.getTypeNameEncodedMetaString(name); + TypeInfo typeInfo = new TypeInfo(cls, nsBytes, nameBytes, serializer, Types.NAMED_ENUM, -1); + typeInfo.setSerializer(this, serializer); + classInfoMap.put(cls, typeInfo); + compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes, nameBytes), typeInfo); + extRegistry.registeredClasses.put(fullname, cls); + registerGraalvmClass(cls); + } + /** * Registers multiple classes for internal use with auto-assigned internal IDs. 
* diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java b/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java index ca516748d3..6725a38699 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java @@ -26,6 +26,7 @@ import org.apache.fory.exception.ForyException; import org.apache.fory.meta.TypeDef; import org.apache.fory.serializer.StaticGeneratedStructSerializer; +import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory; import org.apache.fory.type.Descriptor; /** Shared registry of build-time generated static serializer mappings. */ @@ -90,11 +91,48 @@ List getGeneratedDescriptors() { private final ConcurrentHashMap, Entry> xlangSerializers = new ConcurrentHashMap<>(); private final ConcurrentHashMap, Entry> nativeSerializers = new ConcurrentHashMap<>(); + private final ConcurrentHashMap, StaticGeneratedStructSerializerFactory> + xlangFactories = new ConcurrentHashMap<>(); + private final ConcurrentHashMap, StaticGeneratedStructSerializerFactory> + nativeFactories = new ConcurrentHashMap<>(); private final ConcurrentHashMap, Boolean> missingXlangSerializers = new ConcurrentHashMap<>(); private final ConcurrentHashMap, Boolean> missingNativeSerializers = new ConcurrentHashMap<>(); + void registerFactory( + Class targetType, boolean xlang, StaticGeneratedStructSerializerFactory factory) { + ConcurrentHashMap, StaticGeneratedStructSerializerFactory> factories = + xlang ? 
xlangFactories : nativeFactories; + StaticGeneratedStructSerializerFactory existing = factories.putIfAbsent(targetType, factory); + if (existing != null && existing != factory && existing.getClass() != factory.getClass()) { + throw new IllegalArgumentException( + "Conflicting static generated serializer factory for " + targetType.getName()); + } + } + + StaticGeneratedStructSerializer newRegisteredSerializer( + TypeResolver resolver, Class targetType, TypeDef typeDef) { + StaticGeneratedStructSerializerFactory factory = + getRegisteredFactory(targetType, resolver.isCrossLanguage()); + if (factory == null) { + return null; + } + return factory.newSerializer(resolver, targetType, typeDef); + } + + List getRegisteredDescriptors(Class targetType, boolean xlang) { + StaticGeneratedStructSerializerFactory factory = getRegisteredFactory(targetType, xlang); + return factory == null ? null : factory.getGeneratedDescriptors(); + } + + private StaticGeneratedStructSerializerFactory getRegisteredFactory( + Class targetType, boolean xlang) { + ConcurrentHashMap, StaticGeneratedStructSerializerFactory> factories = + xlang ? 
xlangFactories : nativeFactories; + return factories.get(targetType); + } + Class getSerializerClass( Class targetType, boolean xlang) { Entry entry = getEntry(targetType, xlang); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 3532f3a09e..69e0cbf648 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -82,12 +82,14 @@ import org.apache.fory.serializer.CodegenSerializer; import org.apache.fory.serializer.CodegenSerializer.LazyInitBeanSerializer; import org.apache.fory.serializer.CompatibleSerializer; +import org.apache.fory.serializer.DeferedLazySerializer; import org.apache.fory.serializer.ObjectSerializer; import org.apache.fory.serializer.PrimitiveSerializers; import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.SerializerFactory; import org.apache.fory.serializer.Serializers; import org.apache.fory.serializer.StaticGeneratedStructSerializer; +import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory; import org.apache.fory.serializer.UnknownClass; import org.apache.fory.serializer.UnknownClass.UnknownEmptyStruct; import org.apache.fory.serializer.UnknownClass.UnknownStruct; @@ -289,6 +291,26 @@ public void register(String className, String namespace, String typeName) { public abstract void registerUnion( Class type, String namespace, String typeName, Serializer serializer); + /** + * Registers {@code caseType} as a runtime class alias for an already registered union type. + * + *

Some JVM languages compile a union value to concrete case subclasses even though the wire + * type is owned by the sealed union base. This method makes runtime dispatch for those case + * subclasses use the base union {@link TypeInfo}; it must not create another wire type name or + * user type ID. + */ + @Internal + public abstract void registerUnionCase(Class unionType, Class caseType); + + /** Registers a non-Java enum type with a user-specified ID and serializer. */ + @Internal + public abstract void registerEnum(Class type, long id, Serializer serializer); + + /** Registers a non-Java enum type with a namespace, type name, and serializer. */ + @Internal + public abstract void registerEnum( + Class type, String namespace, String typeName, Serializer serializer); + /** * Registers a custom serializer for a type. * @@ -428,6 +450,10 @@ public boolean isCollectionDescriptor(Descriptor descriptor) { return isCollection(descriptor.getRawType()); } + public boolean isMapDescriptor(Descriptor descriptor) { + return isMap(descriptor.getRawType()); + } + public abstract boolean isMonomorphic(Descriptor descriptor); public abstract boolean isMonomorphic(Class clz); @@ -1032,6 +1058,13 @@ private TypeInfo getMetaSharedTypeInfo(TypeDef typeDef, Class clz) { // type metadata or a concrete target-class transformation. 
return typeInfo; } + StaticGeneratedStructSerializer registeredStaticSerializer = + sharedRegistry.staticGeneratedSerializerRegistry.newRegisteredSerializer( + this, cls, typeDef); + if (registeredStaticSerializer != null) { + typeInfo.setSerializer(this, registeredStaticSerializer); + return typeInfo; + } Class sc = getCompatibleDeserializerClassFromGraalvmRegistry(cls, typeDef); if (sc == null) { @@ -1283,6 +1316,21 @@ private Serializer getNativeTypedValueSerializer(int typeId, Class rawType public abstract void setSerializerIfAbsent(Class cls, Serializer serializer); + @Internal + @SuppressWarnings({"rawtypes", "unchecked"}) + public final void registerStaticGeneratedStructSerializerFactory( + Class cls, StaticGeneratedStructSerializerFactory factory) { + if (!isRegistered(cls)) { + register(cls); + } + sharedRegistry.staticGeneratedSerializerRegistry.registerFactory( + cls, isCrossLanguage(), factory); + Serializer serializer = + new DeferedLazySerializer.DeferredLazyObjectSerializer( + this, cls, () -> Tuple2.of(true, factory.newSerializer(this, cls, null))); + setSerializer(cls, serializer); + } + /** * Reset serializer if {@code serializer} is not null, otherwise clear serializer for {@code cls}. 
*/ @@ -1539,6 +1587,7 @@ private DescriptorGrouper buildDescriptorGrouper( this::usesPrimitiveFieldOrdering, this::isBuildIn, this::isCollectionDescriptor, + this::isMapDescriptor, descriptors, descriptorsGroupedOrdered, descriptorUpdator, @@ -1556,6 +1605,12 @@ public final DescriptorGrouper groupDescriptors( } private List buildFieldDescriptors(Class clz, boolean searchParent) { + List registeredStaticDescriptors = + sharedRegistry.staticGeneratedSerializerRegistry.getRegisteredDescriptors( + clz, isCrossLanguage()); + if (registeredStaticDescriptors != null) { + return normalizeFieldDescriptors(clz, searchParent, registeredStaticDescriptors); + } if (shouldPreferStaticGeneratedSerializer(clz)) { List staticDescriptors = getStaticGeneratedStructDescriptors(clz); if (staticDescriptors != null) { @@ -1861,9 +1916,11 @@ private int getPrimitiveFieldSize(Descriptor descriptor) { *

  • Otherwise: return true only for Optional types, false for all other non-primitives * * - *

    For native mode: reflected value fields are nullable by default. Descriptors without a - * backing field already carry schema-owned nullability, for example TypeDef descriptors and - * annotation-processor generated native descriptors. + *

    Descriptors without a backing Java {@link Field} already carry schema-owned nullability, for + * example TypeDef descriptors and static descriptors emitted by annotation processors or Scala + * macro derivation. + * + *

    For native reflected descriptors: value fields are nullable by default. * *

    Important: this must match the TypeDef metadata for the same descriptor source. Xlang local * descriptors use xlang defaults, native reflected descriptors use native nullable-by-default @@ -1878,15 +1935,15 @@ private boolean isFieldNullable(Descriptor descriptor) { if (typeExtMeta != null) { return typeExtMeta.nullable(); } + if (descriptor.getField() == null) { + return descriptor.isNullable(); + } if (isCrossLanguage()) { // For xlang mode: apply xlang defaults // This must match what TypeDefEncoder.buildFieldType uses for TypeDef metadata // Default for xlang: false for all non-primitives, except Optional types return TypeUtils.isOptionalType(rawType); } - if (descriptor.getField() == null) { - return descriptor.isNullable(); - } return descriptor.hasForyField() ? descriptor.isNullable() : true; } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index 422c76cbf9..b6e990d035 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -89,6 +89,7 @@ import org.apache.fory.serializer.PrimitiveArraySerializers; import org.apache.fory.serializer.PrimitiveSerializers; import org.apache.fory.serializer.Serializer; +import org.apache.fory.serializer.SerializerFactory; import org.apache.fory.serializer.Serializers; import org.apache.fory.serializer.Shareable; import org.apache.fory.serializer.SqlTimeSerializers; @@ -260,7 +261,7 @@ public void register(Class type, String namespace, String typeName) { if (typeInfo.typeName != null) { String prevNamespace = typeInfo.decodeNamespace(); String prevTypeName = typeInfo.decodeTypeName(); - if (!namespace.equals(prevNamespace) || typeName.equals(prevTypeName)) { + if (!namespace.equals(prevNamespace) || !typeName.equals(prevTypeName)) { throw new IllegalArgumentException( String.format( "Type %s has been 
registered with namespace %s type %s", @@ -376,7 +377,7 @@ public void registerUnion( if (typeInfo != null && typeInfo.typeName != null) { String prevNamespace = typeInfo.decodeNamespace(); String prevTypeName = typeInfo.decodeTypeName(); - if (!namespace.equals(prevNamespace) || typeName.equals(prevTypeName)) { + if (!namespace.equals(prevNamespace) || !typeName.equals(prevTypeName)) { throw new IllegalArgumentException( String.format( "Type %s has been registered with namespace %s type %s", @@ -387,6 +388,73 @@ public void registerUnion( register(type, serializer, namespace, typeName, xtypeId, -1); } + @Override + public void registerUnionCase(Class unionType, Class caseType) { + checkRegisterAllowed(); + TypeInfo typeInfo = classInfoMap.get(unionType); + Preconditions.checkArgument( + typeInfo != null && Types.isUnionType(typeInfo.typeId), + "Union type %s must be registered before case type %s", + unionType, + caseType); + TypeInfo existingInfo = classInfoMap.get(caseType); + Preconditions.checkArgument( + existingInfo == null || existingInfo == typeInfo, + "Union case type %s has been registered as %s", + caseType, + existingInfo); + classInfoMap.put(caseType, typeInfo); + extRegistry.registeredClasses.put(caseType.getName(), caseType); + registerGraalvmClass(caseType); + } + + @Override + public void registerEnum(Class type, long userTypeId, Serializer serializer) { + checkRegisterAllowed(); + Preconditions.checkNotNull(serializer); + int checkedUserTypeId = toUserTypeId(userTypeId); + Preconditions.checkArgument( + !containsUserTypeId(checkedUserTypeId), "Type id %s has been registered", userTypeId); + TypeInfo typeInfo = classInfoMap.get(type); + if (typeInfo != null && typeInfo.typeId != 0) { + throw new IllegalArgumentException( + String.format("Type %s has been registered with id %s", type, typeInfo.typeId)); + } + register( + type, + serializer, + ReflectionUtils.getPackage(type), + ReflectionUtils.getClassNameWithoutPackage(type), + Types.ENUM, + 
checkedUserTypeId); + } + + @Override + public void registerEnum( + Class type, String namespace, String typeName, Serializer serializer) { + checkRegisterAllowed(); + Preconditions.checkNotNull(serializer); + if (namespace == null) { + namespace = ""; + } + Preconditions.checkArgument( + !typeName.contains("."), + "Typename %s should not contains `.`, please put it into namespace", + typeName); + TypeInfo typeInfo = classInfoMap.get(type); + if (typeInfo != null && typeInfo.typeName != null) { + String prevNamespace = typeInfo.decodeNamespace(); + String prevTypeName = typeInfo.decodeTypeName(); + if (!namespace.equals(prevNamespace) || !typeName.equals(prevTypeName)) { + throw new IllegalArgumentException( + String.format( + "Type %s has been registered with namespace %s type %s", + type, prevNamespace, prevTypeName)); + } + } + register(type, serializer, namespace, typeName, Types.NAMED_ENUM, -1); + } + /** * Register type with given type id and serializer for type in fory type system. 
* @@ -651,6 +719,12 @@ public boolean isMonomorphic(Descriptor descriptor) { return false; } byte typeIdByte = getInternalTypeId(rawType); + if (Types.isUnionType(typeIdByte)) { + return true; + } + if (Types.isEnumType(typeIdByte)) { + return true; + } if (isCompatible()) { return !Types.isUserDefinedType(typeIdByte) && typeIdByte != Types.UNKNOWN; } @@ -680,6 +754,9 @@ public boolean isMonomorphic(Class clz) { } TypeInfo typeInfo = getTypeInfo(clz, false); if (typeInfo != null) { + if (Types.isEnumType(typeInfo.typeId) || Types.isUnionType(typeInfo.typeId)) { + return true; + } Serializer s = typeInfo.serializer; if (s instanceof TimeSerializers.TimeSerializer || s instanceof MapLikeSerializer @@ -801,15 +878,7 @@ private TypeInfo buildTypeInfo(Class cls) { cls = HashMap.class; serializer = new HashMapSerializer(this); } else { - TypeInfo cachedTypeInfo = classInfoMap.get(cls); - if (cachedTypeInfo != null - && cachedTypeInfo.serializer != null - && cachedTypeInfo.serializer instanceof MapLikeSerializer - && ((MapLikeSerializer) cachedTypeInfo.serializer).supportCodegenHook()) { - serializer = cachedTypeInfo.serializer; - } else { - serializer = new MapSerializer(this, cls); - } + serializer = getMapSerializer(cls); } typeId = Types.MAP; } else if (UnknownClass.class.isAssignableFrom(cls)) { @@ -845,9 +914,33 @@ private Serializer getCollectionSerializer(Class cls) { && ((CollectionLikeSerializer) (typeInfo.serializer)).supportCodegenHook()) { return typeInfo.serializer; } + Serializer serializer = createSerializerFromFactory(cls); + if (serializer != null) { + return serializer; + } return new CollectionSerializer(this, cls); } + private Serializer getMapSerializer(Class cls) { + TypeInfo typeInfo = classInfoMap.get(cls); + if (typeInfo != null + && typeInfo.serializer != null + && typeInfo.serializer instanceof MapLikeSerializer + && ((MapLikeSerializer) typeInfo.serializer).supportCodegenHook()) { + return typeInfo.serializer; + } + Serializer serializer = 
createSerializerFromFactory(cls); + if (serializer != null) { + return serializer; + } + return new MapSerializer(this, cls); + } + + private Serializer createSerializerFromFactory(Class cls) { + SerializerFactory serializerFactory = getSerializerFactory(); + return serializerFactory == null ? null : serializerFactory.createSerializer(this, cls); + } + private void registerDefaultTypes() { Config config = this.config; // Boolean types diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java index c33ab80a57..14fdc817f2 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java @@ -32,6 +32,7 @@ import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.FieldInfo; import org.apache.fory.meta.TypeDef; +import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.FieldGroups.SerializationFieldInfo; import org.apache.fory.serializer.converter.FieldConverters; @@ -65,7 +66,7 @@ public StaticGeneratedStructSerializer(TypeResolver typeResolver, Class type) } @SuppressWarnings("unchecked") - protected StaticGeneratedStructSerializer( + public StaticGeneratedStructSerializer( TypeResolver typeResolver, Class type, TypeDef typeDef, List descriptors) { super(typeResolver, (Class) type); setSerializerIfAbsent(typeResolver, (Class) type); @@ -77,12 +78,17 @@ protected StaticGeneratedStructSerializer( } private void setSerializerIfAbsent(TypeResolver typeResolver, Class type) { - if (!typeResolver.isCrossLanguage() || typeResolver.getTypeInfo(type, false) != null) { + TypeInfo typeInfo = typeResolver.getTypeInfo(type, false); + if (!typeResolver.isCrossLanguage() || typeInfo != null) { // Field-group construction resolves 
monomorphic field serializers. A generated serializer can // therefore encounter its own type before the subclass constructor has finished, just like // ObjectSerializer. Install this instance early so recursive fields reuse it instead of // constructing another serializer for the same type. - typeResolver.setSerializerIfAbsent(type, this); + if (typeInfo != null && typeInfo.getSerializer() instanceof DeferedLazySerializer) { + typeResolver.setSerializer(type, this); + } else { + typeResolver.setSerializerIfAbsent(type, this); + } } } @@ -101,9 +107,13 @@ public final List getDescriptors() { return getGeneratedDescriptors(); } + public final List getRemoteFields() { + return remoteFields; + } + public abstract T readCompatible(ReadContext readContext); - protected final FieldGroups buildFieldGroups(List descriptors) { + public final FieldGroups buildFieldGroups(List descriptors) { descriptors = runtimeDescriptors(descriptors); DescriptorGrouper grouper = FieldGroups.buildDescriptorGrouper( @@ -111,7 +121,7 @@ protected final FieldGroups buildFieldGroups(List descriptors) { return FieldGroups.buildFieldInfos(typeResolver, grouper); } - protected final FieldGroups buildLocalFieldGroups(List descriptors) { + public final FieldGroups buildLocalFieldGroups(List descriptors) { if (!typeResolver.isShareMeta()) { return buildFieldGroups(descriptors); } @@ -127,7 +137,7 @@ protected final List runtimeDescriptors(List descriptors return typeResolver.normalizeFieldDescriptors(type, true, descriptors); } - protected final int[] localFieldIds( + public final int[] localFieldIds( SerializationFieldInfo[] fieldInfos, List descriptors) { Map localIds = new HashMap<>(); for (int i = 0; i < descriptors.size(); i++) { @@ -187,23 +197,64 @@ protected final void writeOtherFieldValue( fieldValue); } - protected final void writeFieldValue( + public final void writeFieldValue( WriteContext writeContext, SerializationFieldInfo fieldInfo, Object fieldValue) { + writeFieldValue(typeResolver, 
writeContext, fieldInfo, fieldValue); + } + + public static void writeFieldValue( + TypeResolver typeResolver, + WriteContext writeContext, + SerializationFieldInfo fieldInfo, + Object fieldValue) { switch (fieldInfo.codecCategory) { case BUILD_IN: - writeBuildInFieldValue(writeContext, fieldInfo, fieldValue); + // Some schema-built-in fields still use container-shaped Java accessors, such as + // @ArrayType List. The override owns the accessor-to-payload conversion. + if (fieldInfo.containerSerializerOverride != null) { + writeContainerFieldValue(typeResolver, writeContext, fieldInfo, fieldValue); + return; + } + AbstractObjectSerializer.writeBuildInFieldValue( + writeContext, + typeResolver, + writeContext.getRefWriter(), + fieldInfo, + writeContext.getBuffer(), + fieldValue); return; case CONTAINER: - writeContainerFieldValue(writeContext, fieldInfo, fieldValue); + writeContainerFieldValue(typeResolver, writeContext, fieldInfo, fieldValue); return; case OTHER: - writeOtherFieldValue(writeContext, fieldInfo, fieldValue); + AbstractObjectSerializer.writeField( + writeContext, + typeResolver, + writeContext.getRefWriter(), + fieldInfo, + writeContext.getBuffer(), + fieldValue); return; default: throw new IllegalStateException("Unknown field codec category " + fieldInfo.codecCategory); } } + private static void writeContainerFieldValue( + TypeResolver typeResolver, + WriteContext writeContext, + SerializationFieldInfo fieldInfo, + Object fieldValue) { + AbstractObjectSerializer.writeContainerFieldValue( + writeContext, + typeResolver, + writeContext.getRefWriter(), + writeContext.getGenerics(), + fieldInfo, + writeContext.getBuffer(), + fieldValue); + } + protected final Object readBuildInFieldValue( ReadContext readContext, SerializationFieldInfo fieldInfo) { // See writeBuildInFieldValue: built-in schema groups can still need container conversion. 
@@ -232,18 +283,48 @@ protected final Object readOtherFieldValue( } protected final Object readFieldValue(ReadContext readContext, SerializationFieldInfo fieldInfo) { + return readFieldValue(typeResolver, readContext, fieldInfo); + } + + public static Object readFieldValue( + TypeResolver typeResolver, ReadContext readContext, SerializationFieldInfo fieldInfo) { switch (fieldInfo.codecCategory) { case BUILD_IN: - return readBuildInFieldValue(readContext, fieldInfo); + // See writeFieldValue: built-in schema groups can still need container conversion. + if (fieldInfo.containerSerializerOverride != null) { + return readContainerFieldValue(typeResolver, readContext, fieldInfo); + } + return AbstractObjectSerializer.readBuildInFieldValue( + readContext, + typeResolver, + readContext.getRefReader(), + fieldInfo, + readContext.getBuffer()); case CONTAINER: - return readContainerFieldValue(readContext, fieldInfo); + return readContainerFieldValue(typeResolver, readContext, fieldInfo); case OTHER: - return readOtherFieldValue(readContext, fieldInfo); + return AbstractObjectSerializer.readField( + readContext, + typeResolver, + readContext.getRefReader(), + fieldInfo, + readContext.getBuffer()); default: throw new IllegalStateException("Unknown field codec category " + fieldInfo.codecCategory); } } + private static Object readContainerFieldValue( + TypeResolver typeResolver, ReadContext readContext, SerializationFieldInfo fieldInfo) { + return AbstractObjectSerializer.readContainerFieldValue( + readContext, + typeResolver, + readContext.getRefReader(), + readContext.getGenerics(), + fieldInfo, + readContext.getBuffer()); + } + protected final Object readRemoteField(ReadContext readContext, RemoteFieldInfo remoteField) { if (remoteField.compatibleCollectionArrayReadAction != null) { return CompatibleCollectionArrayReader.read( @@ -254,7 +335,7 @@ protected final Object readRemoteField(ReadContext readContext, RemoteFieldInfo return readField(readContext, 
remoteField.serializationFieldInfo); } - protected final void skipField(ReadContext readContext, RemoteFieldInfo remoteField) { + public final void skipField(ReadContext readContext, RemoteFieldInfo remoteField) { try { FieldSkipper.skipField( readContext, @@ -273,7 +354,7 @@ protected final SerializationFieldInfo localFieldInfo(int matchedId) { return localFieldsById[matchedId]; } - protected final boolean canReadRemoteField( + public final boolean canReadRemoteField( RemoteFieldInfo remoteField, SerializationFieldInfo localFieldInfo) { if (remoteField.incompatibleCollectionArrayMatch) { throw new DeserializationException( @@ -295,7 +376,7 @@ protected final boolean canReadRemoteField( return FieldConverters.canConvert(remoteType, localType); } - protected final Object readCompatibleFieldValue( + public final Object readCompatibleFieldValue( ReadContext readContext, RemoteFieldInfo remoteField, SerializationFieldInfo localFieldInfo) { Object fieldValue = readRemoteField(readContext, remoteField); if (remoteField.compatibleCollectionArrayReadAction != null) { @@ -385,15 +466,16 @@ protected final Object copyFieldValue( return copyContext.copyObject(fieldValue, fieldInfo.dispatchId); } - protected final int computeClassVersionHash(List descriptors) { - descriptors = runtimeDescriptors(descriptors); - return ObjectSerializer.computeStructHash( - typeResolver, - FieldGroups.buildDescriptorGrouper( - typeResolver, descriptors, false, descriptor -> descriptor)); + public final int computeClassVersionHash(List descriptors) { + DescriptorGrouper grouper = + typeResolver.isShareMeta() + ? 
typeResolver.createDescriptorGrouper(typeResolver.getTypeDef(type, true), type) + : FieldGroups.buildDescriptorGrouper( + typeResolver, runtimeDescriptors(descriptors), false, descriptor -> descriptor); + return ObjectSerializer.computeStructHash(typeResolver, grouper); } - protected final void checkClassVersion(int readHash, int classVersionHash) { + public final void checkClassVersion(int readHash, int classVersionHash) { ObjectSerializer.checkClassVersion(type, readHash, classVersionHash); } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java new file mode 100644 index 0000000000..3882e3ceea --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.fory.serializer; + +import java.util.List; +import org.apache.fory.annotation.Internal; +import org.apache.fory.meta.TypeDef; +import org.apache.fory.resolver.TypeResolver; +import org.apache.fory.type.Descriptor; + +/** + * Factory for statically shaped struct serializers whose implementation is generated by a non-Java + * compiler. + * + *

    Named generated serializer classes are still discovered by {@link + * org.apache.fory.resolver.StaticGeneratedSerializerRegistry}. This factory path is for language + * frontends such as Scala 3 macro derivation where the serializer code is emitted at the typeclass + * call site rather than as a separately named JVM class. + */ +@Internal +public interface StaticGeneratedStructSerializerFactory { + /** Descriptor metadata generated from the source-language type model. */ + List getGeneratedDescriptors(); + + /** + * Create a serializer for {@code type}. + * + * @param typeDef remote TypeDef for compatible reads, or {@code null} for local schema reads. + */ + StaticGeneratedStructSerializer newSerializer( + TypeResolver typeResolver, Class type, TypeDef typeDef); +} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java index 28f50bb0fc..817f9106b6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java @@ -197,6 +197,140 @@ public Union copy(CopyContext copyContext, Union union) { return factory.apply(union.getIndex(), copiedValue); } + /** + * Writes a schema-defined union case for generated non-Java union carriers. + * + *

    The union envelope is owned here so generated serializers for Scala/Kotlin/etc. do not + * duplicate protocol details or allocate temporary Java {@link Union} carriers. + */ + public static void writeCaseValue( + TypeResolver resolver, + WriteContext writeContext, + FieldGroups.SerializationFieldInfo fieldInfo, + Object value, + int caseId) { + writeContext.getBuffer().writeVarUInt32(caseId); + writeKnownCasePayload(resolver, writeContext, fieldInfo, value); + } + + /** + * Writes an unknown union case payload using dynamic value metadata. + * + *

    Unknown cases preserve the original peer case id; id {@code 0} remains the local unknown + * carrier tag and is not written for schema-defined cases. + */ + public static void writeUnknownCaseValue( + WriteContext writeContext, Object value, int originalCaseId) { + Preconditions.checkArgument( + originalCaseId > 0, "Unknown union case id must preserve a positive original id"); + writeContext.getBuffer().writeVarUInt32(originalCaseId); + writeContext.writeRef(value); + } + + /** Reads a schema-defined union case payload for generated non-Java union carriers. */ + public static Object readCaseValue( + TypeResolver resolver, + ReadContext readContext, + FieldGroups.SerializationFieldInfo fieldInfo) { + int typeId = Types.getDescriptorTypeId(resolver, fieldInfo.descriptor); + if (typeId == Types.UNKNOWN) { + return readContext.readRef(); + } + int nextReadRefId = readContext.tryPreserveRefId(); + if (nextReadRefId >= Fory.NOT_NULL_VALUE_FLAG) { + TypeInfo declared = getDeclaredCaseTypeInfo(fieldInfo, typeId); + TypeInfo readTypeInfo = resolver.readTypeInfo(readContext, declared); + Serializer serializer = getCaseSerializer(fieldInfo, readTypeInfo.getTypeId(), readTypeInfo); + Object caseValue = + readCaseValue( + readContext, serializer, getCaseGenericType(fieldInfo, readTypeInfo.getTypeId())); + readContext.setReadRef(nextReadRefId, caseValue); + return caseValue; + } + return readContext.getReadRef(); + } + + private static void writeKnownCasePayload( + TypeResolver resolver, + WriteContext writeContext, + FieldGroups.SerializationFieldInfo fieldInfo, + Object value) { + int typeId = Types.getDescriptorTypeId(resolver, fieldInfo.descriptor); + if (typeId == Types.UNKNOWN) { + writeContext.writeRef(value); + return; + } + MemoryBuffer buffer = writeContext.getBuffer(); + if (value == null) { + buffer.writeByte(Fory.NULL_FLAG); + return; + } + TypeInfo typeInfo = getCaseTypeInfo(resolver, fieldInfo, value, typeId); + Serializer serializer = 
getCaseSerializer(fieldInfo, typeId, typeInfo); + if (serializer != null && serializer.needToWriteRef()) { + if (writeContext.writeRefOrNull(value)) { + return; + } + } else { + buffer.writeByte(Fory.NOT_NULL_VALUE_FLAG); + } + if (!Types.isUserDefinedType(typeId)) { + buffer.writeUInt8(typeId); + } else { + resolver.writeTypeInfo(writeContext, typeInfo); + } + writeValue(writeContext, value, typeId, serializer, getCaseGenericType(fieldInfo, typeId)); + } + + private static TypeInfo getCaseTypeInfo( + TypeResolver resolver, + FieldGroups.SerializationFieldInfo fieldInfo, + Object value, + int typeId) { + if (Types.isPrimitiveType(typeId)) { + return resolver.getTypeInfoByTypeId(typeId); + } + TypeInfo declared = getDeclaredCaseTypeInfo(fieldInfo, typeId); + if (declared != null) { + return declared; + } + if (!Types.isUserDefinedType(typeId)) { + return resolver.getTypeInfoByTypeId(typeId); + } + return resolver.getTypeInfo(value.getClass()); + } + + private static TypeInfo getDeclaredCaseTypeInfo( + FieldGroups.SerializationFieldInfo fieldInfo, int typeId) { + if (Types.isPrimitiveType(typeId)) { + return null; + } + if (fieldInfo.containerTypeInfo != null) { + return fieldInfo.containerTypeInfo; + } + return fieldInfo.typeInfo; + } + + private static Serializer getCaseSerializer( + FieldGroups.SerializationFieldInfo fieldInfo, int typeId, TypeInfo fallbackTypeInfo) { + if (fieldInfo.containerSerializerOverride != null + && (typeId == Types.LIST + || typeId == Types.SET + || typeId == Types.MAP + || Types.isPrimitiveArray(typeId))) { + return fieldInfo.containerSerializerOverride; + } + return fallbackTypeInfo.getSerializer(); + } + + private static GenericType getCaseGenericType( + FieldGroups.SerializationFieldInfo fieldInfo, int typeId) { + if (typeId != Types.LIST && typeId != Types.SET && typeId != Types.MAP) { + return null; + } + return fieldInfo.genericType; + } + private void writeCaseValue(WriteContext writeContext, Object value, int typeId, int caseId) 
{ MemoryBuffer buffer = writeContext.getBuffer(); byte internalTypeId = (byte) typeId; @@ -233,7 +367,7 @@ private void writeCaseValue(WriteContext writeContext, Object value, int typeId, writeValue(writeContext, value, typeId, serializer, getCaseGenericType(caseId, typeId)); } - private void writeCaseValue( + private static void writeCaseValue( WriteContext writeContext, Serializer serializer, GenericType genericType, Object value) { if (genericType == null) { Serializers.write(writeContext, serializer, value); @@ -249,7 +383,7 @@ private void writeCaseValue( } } - private void writeValue( + private static void writeValue( WriteContext writeContext, Object value, int typeId, @@ -317,7 +451,7 @@ private void writeValue( throw new IllegalStateException("Missing serializer for union type id " + typeId); } - private Object readCaseValue( + private static Object readCaseValue( ReadContext readContext, Serializer serializer, GenericType genericType) { if (genericType == null) { return Serializers.read(readContext, serializer); diff --git a/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java b/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java index ae828ca427..004d3b6dbc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java @@ -60,6 +60,7 @@ public class DescriptorGrouper { private final Predicate usesPrimitiveFieldOrdering; private final Predicate isBuildIn; private final Predicate isCollection; + private final Predicate isMap; private final Function descriptorUpdater; private final boolean descriptorsGroupedOrdered; private boolean sorted = false; @@ -89,6 +90,7 @@ private DescriptorGrouper( Predicate usesPrimitiveFieldOrdering, Predicate isBuildIn, Predicate isCollection, + Predicate isMap, Collection descriptors, boolean descriptorsGroupedOrdered, Function descriptorUpdater, @@ -98,6 +100,7 @@ private 
DescriptorGrouper( this.descriptors = descriptors; this.isBuildIn = isBuildIn; this.isCollection = isCollection; + this.isMap = isMap; this.descriptorUpdater = descriptorUpdater; this.descriptorsGroupedOrdered = descriptorsGroupedOrdered; this.primitiveDescriptors = @@ -166,7 +169,7 @@ public DescriptorGrouper sort() { } } else if (isCollection.test(descriptor)) { collectionDescriptors.add(descriptorUpdater.apply(descriptor)); - } else if (TypeUtils.isMap(descriptor.getRawType())) { + } else if (isMap.test(descriptor)) { mapDescriptors.add(descriptorUpdater.apply(descriptor)); } else if (isBuildIn.test(descriptor)) { buildInDescriptors.add(descriptorUpdater.apply(descriptor)); @@ -260,6 +263,7 @@ public static DescriptorGrouper createDescriptorGrouper( || TypeUtils.isBoxed(descriptor.getRawType()), isBuildIn, DescriptorGrouper::isDefaultCollectionDescriptor, + descriptor -> TypeUtils.isMap(descriptor.getRawType()), descriptors, descriptorsGroupedOrdered, descriptorUpdator, @@ -279,6 +283,7 @@ public static DescriptorGrouper createDescriptorGrouper( usesPrimitiveFieldOrdering, isBuildIn, DescriptorGrouper::isDefaultCollectionDescriptor, + descriptor -> TypeUtils.isMap(descriptor.getRawType()), descriptors, descriptorsGroupedOrdered, descriptorUpdator, @@ -290,6 +295,7 @@ public static DescriptorGrouper createDescriptorGrouper( Predicate usesPrimitiveFieldOrdering, Predicate isBuildIn, Predicate isCollection, + Predicate isMap, Collection descriptors, boolean descriptorsGroupedOrdered, Function descriptorUpdator, @@ -299,6 +305,7 @@ public static DescriptorGrouper createDescriptorGrouper( usesPrimitiveFieldOrdering, isBuildIn, isCollection, + isMap, descriptors, descriptorsGroupedOrdered, descriptorUpdator == null ? 
DescriptorGrouper::createDescriptor : descriptorUpdator, diff --git a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java index 96ed500566..05f13829ba 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java @@ -576,7 +576,7 @@ public static TypeRef getElementType(TypeRef typeRef) { List> typeArguments = typeRef.getTypeArguments(); if (typeArguments.size() == 1) { Class rawType = getRawType(typeRef); - if (Iterable.class.isAssignableFrom(rawType)) { + if (Iterable.class.isAssignableFrom(rawType) || isScalaCollectionClass(rawType)) { return typeArguments.get(0); } } @@ -612,7 +612,8 @@ public static Tuple2, TypeRef> getMapKeyValueType(TypeRef typeR List> typeArguments = typeRef.getTypeArguments(); if (typeArguments.size() == 2) { Class rawType = getRawType(typeRef); - if (Map.class.isAssignableFrom(rawType) && rawType.getTypeParameters().length == 2) { + if ((Map.class.isAssignableFrom(rawType) || isScalaCollectionClass(rawType)) + && rawType.getTypeParameters().length == 2) { return Tuple2.of(typeArguments.get(0), typeArguments.get(1)); } } @@ -640,6 +641,10 @@ public static Tuple2, TypeRef> getMapKeyValueType(TypeRef typeR return Tuple2.of(keyType, valueType); } + private static boolean isScalaCollectionClass(Class rawType) { + return rawType.getName().startsWith("scala.collection"); + } + public static void applyRefTrackingOverride( GenericType genericType, Object typeUse, boolean globalTrackingRef) { if (genericType == null || typeUse == null) { diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java new file mode 100644 index 0000000000..217b68ad0b --- /dev/null +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.xlang; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import lombok.Data; +import org.apache.fory.Fory; +import org.apache.fory.annotation.ForyField; +import org.apache.fory.annotation.ForyStruct; +import org.apache.fory.annotation.Nullable; +import org.apache.fory.config.Language; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.memory.MemoryUtils; +import org.apache.fory.serializer.UnionSerializer; +import org.apache.fory.test.TestUtils; +import org.apache.fory.type.union.Union; +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +/** Executes Java-driven Scala 3 macro xlang serializer tests. 
*/ +@Test +public class ScalaXlangTest extends XlangTestBase { + private static final String DERIVED_CASE = "derived_struct_round_trip"; + private static final String KNOWN_UNION_CASE = "known_union_case_round_trip"; + private static final String UNKNOWN_UNION_CASE = "unknown_union_case_round_trip"; + private static final String NAMESPACE = "scala_peer"; + private static final File SCALA_DIR = new File("../../scala"); + + @BeforeMethod(alwaysRun = true) + public void skipInheritedXlangCases(Method method) { + if (method.getDeclaringClass() != ScalaXlangTest.class) { + throw new SkipException( + "Scala xlang phase 1 validates macro-derived Scala 3 peer serializers"); + } + } + + @Override + protected void ensurePeerReady() { + String enabled = System.getenv("FORY_SCALA_JAVA_CI"); + if (!"1".equals(enabled)) { + throw new SkipException("Skipping ScalaXlangTest: FORY_SCALA_JAVA_CI not set to 1"); + } + boolean buildSuccess = + TestUtils.executeCommand( + Arrays.asList("sbt", "--batch", "++3.3.1", "Test/compile"), + 240, + Collections.emptyMap(), + SCALA_DIR); + if (!buildSuccess) { + throw new AssertionError("Failed to compile Scala xlang peer"); + } + } + + @Override + protected CommandContext buildCommandContext(String caseName, Path dataFile) { + return new CommandContext( + Arrays.asList( + "sbt", + "--batch", + "++3.3.1", + "Test/runMain org.apache.fory.serializer.scala.ScalaXlangPeer " + + caseName + + " " + + dataFile.toAbsolutePath()), + envBuilder(dataFile), + SCALA_DIR); + } + + @Override + protected ExecutionContext prepareExecution(String caseName, byte[] payload) throws IOException { + if (!DERIVED_CASE.equals(caseName) + && !KNOWN_UNION_CASE.equals(caseName) + && !UNKNOWN_UNION_CASE.equals(caseName)) { + throw new SkipException( + "Scala xlang phase 1 validates macro-derived Scala 3 peer serializers"); + } + return super.prepareExecution(caseName, payload); + } + + @Test(groups = "xlang") + public void testDerivedStructRoundTrip() throws IOException { + 
Fory fory = newFory(); + registerScalaPeerTypes(fory); + + ScalaPeerUserMirror request = new ScalaPeerUserMirror(); + request.id = 41; + request.name = "java"; + request.email = "java@example.com"; + + ExecutionContext context = executePeer(DERIVED_CASE, fory, request); + ScalaPeerUserMirror response = + (ScalaPeerUserMirror) fory.deserialize(readBuffer(context.dataFile())); + Assert.assertEquals(response.id, 42); + Assert.assertEquals(response.name, "scala-java"); + Assert.assertNull(response.email); + } + + @Test(groups = "xlang") + public void testKnownUnionCaseRoundTrip() throws IOException { + Fory fory = newFory(); + registerScalaPeerTypes(fory); + + ScalaPeerUserMirror user = new ScalaPeerUserMirror(); + user.id = 41; + user.name = "java"; + user.email = "java@example.com"; + + ExecutionContext context = + executePeer(KNOWN_UNION_CASE, fory, new ScalaPeerTargetMirror(1, user)); + ScalaPeerTargetMirror response = + (ScalaPeerTargetMirror) fory.deserialize(readBuffer(context.dataFile())); + Assert.assertEquals(response.getIndex(), 1); + Assert.assertTrue(response.getValue() instanceof ScalaPeerUserMirror); + ScalaPeerUserMirror responseUser = (ScalaPeerUserMirror) response.getValue(); + Assert.assertEquals(responseUser.id, 42); + Assert.assertEquals(responseUser.name, "scala-java"); + Assert.assertNull(responseUser.email); + } + + @Test(groups = "xlang") + public void testUnknownUnionCaseRoundTrip() throws IOException { + Fory fory = newFory(); + registerScalaPeerTypes(fory); + + ScalaPeerUserMirror unknownPayload = new ScalaPeerUserMirror(); + unknownPayload.id = 99; + unknownPayload.name = "future"; + unknownPayload.email = "future@example.com"; + + ExecutionContext context = + executePeer(UNKNOWN_UNION_CASE, fory, new ScalaPeerTargetMirror(99, unknownPayload)); + ScalaPeerTargetMirror response = + (ScalaPeerTargetMirror) fory.deserialize(readBuffer(context.dataFile())); + Assert.assertEquals(response.getIndex(), 99); + Assert.assertTrue(response.getValue() 
instanceof ScalaPeerUserMirror); + ScalaPeerUserMirror responsePayload = (ScalaPeerUserMirror) response.getValue(); + Assert.assertEquals(responsePayload.id, 100); + Assert.assertEquals(responsePayload.name, "scala-future"); + Assert.assertNull(responsePayload.email); + } + + private ExecutionContext executePeer(String caseName, Fory fory, Object request) + throws IOException { + MemoryBuffer buffer = MemoryUtils.buffer(128); + fory.serialize(buffer, request); + ExecutionContext context = prepareExecution(caseName, buffer.getBytes(0, buffer.writerIndex())); + runPeer(context, 180); + return context; + } + + private static Fory newFory() { + return Fory.builder() + .withLanguage(Language.XLANG) + .withCompatible(true) + .requireClassRegistration(true) + .build(); + } + + private static void registerScalaPeerTypes(Fory fory) { + fory.register(ScalaPeerUserMirror.class, NAMESPACE, "ScalaPeerUser"); + fory.registerUnion( + ScalaPeerTargetMirror.class, + NAMESPACE, + "ScalaPeerTarget", + new UnionSerializer(fory.getTypeResolver(), ScalaPeerTargetMirror.class)); + } + + @Data + @ForyStruct + public static class ScalaPeerUserMirror { + @ForyField(id = 1) + public int id; + + @ForyField(id = 2) + public String name; + + @Nullable + @ForyField(id = 3) + public String email; + } + + public static final class ScalaPeerTargetMirror extends Union { + public ScalaPeerTargetMirror(int index, Object value) { + super(index, value); + } + } +} diff --git a/scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java b/scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java new file mode 100644 index 0000000000..a69e7c5635 --- /dev/null +++ b/scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.scala; + +/** Marker interface for Scala 3 enums generated from Fory schema enum definitions. */ +public interface ForyScalaEnum { + int getForyId(); +} diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java index e2af2a2efa..6f9ca18a96 100644 --- a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java @@ -20,6 +20,7 @@ package org.apache.fory.serializer.scala; import org.apache.fory.resolver.TypeResolver; +import org.apache.fory.scala.ForyScalaEnum; import org.apache.fory.serializer.JavaSerializer; import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.SerializerFactory; @@ -42,6 +43,23 @@ public class ScalaDispatcher implements SerializerFactory { */ @Override public Serializer createSerializer(TypeResolver typeResolver, Class clz) { + if (ForyScalaEnum.class.isAssignableFrom(clz)) { + return new ScalaEnumSerializer(typeResolver, clz); + } + if (scala.Option.class.isAssignableFrom(clz)) { + return new ScalaOptionSerializer(typeResolver, clz); + } + if (typeResolver.isCrossLanguage()) { + if (scala.collection.Map.class.isAssignableFrom(clz)) { + return new ScalaXlangMapSerializer(typeResolver, clz); + } else if 
(scala.collection.Set.class.isAssignableFrom(clz)) { + return new ScalaXlangSetSerializer(typeResolver, clz); + } else if (scala.collection.Seq.class.isAssignableFrom(clz)) { + return new ScalaXlangSeqSerializer(typeResolver, clz); + } else if (scala.collection.Iterable.class.isAssignableFrom(clz)) { + return new ScalaXlangCollectionSerializer(typeResolver, clz); + } + } // Many map/seq/set types doesn't extends DefaultSerializable. if (scala.collection.SortedMap.class.isAssignableFrom(clz)) { return new ScalaSortedMapSerializer(typeResolver, clz); diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java new file mode 100644 index 0000000000..c1b6d4b5eb --- /dev/null +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.fory.serializer.scala; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Arrays; +import org.apache.fory.collection.LongMap; +import org.apache.fory.config.Config; +import org.apache.fory.context.ReadContext; +import org.apache.fory.context.WriteContext; +import org.apache.fory.scala.ForyScalaEnum; +import org.apache.fory.serializer.ImmutableSerializer; +import org.apache.fory.serializer.Shareable; +import org.apache.fory.util.Preconditions; + +/** Serializer for Scala 3 enums generated by the Fory Scala schema IDL target. */ +@SuppressWarnings({"unchecked", "rawtypes"}) +public final class ScalaEnumSerializer extends ImmutableSerializer implements Shareable { + private static final int MAX_ENUM_ID_ARRAY_SIZE = 2048; + + private final Config config; + private final Object[] enumConstants; + private final Object[] enumConstantByTagArray; + private final LongMap enumConstantByTagMap; + + public ScalaEnumSerializer(org.apache.fory.resolver.TypeResolver resolver, Class cls) { + super(resolver.getConfig(), (Class) cls, false); + config = resolver.getConfig(); + Preconditions.checkArgument( + ForyScalaEnum.class.isAssignableFrom(cls), + "Scala enum %s must implement %s", + cls, + ForyScalaEnum.class.getName()); + enumConstants = loadValues(cls); + LongMap constantsByTag = new LongMap<>(enumConstants.length); + int maxTag = 0; + for (Object enumConstant : enumConstants) { + int tag = ((ForyScalaEnum) enumConstant).getForyId(); + Object previous = constantsByTag.put(tag, enumConstant); + Preconditions.checkArgument( + previous == null, + "Scala enum %s reuses Fory enum id %s for %s and %s", + cls.getName(), + tag, + previous, + enumConstant); + if (tag > maxTag) { + maxTag = tag; + } + } + if (maxTag < MAX_ENUM_ID_ARRAY_SIZE) { + enumConstantByTagArray = new Object[maxTag + 1]; + constantsByTag.forEach((tag, value) -> enumConstantByTagArray[tag.intValue()] = value); + enumConstantByTagMap = 
null; + } else { + enumConstantByTagArray = null; + enumConstantByTagMap = constantsByTag; + } + } + + @Override + public void write(WriteContext writeContext, Object value) { + writeContext.getBuffer().writeVarUInt32Small7(((ForyScalaEnum) value).getForyId()); + } + + @Override + public Object read(ReadContext readContext) { + int tag = readContext.getBuffer().readVarUInt32Small7(); + Object value = null; + if (enumConstantByTagArray != null && tag < enumConstantByTagArray.length) { + value = enumConstantByTagArray[tag]; + } else if (enumConstantByTagMap != null) { + value = enumConstantByTagMap.get(tag); + } + if (value != null) { + return value; + } + return handleUnknownEnumValue(tag); + } + + private Object handleUnknownEnumValue(int tag) { + switch (config.getUnknownEnumValueStrategy()) { + case RETURN_NULL: + return null; + case RETURN_FIRST_VARIANT: + return enumConstants[0]; + case RETURN_LAST_VARIANT: + return enumConstants[enumConstants.length - 1]; + default: + throw new IllegalArgumentException( + String.format("Scala enum tag %s not in %s", tag, Arrays.toString(enumConstants))); + } + } + + private static Object[] loadValues(Class cls) { + try { + Method values = cls.getMethod("values"); + Object result = values.invoke(null); + Preconditions.checkArgument( + result instanceof Object[], "Scala enum %s values() did not return an array", cls); + return (Object[]) result; + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalArgumentException("Failed to load Scala enum values for " + cls.getName(), e); + } + } +} diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java index 64a7c83942..9699d56062 100644 --- a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java @@ -41,6 +41,9 @@ public static 
void registerSerializers(ThreadSafeFory fory) { public static void registerSerializers(Fory fory) { TypeResolver resolver = setSerializerFactory(fory); + if (resolver.isCrossLanguage()) { + return; + } Config config = resolver.getConfig(); resolver.registerSerializer( @@ -171,6 +174,16 @@ public static void registerSerializers(Fory fory) { resolver.register(scala.collection.mutable.BitSet$.class); } + public static void registerEnum(Fory fory, Class cls, long typeId) { + TypeResolver resolver = fory.getTypeResolver(); + resolver.registerEnum(cls, typeId, new ScalaEnumSerializer(resolver, cls)); + } + + public static void registerEnum(Fory fory, Class cls, String namespace, String typeName) { + TypeResolver resolver = fory.getTypeResolver(); + resolver.registerEnum(cls, namespace, typeName, new ScalaEnumSerializer(resolver, cls)); + } + private static TypeResolver setSerializerFactory(Fory fory) { TypeResolver resolver = fory.getTypeResolver(); ScalaDispatcher dispatcher = new ScalaDispatcher(); diff --git a/scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala b/scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala new file mode 100644 index 0000000000..613fa6be8b --- /dev/null +++ b/scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.scala + +import org.apache.fory.{Fory, ThreadSafeFory} +import org.apache.fory.meta.TypeDef +import org.apache.fory.resolver.TypeResolver +import org.apache.fory.serializer.{ + Serializer, + StaticGeneratedStructSerializerFactory +} + +trait ForySerializer[T] { + def createSerializer(typeResolver: TypeResolver): Serializer[T] = + createSerializer(typeResolver, null) + + def createSerializer(typeResolver: TypeResolver, typeDef: TypeDef): Serializer[T] + + def isUnion: Boolean = false + + def registrationClasses(cls: Class[T]): Array[Class[_]] = Array(cls) +} + +object ForySerializer { + inline def derived[T]: ForySerializer[T] = + ${ org.apache.fory.scala.internal.ForySerializerMacros.derive[T] } + + def register[T](fory: Fory, cls: Class[T])(using serializer: ForySerializer[T]): Unit = { + register(fory, cls, null, null) + } + + def register[T]( + fory: Fory, + cls: Class[T], + typeId: Long)(using serializer: ForySerializer[T]): Unit = { + register(fory, cls, java.lang.Long.valueOf(typeId), null, null) + } + + def register[T]( + fory: Fory, + cls: Class[T], + namespace: String, + typeName: String)(using serializer: ForySerializer[T]): Unit = { + register(fory, cls, null, namespace, typeName) + } + + private def register[T]( + fory: Fory, + cls: Class[T], + typeId: java.lang.Long, + namespace: String, + typeName: String)(using serializer: ForySerializer[T]): Unit = { + val resolver = fory.getTypeResolver + serializer match { + case factory: StaticGeneratedStructSerializerFactory[T] @unchecked => + registerType(fory, 
cls, typeId, namespace, typeName) + resolver.registerStaticGeneratedStructSerializerFactory(cls, factory) + case _ if serializer.isUnion => + val unionSerializer = serializer.createSerializer(resolver) + if typeId != null then { + resolver.registerUnion(cls, typeId.longValue(), unionSerializer) + } else { + val unionNamespace = + if namespace != null then namespace else Option(cls.getPackage).map(_.getName).orNull + val unionTypeName = if typeName != null then typeName else cls.getSimpleName + fory.registerUnion( + cls, + if unionNamespace == null then "" else unionNamespace, + unionTypeName, + unionSerializer) + } + serializer.registrationClasses(cls).foreach { registrationClass => + if registrationClass != cls then { + resolver.registerUnionCase(cls, registrationClass) + } + } + case _ => + registerType(fory, cls, typeId, namespace, typeName) + resolver.setSerializer(cls, serializer.createSerializer(resolver)) + } + } + + def register[T]( + fory: ThreadSafeFory, + cls: Class[T])(using serializer: ForySerializer[T]): Unit = { + fory.registerCallback((runtime: Fory) => register(runtime, cls)(using serializer)) + } + + def register[T]( + fory: ThreadSafeFory, + cls: Class[T], + typeId: Long)(using serializer: ForySerializer[T]): Unit = { + fory.registerCallback((runtime: Fory) => register(runtime, cls, typeId)(using serializer)) + } + + def register[T]( + fory: ThreadSafeFory, + cls: Class[T], + namespace: String, + typeName: String)(using serializer: ForySerializer[T]): Unit = { + fory.registerCallback((runtime: Fory) => + register(runtime, cls, namespace, typeName)(using serializer)) + } + + private def registerType[T]( + fory: Fory, + cls: Class[T], + typeId: java.lang.Long, + namespace: String, + typeName: String): Unit = { + if typeId != null then { + fory.getTypeResolver.register(cls, typeId.longValue()) + } else if namespace == null || typeName == null then { + fory.register(cls) + } else { + fory.register(cls, namespace, typeName) + } + } +} diff --git 
a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala new file mode 100644 index 0000000000..1f8a081eba --- /dev/null +++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala @@ -0,0 +1,1209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.fory.scala.internal
+
+import org.apache.fory.annotation.{ForyCase, ForyField, ForyStruct, ForyUnion}
+import org.apache.fory.meta.{TypeDef => ForyTypeDef, TypeExtMeta}
+import org.apache.fory.resolver.TypeResolver
+import org.apache.fory.scala.ForyScalaEnum
+import org.apache.fory.scala.ForySerializer
+import org.apache.fory.serializer.{
+  FieldGroups,
+  Serializer,
+  StaticGeneratedStructSerializer,
+  StaticGeneratedStructSerializerFactory,
+  UnionSerializer
+}
+import org.apache.fory.`type`.{Descriptor, Types}
+
+import java.lang.reflect.Modifier
+import scala.quoted.*
+/**
+ * Scala 3 macro implementations that build [[ForySerializer]] instances at compile time.
+ *
+ * `derive` is the entry point: a type whose symbol carries `@ForyUnion` is handled by
+ * `deriveUnion`, every other type is handled by `deriveStruct`.
+ */
+object ForySerializerMacros {
+  def derive[T: Type](using q: Quotes): Expr[ForySerializer[T]] = {
+    import q.reflect.*
+    val symbol = TypeRepr.of[T].typeSymbol
+    if hasAnnotation[ForyUnion](symbol) then deriveUnion[T](symbol)
+    else deriveStruct[T](symbol)
+  }
+  /**
+   * Generates the serializer for a `@ForyStruct` class.
+   *
+   * Compilation is aborted when `owner` lacks `@ForyStruct` or has no serializable fields.
+   * Serializable fields are the primary-constructor vals (or the case fields when there are
+   * none), followed by body fields that carry a `@ForyField` id.
+   */
+  private def deriveStruct[T: Type](using q: Quotes)(
+      owner: q.reflect.Symbol): Expr[ForySerializer[T]] = {
+    import q.reflect.*
+    /**
+     * Compile-time description of one serializable field.
+     *
+     * `index` is the position in `serializableFields` and is the key used by the generated
+     * dispatch chains; `fieldId` is the `@ForyField` id, or -1 when absent. `option` marks a
+     * field declared as `scala.Option`, in which case `wireType` is the boxed element type.
+     * `constructorOwned` is true for fields that belong to the constructor field set.
+     */
+    final case class FieldMeta(
+        symbol: Symbol,
+        name: String,
+        index: Int,
+        fieldId: Int,
+        sourceType: TypeRepr,
+        wireType: TypeRepr,
+        option: Boolean,
+        nullable: Boolean,
+        trackingRef: Boolean,
+        constructorOwned: Boolean)
+
+    if !hasAnnotation[ForyStruct](owner) then {
+      report.errorAndAbort(
+        s"${owner.fullName} must be annotated with @ForyStruct to derive ForySerializer")
+    }
+    val ownerClassName = owner.fullName.replace("$.", "$")
+    /** Element type of `tpe` when, after peeling annotations and dealiasing, it is `scala.Option[_]`. */
+    def optionElement(tpe: TypeRepr): Option[TypeRepr] = {
+      peelAnnotations(tpe)._1.dealias match {
+        case AppliedType(base, List(arg)) if base.typeSymbol.fullName == "scala.Option" =>
+          Some(arg)
+        case _ => None
+      }
+    }
+    /**
+     * Removes `AnnotatedType` wrappers, including ones that only show up after dealiasing,
+     * and returns the remaining type together with the annotation terms, outermost first.
+     */
+    def peelAnnotations(tpe: TypeRepr): (TypeRepr, List[Term]) = {
+      tpe match {
+        case AnnotatedType(underlying, annotation) =>
+          val (base, annotations) = peelAnnotations(underlying)
+          (base, annotation :: annotations)
+        case other =>
+          other.dealias match {
+            case AnnotatedType(underlying, annotation) =>
+              val (base, annotations) = peelAnnotations(underlying)
+              (base, annotation :: annotations)
+            case dealiased => (dealiased, Nil)
+          }
+      }
+    }
+    /**
+     * Replaces Scala Boolean/Byte/Short/Int/Long/Float/Double with the matching `java.lang`
+     * box and re-applies the peeled annotations; any other type is returned as peeled.
+     *
+     * NOTE(review): `Char` is not boxed here although `classFor` maps it to a primitive
+     * class; confirm this is intended for `Option[Char]` fields.
+     */
+    def boxedIfPrimitive(tpe: TypeRepr): TypeRepr = {
+      val (base, annotations) = peelAnnotations(tpe)
+      val boxed =
+        if base =:= TypeRepr.of[Boolean] then TypeRepr.of[java.lang.Boolean]
+        else if base =:= TypeRepr.of[Byte] then TypeRepr.of[java.lang.Byte]
+        else if base =:= TypeRepr.of[Short] then TypeRepr.of[java.lang.Short]
+        else if base =:= TypeRepr.of[Int] then TypeRepr.of[java.lang.Integer]
+        else if base =:= TypeRepr.of[Long] then TypeRepr.of[java.lang.Long]
+        else if base =:= TypeRepr.of[Float] then TypeRepr.of[java.lang.Float]
+        else if base =:= TypeRepr.of[Double] then TypeRepr.of[java.lang.Double]
+        else base
+      annotations.foldRight(boxed)((annotation, current) => AnnotatedType(current, annotation))
+    }
+    /**
+     * Builds an expression for the runtime class of `tpe`: the primitive `TYPE` class for
+     * Scala primitives, `classOf[String]` for String and its Predef alias, `Class.forName`
+     * on a JVM array name for `scala.Array`, and otherwise `Class.forName` on the symbol's
+     * full name with "$." rewritten to "$".
+     */
+    def classFor(tpe: TypeRepr): Expr[Class[?]] = {
+      val normalized = peelAnnotations(tpe.widen)._1.dealias
+      val fullName = normalized.typeSymbol.fullName
+      if normalized =:= TypeRepr.of[Boolean] then '{ java.lang.Boolean.TYPE }
+      else if normalized =:= TypeRepr.of[Byte] then '{ java.lang.Byte.TYPE }
+      else if normalized =:= TypeRepr.of[Short] then '{ java.lang.Short.TYPE }
+      else if normalized =:= TypeRepr.of[Int] then '{ java.lang.Integer.TYPE }
+      else if normalized =:= TypeRepr.of[Long] then '{ java.lang.Long.TYPE }
+      else if normalized =:= TypeRepr.of[Float] then '{ java.lang.Float.TYPE }
+      else if normalized =:= TypeRepr.of[Double] then '{ java.lang.Double.TYPE }
+      else if normalized =:= TypeRepr.of[Char] then '{ java.lang.Character.TYPE }
+      else if normalized =:= TypeRepr.of[String] ||
+        normalized.typeSymbol == TypeRepr.of[String].typeSymbol ||
+        fullName == "scala.Predef.String" ||
+        fullName == "scala.Predef$.String" ||
+        fullName.endsWith("Predef.String") ||
+        fullName.endsWith("Predef$.String")
+      then '{ classOf[String] }
+      else if fullName == "scala.Array" then {
+        '{ Class.forName(${ Expr(arrayClassName(normalized)) }) }
+      } else '{ Class.forName(${ Expr(fullName.replace("$.", "$")) }) }
+    }
+    /** JVM class name ("[" followed by the component descriptor) of a `scala.Array` type; aborts otherwise. */
+    def arrayClassName(tpe: TypeRepr): String = {
+      tpe.dealias match {
+        case AppliedType(arrayType, List(componentType))
+            if arrayType.typeSymbol.fullName == "scala.Array" =>
+          "[" + arrayComponentDescriptor(componentType)
+        case _ =>
+          report.errorAndAbort(s"Expected Scala Array type, got ${tpe.show}")
+      }
+    }
+    /**
+     * JVM descriptor of an array component: a single letter for Scala primitives, the nested
+     * array name for arrays, and "L" + class name + ";" for everything else.
+     */
+    def arrayComponentDescriptor(tpe: TypeRepr): String = {
+      val normalized = peelAnnotations(tpe.widen)._1.dealias
+      if normalized =:= TypeRepr.of[Boolean] then "Z"
+      else if normalized =:= TypeRepr.of[Byte] then "B"
+      else if normalized =:= TypeRepr.of[Short] then "S"
+      else if normalized =:= TypeRepr.of[Int] then "I"
+      else if normalized =:= TypeRepr.of[Long] then "J"
+      else if normalized =:= TypeRepr.of[Float] then "F"
+      else if normalized =:= TypeRepr.of[Double] then "D"
+      else if normalized =:= TypeRepr.of[Char] then "C"
+      else if normalized.typeSymbol.fullName == "scala.Array" then arrayClassName(normalized)
+      else "L" + normalized.typeSymbol.fullName.replace("$.", "$") + ";"
+    }
+
+    val constructorFields = {
+      val params = owner.primaryConstructor.paramSymss.flatten.filter(_.isValDef)
+      if params.nonEmpty then params else owner.caseFields
+    }
+    val constructorFieldSet = constructorFields.toSet
+    val bodyFields =
+      owner.fieldMembers.filter(field =>
+        !constructorFieldSet.contains(field) && annotationIntArg[ForyField](field, "id").nonEmpty)
+    val serializableFields = constructorFields ++ bodyFields
+    if serializableFields.isEmpty then {
+      report.errorAndAbort(s"${owner.fullName} has no serializable fields")
+    }
+    /** Declared type taken from the symbol's `ValDef` tree, else the widened term reference. */
+    def declaredType(symbol: Symbol): TypeRepr = {
+      symbol.tree match {
+        case ValDef(_, tpt, _) => tpt.tpe
+        case _ => symbol.termRef.widen
+      }
+    }
+    /**
+     * One `FieldMeta` per serializable field. Option fields use their boxed element type as
+     * wire type and are nullable; reference tracking comes from `@Ref` on the field symbol
+     * or on its top-level type.
+     */
+    val fields = serializableFields.zipWithIndex.map { (field, index) =>
+      val sourceType = declaredType(field)
+      val (wireType, option, nullable) = optionElement(sourceType) match {
+        case Some(inner) => (boxedIfPrimitive(inner), true, true)
+        case None => (sourceType, false, false)
+      }
+      FieldMeta(
+        field,
+        field.name,
+        index,
+        annotationIntArg[ForyField](field, "id").getOrElse(-1),
+        sourceType,
+        wireType,
+        option,
+        nullable,
+        hasRef(field) || topLevelTypeHasRef(sourceType),
+        constructorFieldSet.contains(field))
+    }
+    /**
+     * Builds the `Descriptor.GeneratedType` expression for `tpe`, recursing into type
+     * arguments and, for `scala.Array`, into the component type instead. An `Option` is
+     * replaced by its boxed element type and marked nullable. The type id comes from the
+     * first recognised type annotation, then `@ForyUnion`, then the enum check, and is
+     * `Types.UNKNOWN` otherwise.
+     */
+    def generatedType(tpe: TypeRepr): Expr[Descriptor.GeneratedType] = {
+      val (outer, outerAnnotations) = peelAnnotations(tpe)
+      val option = optionElement(outer)
+      val fieldSource = option.map(boxedIfPrimitive).getOrElse(outer)
+      val (base, baseAnnotations) = peelAnnotations(fieldSource)
+      val annotations = outerAnnotations ++ baseAnnotations
+      val argumentSource = fieldSource
+      def appliedType(tpe: TypeRepr): Option[(TypeRepr, List[TypeRepr])] = {
+        val directArgs = tpe.typeArgs
+        if directArgs.nonEmpty then {
+          tpe match {
+            case AppliedType(typeConstructor, _) => Some((typeConstructor, directArgs))
+            case other =>
+              other.dealias match {
+                case AppliedType(typeConstructor, _) => Some((typeConstructor, directArgs))
+                case _ => Some((tpe, directArgs))
+              }
+          }
+        } else {
+          tpe match {
+            case AppliedType(typeConstructor, typeArgs) => Some((typeConstructor, typeArgs))
+            case other =>
+              other.dealias match {
+                case AppliedType(typeConstructor, typeArgs) => Some((typeConstructor, typeArgs))
+                case _ => None
+              }
+          }
+        }
+      }
+      val component = appliedType(argumentSource) match {
+        case Some((arrayType, List(componentType)))
+            if arrayType.typeSymbol.fullName == "scala.Array" =>
+          Some(generatedType(componentType))
+        case _ => None
+      }
+      val args = appliedType(argumentSource) match {
+        case Some((arrayType, List(_))) if arrayType.typeSymbol.fullName == "scala.Array" =>
+          Nil
+        case Some((_, typeArgs)) => typeArgs
+        case _ => Nil
+      }
+      val argExprs = args.map(generatedType)
+      val argList: Expr[java.util.List[Descriptor.GeneratedType]] =
+        '{
+          import scala.jdk.CollectionConverters.*
+          java.util.Collections.unmodifiableList(${ Expr.ofList(argExprs) }.asJava)
+        }
+      val componentExpr: Expr[Descriptor.GeneratedType] =
+        component.getOrElse('{ null.asInstanceOf[Descriptor.GeneratedType] })
+      val typeId =
+        annotations.flatMap(typeIdForAnnotation).headOption
+          .orElse {
+            if hasAnnotation[ForyUnion](base.typeSymbol) then Some(Types.UNION) else None
+          }
+          .orElse {
+            if isScalaEnumType(base) then Some(Types.ENUM) else None
+          }
+          .getOrElse(Types.UNKNOWN)
+      val rawClass = classFor(base)
+      val typeExtMeta = generatedTypeExtMeta(
+        typeId,
+        nullable = option.nonEmpty,
+        trackingRef = annotations.exists(isRefAnnotation),
+        rawClass = Some(rawClass))
+      '{ Descriptor.generatedType($rawClass, $typeExtMeta, $argList, $componentExpr) }
+    }
+    /** Expression that constructs the `Descriptor` of one field from its `FieldMeta`. */
+    def descriptor(field: FieldMeta): Expr[Descriptor] = {
+      '{
+        new Descriptor(
+          ${ generatedType(field.sourceType) },
+          ${ Expr(field.sourceType.show) },
+          ${ Expr(field.name) },
+          ${ Expr(Modifier.PRIVATE | Modifier.FINAL) },
+          ${ Expr(ownerClassName) },
+          true,
+          ${ Expr(field.fieldId) },
+          ${ Expr(field.nullable) },
+          ${ Expr(field.trackingRef) },
+          ForyField.Dynamic.AUTO,
+          false
+        )
+      }
+    }
+
+    val descriptorsExpr: Expr[java.util.List[Descriptor]] = {
+      val exprs = fields.map(descriptor)
+      '{
+        import scala.jdk.CollectionConverters.*
+        java.util.Collections.unmodifiableList(${ Expr.ofList(exprs) }.asJava)
+      }
+    }
+    /** Selects the member named after `field` on `valueExpr`. */
+    def selectValue(valueExpr: Expr[T], field: FieldMeta): Expr[Any] = {
+      Select.unique(valueExpr.asTerm, field.name).asExpr
+    }
+    /**
+     * Emits an if/else chain keyed by `field.index` that writes the selected field value;
+     * Option fields are written as `orNull`. An id that matches no field reaches the
+     * innermost branch and throws `IllegalStateException`.
+     */
+    def writeDispatch(
+        valueExpr: Expr[T],
+        fieldIdExpr: Expr[Int],
+        fieldInfoExpr: Expr[FieldGroups.SerializationFieldInfo],
+        writeContextExpr: Expr[org.apache.fory.context.WriteContext],
+        resolverExpr: Expr[TypeResolver]): Expr[Unit] = {
+      fields.foldRight(
+        '{
+          throw new IllegalStateException("Unknown generated Scala field id " + $fieldIdExpr)
+        }: Expr[Unit]) { (field, next) =>
+        val fieldValue = selectValue(valueExpr, field)
+        val wireValue =
+          if field.option then '{ $fieldValue.asInstanceOf[Option[Any]].orNull }
+          else fieldValue
+        '{
+          if $fieldIdExpr == ${ Expr(field.index) } then {
+            StaticGeneratedStructSerializer.writeFieldValue(
+              $resolverExpr,
+              $writeContextExpr,
+              $fieldInfoExpr,
+              $wireValue)
+          } else {
+            $next
+          }
+        }
+      }
+    }
+    /**
+     * Casts a raw value back to the field's source type; for Option fields the raw value is
+     * first wrapped with `Option(...)`.
+     */
+    def decodeValue(raw: Expr[Any], field: FieldMeta): Expr[Any] = {
+      if field.option then {
+        field.sourceType.asType match {
+          case '[a] => '{ Option($raw).asInstanceOf[a] }
+        }
+      } else {
+        field.sourceType.asType match {
+          case '[a] => '{ $raw.asInstanceOf[a] }
+        }
+      }
+    }
+    /** Decoded value of `field` taken from slot `field.index` of the values array. */
+    def valueArg(valuesExpr: Expr[Array[Any]], field: FieldMeta): Expr[Any] =
+      decodeValue('{ $valuesExpr(${ Expr(field.index) }) }, field)
+    /** Assignment of the decoded `raw` value to the member named after `field` on `objExpr`. */
+    def assignRawValue(objExpr: Expr[T], field: FieldMeta, raw: Expr[Any]): Expr[Unit] =
+      Assign(Select.unique(objExpr.asTerm, field.name), decodeValue(raw, field).asTerm)
+        .asExprOf[Unit]
+    /** If/else chain keyed by `field.index` that assigns `raw` to the matching field; throws for unknown ids. */
+    def assignValueById(objExpr: Expr[T], fieldIdExpr: Expr[Int], raw: Expr[Any]): Expr[Unit] = {
+      fields.foldRight(
+        '{
+          throw new IllegalStateException("Unknown generated Scala field id " + $fieldIdExpr)
+        }: Expr[Unit]) { (field, next) =>
+        '{
+          if $fieldIdExpr == ${ Expr(field.index) } then {
+            ${ assignRawValue(objExpr, field, raw) }
+          } else {
+            $next
+          }
+        }
+      }
+    }
+    /** Reads one field value through `StaticGeneratedStructSerializer.readFieldValue` and assigns it by id. */
+    def readAndAssignDispatch(
+        objExpr: Expr[T],
+        fieldIdExpr: Expr[Int],
+        fieldInfoExpr: Expr[FieldGroups.SerializationFieldInfo],
+        readContextExpr: Expr[org.apache.fory.context.ReadContext],
+        resolverExpr: Expr[TypeResolver]): Expr[Unit] = {
+      '{
+        val fieldValue =
+          StaticGeneratedStructSerializer.readFieldValue($resolverExpr, $readContextExpr, $fieldInfoExpr)
+        ${ assignValueById(objExpr, fieldIdExpr, 'fieldValue) }
+      }
+    }
+    /**
+     * Instantiates `T` through the primary constructor using the constructor-owned slots of
+     * the values array, then assigns the remaining fields on the new instance.
+     */
+    def constructFromValues(valuesExpr: Expr[Array[Any]]): Expr[T] = {
+      val constructorOwned = fields.filter(_.constructorOwned)
+      val postConstruction = fields.filterNot(_.constructorOwned)
+      if postConstruction.isEmpty then {
+        val args = constructorOwned.map { field =>
+          valueArg(valuesExpr, field).asTerm
+        }
+        Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args).asExprOf[T]
+      } else if constructorOwned.isEmpty then {
+        val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol)
+        val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil)
+        val assignments = postConstruction.map { field =>
+          Assign(Select.unique(Ref(obj), field.name), valueArg(valuesExpr, field).asTerm)
+        }
+        Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T]
+      } else {
+        val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol)
+        val args = constructorOwned.map { field =>
+          valueArg(valuesExpr, field).asTerm
+        }
+        val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args)
+        val assignments = postConstruction.map { field =>
+          Assign(Select.unique(Ref(obj), field.name), valueArg(valuesExpr, field).asTerm)
+        }
+        Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T]
+      }
+    }
+    /**
+     * Same construction shape as `constructFromValues`, but every value is selected from
+     * `valueExpr` and cast to its source type; the selected field values are passed on as-is.
+     */
+    def copyValue(valueExpr: Expr[T]): Expr[T] = {
+      val constructorOwned = fields.filter(_.constructorOwned)
+      val postConstruction = fields.filterNot(_.constructorOwned)
+      if postConstruction.isEmpty then {
+        val args = constructorOwned.map { field =>
+          val selected = selectValue(valueExpr, field)
+          field.sourceType.asType match {
+            case '[a] => '{ $selected.asInstanceOf[a] }.asTerm
+          }
+        }
+        Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args).asExprOf[T]
+      } else if constructorOwned.isEmpty then {
+        val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol)
+        val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil)
+        val assignments = postConstruction.map { field =>
+          val selected = selectValue(valueExpr, field)
+          val copied = field.sourceType.asType match {
+            case '[a] => '{ $selected.asInstanceOf[a] }
+          }
+          Assign(Select.unique(Ref(obj), field.name), copied.asTerm)
+        }
+        Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T]
+      } else {
+        val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol)
+        val args = constructorOwned.map { field =>
+          val selected = selectValue(valueExpr, field)
+          field.sourceType.asType match {
+            case '[a] => '{ $selected.asInstanceOf[a] }.asTerm
+          }
+        }
+        val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args)
+        val assignments = postConstruction.map { field =>
+          val selected = selectValue(valueExpr, field)
+          val copied = field.sourceType.asType match {
+            case '[a] => '{ $selected.asInstanceOf[a] }
+          }
+          Assign(Select.unique(Ref(obj), field.name), copied.asTerm)
+        }
+        Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T]
+      }
+    }
+    /**
+     * Aborts when the class has no constructor-owned fields, otherwise delegates to
+     * `constructFromValues`. `readContextExpr` is accepted but not used here.
+     */
+    def constructRead(valuesExpr: Expr[Array[Any]], readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[T] = {
+      val constructorOwned = fields.filter(_.constructorOwned)
+      if constructorOwned.isEmpty then {
+        report.errorAndAbort(
+          s"${owner.fullName} cycle-owned generated classes must use generated mutable read paths")
+      } else constructFromValues(valuesExpr)
+    }
+    /**
+     * Body of the generated `readSchemaConsistent`. After the optional class-version check,
+     * a class without constructor-owned fields is instantiated first, registered through
+     * `readContext.reference`, and filled field by field. Otherwise every value is read
+     * into an array slot given by `allFieldIds` and the instance is built by `constructRead`.
+     */
+    def readSchemaConsistentBody(
+        serializerExpr: Expr[StaticGeneratedStructSerializer[T]],
+        resolverExpr: Expr[TypeResolver],
+        descriptorsExpr: Expr[java.util.List[Descriptor]],
+        classVersionHashExpr: Expr[Int],
+        allFieldsExpr: Expr[Array[FieldGroups.SerializationFieldInfo]],
+        allFieldIdsExpr: Expr[Array[Int]],
+        readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[T] = {
+      val constructorOwned = fields.filter(_.constructorOwned)
+      if constructorOwned.isEmpty then {
+        '{
+          val buffer = $readContextExpr.getBuffer
+          if $resolverExpr.checkClassVersion() then {
+            $serializerExpr.checkClassVersion(buffer.readInt32(), $classVersionHashExpr)
+          }
+          val obj = ${ Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil).asExprOf[T] }
+          $readContextExpr.reference(obj)
+          var i = 0
+          while i < $allFieldsExpr.length do {
+            val fieldInfo = $allFieldsExpr(i)
+            val fieldId = $allFieldIdsExpr(i)
+            ${ readAndAssignDispatch('obj, 'fieldId, 'fieldInfo, readContextExpr, resolverExpr) }
+            i += 1
+          }
+          obj
+        }
+      } else {
+        '{
+          val buffer = $readContextExpr.getBuffer
+          if $resolverExpr.checkClassVersion() then {
+            $serializerExpr.checkClassVersion(buffer.readInt32(), $classVersionHashExpr)
+          }
+          val values = new Array[Any]($descriptorsExpr.size())
+          var i = 0
+          while i < $allFieldsExpr.length do {
+            val fieldInfo = $allFieldsExpr(i)
+            values($allFieldIdsExpr(i)) =
+              StaticGeneratedStructSerializer.readFieldValue($resolverExpr, $readContextExpr, fieldInfo)
+            i += 1
+          }
+          ${ constructRead('values, readContextExpr) }
+        }
+      }
+    }
+    /**
+     * Body of the generated `readCompatible`. When `sameSchemaCompatible` holds it calls
+     * the serializer's `read`. Otherwise it walks the serializer's remote fields: a field
+     * with a non-negative `matchedId` that `canReadRemoteField` accepts is read and stored
+     * under that id, every other remote field is skipped.
+     */
+    def readCompatibleBody(
+        serializerExpr: Expr[StaticGeneratedStructSerializer[T]],
+        descriptorsExpr: Expr[java.util.List[Descriptor]],
+        fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]],
+        sameSchemaCompatibleExpr: Expr[Boolean],
+        readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[T] = {
+      val constructorOwned = fields.filter(_.constructorOwned)
+      if constructorOwned.isEmpty then {
+        '{
+          if $sameSchemaCompatibleExpr then {
+            $serializerExpr.read($readContextExpr)
+          } else {
+            val obj = ${ Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil).asExprOf[T] }
+            $readContextExpr.reference(obj)
+            val remoteFields = $serializerExpr.getRemoteFields()
+            var i = 0
+            while i < remoteFields.size() do {
+              val remoteField = remoteFields.get(i)
+              val matchedId = remoteField.matchedId
+              if matchedId >= 0 then {
+                val localField = $fieldsByIdExpr(matchedId)
+                if $serializerExpr.canReadRemoteField(remoteField, localField) then {
+                  val fieldValue =
+                    $serializerExpr.readCompatibleFieldValue($readContextExpr, remoteField, localField)
+                  ${ assignValueById('obj, 'matchedId, 'fieldValue) }
+                } else {
+                  $serializerExpr.skipField($readContextExpr, remoteField)
+                }
+              } else {
+                $serializerExpr.skipField($readContextExpr, remoteField)
+              }
+              i += 1
+            }
+            obj
+          }
+        }
+      } else {
+        '{
+          if $sameSchemaCompatibleExpr then {
+            $serializerExpr.read($readContextExpr)
+          } else {
+            val values = new Array[Any]($descriptorsExpr.size())
+            val remoteFields = $serializerExpr.getRemoteFields()
+            var i = 0
+            while i < remoteFields.size() do {
+              val remoteField = remoteFields.get(i)
+              val matchedId = remoteField.matchedId
+              if matchedId >= 0 then {
+                val localField = $fieldsByIdExpr(matchedId)
+                if $serializerExpr.canReadRemoteField(remoteField, localField) then {
+                  values(matchedId) =
+                    $serializerExpr.readCompatibleFieldValue($readContextExpr, remoteField, localField)
+                } else {
+                  $serializerExpr.skipField($readContextExpr, remoteField)
+                }
+              } else {
+                $serializerExpr.skipField($readContextExpr, remoteField)
+              }
+              i += 1
+            }
+            ${ constructRead('values, readContextExpr) }
+          }
+        }
+      }
+    }
+
+    val classExpr: Expr[Class[T]] =
+      '{ Class.forName(${ Expr(ownerClassName) }).asInstanceOf[Class[T]] }
+    /**
+     * Generated result: a `ForySerializer[T]` that is also a factory for a
+     * `StaticGeneratedStructSerializer[T]` whose write/read/copy bodies are spliced in
+     * from the builders above.
+     */
+    '{
+      new ForySerializer[T] with StaticGeneratedStructSerializerFactory[T] {
+        private val descriptors: java.util.List[Descriptor] = $descriptorsExpr
+
+        override def getGeneratedDescriptors(): java.util.List[Descriptor] = descriptors
+
+        override def createSerializer(
+            resolver: TypeResolver,
+            remoteTypeDef: ForyTypeDef): Serializer[T] = {
+          newSerializer(resolver, $classExpr, remoteTypeDef)
+        }
+
+        override def newSerializer(
+            resolver: TypeResolver,
+            cls: Class[?],
+            remoteTypeDef: ForyTypeDef): StaticGeneratedStructSerializer[T] = {
+          new StaticGeneratedStructSerializer[T](resolver, cls, remoteTypeDef, descriptors) {
+            private val generatedSerializer: StaticGeneratedStructSerializer[T] = this
+            private val fieldGroups: FieldGroups =
+              buildLocalFieldGroups(descriptors)
+            private val allFields: Array[FieldGroups.SerializationFieldInfo] =
+              fieldGroups.allFields
+            private val allFieldIds: Array[Int] = localFieldIds(allFields, descriptors)
+            private val fieldsById: Array[FieldGroups.SerializationFieldInfo] = {
+              val result = new Array[FieldGroups.SerializationFieldInfo](descriptors.size())
+              var i = 0
+              while i < allFields.length do {
+                result(allFieldIds(i)) = allFields(i)
+                i += 1
+              }
+              result
+            }
+            private val classVersionHash: Int =
+              if resolver.checkClassVersion() then computeClassVersionHash(descriptors) else 0
+            private val sameSchemaCompatible: Boolean =
+              remoteTypeDef != null &&
+                remoteTypeDef.getId == ForyTypeDef.buildTypeDef(resolver, cls).getId
+
+            override def getGeneratedDescriptors(): java.util.List[Descriptor] = descriptors
+
+            override def write(
+                writeContext: org.apache.fory.context.WriteContext,
+                value: T): Unit = {
+              val buffer = writeContext.getBuffer
+              if resolver.checkClassVersion() then {
+                buffer.writeInt32(classVersionHash)
+              }
+              var i = 0
+              while i < allFields.length do {
+                val fieldInfo = allFields(i)
+                val fieldId = allFieldIds(i)
+                ${ writeDispatch('value, 'fieldId, 'fieldInfo, 'writeContext, 'resolver) }
+                i += 1
+              }
+            }
+
+            override def read(readContext: org.apache.fory.context.ReadContext): T = {
+              if remoteTypeDef != null && !sameSchemaCompatible then readCompatible(readContext)
+              else readSchemaConsistent(readContext)
+            }
+
+            private def readSchemaConsistent(
+                readContext: org.apache.fory.context.ReadContext): T =
+              ${
+                readSchemaConsistentBody(
+                  'generatedSerializer,
+                  'resolver,
+                  'descriptors,
+                  'classVersionHash,
+                  'allFields,
+                  'allFieldIds,
+                  'readContext)
+              }
+
+            override def readCompatible(readContext: org.apache.fory.context.ReadContext): T = {
+              ${
+                readCompatibleBody(
+                  'generatedSerializer,
+                  'descriptors,
+                  'fieldsById,
+                  'sameSchemaCompatible,
+                  'readContext)
+              }
+            }
+
+            override def copy(
+                copyContext: org.apache.fory.context.CopyContext,
+                value: T): T = ${ copyValue('value) }
+          }
+        }
+      }
+    }
+  }
+  /**
+   * Generates the serializer for a `@ForyUnion` type from its children that carry
+   * `@ForyCase`.
+   *
+   * Compilation is aborted unless exactly one child uses id 0 (the unknown case, shaped
+   * `(caseId: Int, value: Any)`), every other case has exactly one payload parameter, no
+   * id is negative, and the ids of the known cases are unique.
+   */
+  private def deriveUnion[T: Type](using q: Quotes)(
+      owner: q.reflect.Symbol): Expr[ForySerializer[T]] = {
+    import q.reflect.*
+    /**
+     * Compile-time description of one union case. `unknown` is true for id 0.
+     * `fieldIndex` is the position among the known cases and indexes the case field-info
+     * array; it stays -1 for the unknown case. `unknownIdName` is empty for known cases.
+     */
+    final case class CaseMeta(
+        symbol: Symbol,
+        id: Int,
+        payloadType: TypeRepr,
+        payloadName: String,
+        unknownIdName: String,
+        unknown: Boolean,
+        fieldIndex: Int)
+    /**
+     * Validates the constructor shape of a case and returns
+     * (payload type, payload parameter name, case-id parameter name).
+     */
+    def payloadMeta(child: Symbol, id: Int): (TypeRepr, String, String) = {
+      val params = child.primaryConstructor.paramSymss.flatten
+      if id == 0 then {
+        if params.size != 2 then {
+          report.errorAndAbort(
+            s"${child.fullName} is the unknown union case and must have (caseId: Int, value: Any)")
+        }
+        val caseIdType = params.head.tree match {
+          case ValDef(_, tpt, _) => tpt.tpe
+          case _ => params.head.termRef.widen
+        }
+        val tpe = params(1).tree match {
+          case ValDef(_, tpt, _) => tpt.tpe
+          case _ => TypeRepr.of[Any]
+        }
+        if !(caseIdType =:= TypeRepr.of[Int]) || !(tpe =:= TypeRepr.of[Any]) then {
+          report.errorAndAbort(
+            s"${child.fullName} is the unknown union case and must have (caseId: Int, value: Any)")
+        }
+        (tpe, params(1).name, params.head.name)
+      } else {
+        if params.size != 1 then {
+          report.errorAndAbort(s"${child.fullName} must have exactly one payload parameter")
+        }
+        val tpe = params.head.tree match {
+          case ValDef(_, tpt, _) => tpt.tpe
+          case _ => params.head.termRef.widen
+        }
+        (tpe, params.head.name, "")
+      }
+    }
+    /** Children without a `@ForyCase` id are ignored; known cases are then numbered in order. */
+    val rawCases = owner.children.flatMap { child =>
+      annotationIntArg[ForyCase](child, "id").map { id =>
+        if id < 0 then report.errorAndAbort(s"${child.fullName} @ForyCase id must be >= 0")
+        val (tpe, payloadName, unknownIdName) = payloadMeta(child, id)
+        CaseMeta(child, id, tpe, payloadName, unknownIdName, id == 0, -1)
+      }
+    }
+    var nextFieldIndex = 0
+    val cases = rawCases.map { unionCase =>
+      if unionCase.unknown then unionCase
+      else {
+        val indexed = unionCase.copy(fieldIndex = nextFieldIndex)
+        nextFieldIndex += 1
+        indexed
+      }
+    }
+    val knownCases = cases.filterNot(_.unknown)
+    if cases.count(_.unknown) > 1 then {
+      report.errorAndAbort(s"${owner.fullName} must define exactly one @ForyCase(id = 0) unknown case")
+    }
+    if cases.filterNot(_.unknown).groupBy(_.id).exists(_._2.size > 1) then {
+      report.errorAndAbort(s"${owner.fullName} has duplicate @ForyCase ids")
+    }
+    val unknown =
+      cases.find(_.unknown).getOrElse(
+        report.errorAndAbort(s"${owner.fullName} must define @ForyCase(id = 0) unknown case"))
+    /**
+     * Removes `AnnotatedType` wrappers (also after dealiasing) and returns the bare type
+     * with the annotation terms, outermost first. Textually identical to the helper of the
+     * same name in `deriveStruct`, as are the six helpers that follow.
+     */
+    def peelAnnotations(tpe: TypeRepr): (TypeRepr, List[Term]) = {
+      tpe match {
+        case AnnotatedType(underlying, annotation) =>
+          val (base, annotations) = peelAnnotations(underlying)
+          (base, annotation :: annotations)
+        case other =>
+          other.dealias match {
+            case AnnotatedType(underlying, annotation) =>
+              val (base, annotations) = peelAnnotations(underlying)
+              (base, annotation :: annotations)
+            case dealiased => (dealiased, Nil)
+          }
+      }
+    }
+    /** Element type of `tpe` when, after peeling annotations and dealiasing, it is `scala.Option[_]`. */
+    def optionElement(tpe: TypeRepr): Option[TypeRepr] = {
+      peelAnnotations(tpe)._1.dealias match {
+        case AppliedType(base, List(arg)) if base.typeSymbol.fullName == "scala.Option" =>
+          Some(arg)
+        case _ => None
+      }
+    }
+    /** Boxes Scala Boolean/Byte/Short/Int/Long/Float/Double and re-applies the peeled annotations. */
+    def boxedIfPrimitive(tpe: TypeRepr): TypeRepr = {
+      val (base, annotations) = peelAnnotations(tpe)
+      val boxed =
+        if base =:= TypeRepr.of[Boolean] then TypeRepr.of[java.lang.Boolean]
+        else if base =:= TypeRepr.of[Byte] then TypeRepr.of[java.lang.Byte]
+        else if base =:= TypeRepr.of[Short] then TypeRepr.of[java.lang.Short]
+        else if base =:= TypeRepr.of[Int] then TypeRepr.of[java.lang.Integer]
+        else if base =:= TypeRepr.of[Long] then TypeRepr.of[java.lang.Long]
+        else if base =:= TypeRepr.of[Float] then TypeRepr.of[java.lang.Float]
+        else if base =:= TypeRepr.of[Double] then TypeRepr.of[java.lang.Double]
+        else base
+      annotations.foldRight(boxed)((annotation, current) => AnnotatedType(current, annotation))
+    }
+    /** Expression for the runtime class of `tpe` (primitive, String, array or named class). */
+    def classFor(tpe: TypeRepr): Expr[Class[?]] = {
+      val normalized = peelAnnotations(tpe.widen)._1.dealias
+      val fullName = normalized.typeSymbol.fullName
+      if normalized =:= TypeRepr.of[Boolean] then '{ java.lang.Boolean.TYPE }
+      else if normalized =:= TypeRepr.of[Byte] then '{ java.lang.Byte.TYPE }
+      else if normalized =:= TypeRepr.of[Short] then '{ java.lang.Short.TYPE }
+      else if normalized =:= TypeRepr.of[Int] then '{ java.lang.Integer.TYPE }
+      else if normalized =:= TypeRepr.of[Long] then '{ java.lang.Long.TYPE }
+      else if normalized =:= TypeRepr.of[Float] then '{ java.lang.Float.TYPE }
+      else if normalized =:= TypeRepr.of[Double] then '{ java.lang.Double.TYPE }
+      else if normalized =:= TypeRepr.of[Char] then '{ java.lang.Character.TYPE }
+      else if normalized =:= TypeRepr.of[String] ||
+        normalized.typeSymbol == TypeRepr.of[String].typeSymbol ||
+        fullName == "scala.Predef.String" ||
+        fullName == "scala.Predef$.String" ||
+        fullName.endsWith("Predef.String") ||
+        fullName.endsWith("Predef$.String")
+      then '{ classOf[String] }
+      else if fullName == "scala.Array" then {
+        '{ Class.forName(${ Expr(arrayClassName(normalized)) }) }
+      } else '{ Class.forName(${ Expr(fullName.replace("$.", "$")) }) }
+    }
+    /** JVM class name of a `scala.Array` type; aborts for any other type. */
+    def arrayClassName(tpe: TypeRepr): String = {
+      tpe.dealias match {
+        case AppliedType(arrayType, List(componentType))
+            if arrayType.typeSymbol.fullName == "scala.Array" =>
+          "[" + arrayComponentDescriptor(componentType)
+        case _ =>
+          report.errorAndAbort(s"Expected Scala Array type, got ${tpe.show}")
+      }
+    }
+    /** JVM descriptor of an array component type. */
+    def arrayComponentDescriptor(tpe: TypeRepr): String = {
+      val normalized = peelAnnotations(tpe.widen)._1.dealias
+      if normalized =:= TypeRepr.of[Boolean] then "Z"
+      else if normalized =:= TypeRepr.of[Byte] then "B"
+      else if normalized =:= TypeRepr.of[Short] then "S"
+      else if normalized =:= TypeRepr.of[Int] then "I"
+      else if normalized =:= TypeRepr.of[Long] then "J"
+      else if normalized =:= TypeRepr.of[Float] then "F"
+      else if normalized =:= TypeRepr.of[Double] then "D"
+      else if normalized =:= TypeRepr.of[Char] then "C"
+      else if normalized.typeSymbol.fullName == "scala.Array" then arrayClassName(normalized)
+      else "L" + normalized.typeSymbol.fullName.replace("$.", "$") + ";"
+    }
+    /** `Descriptor.GeneratedType` expression for `tpe`, built recursively over type arguments. */
+    def generatedType(tpe: TypeRepr): Expr[Descriptor.GeneratedType] = {
+      val (outer, outerAnnotations) = peelAnnotations(tpe)
+      val option = optionElement(outer)
+      val fieldSource = option.map(boxedIfPrimitive).getOrElse(outer)
+      val (base, baseAnnotations) = peelAnnotations(fieldSource)
+      val annotations = outerAnnotations ++ baseAnnotations
+      val argumentSource = fieldSource
+      def appliedType(tpe: TypeRepr): Option[(TypeRepr, List[TypeRepr])] = {
+        val directArgs = tpe.typeArgs
+        if directArgs.nonEmpty then {
+          tpe match {
+            case AppliedType(typeConstructor, _) => Some((typeConstructor, directArgs))
+            case other =>
+              other.dealias match {
+                case AppliedType(typeConstructor, _) => Some((typeConstructor, directArgs))
+                case _ => Some((tpe, directArgs))
+              }
+          }
+        } else {
+          tpe match {
+            case AppliedType(typeConstructor, typeArgs) => Some((typeConstructor, typeArgs))
+            case other =>
+              other.dealias match {
+                case AppliedType(typeConstructor, typeArgs) => Some((typeConstructor, typeArgs))
+                case _ => None
+              }
+          }
+        }
+      }
+      val component = appliedType(argumentSource) match {
+        case Some((arrayType, List(componentType)))
+            if arrayType.typeSymbol.fullName == "scala.Array" =>
+          Some(generatedType(componentType))
+        case _ => None
+      }
+      val args = appliedType(argumentSource) match {
+        case Some((arrayType, List(_))) if arrayType.typeSymbol.fullName == "scala.Array" =>
+          Nil
+        case Some((_, typeArgs)) => typeArgs
+        case _ => Nil
+      }
+      val argExprs = args.map(generatedType)
+      val argList: Expr[java.util.List[Descriptor.GeneratedType]] =
+        '{
+          import scala.jdk.CollectionConverters.*
+          java.util.Collections.unmodifiableList(${ Expr.ofList(argExprs) }.asJava)
+        }
+      val componentExpr: Expr[Descriptor.GeneratedType] =
+        component.getOrElse('{ null.asInstanceOf[Descriptor.GeneratedType] })
+      val typeId =
+        annotations.flatMap(typeIdForAnnotation).headOption
+          .orElse {
+            if hasAnnotation[ForyUnion](base.typeSymbol) then Some(Types.UNION) else None
+          }
+          .orElse {
+            if isScalaEnumType(base) then Some(Types.ENUM) else None
+          }
+          .getOrElse(Types.UNKNOWN)
+      val rawClass = classFor(base)
+      val typeExtMeta = generatedTypeExtMeta(
+        typeId,
+        nullable = option.nonEmpty,
+        trackingRef = annotations.exists(isRefAnnotation),
+        rawClass = Some(rawClass))
+      '{ Descriptor.generatedType($rawClass, $typeExtMeta, $argList, $componentExpr) }
+    }
+    /**
+     * Descriptor for the payload of a known case. The name is the case name followed by
+     * ".value", and the case id is passed in the argument position where `deriveStruct`
+     * passes the field id.
+     */
+    def caseDescriptor(unionCase: CaseMeta): Expr[Descriptor] = {
+      '{
+        new Descriptor(
+          ${ generatedType(unionCase.payloadType) },
+          ${ Expr(unionCase.payloadType.show) },
+          ${ Expr(unionCase.symbol.name + ".value") },
+          ${ Expr(Modifier.PRIVATE | Modifier.FINAL) },
+          ${ Expr(owner.fullName.replace("$.", "$")) },
+          true,
+          ${ Expr(unionCase.id) },
+          false,
+          false,
+          ForyField.Dynamic.AUTO,
+          false
+        )
+      }
+    }
+
+    val caseDescriptorsExpr: Expr[java.util.List[Descriptor]] =
+      '{
+        import scala.jdk.CollectionConverters.*
+        java.util.Collections.unmodifiableList(${ Expr.ofList(knownCases.map(caseDescriptor)) }.asJava)
+      }
+    /**
+     * Emits an `isInstanceOf` chain over all cases. The unknown case writes its payload
+     * together with the case id it carries; a known case writes through its field info
+     * with its declared id. A value matching no case throws `IllegalStateException`.
+     */
+    def writeDispatch(
+        valueExpr: Expr[T],
+        writeContextExpr: Expr[org.apache.fory.context.WriteContext],
+        resolverExpr: Expr[TypeResolver],
+        caseFieldInfosExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[Unit] = {
+      cases.foldRight(
+        '{
+          throw new IllegalStateException("Unknown Scala union case " + $valueExpr)
+        }: Expr[Unit]) { (unionCase, next) =>
+        unionCase.symbol.typeRef.asType match {
+          case '[c] =>
+            if unionCase.unknown then {
+              val originalId =
+                Select.unique(
+                  '{ $valueExpr.asInstanceOf[c] }.asTerm,
+                  unionCase.unknownIdName).asExprOf[Int]
+              val payload =
+                Select.unique(
+                  '{ $valueExpr.asInstanceOf[c] }.asTerm,
+                  unionCase.payloadName).asExpr
+              '{
+                if $valueExpr.isInstanceOf[c] then {
+                  UnionSerializer.writeUnknownCaseValue($writeContextExpr, $payload, $originalId)
+                } else {
+                  $next
+                }
+              }
+            } else {
+              val payload =
+                Select.unique(
+                  '{ $valueExpr.asInstanceOf[c] }.asTerm,
+                  unionCase.payloadName).asExpr
+              '{
+                if $valueExpr.isInstanceOf[c] then {
+                  UnionSerializer.writeCaseValue(
+                    $resolverExpr,
+                    $writeContextExpr,
+                    $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) }),
+                    $payload,
+                    ${ Expr(unionCase.id) })
+                } else {
+                  $next
+                }
+              }
+            }
+        }
+      }
+    }
+    /** Calls the primary constructor of the case class with `args`. */
+    def construct(unionCase: CaseMeta, args: List[Term]): Expr[T] = {
+      Apply(Select(New(TypeTree.ref(unionCase.symbol)), unionCase.symbol.primaryConstructor), args)
+        .asExprOf[T]
+    }
+    /**
+     * Emits an if/else chain over the known case ids. Any other id reads the payload with
+     * `readContext.readRef()` and builds the unknown case from the id and that payload.
+     */
+    def readDispatch(
+        caseIdExpr: Expr[Int],
+        readContextExpr: Expr[org.apache.fory.context.ReadContext],
+        resolverExpr: Expr[TypeResolver],
+        caseFieldInfosExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[T] = {
+      val unknownPayload = '{ $readContextExpr.readRef() }
+      val unknownExpr = construct(unknown, List(caseIdExpr.asTerm, unknownPayload.asTerm))
+      knownCases.foldRight(unknownExpr) { (unionCase, next) =>
+        unionCase.payloadType.asType match {
+          case '[p] =>
+            val rawPayload =
+              '{
+                UnionSerializer.readCaseValue(
+                  $resolverExpr,
+                  $readContextExpr,
+                  $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) }))
+              }
+            val payload = coercePayload[p](rawPayload, unionCase.payloadType)
+            val current = construct(unionCase, List(payload.asTerm))
+            '{
+              if $caseIdExpr == ${ Expr(unionCase.id) } then $current else $next
+            }
+        }
+      }
+    }
+    /**
+     * Casts a read payload to `P`. When the payload type is recognised as List/Seq, Set or
+     * Map (by symbol name or by rendered type prefix), a `java.util` collection is first
+     * converted to the immutable Scala counterpart.
+     */
+    def coercePayload[P: Type](
+        payloadExpr: Expr[Any],
+        payloadType: TypeRepr): Expr[P] = {
+      val rawTypeName = payloadType.dealias match {
+        case AppliedType(base, _) => base.typeSymbol.fullName
+        case other => other.typeSymbol.fullName
+      }
+      val renderedType = payloadType.show
+      if rawTypeName == "scala.collection.immutable.List" ||
+        rawTypeName == "scala.collection.Seq" ||
+        rawTypeName == "scala.collection.immutable.Seq" ||
+        renderedType.startsWith("scala.List[") ||
+        renderedType.startsWith("List[")
+      then {
+        '{
+          $payloadExpr match {
+            case value: scala.collection.immutable.List[?] => value.asInstanceOf[P]
+            case value: java.util.List[?] =>
+              import scala.jdk.CollectionConverters.*
+              value.asScala.toList.asInstanceOf[P]
+            case value => value.asInstanceOf[P]
+          }
+        }
+      } else if rawTypeName == "scala.collection.immutable.Set" ||
+        rawTypeName == "scala.collection.Set" ||
+        renderedType.startsWith("scala.Set[") ||
+        renderedType.startsWith("Set[")
+      then {
+        '{
+          $payloadExpr match {
+            case value: scala.collection.immutable.Set[?] => value.asInstanceOf[P]
+            case value: java.util.Set[?] =>
+              import scala.jdk.CollectionConverters.*
+              value.asScala.toSet.asInstanceOf[P]
+            case value => value.asInstanceOf[P]
+          }
+        }
+      } else if rawTypeName == "scala.collection.immutable.Map" ||
+        rawTypeName == "scala.collection.Map" ||
+        renderedType.startsWith("scala.Map[") ||
+        renderedType.startsWith("Map[")
+      then {
+        '{
+          $payloadExpr match {
+            case value: scala.collection.immutable.Map[?, ?] => value.asInstanceOf[P]
+            case value: java.util.Map[?, ?] =>
+              import scala.jdk.CollectionConverters.*
+              value.asScala.toMap.asInstanceOf[P]
+            case value => value.asInstanceOf[P]
+          }
+        }
+      } else '{ $payloadExpr.asInstanceOf[P] }
+    }
+
+    val ownerClassName = owner.fullName.replace("$.", "$")
+    val classExpr: Expr[Class[T]] =
+      '{ Class.forName(${ Expr(ownerClassName) }).asInstanceOf[Class[T]] }
+    val caseClassesExpr: Expr[List[Class[_]]] =
+      Expr.ofList(cases.map(unionCase =>
+        '{ Class.forName(${ Expr(ownerClassName + "$" + unionCase.symbol.name) }) }))
+    /**
+     * Generated result: a union `ForySerializer[T]` that registers the union class plus
+     * every case class (looked up as owner name + "$" + case name) and whose `copy`
+     * returns the value it was given.
+     */
+    '{
+      new ForySerializer[T] {
+        override def isUnion: Boolean = true
+
+        override def registrationClasses(cls: Class[T]): Array[Class[_]] =
+          (cls :: $caseClassesExpr).toArray
+
+        override def createSerializer(
+            resolver: TypeResolver,
+            remoteTypeDef: ForyTypeDef): Serializer[T] = {
+          new Serializer[T](resolver.getConfig, $classExpr) {
+            private val caseFieldInfos: Array[FieldGroups.SerializationFieldInfo] = {
+              val descriptors = $caseDescriptorsExpr
+              val result = new Array[FieldGroups.SerializationFieldInfo](descriptors.size())
+              var i = 0
+              while i < descriptors.size() do {
+                result(i) = new FieldGroups.SerializationFieldInfo(resolver, descriptors.get(i))
+                i += 1
+              }
+              result
+            }
+
+            override def write(
+                writeContext: org.apache.fory.context.WriteContext,
+                value: T): Unit = {
+              ${ writeDispatch('value, 'writeContext, 'resolver, 'caseFieldInfos) }
+            }
+
+            override def read(readContext: org.apache.fory.context.ReadContext): T = {
+              val buffer = readContext.getBuffer
+              val caseId = buffer.readVarUInt32()
+              ${ readDispatch('caseId, 'readContext, 'resolver, 'caseFieldInfos) }
+            }
+
+            override def copy(copyContext: org.apache.fory.context.CopyContext, value: T): T =
+              value
+          }
+        }
+      }
+    }
+  }
+  /**
+   * Int argument of the first annotation of type `A` on `symbol`: the first argument that
+   * is either a named int literal called `name` or a positional int literal.
+   *
+   * NOTE(review): a positional int literal is accepted regardless of which annotation
+   * parameter it belongs to; confirm that is acceptable for multi-parameter annotations.
+   */
+  private def annotationIntArg[A: Type](using q: Quotes)(
+      symbol: q.reflect.Symbol,
+      name: String): Option[Int] = {
+    import q.reflect.*
+    symbol.annotations
+      .find(_.tpe <:< TypeRepr.of[A])
+      .flatMap {
+        case Apply(_, args) =>
+          args.collectFirst {
+            case NamedArg(`name`, Literal(IntConstant(value))) => value
+            case Literal(IntConstant(value)) => value
+          }
+        case _ => None
+      }
+  }
+  /** True when `symbol` carries an annotation whose type conforms to `A`. */
+  private def hasAnnotation[A: Type](using q: Quotes)(symbol: q.reflect.Symbol): Boolean = {
+    import q.reflect.*
+    symbol.annotations.exists(_.tpe <:< TypeRepr.of[A])
+  }
+  /** True when `symbol` carries `org.apache.fory.annotation.Ref` (matched by full name). */
+  private def hasRef(using q: Quotes)(symbol: q.reflect.Symbol): Boolean = {
+    import q.reflect.*
+    symbol.annotations.exists(_.tpe.typeSymbol.fullName == "org.apache.fory.annotation.Ref")
+  }
+  /**
+   * True when `@Ref` annotates the top-level type. For `Option[X]` only the annotations
+   * on `X` are inspected; annotations on the Option itself are not considered.
+   */
+  private def topLevelTypeHasRef(using q: Quotes)(tpe: q.reflect.TypeRepr): Boolean = {
+    import q.reflect.*
+
+    def peelAnnotations(tpe: TypeRepr): (TypeRepr, List[Term]) = {
+      tpe match {
+        case AnnotatedType(underlying, annotation) =>
+          val (base, annotations) = peelAnnotations(underlying)
+          (base, annotation :: annotations)
+        case other =>
+          other.dealias match {
+            case AnnotatedType(underlying, annotation) =>
+              val (base, annotations) = peelAnnotations(underlying)
+              (base, annotation :: annotations)
+            case dealiased => (dealiased, Nil)
+          }
+      }
+    }
+
+    def isRef(annotation: Term): Boolean =
+      annotation.tpe.typeSymbol.fullName == "org.apache.fory.annotation.Ref"
+
+    val (base, annotations) = peelAnnotations(tpe)
+    base.dealias match {
+      case AppliedType(optionType, List(inner))
+          if optionType.typeSymbol.fullName == "scala.Option" =>
+        peelAnnotations(inner)._2.exists(isRef)
+      case _ => annotations.exists(isRef)
+    }
+  }
+  /**
+   * Builds the `TypeExtMeta` expression for a generated type. With an UNKNOWN type id and
+   * a raw class available, the `ForyScalaEnum` check is emitted as runtime code. A null
+   * `TypeExtMeta` is produced when the id stays UNKNOWN and the type is neither nullable
+   * nor reference-tracked.
+   */
+  private def generatedTypeExtMeta(using q: Quotes)(
+      typeId: Int,
+      nullable: Boolean,
+      trackingRef: Boolean,
+      rawClass: Option[Expr[Class[?]]] = None): Expr[TypeExtMeta] = {
+    if typeId == Types.UNKNOWN && rawClass.nonEmpty then {
+      val raw = rawClass.get
+      '{
+        val resolvedTypeId =
+          if classOf[ForyScalaEnum].isAssignableFrom($raw) then Types.ENUM else Types.UNKNOWN
+        if resolvedTypeId == Types.UNKNOWN && !${ Expr(nullable) } && !${ Expr(trackingRef) } then {
+          null.asInstanceOf[TypeExtMeta]
+        } else {
+          TypeExtMeta.of(resolvedTypeId, ${ Expr(nullable) }, ${ Expr(trackingRef) })
+        }
+      }
+    } else if typeId == Types.UNKNOWN && !nullable && !trackingRef then {
+      '{ null.asInstanceOf[TypeExtMeta] }
+    } else {
+      '{ TypeExtMeta.of(${ Expr(typeId) }, ${ Expr(nullable) }, ${ Expr(trackingRef) }) }
+    }
+  }
+  /** True for symbols flagged `Enum` and for types that extend `ForyScalaEnum`. */
+  private def isScalaEnumType(using q: Quotes)(tpe: q.reflect.TypeRepr): Boolean = {
+    import q.reflect.*
+    tpe.typeSymbol.flags.is(Flags.Enum) ||
+      tpe <:< TypeRepr.of[ForyScalaEnum] ||
+      tpe.baseClasses.exists(_.fullName == "org.apache.fory.scala.ForyScalaEnum")
+  }
+  /**
+   * Maps a Fory numeric type annotation (matched by full name) to a `Types` id. For the
+   * 32- and 64-bit annotations the id depends on the `encoding` argument; when it is
+   * absent or unrecognised the variable-length id is used.
+   */
+  private def typeIdForAnnotation(using q: Quotes)(annotation: q.reflect.Term): Option[Int] = {
+    import q.reflect.*
+    val annotationName = annotation.tpe.typeSymbol.fullName
+    annotationName match {
+      case "org.apache.fory.annotation.Int8Type" => Some(Types.INT8)
+      case "org.apache.fory.annotation.UInt8Type" => Some(Types.UINT8)
+      case "org.apache.fory.annotation.UInt16Type" => Some(Types.UINT16)
+      case "org.apache.fory.annotation.Float16Type" => Some(Types.FLOAT16)
+      case "org.apache.fory.annotation.BFloat16Type" => Some(Types.BFLOAT16)
+      case "org.apache.fory.annotation.Int32Type" =>
+        Some(if annotationEncoding(annotation).contains("FIXED") then Types.INT32 else Types.VARINT32)
+      case "org.apache.fory.annotation.UInt32Type" =>
+        Some(if annotationEncoding(annotation).contains("FIXED") then Types.UINT32 else Types.VAR_UINT32)
+      case "org.apache.fory.annotation.Int64Type" =>
+        annotationEncoding(annotation) match {
+          case Some("FIXED") => Some(Types.INT64)
+          case Some("TAGGED") => Some(Types.TAGGED_INT64)
+          case _ => Some(Types.VARINT64)
+        }
+      case "org.apache.fory.annotation.UInt64Type" =>
+        annotationEncoding(annotation) match {
+          case Some("FIXED") => Some(Types.UINT64)
+          case Some("TAGGED") => Some(Types.TAGGED_UINT64)
+          case _ => Some(Types.VAR_UINT64)
+        }
+      case _ => None
+    }
+  }
+  /** Name of the constant passed as the named `encoding` argument, when one is present. */
+  private def annotationEncoding(using q: Quotes)(annotation: q.reflect.Term): Option[String] = {
+    import q.reflect.*
+    annotation match {
+      case Apply(_, args) => args.collectFirst {
+        case NamedArg("encoding", Select(_, name)) => name
+        case NamedArg("encoding", Ident(name)) => name
+      }
+      case _ => None
+    }
+  }
+  /** True when the annotation term is `org.apache.fory.annotation.Ref` (matched by full name). */
+  private def isRefAnnotation(using q: Quotes)(annotation: q.reflect.Term): Boolean = {
+    annotation.tpe.typeSymbol.fullName == "org.apache.fory.annotation.Ref"
+  }
+}
diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala new file mode 100644 index 0000000000..a9014103b6 --- /dev/null +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer.scala + +import org.apache.fory.context.{CopyContext, ReadContext, WriteContext} +import org.apache.fory.resolver.TypeResolver +import org.apache.fory.serializer.Serializer +import org.apache.fory.serializer.collection.{CollectionLikeSerializer, MapLikeSerializer} + +import java.util +import scala.collection.mutable +import scala.collection.{immutable => simmutable} + +abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.Iterable[A]]( + typeResolver: TypeResolver, + cls: Class[T]) + extends CollectionLikeSerializer[T](typeResolver, cls) { + + override def onCollectionWrite(writeContext: WriteContext, value: T): util.Collection[_] = { + writeContext.getBuffer.writeVarUInt32Small7(value.size) + new XlangCollectionAdapter[A](value) + } + + override def newCollection(readContext: ReadContext): util.Collection[_] = { + val numElements = readCollectionSize(readContext.getBuffer) + setNumElements(numElements) + new XlangCollectionBuilder[A, T](newBuilder(numElements)) + } + + protected def newBuilder(numElements: Int): mutable.Builder[A, T] + + override def onCollectionRead(collection: util.Collection[_]): T = { + collection.asInstanceOf[XlangCollectionBuilder[A, T]].builder.result() + } + + override def copy(copyContext: CopyContext, value: T): T = { + if (isImmutable) { + value + } else { + super.copy(copyContext, value) + } + } +} + +class ScalaXlangSeqSerializer[A, T <: scala.collection.Seq[A]]( + typeResolver: TypeResolver, + cls: Class[T]) + extends 
AbstractScalaXlangCollectionSerializer[A, T](typeResolver, cls) { + override protected def newBuilder(numElements: Int): mutable.Builder[A, T] = { + val builder = simmutable.List.newBuilder[A] + builder.sizeHint(numElements) + builder.asInstanceOf[mutable.Builder[A, T]] + } +} + +class ScalaXlangSetSerializer[A, T <: scala.collection.Set[A]]( + typeResolver: TypeResolver, + cls: Class[T]) + extends AbstractScalaXlangCollectionSerializer[A, T](typeResolver, cls) { + override protected def newBuilder(numElements: Int): mutable.Builder[A, T] = { + val builder = simmutable.Set.newBuilder[A] + builder.sizeHint(numElements) + builder.asInstanceOf[mutable.Builder[A, T]] + } +} + +class ScalaXlangCollectionSerializer[A, T <: scala.collection.Iterable[A]]( + typeResolver: TypeResolver, + cls: Class[T]) + extends AbstractScalaXlangCollectionSerializer[A, T](typeResolver, cls) { + override protected def newBuilder(numElements: Int): mutable.Builder[A, T] = { + val builder = simmutable.List.newBuilder[A] + builder.sizeHint(numElements) + builder.asInstanceOf[mutable.Builder[A, T]] + } +} + +private final class XlangCollectionAdapter[A](coll: scala.collection.Iterable[A]) + extends util.AbstractCollection[A] { + override def iterator(): util.Iterator[A] = new util.Iterator[A] { + private val it = coll.iterator + + override def hasNext: Boolean = it.hasNext + + override def next(): A = it.next() + } + + override def size(): Int = coll.size +} + +private final class XlangCollectionBuilder[A, T](val builder: mutable.Builder[A, T]) + extends util.AbstractCollection[A] { + override def add(e: A): Boolean = { + builder.addOne(e) + true + } + + override def iterator(): util.Iterator[A] = + throw new UnsupportedOperationException("Scala xlang collection builder is write-only") + + override def size(): Int = + throw new UnsupportedOperationException("Scala xlang collection builder is write-only") +} + +abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K, 
V]]( + typeResolver: TypeResolver, + cls: Class[T]) + extends MapLikeSerializer[T](typeResolver, cls) { + + override def onMapWrite(writeContext: WriteContext, value: T): util.Map[_, _] = { + writeContext.getBuffer.writeVarUInt32Small7(value.size) + new XlangMapAdapter[K, V](value) + } + + override def newMap(readContext: ReadContext): util.Map[_, _] = { + val numElements = readMapSize(readContext.getBuffer) + setNumElements(numElements) + val builder = simmutable.Map.newBuilder[K, V] + builder.sizeHint(numElements) + new XlangMapBuilder[K, V, T](builder.asInstanceOf[mutable.Builder[(K, V), T]]) + } + + override def onMapRead(map: util.Map[_, _]): T = { + map.asInstanceOf[XlangMapBuilder[K, V, T]].builder.result() + } + + override def onMapCopy(map: util.Map[_, _]): T = onMapRead(map) +} + +class ScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K, V]]( + typeResolver: TypeResolver, + cls: Class[T]) + extends AbstractScalaXlangMapSerializer[K, V, T](typeResolver, cls) + +private final class XlangMapAdapter[K, V](map: scala.collection.Map[K, V]) + extends util.AbstractMap[K, V] { + override def entrySet(): util.Set[util.Map.Entry[K, V]] = + new util.AbstractSet[util.Map.Entry[K, V]] { + override def size(): Int = map.size + + override def iterator(): util.Iterator[util.Map.Entry[K, V]] = + new util.Iterator[util.Map.Entry[K, V]] { + private val it = map.iterator + + override def hasNext: Boolean = it.hasNext + + override def next(): util.Map.Entry[K, V] = { + val entry = it.next() + new org.apache.fory.collection.MapEntry[K, V](entry._1, entry._2) + } + } + } +} + +private final class XlangMapBuilder[K, V, T](val builder: mutable.Builder[(K, V), T]) + extends util.AbstractMap[K, V] { + override def entrySet(): util.Set[util.Map.Entry[K, V]] = + throw new UnsupportedOperationException("Scala xlang map builder is write-only") + + override def put(key: K, value: V): V = { + builder.addOne((key, value)) + value + } +} + +final class 
ScalaOptionSerializer(typeResolver: TypeResolver, cls: Class[_]) + extends Serializer[Option[Any]](typeResolver.getConfig, cls.asInstanceOf[Class[Option[Any]]]) { + override def write(writeContext: WriteContext, value: Option[Any]): Unit = { + writeContext.writeRef(value.orNull) + } + + override def read(readContext: ReadContext): Option[Any] = { + Option(readContext.readRef()) + } + + override def copy(copyContext: CopyContext, value: Option[Any]): Option[Any] = { + value.map(copyContext.copyObject(_)) + } +} diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala new file mode 100644 index 0000000000..bd4b4ce77d --- /dev/null +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.fory.serializer.scala + +import org.apache.fory.Fory +import org.apache.fory.annotation.{ForyCase, ForyField, ForyStruct, ForyUnion, Ref} +import org.apache.fory.scala.ForySerializer +import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory +import org.apache.fory.`type`.TypeUtils +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.jdk.CollectionConverters._ + +object ForySerializerDerivationTest { + @ForyStruct + final case class Person( + @ForyField(id = 1) name: String, + @ForyField(id = 2) age: Int, + @ForyField(id = 3) email: Option[String]) + derives ForySerializer + + @ForyStruct + final case class SearchUser(@ForyField(id = 1) name: String) derives ForySerializer + + @ForyStruct + final case class CollectionBox( + @ForyField(id = 1) names: List[String], + @ForyField(id = 2) tags: Set[String], + @ForyField(id = 3) scores: Map[String, Int]) + derives ForySerializer + + @ForyStruct + final class RefNode() derives ForySerializer { + @ForyField(id = 1) + var children: List[RefNode @Ref] = List.empty + + @Ref + @ForyField(id = 2) + var parent: Option[RefNode @Ref] = None + } + + @ForyStruct + final class MixedRecord(@ForyField(id = 1) val id: Int) derives ForySerializer { + @ForyField(id = 2) + var name: String = "" + } + + @ForyUnion + enum SearchTarget derives ForySerializer { + @ForyCase(id = 0) + case UnknownCase(caseId: Int, value: Any) + + @ForyCase(id = 1) + case UserCase(value: SearchUser) + + @ForyCase(id = 2) + case FixedIdCase(value: Int) + } + + def xlangFory(): Fory = { + val fory = Fory.builder() + .withXlang(true) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .suppressClassRegistrationWarnings(false) + .build() + ScalaSerializers.registerSerializers(fory) + ForySerializer.register(fory, classOf[Person], "scala_test", "Person") + ForySerializer.register(fory, classOf[SearchUser], "scala_test", 
"SearchUser") + ForySerializer.register(fory, classOf[CollectionBox], "scala_test", "CollectionBox") + ForySerializer.register(fory, classOf[MixedRecord], "scala_test", "MixedRecord") + ForySerializer.register(fory, classOf[SearchTarget], "scala_test", "SearchTarget") + fory + } +} + +class ForySerializerDerivationTest extends AnyWordSpec with Matchers { + import ForySerializerDerivationTest._ + + "Scala 3 ForySerializer derivation" should { + "serialize derived case classes with Option fields" in { + val fory = xlangFory() + fory.deserialize(fory.serialize(Person("Ada", 36, Some("ada@example.com")))) shouldEqual + Person("Ada", 36, Some("ada@example.com")) + fory.deserialize(fory.serialize(Person("Grace", 85, None))) shouldEqual + Person("Grace", 85, None) + } + + "serialize derived union enum cases" in { + val fory = xlangFory() + val user = SearchTarget.UserCase(SearchUser("Ada")) + val fixed = SearchTarget.FixedIdCase(7) + fory.deserialize(fory.serialize(user)) shouldEqual user + fory.deserialize(fory.serialize(fixed)) shouldEqual fixed + } + + "serialize derived case classes with Scala collection fields" in { + val fory = xlangFory() + val box = CollectionBox(List("a", "b"), Set("x", "y"), Map("a" -> 1, "b" -> 2)) + fory.deserialize(fory.serialize(box)) shouldEqual box + } + + "serialize mixed constructor and mutable field classes" in { + val fory = xlangFory() + val record = new MixedRecord(7) + record.name = "Ada" + val restored = fory.deserialize(fory.serialize(record)).asInstanceOf[MixedRecord] + restored.id shouldBe 7 + restored.name shouldBe "Ada" + } + + "preserve nested reference metadata in generated descriptors" in { + val factory = + summon[ForySerializer[RefNode]] + .asInstanceOf[StaticGeneratedStructSerializerFactory[RefNode]] + val descriptors = factory.getGeneratedDescriptors.asScala + val children = descriptors.find(_.getName == "children").get + val parent = descriptors.find(_.getName == "parent").get + + children.isTrackingRef shouldBe false 
+ TypeUtils.getElementType(children.getTypeRef).getTypeExtMeta.trackingRef() shouldBe true + parent.isNullable shouldBe true + parent.isTrackingRef shouldBe true + } + + "serialize derived union unknown cases with original ids" in { + val fory = xlangFory() + val unknown = SearchTarget.UnknownCase(99, SearchUser("Future")) + fory.deserialize(fory.serialize(unknown)) shouldEqual unknown + } + } +} diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala new file mode 100644 index 0000000000..a782060345 --- /dev/null +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.fory.serializer.scala + +import org.apache.fory.Fory +import org.apache.fory.annotation.{ForyCase, ForyField, ForyStruct, ForyUnion} +import org.apache.fory.config.Language +import org.apache.fory.scala.ForySerializer + +import java.nio.file.{Files, Path} + +@ForyStruct +final case class ScalaPeerUser( + @ForyField(id = 1) id: Int, + @ForyField(id = 2) name: String, + @ForyField(id = 3) email: Option[String]) + derives ForySerializer + +@ForyUnion +enum ScalaPeerTarget derives ForySerializer { + @ForyCase(id = 0) + case UnknownCase(caseId: Int, value: Any) + + @ForyCase(id = 1) + case UserCase(value: ScalaPeerUser) +} + +object ScalaXlangPeer { + private val Namespace = "scala_peer" + + def main(args: Array[String]): Unit = { + require(args.length == 2, "Usage: ScalaXlangPeer ") + args(0) match { + case "derived_struct_round_trip" => + roundTripUser(Path.of(args(1))) + case "known_union_case_round_trip" => + roundTripTarget(Path.of(args(1)), preserveUnknownCase = false) + case "unknown_union_case_round_trip" => + roundTripTarget(Path.of(args(1)), preserveUnknownCase = true) + case other => + throw new IllegalArgumentException(s"Unknown Scala xlang peer case: $other") + } + } + + private def roundTripUser(dataFile: Path): Unit = { + val fory = newFory() + val request = fory.deserialize(Files.readAllBytes(dataFile)).asInstanceOf[ScalaPeerUser] + Files.write( + dataFile, + fory.serialize(request.copy(id = request.id + 1, name = "scala-" + request.name, email = None))) + } + + private def roundTripTarget(dataFile: Path, preserveUnknownCase: Boolean): Unit = { + val fory = newFory() + val request = fory.deserialize(Files.readAllBytes(dataFile)).asInstanceOf[ScalaPeerTarget] + val response = request match { + case ScalaPeerTarget.UserCase(user) => + ScalaPeerTarget.UserCase( + user.copy(id = user.id + 1, name = "scala-" + user.name, email = None)) + case ScalaPeerTarget.UnknownCase(caseId, value: ScalaPeerUser) if preserveUnknownCase => + 
ScalaPeerTarget.UnknownCase( + caseId, + value.copy(id = value.id + 1, name = "scala-" + value.name, email = None)) + case ScalaPeerTarget.UnknownCase(caseId, value) => + ScalaPeerTarget.UnknownCase(caseId, value) + } + Files.write(dataFile, fory.serialize(response)) + } + + private def newFory(): Fory = { + val fory = Fory.builder() + .withLanguage(Language.XLANG) + .withCompatible(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + ScalaSerializers.registerSerializers(fory) + ForySerializer.register(fory, classOf[ScalaPeerUser], Namespace, "ScalaPeerUser") + ForySerializer.register(fory, classOf[ScalaPeerTarget], Namespace, "ScalaPeerTarget") + fory + } +} diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala new file mode 100644 index 0000000000..0e414613a5 --- /dev/null +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.fory.serializer.scala + +import org.apache.fory.Fory +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.jdk.CollectionConverters.* + +class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { + def fory: Fory = { + val runtime = Fory.builder() + .withXlang(true) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(false) + .suppressClassRegistrationWarnings(false) + .build() + ScalaSerializers.registerSerializers(runtime) + runtime + } + + "fory scala xlang support" should { + "serialize collections with canonical xlang serializers" in { + val runtime = fory + val list = List("a", "b", "c") + val set = Set("a", "b", "c") + val map = Map("a" -> 1, "b" -> 2) + runtime + .deserialize(runtime.serialize(list)) + .asInstanceOf[java.util.List[String]] + .asScala + .toList shouldEqual list + runtime + .deserialize(runtime.serialize(set)) + .asInstanceOf[java.util.Set[String]] + .asScala + .toSet shouldEqual set + runtime + .deserialize(runtime.serialize(map)) + .asInstanceOf[java.util.Map[String, Int]] + .asScala + .toMap shouldEqual map + } + + } +} From 6ee39a169388120f7981eb9801f55d76e1ea4c4f Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 08:00:00 +0800 Subject: [PATCH 3/9] fix(scala): cover inherited xlang cases --- .../apache/fory/resolver/ClassResolver.java | 20 + .../apache/fory/resolver/TypeResolver.java | 10 + .../apache/fory/resolver/XtypeResolver.java | 20 + .../org/apache/fory/xlang/ScalaXlangTest.java | 315 ++++++- scala/build.sbt | 10 + .../serializer/scala/ScalaEnumSerializer.java | 2 +- .../serializer/scala/ScalaSerializers.java | 11 + .../scala/internal/ForySerializerMacros.scala | 24 +- .../serializer/scala/ScalaXlangPeer.scala | 865 +++++++++++++++++- 9 files changed, 1218 insertions(+), 59 deletions(-) diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java 
b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 7aa11ffeae..238e8dec6b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -615,6 +615,26 @@ public void registerUnionCase(Class unionType, Class caseType) { registerGraalvmClass(caseType); } + @Override + public void registerEnumCase(Class enumType, Class caseType) { + checkRegisterAllowed(); + TypeInfo typeInfo = classInfoMap.get(enumType); + Preconditions.checkArgument( + typeInfo != null && Types.isEnumType(typeInfo.typeId), + "Enum type %s must be registered before case type %s", + enumType, + caseType); + TypeInfo existingInfo = classInfoMap.get(caseType); + Preconditions.checkArgument( + existingInfo == null || existingInfo == typeInfo, + "Enum case type %s has been registered as %s", + caseType, + existingInfo); + classInfoMap.put(caseType, typeInfo); + extRegistry.registeredClasses.put(caseType.getName(), caseType); + registerGraalvmClass(caseType); + } + @Override public void registerEnum(Class cls, long userId, Serializer serializer) { checkRegisterAllowed(); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 69e0cbf648..8f767dafb6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -302,6 +302,16 @@ public abstract void registerUnion( @Internal public abstract void registerUnionCase(Class unionType, Class caseType); + /** + * Registers {@code caseType} as a runtime class alias for an already registered enum type. + * + *

    Some JVM languages compile enum cases to concrete singleton subclasses even though the wire + * type is owned by the enum base. This method makes runtime dispatch for those case subclasses + * use the base enum {@link TypeInfo}; it must not create another wire type name or user type ID. + */ + @Internal + public abstract void registerEnumCase(Class enumType, Class caseType); + /** Registers a non-Java enum type with a user-specified ID and serializer. */ @Internal public abstract void registerEnum(Class type, long id, Serializer serializer); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index b6e990d035..087aefb893 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -408,6 +408,26 @@ public void registerUnionCase(Class unionType, Class caseType) { registerGraalvmClass(caseType); } + @Override + public void registerEnumCase(Class enumType, Class caseType) { + checkRegisterAllowed(); + TypeInfo typeInfo = classInfoMap.get(enumType); + Preconditions.checkArgument( + typeInfo != null && Types.isEnumType(typeInfo.typeId), + "Enum type %s must be registered before case type %s", + enumType, + caseType); + TypeInfo existingInfo = classInfoMap.get(caseType); + Preconditions.checkArgument( + existingInfo == null || existingInfo == typeInfo, + "Enum case type %s has been registered as %s", + caseType, + existingInfo); + classInfoMap.put(caseType, typeInfo); + extRegistry.registeredClasses.put(caseType.getName(), caseType); + registerGraalvmClass(caseType); + } + @Override public void registerEnum(Class type, long userTypeId, Serializer serializer) { checkRegisterAllowed(); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java index 
217b68ad0b..98370f710d 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/ScalaXlangTest.java @@ -21,7 +21,8 @@ import java.io.File; import java.io.IOException; -import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; @@ -38,7 +39,7 @@ import org.apache.fory.type.union.Union; import org.testng.Assert; import org.testng.SkipException; -import org.testng.annotations.BeforeMethod; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; /** Executes Java-driven Scala 3 macro xlang serializer tests. */ @@ -49,13 +50,15 @@ public class ScalaXlangTest extends XlangTestBase { private static final String UNKNOWN_UNION_CASE = "unknown_union_case_round_trip"; private static final String NAMESPACE = "scala_peer"; private static final File SCALA_DIR = new File("../../scala"); + private static final File SCALA_CLASSPATH_FILE = + new File(SCALA_DIR, "target/scala-xlang-test-classpath"); + private static final String JAVA_EXECUTABLE = + new File(new File(System.getProperty("java.home"), "bin"), "java").getAbsolutePath(); + private static volatile String scalaPeerClasspath; - @BeforeMethod(alwaysRun = true) - public void skipInheritedXlangCases(Method method) { - if (method.getDeclaringClass() != ScalaXlangTest.class) { - throw new SkipException( - "Scala xlang phase 1 validates macro-derived Scala 3 peer serializers"); - } + @DataProvider(name = "enableCodegenParallel") + public static Object[][] enableCodegenParallel() { + return enableCodegen(); } @Override @@ -66,41 +69,39 @@ protected void ensurePeerReady() { } boolean buildSuccess = TestUtils.executeCommand( - Arrays.asList("sbt", "--batch", "++3.3.1", "Test/compile"), + Arrays.asList("sbt", "--batch", "++3.3.1", "writeTestClasspath"), 240, Collections.emptyMap(), SCALA_DIR); if 
(!buildSuccess) { throw new AssertionError("Failed to compile Scala xlang peer"); } + try { + scalaPeerClasspath = + new String(Files.readAllBytes(SCALA_CLASSPATH_FILE.toPath()), StandardCharsets.UTF_8) + .trim(); + } catch (IOException e) { + throw new AssertionError("Failed to read Scala xlang peer classpath", e); + } } @Override protected CommandContext buildCommandContext(String caseName, Path dataFile) { + if (scalaPeerClasspath == null || scalaPeerClasspath.isEmpty()) { + throw new IllegalStateException("Scala xlang peer classpath is not initialized"); + } return new CommandContext( Arrays.asList( - "sbt", - "--batch", - "++3.3.1", - "Test/runMain org.apache.fory.serializer.scala.ScalaXlangPeer " - + caseName - + " " - + dataFile.toAbsolutePath()), + JAVA_EXECUTABLE, + "-cp", + scalaPeerClasspath, + "org.apache.fory.serializer.scala.ScalaXlangPeer", + caseName, + dataFile.toAbsolutePath().toString()), envBuilder(dataFile), SCALA_DIR); } - @Override - protected ExecutionContext prepareExecution(String caseName, byte[] payload) throws IOException { - if (!DERIVED_CASE.equals(caseName) - && !KNOWN_UNION_CASE.equals(caseName) - && !UNKNOWN_UNION_CASE.equals(caseName)) { - throw new SkipException( - "Scala xlang phase 1 validates macro-derived Scala 3 peer serializers"); - } - return super.prepareExecution(caseName, payload); - } - @Test(groups = "xlang") public void testDerivedStructRoundTrip() throws IOException { Fory fory = newFory(); @@ -208,4 +209,264 @@ public ScalaPeerTargetMirror(int index, Object value) { super(index, value); } } + + // ============================================================================ + // Test methods - duplicated from XlangTestBase for Maven Surefire discovery + // ============================================================================ + + @Test(groups = "xlang") + public void testBuffer() throws IOException { + super.testBuffer(); + } + + @Test(groups = "xlang") + public void testBufferVar() throws IOException { + 
super.testBufferVar(); + } + + @Test(groups = "xlang") + public void testMurmurHash3() throws IOException { + super.testMurmurHash3(); + } + + @Test(groups = "xlang") + public void testStringSerializer() throws Exception { + super.testStringSerializer(); + } + + @Test(groups = "xlang") + public void testCrossLanguageSerializer() throws Exception { + super.testCrossLanguageSerializer(); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSimpleStruct(boolean enableCodegen) throws IOException { + super.testSimpleStruct(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSimpleNamedStruct(boolean enableCodegen) throws IOException { + super.testSimpleNamedStruct(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testStructEvolvingOverride(boolean enableCodegen) throws IOException { + super.testStructEvolvingOverride(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testList(boolean enableCodegen) throws IOException { + super.testList(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testMap(boolean enableCodegen) throws IOException { + super.testMap(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testInteger(boolean enableCodegen) throws IOException { + super.testInteger(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws IOException { + super.testDecimal(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testItem(boolean enableCodegen) throws IOException { + super.testItem(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testColor(boolean enableCodegen) throws IOException { + 
super.testColor(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testStructWithList(boolean enableCodegen) throws IOException { + super.testStructWithList(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testStructWithMap(boolean enableCodegen) throws IOException { + super.testStructWithMap(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNestedAnnotatedContainerSchemaConsistent(boolean enableCodegen) + throws IOException { + super.testNestedAnnotatedContainerSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNestedAnnotatedContainerCompatible(boolean enableCodegen) throws IOException { + super.testNestedAnnotatedContainerCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testCollectionElementRefOverride(boolean enableCodegen) throws IOException { + super.testCollectionElementRefOverride(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testCollectionElementRefRemoteTracking(boolean enableCodegen) throws IOException { + super.testCollectionElementRefRemoteTracking(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSkipIdCustom(boolean enableCodegen) throws IOException { + super.testSkipIdCustom(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSkipNameCustom(boolean enableCodegen) throws IOException { + super.testSkipNameCustom(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testConsistentNamed(boolean enableCodegen) throws IOException { + super.testConsistentNamed(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public 
void testStructVersionCheck(boolean enableCodegen) throws IOException { + super.testStructVersionCheck(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStruct(boolean enableCodegen) throws IOException { + super.testReducedPrecisionFloatStruct(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStructCompatibleFieldSkip(boolean enableCodegen) + throws IOException { + super.testReducedPrecisionFloatStructCompatibleFieldSkip(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testPolymorphicList(boolean enableCodegen) throws IOException { + super.testPolymorphicList(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testPolymorphicMap(boolean enableCodegen) throws IOException { + super.testPolymorphicMap(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneStringFieldSchemaConsistent(boolean enableCodegen) throws IOException { + super.testOneStringFieldSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneStringFieldCompatible(boolean enableCodegen) throws IOException { + super.testOneStringFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testTwoStringFieldCompatible(boolean enableCodegen) throws IOException { + super.testTwoStringFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSchemaEvolutionCompatible(boolean enableCodegen) throws IOException { + super.testSchemaEvolutionCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneEnumFieldSchemaConsistent(boolean enableCodegen) throws 
IOException { + super.testOneEnumFieldSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneEnumFieldCompatible(boolean enableCodegen) throws IOException { + super.testOneEnumFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testTwoEnumFieldCompatible(boolean enableCodegen) throws IOException { + super.testTwoEnumFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testEnumSchemaEvolutionCompatible(boolean enableCodegen) throws IOException { + super.testEnumSchemaEvolutionCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldSchemaConsistentNotNull(boolean enableCodegen) throws IOException { + super.testNullableFieldSchemaConsistentNotNull(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldSchemaConsistentNull(boolean enableCodegen) throws IOException { + super.testNullableFieldSchemaConsistentNull(enableCodegen); + } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldCompatibleNotNull(boolean enableCodegen) throws IOException { + super.testNullableFieldCompatibleNotNull(enableCodegen); + } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldCompatibleNull(boolean enableCodegen) throws IOException { + super.testNullableFieldCompatibleNull(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnionXlang(boolean enableCodegen) throws IOException { + super.testUnionXlang(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testRefSchemaConsistent(boolean enableCodegen) throws IOException { + 
super.testRefSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testRefCompatible(boolean enableCodegen) throws IOException { + super.testRefCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testCircularRefSchemaConsistent(boolean enableCodegen) throws IOException { + super.testCircularRefSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testCircularRefCompatible(boolean enableCodegen) throws IOException { + super.testCircularRefCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnsignedSchemaConsistent(boolean enableCodegen) throws IOException { + super.testUnsignedSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnsignedSchemaConsistentSimple(boolean enableCodegen) throws IOException { + super.testUnsignedSchemaConsistentSimple(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnsignedSchemaCompatible(boolean enableCodegen) throws IOException { + super.testUnsignedSchemaCompatible(enableCodegen); + } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testManualSchemaKindStruct(boolean enableCodegen) throws IOException { + super.testManualSchemaKindStruct(enableCodegen); + } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws IOException { + super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/scala/build.sbt b/scala/build.sbt index ed279c2fd8..ebc18659a6 100644 --- a/scala/build.sbt +++ b/scala/build.sbt @@ -47,6 +47,16 @@ libraryDependencies ++= Seq( "dev.zio" %% "zio" % "2.1.7" % Test, ) +lazy val writeTestClasspath 
= taskKey[File]("Writes the Scala test runtime classpath") + +writeTestClasspath := { + val output = target.value / "scala-xlang-test-classpath" + IO.write( + output, + (Test / fullClasspath).value.map(_.data.getAbsolutePath).mkString(java.io.File.pathSeparator)) + output +} + // Exclude sonatypeRelease and sonatypeBundleRelease commands because we // don't want to release this project to Maven Central without having // to complete the release using the repository.apache.org web site. diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java index c1b6d4b5eb..3a591cb3c3 100644 --- a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java @@ -110,7 +110,7 @@ private Object handleUnknownEnumValue(int tag) { } } - private static Object[] loadValues(Class cls) { + static Object[] loadValues(Class cls) { try { Method values = cls.getMethod("values"); Object result = values.invoke(null); diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java index 9699d56062..4541d3d3f3 100644 --- a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java @@ -177,11 +177,22 @@ public static void registerSerializers(Fory fory) { public static void registerEnum(Fory fory, Class cls, long typeId) { TypeResolver resolver = fory.getTypeResolver(); resolver.registerEnum(cls, typeId, new ScalaEnumSerializer(resolver, cls)); + registerEnumCases(resolver, cls); } public static void registerEnum(Fory fory, Class cls, String namespace, String typeName) { TypeResolver resolver = fory.getTypeResolver(); resolver.registerEnum(cls, namespace, typeName, new 
ScalaEnumSerializer(resolver, cls)); + registerEnumCases(resolver, cls); + } + + private static void registerEnumCases(TypeResolver resolver, Class cls) { + for (Object enumConstant : ScalaEnumSerializer.loadValues(cls)) { + Class caseClass = enumConstant.getClass(); + if (caseClass != cls) { + resolver.registerEnumCase(cls, caseClass); + } + } } private static TypeResolver setSerializerFactory(Fory fory) { diff --git a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala index 1f8a081eba..042969a439 100644 --- a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala +++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala @@ -155,13 +155,25 @@ object ForySerializerMacros { if params.nonEmpty then params else owner.caseFields } val constructorFieldSet = constructorFields.toSet - val bodyFields = - owner.fieldMembers.filter(field => - !constructorFieldSet.contains(field) && annotationIntArg[ForyField](field, "id").nonEmpty) - val serializableFields = constructorFields ++ bodyFields - if serializableFields.isEmpty then { - report.errorAndAbort(s"${owner.fullName} has no serializable fields") + val bodyFields = { + val candidates = owner.fieldMembers.filter { field => + !constructorFieldSet.contains(field) && + !field.flags.is(Flags.Private) && + !field.flags.is(Flags.Synthetic) && + !field.name.contains("$") + } + val selected = + if constructorFields.isEmpty then candidates + else candidates.filter(field => annotationIntArg[ForyField](field, "id").nonEmpty) + selected.foreach { field => + if !field.flags.is(Flags.Mutable) then { + report.errorAndAbort( + s"${owner.fullName}.${field.name} is a post-construction field and must be a mutable var") + } + } + selected } + val serializableFields = constructorFields ++ bodyFields def declaredType(symbol: Symbol): TypeRepr = { symbol.tree match { diff --git 
a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala index a782060345..60689af816 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala @@ -20,11 +20,318 @@ package org.apache.fory.serializer.scala import org.apache.fory.Fory -import org.apache.fory.annotation.{ForyCase, ForyField, ForyStruct, ForyUnion} -import org.apache.fory.config.Language -import org.apache.fory.scala.ForySerializer +import org.apache.fory.annotation.{ + ArrayType, + ForyCase, + ForyField, + ForyStruct, + ForyUnion, + Int32Type, + Int64Type, + Ref, + UInt16Type, + UInt32Type, + UInt64Type, + UInt8Type +} +import org.apache.fory.annotation.ForyStruct.Evolution +import org.apache.fory.collection.{BFloat16List, Float16List, Int32List} +import org.apache.fory.config.{Int32Encoding, Int64Encoding, Language} +import org.apache.fory.context.{ReadContext, WriteContext} +import org.apache.fory.memory.{MemoryBuffer, MemoryUtils} +import org.apache.fory.meta.MetaCompressor +import org.apache.fory.resolver.TypeResolver +import org.apache.fory.scala.{ForyScalaEnum, ForySerializer} +import org.apache.fory.serializer.Serializer +import org.apache.fory.`type`.{BFloat16, Float16} +import org.apache.fory.`type`.union.Union2 +import org.apache.fory.util.MurmurHash3 +import java.math.BigDecimal as JBigDecimal +import java.nio.charset.StandardCharsets import java.nio.file.{Files, Path} +import java.time.{Instant, LocalDate} +import java.util +import scala.jdk.CollectionConverters.* + +enum Color(val foryId: Int) extends ForyScalaEnum { + case Green extends Color(0) + case Red extends Color(1) + case Blue extends Color(2) + case White extends Color(3) + + override def getForyId(): Int = foryId +} + +enum TestEnum(val foryId: Int) extends ForyScalaEnum { + case VALUE_A extends TestEnum(0) + case 
VALUE_B extends TestEnum(1) + case VALUE_C extends TestEnum(2) + + override def getForyId(): Int = foryId +} + +@ForyStruct +final case class Item(name: String) derives ForySerializer + +@ForyStruct +final case class SimpleStruct( + f1: util.HashMap[Integer, java.lang.Double], + f2: Int, + f3: Item, + f4: String, + f5: Color, + f6: util.List[String], + f7: Int, + f8: Int, + last: Int) + derives ForySerializer + +@ForyStruct(evolution = Evolution.ENABLED) +final case class EvolvingOverrideStruct(f1: String) derives ForySerializer + +@ForyStruct(evolution = Evolution.DISABLED) +final case class FixedOverrideStruct(f1: String) derives ForySerializer + +@ForyStruct +final case class Item1(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int, f6: Int) + derives ForySerializer + +@ForyStruct +final case class StructWithUnion2(union: Union2[String, java.lang.Long]) derives ForySerializer + +@ForyStruct +final case class StructWithList(items: util.List[String]) derives ForySerializer + +@ForyStruct +final case class StructWithMap(data: util.Map[String, String]) derives ForySerializer + +@ForyStruct +final case class NestedAnnotatedContainerSchemaConsistent( + values: util.Map[ + Long @UInt32Type(encoding = Int32Encoding.FIXED), + util.List[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]]) + derives ForySerializer + +@ForyStruct +final case class NestedAnnotatedContainerCompatible( + values: util.Map[ + Long @UInt32Type(encoding = Int32Encoding.FIXED), + util.List[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]]) + derives ForySerializer + +@ForyStruct +final case class MyStruct(id: Int) derives ForySerializer + +final case class MyExt(id: Int) + +final class MyExtSerializer(resolver: TypeResolver, cls: Class[MyExt]) + extends Serializer[MyExt](resolver.getConfig, cls) { + override def write(writeContext: WriteContext, value: MyExt): Unit = + writeContext.getBuffer.writeVarInt32(value.id) + + override def read(readContext: ReadContext): MyExt = + 
MyExt(readContext.getBuffer.readVarInt32()) +} + +@ForyStruct +final case class EmptyWrapper() derives ForySerializer + +@ForyStruct +final case class VersionCheckStruct(f1: Int, f2: Option[String], f3: Double) + derives ForySerializer + +trait Animal + +@ForyStruct +final case class Dog(age: Int, name: Option[String]) extends Animal derives ForySerializer + +@ForyStruct +final case class Cat(age: Int, lives: Int) extends Animal derives ForySerializer + +@ForyStruct +final case class AnimalListHolder(animals: util.List[Animal]) derives ForySerializer + +@ForyStruct +final case class AnimalMapHolder(animal_map: util.Map[String, Animal]) derives ForySerializer + +@ForyStruct +final case class EmptyStruct() derives ForySerializer + +@ForyStruct +final case class OneStringFieldStruct(f1: Option[String]) derives ForySerializer + +@ForyStruct +final case class TwoStringFieldStruct(f1: String, f2: String) derives ForySerializer + +@ForyStruct +final case class ReducedPrecisionFloatStruct( + float16Value: Float16, + bfloat16Value: BFloat16, + float16Array: Float16List, + bfloat16Array: BFloat16List) + derives ForySerializer + +@ForyStruct +final case class OneEnumFieldStruct(f1: TestEnum) derives ForySerializer + +@ForyStruct +final case class TwoEnumFieldStruct(f1: TestEnum, f2: TestEnum) derives ForySerializer + +@ForyStruct +final case class NullableComprehensiveSchemaConsistent( + byteField: Byte, + shortField: Short, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + boolField: Boolean, + stringField: String, + listField: util.List[String], + setField: util.Set[String], + mapField: util.Map[String, String], + nullableInt: Option[Int], + nullableLong: Option[Long], + nullableFloat: Option[Float], + nullableDouble: Option[Double], + nullableBool: Option[Boolean], + nullableString: Option[String], + nullableList: Option[util.List[String]], + nullableSet: Option[util.Set[String]], + nullableMap: Option[util.Map[String, String]]) + derives 
ForySerializer + +@ForyStruct +final case class NullableComprehensiveCompatible( + byteField: Byte, + shortField: Short, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + boolField: Boolean, + boxedInt: Int, + boxedLong: Long, + boxedFloat: Float, + boxedDouble: Double, + boxedBool: Boolean, + stringField: String, + listField: util.List[String], + setField: util.Set[String], + mapField: util.Map[String, String], + nullableInt1: Option[Int], + nullableLong1: Option[Long], + nullableFloat1: Option[Float], + nullableDouble1: Option[Double], + nullableBool1: Option[Boolean], + nullableString2: Option[String], + nullableList2: Option[util.List[String]], + nullableSet2: Option[util.Set[String]], + nullableMap2: Option[util.Map[String, String]]) + derives ForySerializer + +@ForyStruct +final case class RefInnerSchemaConsistent(id: Int, name: String) derives ForySerializer + +@ForyStruct +final case class RefOuterSchemaConsistent( + inner1: Option[RefInnerSchemaConsistent @Ref], + inner2: Option[RefInnerSchemaConsistent @Ref]) + derives ForySerializer + +@ForyStruct +final case class RefInnerCompatible(id: Int, name: String) derives ForySerializer + +@ForyStruct +final case class RefOuterCompatible( + inner1: Option[RefInnerCompatible @Ref], + inner2: Option[RefInnerCompatible @Ref]) + derives ForySerializer + +@ForyStruct +final case class RefOverrideElement(id: Int, name: String) derives ForySerializer + +@ForyStruct +final case class RefOverrideContainer( + listField: util.List[RefOverrideElement @Ref(enable = false)], + setField: util.Set[RefOverrideElement @Ref(enable = false)], + mapField: util.Map[String, RefOverrideElement @Ref(enable = false)]) + derives ForySerializer + +@ForyStruct +final class CircularRefStruct derives ForySerializer { + var name: String = "" + + var selfRef: Option[CircularRefStruct @Ref] = None +} + +@ForyStruct +final case class UnsignedSchemaConsistent( + u8Field: Int @UInt8Type, + u16Field: Int @UInt16Type, 
+ u32VarField: Long @UInt32Type, + u32FixedField: Long @UInt32Type(encoding = Int32Encoding.FIXED), + u64VarField: Long @UInt64Type(encoding = Int64Encoding.VARINT), + u64FixedField: Long @UInt64Type(encoding = Int64Encoding.FIXED), + u64TaggedField: Long @UInt64Type(encoding = Int64Encoding.TAGGED), + u8NullableField: Option[Int @UInt8Type], + u16NullableField: Option[Int @UInt16Type], + u32VarNullableField: Option[Long @UInt32Type], + u32FixedNullableField: Option[Long @UInt32Type(encoding = Int32Encoding.FIXED)], + u64VarNullableField: Option[Long @UInt64Type(encoding = Int64Encoding.VARINT)], + u64FixedNullableField: Option[Long @UInt64Type(encoding = Int64Encoding.FIXED)], + u64TaggedNullableField: Option[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]) + derives ForySerializer + +@ForyStruct +final case class UnsignedSchemaConsistentSimple( + u64Tagged: Long @UInt64Type(encoding = Int64Encoding.TAGGED), + u64TaggedNullable: Option[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]) + derives ForySerializer + +@ForyStruct +final case class UnsignedSchemaCompatible( + u8Field1: Int @UInt8Type, + u16Field1: Int @UInt16Type, + u32VarField1: Long @UInt32Type, + u32FixedField1: Long @UInt32Type(encoding = Int32Encoding.FIXED), + u64VarField1: Long @UInt64Type(encoding = Int64Encoding.VARINT), + u64FixedField1: Long @UInt64Type(encoding = Int64Encoding.FIXED), + u64TaggedField1: Long @UInt64Type(encoding = Int64Encoding.TAGGED), + u8Field2: Option[Int @UInt8Type], + u16Field2: Option[Int @UInt16Type], + u32VarField2: Option[Long @UInt32Type], + u32FixedField2: Option[Long @UInt32Type(encoding = Int32Encoding.FIXED)], + u64VarField2: Option[Long @UInt64Type(encoding = Int64Encoding.VARINT)], + u64FixedField2: Option[Long @UInt64Type(encoding = Int64Encoding.FIXED)], + u64TaggedField2: Option[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]) + derives ForySerializer + +@ForyStruct +final case class XlangCompatibleInt32ListField( + @ForyField(id = 1) + values: 
Int32List @Int32Type(encoding = Int32Encoding.FIXED)) + derives ForySerializer + +@ForyStruct +final case class XlangCompatibleNullableInt32ListField( + @ForyField(id = 1) values: util.List[Integer]) + derives ForySerializer + +@ForyStruct +final case class XlangCompatibleInt32ArrayField(@ForyField(id = 1) values: Array[Int]) + derives ForySerializer + +@ForyStruct +final case class ManualSchemaKindStruct( + orderedValues: util.List[Integer], + denseValues: util.List[Integer] @ArrayType, + primitiveValues: Array[Int], + payload: Array[Byte], + signedBytes: Array[Byte], + unsignedBytes: Array[Byte] @UInt8Type) + derives ForySerializer @ForyStruct final case class ScalaPeerUser( @@ -43,37 +350,545 @@ enum ScalaPeerTarget derives ForySerializer { } object ScalaXlangPeer { - private val Namespace = "scala_peer" + private val ScalaPeerNamespace = "scala_peer" def main(args: Array[String]): Unit = { require(args.length == 2, "Usage: ScalaXlangPeer ") + val dataFile = Path.of(args(1)) args(0) match { - case "derived_struct_round_trip" => - roundTripUser(Path.of(args(1))) - case "known_union_case_round_trip" => - roundTripTarget(Path.of(args(1)), preserveUnknownCase = false) - case "unknown_union_case_round_trip" => - roundTripTarget(Path.of(args(1)), preserveUnknownCase = true) - case other => - throw new IllegalArgumentException(s"Unknown Scala xlang peer case: $other") + case "test_buffer" => testBuffer(dataFile) + case "test_buffer_var" => testBufferVar(dataFile) + case "test_murmurhash3" => testMurmurHash3(dataFile) + case "test_string_serializer" => roundTripValues(dataFile, newFory()) + case "test_cross_language_serializer" => roundTripValues(dataFile, crossLanguageFory()) + case "test_simple_struct" => roundTripValues(dataFile, simpleStructFory(false)) + case "test_named_simple_struct" => roundTripValues(dataFile, simpleStructFory(true)) + case "test_struct_evolving_override" => roundTripValues(dataFile, evolvingOverrideFory()) + case "test_list" | "test_map" | 
"test_item" => roundTripValues(dataFile, itemFory()) + case "test_integer" => roundTripValues(dataFile, integerFory()) + case "test_decimal" => roundTripValues(dataFile, newFory()) + case "test_color" => roundTripValues(dataFile, colorFory()) + case "test_union_xlang" => roundTripValues(dataFile, structWithUnionFory()) + case "test_struct_with_list" => roundTripValues(dataFile, structWithListFory()) + case "test_struct_with_map" => roundTripValues(dataFile, structWithMapFory()) + case "test_nested_annotated_container_schema_consistent" => + roundTripValues(dataFile, nestedAnnotatedSchemaFory()) + case "test_nested_annotated_container_compatible" => + roundTripValues(dataFile, nestedAnnotatedCompatibleFory()) + case "test_skip_id_custom" => roundTripValues(dataFile, emptyWrapperFory(Left(104))) + case "test_skip_name_custom" => roundTripValues(dataFile, emptyWrapperFory(Right(("", "my_wrapper")))) + case "test_consistent_named" => roundTripValues(dataFile, consistentNamedFory()) + case "test_struct_version_check" => roundTripValues(dataFile, versionCheckFory()) + case "test_polymorphic_list" => roundTripValues(dataFile, polymorphicListFory()) + case "test_polymorphic_map" => roundTripValues(dataFile, polymorphicMapFory()) + case "test_one_string_field_schema" => roundTripValues(dataFile, oneStringSchemaFory()) + case "test_one_string_field_compatible" => roundTripValues(dataFile, oneStringCompatibleFory()) + case "test_two_string_field_compatible" => roundTripValues(dataFile, twoStringCompatibleFory(201)) + case "test_schema_evolution_compatible" => schemaEvolutionToEmpty(dataFile) + case "test_schema_evolution_compatible_reverse" => schemaEvolutionToTwoString(dataFile) + case "test_reduced_precision_float_struct" => + roundTripValues(dataFile, reducedPrecisionFory(false)) + case "test_reduced_precision_float_struct_compatible_skip" => + reducedPrecisionToEmpty(dataFile) + case "test_one_enum_field_schema" => roundTripValues(dataFile, oneEnumSchemaFory()) + case 
"test_one_enum_field_compatible" => roundTripValues(dataFile, oneEnumCompatibleFory()) + case "test_two_enum_field_compatible" => roundTripValues(dataFile, twoEnumCompatibleFory(212)) + case "test_enum_schema_evolution_compatible" => enumEvolutionToEmpty(dataFile) + case "test_enum_schema_evolution_compatible_reverse" => enumEvolutionToTwoEnum(dataFile) + case "test_nullable_field_schema_consistent_not_null" | + "test_nullable_field_schema_consistent_null" => + roundTripValues(dataFile, nullableSchemaFory()) + case "test_nullable_field_compatible_not_null" => + roundTripValues(dataFile, nullableCompatibleFory()) + case "test_nullable_field_compatible_null" => nullableCompatibleNull(dataFile) + case "test_ref_schema_consistent" => roundTripValues(dataFile, refSchemaFory()) + case "test_ref_compatible" => roundTripValues(dataFile, refCompatibleFory()) + case "test_collection_element_ref_override" => collectionElementRefOverride(dataFile) + case "test_collection_element_ref_remote_tracking" => collectionElementRefRemoteTracking(dataFile) + case "test_circular_ref_schema_consistent" => roundTripValues(dataFile, circularRefFory(601, false)) + case "test_circular_ref_compatible" => roundTripValues(dataFile, circularRefFory(602, true)) + case "test_unsigned_schema_consistent_simple" => + roundTripValues(dataFile, unsignedSimpleFory()) + case "test_unsigned_schema_consistent" => roundTripValues(dataFile, unsignedSchemaFory()) + case "test_unsigned_schema_compatible" => roundTripValues(dataFile, unsignedCompatibleFory()) + case "test_list_array_compatible_list_to_array" => listArrayListToArray(dataFile) + case "test_list_array_compatible_array_to_list" => listArrayArrayToList(dataFile) + case "test_list_array_compatible_nullable_list_to_array_error" => + roundTripValues(dataFile, nullableInt32ListFory()) + case "derived_struct_round_trip" => roundTripUser(dataFile) + case "known_union_case_round_trip" => roundTripTarget(dataFile, preserveUnknownCase = false) + case 
"unknown_union_case_round_trip" => roundTripTarget(dataFile, preserveUnknownCase = true) + case other => throw new IllegalArgumentException(s"Unknown Scala xlang peer case: $other") } } - private def roundTripUser(dataFile: Path): Unit = { + private object NoOpMetaCompressor extends MetaCompressor { + override def compress(data: Array[Byte], offset: Int, size: Int): Array[Byte] = { + val result = new Array[Byte](size + 1) + System.arraycopy(data, offset, result, 0, size) + result + } + + override def decompress(data: Array[Byte], offset: Int, size: Int): Array[Byte] = { + val result = new Array[Byte](size) + System.arraycopy(data, offset, result, 0, size) + result + } + } + + private def newFory( + compatible: Boolean = true, + refTracking: Boolean = false, + classVersionCheck: Boolean = false, + noCompression: Boolean = false): Fory = { + val builder = Fory.builder() + .withLanguage(Language.XLANG) + .withCompatible(compatible) + .withRefTracking(refTracking) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(false) + if classVersionCheck then builder.withClassVersionCheck(true) + if noCompression then builder.withMetaCompressor(NoOpMetaCompressor) + val fory = builder.build() + ScalaSerializers.registerSerializers(fory) + fory + } + + private def registerStruct[T](fory: Fory, cls: Class[T], id: Long)(using ForySerializer[T]): Unit = + ForySerializer.register(fory, cls, id) + + private def registerStruct[T]( + fory: Fory, + cls: Class[T], + namespace: String, + typeName: String)(using ForySerializer[T]): Unit = + ForySerializer.register(fory, cls, namespace, typeName) + + private def roundTripValues(dataFile: Path, fory: Fory): Unit = { + val bytes = Files.readAllBytes(dataFile) + val input = MemoryUtils.wrap(bytes) + val output = MemoryBuffer.newHeapBuffer(Math.max(256, bytes.length * 2 + 64)) + while input.readerIndex() < bytes.length do { + val value = fory.deserialize(input) + fory.serialize(output, value) + } + writeBuffer(dataFile, output) + } 
+ + private def readOne[T](dataFile: Path, fory: Fory): T = + fory.deserialize(Files.readAllBytes(dataFile)).asInstanceOf[T] + + private def writeOne(dataFile: Path, fory: Fory, value: Any): Unit = + Files.write(dataFile, fory.serialize(value)) + + private def writeBuffer(dataFile: Path, buffer: MemoryBuffer): Unit = + Files.write(dataFile, buffer.getBytes(0, buffer.writerIndex())) + + private def testBuffer(dataFile: Path): Unit = { + val input = MemoryUtils.wrap(Files.readAllBytes(dataFile)) + val boolValue = input.readBoolean() + val byteValue = input.readByte() + val shortValue = input.readInt16() + val intValue = input.readInt32() + val longValue = input.readInt64() + val floatValue = input.readFloat32() + val doubleValue = input.readFloat64() + val varUIntValue = input.readVarUInt32() + val bytes = input.readBytes(input.readInt32()) + + val output = MemoryUtils.buffer(32) + output.writeBoolean(boolValue) + output.writeByte(byteValue) + output.writeInt16(shortValue) + output.writeInt32(intValue) + output.writeInt64(longValue) + output.writeFloat32(floatValue) + output.writeFloat64(doubleValue) + output.writeVarUInt32(varUIntValue) + output.writeInt32(bytes.length) + output.writeBytes(bytes) + writeBuffer(dataFile, output) + } + + private def testBufferVar(dataFile: Path): Unit = { + val input = MemoryUtils.wrap(Files.readAllBytes(dataFile)) + val output = MemoryUtils.buffer(256) + var i = 0 + while i < 18 do { + output.writeVarInt32(input.readVarInt32()) + i += 1 + } + i = 0 + while i < 12 do { + output.writeVarUInt32(input.readVarUInt32()) + i += 1 + } + i = 0 + while i < 19 do { + output.writeVarUInt64(input.readVarUInt64()) + i += 1 + } + i = 0 + while i < 15 do { + output.writeVarInt64(input.readVarInt64()) + i += 1 + } + writeBuffer(dataFile, output) + } + + private def testMurmurHash3(dataFile: Path): Unit = { + val bytes = Files.readAllBytes(dataFile) + if bytes.length == 16 then { + val expected = MurmurHash3.murmurhash3_x64_128(Array[Byte](1, 2, 8), 
0, 3, 47) + val buffer = MemoryUtils.wrap(bytes) + require(buffer.readInt64() == expected(0), "Unexpected MurmurHash3 first word") + require(buffer.readInt64() == expected(1), "Unexpected MurmurHash3 second word") + } + Files.write(dataFile, bytes) + } + + private def crossLanguageFory(): Fory = { val fory = newFory() - val request = fory.deserialize(Files.readAllBytes(dataFile)).asInstanceOf[ScalaPeerUser] - Files.write( - dataFile, - fory.serialize(request.copy(id = request.id + 1, name = "scala-" + request.name, email = None))) + ScalaSerializers.registerEnum(fory, classOf[Color], 101L) + fory } - private def roundTripTarget(dataFile: Path, preserveUnknownCase: Boolean): Unit = { + private def colorFory(): Fory = crossLanguageFory() + + private def simpleStructFory(named: Boolean): Fory = { + val fory = newFory() + if named then { + ScalaSerializers.registerEnum(fory, classOf[Color], "demo", "color") + registerStruct(fory, classOf[Item], "demo", "item") + registerStruct(fory, classOf[SimpleStruct], "demo", "simple_struct") + } else { + ScalaSerializers.registerEnum(fory, classOf[Color], 101L) + registerStruct(fory, classOf[Item], 102L) + registerStruct(fory, classOf[SimpleStruct], 103L) + } + fory + } + + private def evolvingOverrideFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[EvolvingOverrideStruct], "test", "evolving_yes") + registerStruct(fory, classOf[FixedOverrideStruct], "test", "evolving_off") + fory + } + + private def itemFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[Item], 102L) + fory + } + + private def integerFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[Item1], 101L) + fory + } + + private def structWithUnionFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[StructWithUnion2], 301L) + fory + } + + private def structWithListFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[StructWithList], 201L) + fory + } + + private def 
structWithMapFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[StructWithMap], 202L) + fory + } + + private def nestedAnnotatedSchemaFory(): Fory = { + val fory = newFory(compatible = false) + registerStruct(fory, classOf[NestedAnnotatedContainerSchemaConsistent], 801L) + fory + } + + private def nestedAnnotatedCompatibleFory(): Fory = { + val fory = newFory(compatible = true, noCompression = true) + registerStruct(fory, classOf[NestedAnnotatedContainerCompatible], 802L) + fory + } + + private def emptyWrapperFory(registration: Either[Long, (String, String)]): Fory = { val fory = newFory() - val request = fory.deserialize(Files.readAllBytes(dataFile)).asInstanceOf[ScalaPeerTarget] + registration match { + case Left(id) => + fory.register(classOf[MyExt], 103) + fory.registerSerializer(classOf[MyExt], classOf[MyExtSerializer]) + registerStruct(fory, classOf[EmptyWrapper], id) + case Right((namespace, typeName)) => + fory.register(classOf[MyExt], "my_ext") + fory.registerSerializer(classOf[MyExt], classOf[MyExtSerializer]) + registerStruct(fory, classOf[EmptyWrapper], namespace, typeName) + } + fory + } + + private def consistentNamedFory(): Fory = { + val fory = newFory(compatible = false, classVersionCheck = true) + ScalaSerializers.registerEnum(fory, classOf[Color], "", "color") + registerStruct(fory, classOf[MyStruct], "", "my_struct") + fory.register(classOf[MyExt], "my_ext") + fory.registerSerializer(classOf[MyExt], classOf[MyExtSerializer]) + fory + } + + private def versionCheckFory(): Fory = { + val fory = newFory(compatible = false, classVersionCheck = true) + registerStruct(fory, classOf[VersionCheckStruct], 201L) + fory + } + + private def polymorphicListFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[Dog], 302L) + registerStruct(fory, classOf[Cat], 303L) + registerStruct(fory, classOf[AnimalListHolder], 304L) + fory + } + + private def polymorphicMapFory(): Fory = { + val fory = newFory() + registerStruct(fory, 
classOf[Dog], 302L) + registerStruct(fory, classOf[Cat], 303L) + registerStruct(fory, classOf[AnimalMapHolder], 305L) + fory + } + + private def oneStringSchemaFory(): Fory = { + val fory = newFory(compatible = false) + registerStruct(fory, classOf[OneStringFieldStruct], 200L) + fory + } + + private def oneStringCompatibleFory(): Fory = { + val fory = newFory() + registerStruct(fory, classOf[OneStringFieldStruct], 200L) + fory + } + + private def twoStringCompatibleFory(id: Long): Fory = { + val fory = newFory() + registerStruct(fory, classOf[TwoStringFieldStruct], id) + fory + } + + private def schemaEvolutionToEmpty(dataFile: Path): Unit = { + val inFory = twoStringCompatibleFory(200L) + val outFory = newFory() + registerStruct(outFory, classOf[EmptyStruct], 200L) + readOne[TwoStringFieldStruct](dataFile, inFory) + writeOne(dataFile, outFory, EmptyStruct()) + } + + private def schemaEvolutionToTwoString(dataFile: Path): Unit = { + val inFory = oneStringCompatibleFory() + val outFory = twoStringCompatibleFory(200L) + val value = readOne[OneStringFieldStruct](dataFile, inFory) + writeOne(dataFile, outFory, TwoStringFieldStruct(value.f1.orNull, "")) + } + + private def reducedPrecisionFory(compatible: Boolean): Fory = { + val fory = newFory(compatible = compatible, noCompression = compatible) + registerStruct(fory, classOf[ReducedPrecisionFloatStruct], 213L) + fory + } + + private def reducedPrecisionToEmpty(dataFile: Path): Unit = { + val inFory = reducedPrecisionFory(compatible = true) + val outFory = newFory(compatible = true, noCompression = true) + registerStruct(outFory, classOf[EmptyStruct], 213L) + readOne[ReducedPrecisionFloatStruct](dataFile, inFory) + writeOne(dataFile, outFory, EmptyStruct()) + } + + private def oneEnumSchemaFory(): Fory = { + val fory = newFory(compatible = false) + ScalaSerializers.registerEnum(fory, classOf[TestEnum], 210L) + registerStruct(fory, classOf[OneEnumFieldStruct], 211L) + fory + } + + private def oneEnumCompatibleFory(): 
Fory = { + val fory = newFory() + ScalaSerializers.registerEnum(fory, classOf[TestEnum], 210L) + registerStruct(fory, classOf[OneEnumFieldStruct], 211L) + fory + } + + private def twoEnumCompatibleFory(id: Long): Fory = { + val fory = newFory() + ScalaSerializers.registerEnum(fory, classOf[TestEnum], 210L) + registerStruct(fory, classOf[TwoEnumFieldStruct], id) + fory + } + + private def enumEvolutionToEmpty(dataFile: Path): Unit = { + val inFory = twoEnumCompatibleFory(211L) + val outFory = newFory() + ScalaSerializers.registerEnum(outFory, classOf[TestEnum], 210L) + registerStruct(outFory, classOf[EmptyStruct], 211L) + readOne[TwoEnumFieldStruct](dataFile, inFory) + writeOne(dataFile, outFory, EmptyStruct()) + } + + private def enumEvolutionToTwoEnum(dataFile: Path): Unit = { + val inFory = oneEnumCompatibleFory() + val outFory = twoEnumCompatibleFory(211L) + val value = readOne[OneEnumFieldStruct](dataFile, inFory) + writeOne(dataFile, outFory, TwoEnumFieldStruct(value.f1, TestEnum.VALUE_A)) + } + + private def nullableSchemaFory(): Fory = { + val fory = newFory(compatible = false) + registerStruct(fory, classOf[NullableComprehensiveSchemaConsistent], 401L) + fory + } + + private def nullableCompatibleFory(): Fory = { + val fory = newFory(compatible = true, noCompression = true) + registerStruct(fory, classOf[NullableComprehensiveCompatible], 402L) + fory + } + + private def nullableCompatibleNull(dataFile: Path): Unit = { + val fory = nullableCompatibleFory() + val value = readOne[NullableComprehensiveCompatible](dataFile, fory) + val defaults = value.copy( + nullableInt1 = Some(0), + nullableLong1 = Some(0L), + nullableFloat1 = Some(0.0f), + nullableDouble1 = Some(0.0d), + nullableBool1 = Some(false), + nullableString2 = Some(""), + nullableList2 = Some(new util.ArrayList[String]()), + nullableSet2 = Some(new util.HashSet[String]()), + nullableMap2 = Some(new util.HashMap[String, String]())) + writeOne(dataFile, fory, defaults) + } + + private def 
refSchemaFory(): Fory = { + val fory = newFory(compatible = false, refTracking = true) + registerStruct(fory, classOf[RefInnerSchemaConsistent], 501L) + registerStruct(fory, classOf[RefOuterSchemaConsistent], 502L) + fory + } + + private def refCompatibleFory(): Fory = { + val fory = newFory(compatible = true, refTracking = true, noCompression = true) + registerStruct(fory, classOf[RefInnerCompatible], 503L) + registerStruct(fory, classOf[RefOuterCompatible], 504L) + fory + } + + private def refOverrideFory(): Fory = { + val fory = newFory(compatible = false, refTracking = true) + registerStruct(fory, classOf[RefOverrideElement], 701L) + registerStruct(fory, classOf[RefOverrideContainer], 702L) + fory + } + + private def sharedRefOverrideContainer(): RefOverrideContainer = { + val element = RefOverrideElement(7, "shared_element") + val list = new util.ArrayList[RefOverrideElement]() + list.add(element) + list.add(element) + val set = new util.HashSet[RefOverrideElement]() + set.add(element) + val map = new util.HashMap[String, RefOverrideElement]() + map.put("k1", element) + map.put("k2", element) + RefOverrideContainer(list, set, map) + } + + private def collectionElementRefOverride(dataFile: Path): Unit = { + val fory = refOverrideFory() + readOne[RefOverrideContainer](dataFile, fory) + writeOne(dataFile, fory, sharedRefOverrideContainer()) + } + + private def collectionElementRefRemoteTracking(dataFile: Path): Unit = + writeOne(dataFile, refOverrideFory(), sharedRefOverrideContainer()) + + private def circularRefFory(id: Long, compatible: Boolean): Fory = { + val fory = newFory(compatible = compatible, refTracking = true, noCompression = compatible) + registerStruct(fory, classOf[CircularRefStruct], id) + fory + } + + private def unsignedSimpleFory(): Fory = { + val fory = newFory(compatible = false) + registerStruct(fory, classOf[UnsignedSchemaConsistentSimple], 1L) + fory + } + + private def unsignedSchemaFory(): Fory = { + val fory = newFory(compatible = 
false) + registerStruct(fory, classOf[UnsignedSchemaConsistent], 501L) + fory + } + + private def unsignedCompatibleFory(): Fory = { + val fory = newFory(compatible = true, noCompression = true) + registerStruct(fory, classOf[UnsignedSchemaCompatible], 502L) + fory + } + + private def int32ListFory(): Fory = { + val fory = newFory(compatible = true) + registerStruct(fory, classOf[XlangCompatibleInt32ListField], 901L) + fory + } + + private def nullableInt32ListFory(): Fory = { + val fory = newFory(compatible = true) + registerStruct(fory, classOf[XlangCompatibleNullableInt32ListField], 901L) + fory + } + + private def int32ArrayFory(): Fory = { + val fory = newFory(compatible = true) + registerStruct(fory, classOf[XlangCompatibleInt32ArrayField], 901L) + fory + } + + private def listArrayListToArray(dataFile: Path): Unit = { + val value = readOne[XlangCompatibleInt32ListField](dataFile, int32ListFory()) + val values = new Array[Int](value.values.size()) + var i = 0 + while i < value.values.size() do { + values(i) = value.values.get(i) + i += 1 + } + writeOne(dataFile, int32ArrayFory(), XlangCompatibleInt32ArrayField(values)) + } + + private def listArrayArrayToList(dataFile: Path): Unit = { + val value = readOne[XlangCompatibleInt32ArrayField](dataFile, int32ArrayFory()) + writeOne(dataFile, int32ListFory(), XlangCompatibleInt32ListField(new Int32List(value.values))) + } + + private def roundTripUser(dataFile: Path): Unit = { + val fory = scalaPeerFory() + val request = readOne[ScalaPeerUser](dataFile, fory) + writeOne(dataFile, fory, request.copy(id = request.id + 1, name = "scala-" + request.name, email = None)) + } + + private def roundTripTarget(dataFile: Path, preserveUnknownCase: Boolean): Unit = { + val fory = scalaPeerFory() + val request = readOne[ScalaPeerTarget](dataFile, fory) val response = request match { case ScalaPeerTarget.UserCase(user) => - ScalaPeerTarget.UserCase( - user.copy(id = user.id + 1, name = "scala-" + user.name, email = None)) + 
ScalaPeerTarget.UserCase(user.copy(id = user.id + 1, name = "scala-" + user.name, email = None)) case ScalaPeerTarget.UnknownCase(caseId, value: ScalaPeerUser) if preserveUnknownCase => ScalaPeerTarget.UnknownCase( caseId, @@ -81,10 +896,10 @@ object ScalaXlangPeer { case ScalaPeerTarget.UnknownCase(caseId, value) => ScalaPeerTarget.UnknownCase(caseId, value) } - Files.write(dataFile, fory.serialize(response)) + writeOne(dataFile, fory, response) } - private def newFory(): Fory = { + private def scalaPeerFory(): Fory = { val fory = Fory.builder() .withLanguage(Language.XLANG) .withCompatible(true) @@ -92,8 +907,8 @@ object ScalaXlangPeer { .requireClassRegistration(true) .build() ScalaSerializers.registerSerializers(fory) - ForySerializer.register(fory, classOf[ScalaPeerUser], Namespace, "ScalaPeerUser") - ForySerializer.register(fory, classOf[ScalaPeerTarget], Namespace, "ScalaPeerTarget") + ForySerializer.register(fory, classOf[ScalaPeerUser], ScalaPeerNamespace, "ScalaPeerUser") + ForySerializer.register(fory, classOf[ScalaPeerTarget], ScalaPeerNamespace, "ScalaPeerTarget") fory } } From cff4a415c4ea4a593331dd06f155e2090e1d5b73 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 08:26:39 +0800 Subject: [PATCH 4/9] fix(jvm): preserve xlang metadata in ci --- compiler/fory_compiler/generators/scala.py | 54 ++++-- compiler/fory_compiler/ir/construction.py | 5 +- .../tests/test_scala_generator.py | 2 +- .../java/org/apache/fory/meta/FieldTypes.java | 2 +- .../StaticGeneratedStructSerializer.java | 52 +++--- .../fory/serializer/UnionSerializer.java | 158 +++++++++--------- .../apache/fory/type/TypeAnnotationUtils.java | 6 +- .../kotlin/ksp/ForyKotlinSymbolProcessor.kt | 20 ++- .../scala/ScalaXlangSerializerTest.scala | 2 +- 9 files changed, 172 insertions(+), 129 deletions(-) diff --git a/compiler/fory_compiler/generators/scala.py b/compiler/fory_compiler/generators/scala.py index 6211014c02..76984653e2 100644 --- 
a/compiler/fory_compiler/generators/scala.py +++ b/compiler/fory_compiler/generators/scala.py @@ -222,14 +222,18 @@ def _collect_imported_registrations(self) -> List[tuple[str, str]]: if self.schema.source_file: base_dir = Path(self.schema.source_file).resolve().parent for imp in self.schema.imports: - candidate = self._normalize_import_path(str((base_dir / imp.path).resolve())) + candidate = self._normalize_import_path( + str((base_dir / imp.path).resolve()) + ) schema = self._load_schema(candidate) if schema is None: continue package = self._scala_package_for_schema(schema) if not package or package in used: continue - ordered.append((package, self._registration_class_name_for_schema(schema))) + ordered.append( + (package, self._registration_class_name_for_schema(schema)) + ) used.add(package) for package, registration in sorted(packages.items()): if package not in used: @@ -255,7 +259,10 @@ def generate(self) -> List[GeneratedFile]: def generate_enum_file(self, enum: Enum) -> GeneratedFile: lines = self.source_header( - {"org.apache.fory.annotation.ForyEnumId", "org.apache.fory.scala.ForyScalaEnum"} + { + "org.apache.fory.annotation.ForyEnumId", + "org.apache.fory.scala.ForyScalaEnum", + } ) comment = self.format_type_id_comment(enum, "//") if comment: @@ -313,7 +320,9 @@ def generate_enum(self, enum: Enum, indent: int = 0) -> List[str]: case_name = self.safe_identifier( self.to_pascal_case(self.strip_enum_prefix(enum.name, value.name)) ) - lines.append(f"{ind} case {case_name} extends {enum.name}({value.value})") + lines.append( + f"{ind} case {case_name} extends {enum.name}({value.value})" + ) lines.append("") lines.append(f"{ind} @ForyEnumId") lines.append(f"{ind} def getForyId: Int = foryId") @@ -328,7 +337,10 @@ def generate_union( parent_stack: Optional[List[Message]] = None, ) -> List[str]: ind = self.indent_str * indent - lines = [f"{ind}@ForyUnion", f"{ind}enum {union.name} derives ForySerializer {{"] + lines = [ + f"{ind}@ForyUnion", + f"{ind}enum 
{union.name} derives ForySerializer {{", + ] lines.append(f"{ind} @ForyCase(id = 0)") lines.append(f"{ind} case UnknownCase(caseId: Int, value: Any)") lines.append("") @@ -355,9 +367,10 @@ def generate_message( indent: int = 0, parent_stack: Optional[List[Message]] = None, ) -> List[str]: - if self._construction_shapes.get(message.name, None) and self._construction_shapes[ - message.name - ].cycle_owned: + if ( + self._construction_shapes.get(message.name, None) + and self._construction_shapes[message.name].cycle_owned + ): return self.generate_normal_class(message, indent, parent_stack) return self.generate_case_class(message, indent, parent_stack) @@ -388,7 +401,10 @@ def generate_normal_class( ) -> List[str]: ind = self.indent_str * indent current_stack = self.current_stack(parent_stack, message) - lines = [f"{ind}@ForyStruct", f"{ind}final class {message.name}() derives ForySerializer {{"] + lines = [ + f"{ind}@ForyStruct", + f"{ind}final class {message.name}() derives ForySerializer {{", + ] for field in message.fields: field_type = self.generate_type( field.field_type, @@ -530,7 +546,9 @@ def resolve_scala_type_name( for index in range(len(parent_stack) - 1, -1, -1): owner = parent_stack[index] if owner.get_nested_type(name) is not None: - return ".".join([message.name for message in parent_stack[: index + 1]] + [name]) + return ".".join( + [message.name for message in parent_stack[: index + 1]] + [name] + ) return name def apply_type_annotation(self, scala_type: str, annotation: str) -> str: @@ -673,9 +691,7 @@ def collect_type_imports(self, field_type: FieldType, imports: Set[str]) -> None self.collect_type_imports(field_type.key_type, imports) self.collect_type_imports(field_type.value_type, imports) - def collect_integer_imports( - self, field_type: FieldType, imports: Set[str] - ) -> None: + def collect_integer_imports(self, field_type: FieldType, imports: Set[str]) -> None: if not isinstance(field_type, PrimitiveType): return kind = field_type.kind @@ 
-721,9 +737,13 @@ def collect_array_element_imports( def field_type_has_ref(self, field_type: FieldType) -> bool: if isinstance(field_type, ListType): - return field_type.element_ref or self.field_type_has_ref(field_type.element_type) + return field_type.element_ref or self.field_type_has_ref( + field_type.element_type + ) if isinstance(field_type, MapType): - return field_type.value_ref or self.field_type_has_ref(field_type.value_type) + return field_type.value_ref or self.field_type_has_ref( + field_type.value_type + ) return False def is_ref_target_type(self, field_type: FieldType) -> bool: @@ -756,7 +776,9 @@ def generate_registration_file(self) -> GeneratedFile: lines.append(" register(fory)") lines.append(" })") else: - lines.append(" runtime.registerCallback((fory: Fory) => register(fory))") + lines.append( + " runtime.registerCallback((fory: Fory) => register(fory))" + ) lines.append(" runtime") lines.append(" }") lines.append("") diff --git a/compiler/fory_compiler/ir/construction.py b/compiler/fory_compiler/ir/construction.py index 5612245d76..16aafc1ea8 100644 --- a/compiler/fory_compiler/ir/construction.py +++ b/compiler/fory_compiler/ir/construction.py @@ -52,7 +52,10 @@ def analyze_message_construction_shapes( """ messages = {message.name: message for message in schema.messages} - graph = {name: set(_message_dependencies(message, messages)) for name, message in messages.items()} + graph = { + name: set(_message_dependencies(message, messages)) + for name, message in messages.items() + } cycle_owned = _cycle_nodes(graph) return { name: MessageConstructionShape(cycle_owned=name in cycle_owned) diff --git a/compiler/fory_compiler/tests/test_scala_generator.py b/compiler/fory_compiler/tests/test_scala_generator.py index d072aef89c..40a5470da3 100644 --- a/compiler/fory_compiler/tests/test_scala_generator.py +++ b/compiler/fory_compiler/tests/test_scala_generator.py @@ -94,7 +94,7 @@ def 
test_scala_generator_uses_mutable_normal_class_for_construction_cycles(): node = files["graph/Node.scala"] assert "final class Node() derives ForySerializer" in node - assert "var id: String = \"\"" in node + assert 'var id: String = ""' in node assert "var parent: Option[Node @Ref] = None" in node diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java index ef6eb51c49..7fc0b46867 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java @@ -372,7 +372,7 @@ private static int primitiveArrayTypeIdFromComponentMeta(TypeRef typeRef) { if (componentMeta == null || componentMeta.typeId() == Types.UNKNOWN) { return Types.UNKNOWN; } - return TypeAnnotationUtils.getArrayTypeIdFromElementType(componentType); + return TypeAnnotationUtils.getArrayTypeIdFromDenseElementType(componentType); } public abstract static class FieldType implements Serializable { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java index 14fdc817f2..0fbf7cf22d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java @@ -186,6 +186,21 @@ protected final void writeContainerFieldValue( fieldValue); } + private static void writeContainerFieldValue( + TypeResolver typeResolver, + WriteContext writeContext, + SerializationFieldInfo fieldInfo, + Object fieldValue) { + AbstractObjectSerializer.writeContainerFieldValue( + writeContext, + typeResolver, + writeContext.getRefWriter(), + writeContext.getGenerics(), + fieldInfo, + writeContext.getBuffer(), + fieldValue); + } + protected final void writeOtherFieldValue( WriteContext 
writeContext, SerializationFieldInfo fieldInfo, Object fieldValue) { AbstractObjectSerializer.writeField( @@ -240,21 +255,6 @@ public static void writeFieldValue( } } - private static void writeContainerFieldValue( - TypeResolver typeResolver, - WriteContext writeContext, - SerializationFieldInfo fieldInfo, - Object fieldValue) { - AbstractObjectSerializer.writeContainerFieldValue( - writeContext, - typeResolver, - writeContext.getRefWriter(), - writeContext.getGenerics(), - fieldInfo, - writeContext.getBuffer(), - fieldValue); - } - protected final Object readBuildInFieldValue( ReadContext readContext, SerializationFieldInfo fieldInfo) { // See writeBuildInFieldValue: built-in schema groups can still need container conversion. @@ -276,6 +276,17 @@ protected final Object readContainerFieldValue( readContext.getBuffer()); } + private static Object readContainerFieldValue( + TypeResolver typeResolver, ReadContext readContext, SerializationFieldInfo fieldInfo) { + return AbstractObjectSerializer.readContainerFieldValue( + readContext, + typeResolver, + readContext.getRefReader(), + readContext.getGenerics(), + fieldInfo, + readContext.getBuffer()); + } + protected final Object readOtherFieldValue( ReadContext readContext, SerializationFieldInfo fieldInfo) { return AbstractObjectSerializer.readField( @@ -314,17 +325,6 @@ public static Object readFieldValue( } } - private static Object readContainerFieldValue( - TypeResolver typeResolver, ReadContext readContext, SerializationFieldInfo fieldInfo) { - return AbstractObjectSerializer.readContainerFieldValue( - readContext, - typeResolver, - readContext.getRefReader(), - readContext.getGenerics(), - fieldInfo, - readContext.getBuffer()); - } - protected final Object readRemoteField(ReadContext readContext, RemoteFieldInfo remoteField) { if (remoteField.compatibleCollectionArrayReadAction != null) { return CompatibleCollectionArrayReader.read( diff --git 
a/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java index 817f9106b6..991b5955b5 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java @@ -213,6 +213,58 @@ public static void writeCaseValue( writeKnownCasePayload(resolver, writeContext, fieldInfo, value); } + private void writeCaseValue(WriteContext writeContext, Object value, int typeId, int caseId) { + MemoryBuffer buffer = writeContext.getBuffer(); + byte internalTypeId = (byte) typeId; + boolean primitiveArray = Types.isPrimitiveArray(internalTypeId); + Serializer serializer; + TypeInfo typeInfo; + if (value == null) { + buffer.writeByte(Fory.NULL_FLAG); + return; + } + typeInfo = getFinalCaseTypeInfo(caseId); + if (typeInfo == null) { + Preconditions.checkArgument(!primitiveArray); + if (!Types.isUserDefinedType(internalTypeId)) { + typeInfo = resolver.getTypeInfoByTypeId(internalTypeId); + } else { + typeInfo = resolver.getTypeInfo(value.getClass()); + } + } + Preconditions.checkArgument(typeInfo != null); + serializer = getCaseSerializer(caseId, typeId, typeInfo); + if (serializer != null && serializer.needToWriteRef()) { + if (writeContext.writeRefOrNull(value)) { + return; + } + } else { + buffer.writeByte(Fory.NOT_NULL_VALUE_FLAG); + } + if (!Types.isUserDefinedType(internalTypeId)) { + buffer.writeUInt8(typeId); + } else { + resolver.writeTypeInfo(writeContext, typeInfo); + } + writeValue(writeContext, value, typeId, serializer, getCaseGenericType(caseId, typeId)); + } + + private static void writeCaseValue( + WriteContext writeContext, Serializer serializer, GenericType genericType, Object value) { + if (genericType == null) { + Serializers.write(writeContext, serializer, value); + return; + } + writeContext.getGenerics().pushGenericType(genericType, writeContext.getDepth()); + 
writeContext.increaseDepth(); + try { + Serializers.write(writeContext, serializer, value); + } finally { + writeContext.decreaseDepth(); + writeContext.getGenerics().popGenericType(writeContext.getDepth()); + } + } + /** * Writes an unknown union case payload using dynamic value metadata. * @@ -250,6 +302,21 @@ public static Object readCaseValue( return readContext.getReadRef(); } + private static Object readCaseValue( + ReadContext readContext, Serializer serializer, GenericType genericType) { + if (genericType == null) { + return Serializers.read(readContext, serializer); + } + readContext.getGenerics().pushGenericType(genericType, readContext.getDepth()); + readContext.increaseDepth(); + try { + return Serializers.read(readContext, serializer); + } finally { + readContext.decreaseDepth(); + readContext.getGenerics().popGenericType(readContext.getDepth()); + } + } + private static void writeKnownCasePayload( TypeResolver resolver, WriteContext writeContext, @@ -323,6 +390,14 @@ private static Serializer getCaseSerializer( return fallbackTypeInfo.getSerializer(); } + private Serializer getCaseSerializer(int caseId, int typeId, TypeInfo fallbackTypeInfo) { + Serializer serializer = finalCaseSerializers.get(caseId); + if (serializer != null && typeId == Types.LIST) { + return serializer; + } + return fallbackTypeInfo.getSerializer(); + } + private static GenericType getCaseGenericType( FieldGroups.SerializationFieldInfo fieldInfo, int typeId) { if (typeId != Types.LIST && typeId != Types.SET && typeId != Types.MAP) { @@ -331,56 +406,11 @@ private static GenericType getCaseGenericType( return fieldInfo.genericType; } - private void writeCaseValue(WriteContext writeContext, Object value, int typeId, int caseId) { - MemoryBuffer buffer = writeContext.getBuffer(); - byte internalTypeId = (byte) typeId; - boolean primitiveArray = Types.isPrimitiveArray(internalTypeId); - Serializer serializer; - TypeInfo typeInfo; - if (value == null) { - 
buffer.writeByte(Fory.NULL_FLAG); - return; - } - typeInfo = getFinalCaseTypeInfo(caseId); - if (typeInfo == null) { - Preconditions.checkArgument(!primitiveArray); - if (!Types.isUserDefinedType(internalTypeId)) { - typeInfo = resolver.getTypeInfoByTypeId(internalTypeId); - } else { - typeInfo = resolver.getTypeInfo(value.getClass()); - } - } - Preconditions.checkArgument(typeInfo != null); - serializer = getCaseSerializer(caseId, typeId, typeInfo); - if (serializer != null && serializer.needToWriteRef()) { - if (writeContext.writeRefOrNull(value)) { - return; - } - } else { - buffer.writeByte(Fory.NOT_NULL_VALUE_FLAG); - } - if (!Types.isUserDefinedType(internalTypeId)) { - buffer.writeUInt8(typeId); - } else { - resolver.writeTypeInfo(writeContext, typeInfo); - } - writeValue(writeContext, value, typeId, serializer, getCaseGenericType(caseId, typeId)); - } - - private static void writeCaseValue( - WriteContext writeContext, Serializer serializer, GenericType genericType, Object value) { - if (genericType == null) { - Serializers.write(writeContext, serializer, value); - return; - } - writeContext.getGenerics().pushGenericType(genericType, writeContext.getDepth()); - writeContext.increaseDepth(); - try { - Serializers.write(writeContext, serializer, value); - } finally { - writeContext.decreaseDepth(); - writeContext.getGenerics().popGenericType(writeContext.getDepth()); + private GenericType getCaseGenericType(int caseId, int typeId) { + if (typeId != Types.LIST && typeId != Types.MAP) { + return null; } + return finalCaseGenericTypes.get(caseId); } private static void writeValue( @@ -451,36 +481,6 @@ private static void writeValue( throw new IllegalStateException("Missing serializer for union type id " + typeId); } - private static Object readCaseValue( - ReadContext readContext, Serializer serializer, GenericType genericType) { - if (genericType == null) { - return Serializers.read(readContext, serializer); - } - 
readContext.getGenerics().pushGenericType(genericType, readContext.getDepth()); - readContext.increaseDepth(); - try { - return Serializers.read(readContext, serializer); - } finally { - readContext.decreaseDepth(); - readContext.getGenerics().popGenericType(readContext.getDepth()); - } - } - - private Serializer getCaseSerializer(int caseId, int typeId, TypeInfo fallbackTypeInfo) { - Serializer serializer = finalCaseSerializers.get(caseId); - if (serializer != null && typeId == Types.LIST) { - return serializer; - } - return fallbackTypeInfo.getSerializer(); - } - - private GenericType getCaseGenericType(int caseId, int typeId) { - if (typeId != Types.LIST && typeId != Types.MAP) { - return null; - } - return finalCaseGenericTypes.get(caseId); - } - private TypeInfo getFinalCaseTypeInfo(int caseId) { if (!finalCaseSerializersResolved) { resolveFinalCaseTypeInfo(); diff --git a/java/fory-core/src/main/java/org/apache/fory/type/TypeAnnotationUtils.java b/java/fory-core/src/main/java/org/apache/fory/type/TypeAnnotationUtils.java index a296209671..3b37ba7616 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/TypeAnnotationUtils.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/TypeAnnotationUtils.java @@ -345,7 +345,7 @@ public static int getBoxedListArrayTypeId(Descriptor descriptor) { // Fieldless static descriptors already carry the dense array contract on the parent // descriptor. Their element TypeExtMeta is the dense element domain, not a Java source-level // scalar encoding annotation. 
- int typeId = getArrayTypeIdFromElementType(elementTypeRef, true); + int typeId = getArrayTypeIdFromDenseElementType(elementTypeRef); if (typeId == Types.UNKNOWN) { throw new IllegalArgumentException( "@ArrayType List field " @@ -391,6 +391,10 @@ private static int getArrayTypeIdFromElementType( return Types.UNKNOWN; } + public static int getArrayTypeIdFromDenseElementType(TypeRef elementTypeRef) { + return getArrayTypeIdFromElementType(elementTypeRef, true); + } + public static int getArrayTypeIdFromElementTypeId(int elementTypeId) { switch (elementTypeId) { case Types.BOOL: diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt index 02baf96533..afa0c9cb7c 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt @@ -33,6 +33,7 @@ import com.google.devtools.ksp.symbol.KSDeclaration import com.google.devtools.ksp.symbol.KSPropertyDeclaration import com.google.devtools.ksp.symbol.KSType import com.google.devtools.ksp.symbol.KSTypeArgument +import com.google.devtools.ksp.symbol.KSValueParameter import com.google.devtools.ksp.symbol.Modifier import com.google.devtools.ksp.symbol.Nullability import java.nio.charset.StandardCharsets @@ -212,7 +213,7 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso ) return null } - val fieldMeta = resolveForyField(property) ?: return null + val fieldMeta = resolveForyField(property, parameter) ?: return null if (fieldMeta.id < -1) { logger.error("@ForyField id must be -1 or a non-negative value", property) return null @@ -333,8 +334,12 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso ) } - private fun resolveForyField(property: KSPropertyDeclaration): ForyFieldMeta? 
{ + private fun resolveForyField( + property: KSPropertyDeclaration, + parameter: KSValueParameter, + ): ForyFieldMeta? { val propertyMeta = foryFieldMeta(property.annotations) + val parameterMeta = foryFieldMeta(parameter.annotations) val getterHasFory = property.getter?.annotations?.any { isAnnotation(it, FORY_FIELD) } == true val setterHasFory = property.setter?.annotations?.any { isAnnotation(it, FORY_FIELD) } == true if (getterHasFory || setterHasFory) { @@ -344,7 +349,16 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso ) return null } - return propertyMeta ?: ForyFieldMeta.NONE + if (propertyMeta != null && parameterMeta != null && propertyMeta != parameterMeta) { + logger.error( + "@ForyField metadata on Kotlin property and constructor parameter must match", + property, + ) + return null + } + // Java annotations on primary-constructor properties commonly land on the constructor + // parameter when PARAMETER is an allowed target; KSP must preserve schema IDs from that site. 
+ return propertyMeta ?: parameterMeta ?: ForyFieldMeta.NONE } private fun hasFieldAnnotation(property: KSPropertyDeclaration, qualifiedName: String): Boolean { diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala index 0e414613a5..8769df203b 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala @@ -23,7 +23,7 @@ import org.apache.fory.Fory import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import scala.jdk.CollectionConverters.* +import scala.jdk.CollectionConverters._ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { def fory: Fory = { From a320059b1f7fa698f92eec834614673e8906bbad Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 09:02:23 +0800 Subject: [PATCH 5/9] fix(scala): add idl sbt license header --- .../idl_tests/scala/project/build.properties | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/integration_tests/idl_tests/scala/project/build.properties b/integration_tests/idl_tests/scala/project/build.properties index 04267b14af..c12725b419 100644 --- a/integration_tests/idl_tests/scala/project/build.properties +++ b/integration_tests/idl_tests/scala/project/build.properties @@ -1 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + sbt.version=1.9.9 From 8e64c37ec629e353b13be7db4c379e374b1addb0 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 12:38:17 +0800 Subject: [PATCH 6/9] fix(scala): harden xlang idl cycles --- compiler/fory_compiler/generators/java.py | 218 ++++++++++---- compiler/fory_compiler/generators/scala.py | 53 ++-- compiler/fory_compiler/ir/construction.py | 105 +++++-- .../tests/test_generated_code.py | 74 +++++ .../tests/test_package_options.py | 23 ++ .../tests/test_scala_generator.py | 202 +++++++++++++ docs/compiler/schema-idl.md | 34 ++- docs/guide/scala/schema-idl.md | 11 + docs/specification/xlang_type_mapping.md | 48 +-- integration_tests/idl_tests/generate_idl.py | 20 +- .../idl_tests/idl/nested_name.fdl | 42 +++ .../fory/idl_tests/IdlRoundTripTest.java | 77 +++++ .../idl_tests/ScalaIdlRoundTripPeer.scala | 3 + .../idl_tests/ScalaIdlRoundTripTest.scala | 34 +++ .../org/apache/fory/context/CopyContext.java | 32 +- .../StaticGeneratedStructSerializer.java | 12 +- .../fory/serializer/UnionSerializer.java | 6 + .../scala/internal/ForySerializerMacros.scala | 279 +++++++++++++----- .../scala/XlangCollectionSerializer.scala | 132 ++++++++- .../scala/ForySerializerDerivationTest.scala | 139 +++++++++ .../scala/ScalaXlangSerializerTest.scala | 60 ++++ 21 files changed, 1402 insertions(+), 202 deletions(-) create mode 100644 integration_tests/idl_tests/idl/nested_name.fdl diff --git a/compiler/fory_compiler/generators/java.py b/compiler/fory_compiler/generators/java.py index ff6190d421..e55847ce75 100644 --- 
a/compiler/fory_compiler/generators/java.py +++ b/compiler/fory_compiler/generators/java.py @@ -465,7 +465,7 @@ def generate_message_file(self, message: Message) -> GeneratedFile: # Fields for field in message.fields: - field_lines = self.generate_field(field) + field_lines = self.generate_field(field, [message]) for line in field_lines: lines.append(f" {line}") @@ -478,7 +478,7 @@ def generate_message_file(self, message: Message) -> GeneratedFile: # Getters and setters for field in message.fields: - getter_setter = self.generate_getter_setter(field) + getter_setter = self.generate_getter_setter(field, [message]) for line in getter_setter: lines.append(f" {line}") @@ -588,10 +588,16 @@ def generate_outer_class_file(self, outer_classname: str) -> GeneratedFile: return GeneratedFile(path=path, content="\n".join(lines)) - def collect_message_imports(self, message: Message, imports: Set[str]): + def collect_message_imports( + self, + message: Message, + imports: Set[str], + parent_stack: Optional[List[Message]] = None, + ): """Collect imports for a message and all its nested types recursively.""" + lineage = (parent_stack or []) + [message] for field in message.fields: - self.collect_field_imports(field, imports) + self.collect_field_imports(field, imports, lineage) imports.add("org.apache.fory.annotation.ForyStruct") if not self.get_effective_evolving(message): @@ -607,15 +613,20 @@ def collect_message_imports(self, message: Message, imports: Set[str]): # Collect imports from nested messages for nested_msg in message.nested_messages: - self.collect_message_imports(nested_msg, imports) + self.collect_message_imports(nested_msg, imports, lineage) for nested_union in message.nested_unions: - self.collect_union_imports(nested_union, imports) + self.collect_union_imports(nested_union, imports, lineage) def collect_enum_imports(self, imports: Set[str]): """Collect imports required by generated Java enums.""" imports.add("org.apache.fory.annotation.ForyEnumId") - def 
collect_union_imports(self, union: Union, imports: Set[str]): + def collect_union_imports( + self, + union: Union, + imports: Set[str], + parent_stack: Optional[List[Message]] = None, + ): """Collect imports for a union and its cases.""" imports.add("org.apache.fory.type.union.Union") imports.add("org.apache.fory.type.Types") @@ -631,6 +642,7 @@ def collect_union_imports(self, union: Union, imports: Set[str]): field.element_optional, field.element_ref, field, + parent_stack=parent_stack, ) def has_array_field_recursive(self, message: Message) -> bool: @@ -743,7 +755,7 @@ def generate_union_class( for field in union.fields: case_name = self.to_pascal_case(field.name) case_enum_name = self.to_upper_snake_case(field.name) - case_type = self.get_union_case_type(field) + case_type = self.get_union_case_type(field, parent_stack) lines.append( f"{ind} public static {union.name} of{case_name}({case_type} v) {{" ) @@ -774,8 +786,8 @@ def generate_union_class( for field in union.fields: case_name = self.to_pascal_case(field.name) case_enum_name = self.to_upper_snake_case(field.name) - case_type = self.get_union_case_type(field) - cast_type = self.get_union_case_cast_type(field) + case_type = self.get_union_case_type(field, parent_stack) + cast_type = self.get_union_case_cast_type(field, parent_stack) wrap_array_type: Optional[str] = None wrap_list_type: Optional[str] = None if ( @@ -872,7 +884,9 @@ def generate_union_class( lines.append("") return lines - def get_union_case_type(self, field: Field) -> str: + def get_union_case_type( + self, field: Field, parent_stack: Optional[List[Message]] = None + ) -> str: """Return the Java type for a union case.""" return self.generate_type( field.field_type, @@ -880,16 +894,19 @@ def get_union_case_type(self, field: Field) -> str: field.element_optional, field.element_ref, field, + parent_stack=parent_stack, ) - def get_union_case_cast_type(self, field: Field) -> str: + def get_union_case_cast_type( + self, field: Field, parent_stack: 
Optional[List[Message]] = None + ) -> str: """Return the Java cast type for a union case value.""" if isinstance(field.field_type, PrimitiveType): boxed = self.BOXED_MAP.get(field.field_type.kind) if boxed is not None: return boxed return self.PRIMITIVE_MAP[field.field_type.kind] - return self.get_union_case_type(field) + return self.get_union_case_type(field, parent_stack) def get_union_case_type_id_expr( self, field: Field, parent_stack: Optional[List[Message]] @@ -1076,7 +1093,7 @@ def generate_nested_message( # Fields for field in message.fields: - field_lines = self.generate_field(field) + field_lines = self.generate_field(field, lineage) for line in field_lines: lines.append(f" {line}") @@ -1089,7 +1106,7 @@ def generate_nested_message( # Getters and setters for field in message.fields: - getter_setter = self.generate_getter_setter(field) + getter_setter = self.generate_getter_setter(field, lineage) for line in getter_setter: lines.append(f" {line}") @@ -1117,7 +1134,9 @@ def generate_nested_message( lines.append("") return lines - def generate_field(self, field: Field) -> List[str]: + def generate_field( + self, field: Field, parent_stack: Optional[List[Message]] + ) -> List[str]: """Generate field declaration with annotations.""" lines = [] @@ -1157,6 +1176,7 @@ def generate_field(self, field: Field) -> List[str]: field.element_ref, field, type_use=use_type_annotation, + parent_stack=parent_stack, ) if nullable: java_type = self.apply_top_level_type_use_annotation(java_type, "@Nullable") @@ -1166,7 +1186,9 @@ def generate_field(self, field: Field) -> List[str]: return lines - def generate_getter_setter(self, field: Field) -> List[str]: + def generate_getter_setter( + self, field: Field, parent_stack: Optional[List[Message]] + ) -> List[str]: """Generate getter and setter for a field.""" lines = [] is_any = ( @@ -1180,6 +1202,7 @@ def generate_getter_setter(self, field: Field) -> List[str]: field.element_optional, field.element_ref, field, + 
parent_stack=parent_stack, ) field_name = self.to_camel_case(field.name) pascal_name = self.to_pascal_case(field.name) @@ -1206,6 +1229,7 @@ def generate_type( element_ref: bool = False, field: Optional[Field] = None, type_use: bool = False, + parent_stack: Optional[List[Message]] = None, ) -> str: """Generate Java type string.""" if isinstance(field_type, PrimitiveType): @@ -1220,6 +1244,13 @@ def generate_type( return java_type elif isinstance(field_type, NamedType): + if "." not in field_type.name and parent_stack: + for index in range(len(parent_stack), 0, -1): + if ( + parent_stack[index - 1].get_nested_type(field_type.name) + is not None + ): + return field_type.name named_type = self.schema.get_type(field_type.name) if named_type is not None and self.is_imported_type(named_type): java_package = self._java_package_for_type(named_type) @@ -1255,9 +1286,12 @@ def generate_type( ) return java_type element_type = self.generate_type( - field_type.element_type, True, type_use=True + field_type.element_type, + True, + type_use=True, + parent_stack=parent_stack, ) - if self.is_ref_target_type(field_type.element_type): + if self.is_ref_target_type(field_type.element_type, parent_stack): ref_annotation = "@Ref" if child_ref else "@Ref(enable=false)" element_type = f"{ref_annotation} {element_type}" return f"List<{element_type}>" @@ -1274,9 +1308,19 @@ def generate_type( return java_type elif isinstance(field_type, MapType): - key_type = self.generate_type(field_type.key_type, True, type_use=True) - value_type = self.generate_type(field_type.value_type, True, type_use=True) - if self.is_ref_target_type(field_type.value_type): + key_type = self.generate_type( + field_type.key_type, + True, + type_use=True, + parent_stack=parent_stack, + ) + value_type = self.generate_type( + field_type.value_type, + True, + type_use=True, + parent_stack=parent_stack, + ) + if self.is_ref_target_type(field_type.value_type, parent_stack): ref_annotation = ( "@Ref" if field_type.value_ref 
else "@Ref(enable=false)" ) @@ -1293,6 +1337,7 @@ def collect_type_imports( element_ref: bool = False, field: Optional[Field] = None, type_use: bool = False, + parent_stack: Optional[List[Message]] = None, ): """Collect required imports for a field type.""" if isinstance(field_type, PrimitiveType): @@ -1330,12 +1375,21 @@ def collect_type_imports( PrimitiveKind.FLOAT16, PrimitiveKind.BFLOAT16, ): - self.collect_type_imports(field_type.element_type, imports) + self.collect_type_imports( + field_type.element_type, + imports, + parent_stack=parent_stack, + ) return imports.add("java.util.List") - if self.is_ref_target_type(field_type.element_type): + if self.is_ref_target_type(field_type.element_type, parent_stack): imports.add("org.apache.fory.annotation.Ref") - self.collect_type_imports(field_type.element_type, imports, type_use=True) + self.collect_type_imports( + field_type.element_type, + imports, + type_use=True, + parent_stack=parent_stack, + ) elif isinstance(field_type, ArrayType): kind = field_type.element_type.kind @@ -1348,17 +1402,35 @@ def collect_type_imports( self.collect_array_type_use_imports(field_type, imports, type_use) if kind not in (PrimitiveKind.FLOAT16, PrimitiveKind.BFLOAT16): self.collect_type_imports( - field_type.element_type, imports, type_use=True + field_type.element_type, + imports, + type_use=True, + parent_stack=parent_stack, ) elif isinstance(field_type, MapType): imports.add("java.util.Map") - if self.is_ref_target_type(field_type.value_type): + if self.is_ref_target_type(field_type.value_type, parent_stack): imports.add("org.apache.fory.annotation.Ref") - self.collect_type_imports(field_type.key_type, imports, type_use=True) - self.collect_type_imports(field_type.value_type, imports, type_use=True) + self.collect_type_imports( + field_type.key_type, + imports, + type_use=True, + parent_stack=parent_stack, + ) + self.collect_type_imports( + field_type.value_type, + imports, + type_use=True, + parent_stack=parent_stack, + ) - def 
collect_field_imports(self, field: Field, imports: Set[str]): + def collect_field_imports( + self, + field: Field, + imports: Set[str], + parent_stack: Optional[List[Message]] = None, + ): """Collect imports for a field, including list modifiers.""" is_any = ( isinstance(field.field_type, PrimitiveType) @@ -1371,6 +1443,7 @@ def collect_field_imports(self, field: Field, imports: Set[str]): field.element_optional, field.element_ref, field, + parent_stack=parent_stack, ) self.collect_integer_imports(field.field_type, imports) self.collect_array_imports(field, imports) @@ -1379,11 +1452,18 @@ def collect_field_imports(self, field: Field, imports: Set[str]): if field.ref or field.tag_id is not None: imports.add("org.apache.fory.annotation.ForyField") - def is_ref_target_type(self, field_type: FieldType) -> bool: + def is_ref_target_type( + self, field_type: FieldType, parent_stack: Optional[List[Message]] = None + ) -> bool: if not isinstance(field_type, NamedType): return False - resolved = self.schema.get_type(field_type.name) - return isinstance(resolved, (Message, Union)) + if "." in field_type.name or not parent_stack: + return isinstance(self.schema.get_type(field_type.name), (Message, Union)) + for index in range(len(parent_stack), 0, -1): + resolved = parent_stack[index - 1].get_nested_type(field_type.name) + if resolved is not None: + return isinstance(resolved, (Message, Union)) + return isinstance(self.schema.get_type(field_type.name), (Message, Union)) def java_array(self, field: Optional[Field]) -> bool: if field is None: @@ -2045,8 +2125,10 @@ def generate_registration_file( lines.append(" private static final ThreadSafeFory FORY = createFory();") lines.append(" }") lines.append("") - # When outer_classname is set, all top-level types become inner classes - type_prefix = outer_classname if outer_classname else "" + # When outer_classname is set, all top-level types become inner classes. 
+ # The outer class is a JVM code-shape owner only; it is not part of the + # schema namespace used for name registration. + class_prefix = outer_classname if outer_classname else "" local_enums = [e for e in self.schema.enums if not self.is_imported_type(e)] local_unions = [u for u in self.schema.unions if not self.is_imported_type(u)] @@ -2060,15 +2142,15 @@ def generate_registration_file( # Register enums (top-level) for enum in local_enums: - self.generate_enum_registration(lines, enum, type_prefix) + self.generate_enum_registration(lines, enum, class_prefix) # Register unions (top-level) for union in local_unions: - self.generate_union_registration(lines, union, type_prefix) + self.generate_union_registration(lines, union, class_prefix) # Register messages (top-level and nested) for message in local_messages: - self.generate_message_registration(lines, message, type_prefix) + self.generate_message_registration(lines, message, class_prefix) lines.append(" }") lines.append("}") @@ -2084,12 +2166,18 @@ def generate_registration_file( return GeneratedFile(path=path, content="\n".join(lines)) def generate_enum_registration( - self, lines: List[str], enum: Enum, parent_path: str + self, + lines: List[str], + enum: Enum, + class_parent_path: str, + schema_parent_path: str = "", ): """Generate registration code for an enum.""" # In Java, nested class references use OuterClass.InnerClass - class_ref = f"{parent_path}.{enum.name}" if parent_path else enum.name - type_name = class_ref if parent_path else enum.name + class_ref = ( + f"{class_parent_path}.{enum.name}" if class_parent_path else enum.name + ) + type_name = enum.name if self.should_register_by_id(enum): lines.append( @@ -2098,17 +2186,25 @@ def generate_enum_registration( else: # Use FDL package for namespace (consistent across languages) ns = self.schema.package or "default" + if schema_parent_path: + ns = f"{ns}.{schema_parent_path}" lines.append( f' resolver.register({class_ref}.class, "{ns}", 
"{type_name}");' ) def generate_message_registration( - self, lines: List[str], message: Message, parent_path: str + self, + lines: List[str], + message: Message, + class_parent_path: str, + schema_parent_path: str = "", ): """Generate registration code for a message and its nested types.""" # In Java, nested class references use OuterClass.InnerClass - class_ref = f"{parent_path}.{message.name}" if parent_path else message.name - type_name = class_ref if parent_path else message.name + class_ref = ( + f"{class_parent_path}.{message.name}" if class_parent_path else message.name + ) + type_name = message.name if self.should_register_by_id(message): lines.append( @@ -2117,27 +2213,47 @@ def generate_message_registration( else: # Use FDL package for namespace (consistent across languages) ns = self.schema.package or "default" + if schema_parent_path: + ns = f"{ns}.{schema_parent_path}" lines.append( f' resolver.register({class_ref}.class, "{ns}", "{type_name}");' ) + nested_schema_parent_path = ( + f"{schema_parent_path}.{message.name}" + if schema_parent_path + else message.name + ) + # Register nested enums for nested_enum in message.nested_enums: - self.generate_enum_registration(lines, nested_enum, class_ref) + self.generate_enum_registration( + lines, nested_enum, class_ref, nested_schema_parent_path + ) # Register nested unions for nested_union in message.nested_unions: - self.generate_union_registration(lines, nested_union, class_ref) + self.generate_union_registration( + lines, nested_union, class_ref, nested_schema_parent_path + ) # Register nested messages for nested_msg in message.nested_messages: - self.generate_message_registration(lines, nested_msg, class_ref) + self.generate_message_registration( + lines, nested_msg, class_ref, nested_schema_parent_path + ) def generate_union_registration( - self, lines: List[str], union: Union, parent_path: str + self, + lines: List[str], + union: Union, + class_parent_path: str, + schema_parent_path: str = "", ): 
"""Generate registration code for a union.""" - class_ref = f"{parent_path}.{union.name}" if parent_path else union.name + class_ref = ( + f"{class_parent_path}.{union.name}" if class_parent_path else union.name + ) type_name = union.name serializer_ref = f"new org.apache.fory.serializer.UnionSerializer(resolver, {class_ref}.class)" @@ -2147,8 +2263,8 @@ def generate_union_registration( ) else: ns = self.schema.package or "default" - if parent_path: - ns = f"{ns}.{parent_path}" + if schema_parent_path: + ns = f"{ns}.{schema_parent_path}" lines.append( f' resolver.registerUnion({class_ref}.class, "{ns}", "{type_name}", {serializer_ref});' ) diff --git a/compiler/fory_compiler/generators/scala.py b/compiler/fory_compiler/generators/scala.py index 76984653e2..15126a0a85 100644 --- a/compiler/fory_compiler/generators/scala.py +++ b/compiler/fory_compiler/generators/scala.py @@ -367,10 +367,10 @@ def generate_message( indent: int = 0, parent_stack: Optional[List[Message]] = None, ) -> List[str]: - if ( - self._construction_shapes.get(message.name, None) - and self._construction_shapes[message.name].cycle_owned - ): + shape = self._construction_shapes.get( + self.construction_key(parent_stack, message), None + ) + if shape is not None and shape.cycle_owned: return self.generate_normal_class(message, indent, parent_stack) return self.generate_case_class(message, indent, parent_stack) @@ -414,7 +414,7 @@ def generate_normal_class( top_level_ref=field.ref, parent_stack=current_stack, ) - if field.ref and self.is_ref_target_type(field.field_type): + if field.ref and self.is_ref_target_type(field.field_type, current_stack): lines.append(f"{ind} @Ref") lines.append(f"{ind} @ForyField(id = {field.number})") lines.append( @@ -456,7 +456,9 @@ def generate_parameter(self, field: Field, parent_stack: List[Message]) -> str: parent_stack=parent_stack, ) ref_annotation = ( - "@Ref " if field.ref and self.is_ref_target_type(field.field_type) else "" + "@Ref " + if field.ref and 
self.is_ref_target_type(field.field_type, parent_stack) + else "" ) return f"{ref_annotation}@ForyField(id = {field.number}) {field_name}: {field_type}" @@ -472,7 +474,7 @@ def generate_type( base = self._generate_non_optional_type( field_type, element_optional, element_ref, parent_stack ) - if top_level_ref and self.is_ref_target_type(field_type): + if top_level_ref and self.is_ref_target_type(field_type, parent_stack): base = self.apply_type_annotation(base, "Ref") return f"Option[{base}]" if nullable else base @@ -526,6 +528,11 @@ def current_stack( ) -> List[Message]: return [*(parent_stack or []), message] + def construction_key( + self, parent_stack: Optional[List[Message]], message: Message + ) -> str: + return ".".join([*[owner.name for owner in parent_stack or []], message.name]) + def resolve_scala_type_name( self, name: str, parent_stack: Optional[List[Message]] ) -> str: @@ -537,11 +544,6 @@ def resolve_scala_type_name( if package and package != self.get_scala_package(): return f"{package}.{name}" return name - named_type = self.schema.get_type(name) - if named_type is not None and self.is_imported_type(named_type): - package = self._scala_package_for_type(named_type) - if package and package != self.get_scala_package(): - return f"{package}.{name}" if parent_stack: for index in range(len(parent_stack) - 1, -1, -1): owner = parent_stack[index] @@ -549,6 +551,11 @@ def resolve_scala_type_name( return ".".join( [message.name for message in parent_stack[: index + 1]] + [name] ) + named_type = self.schema.get_type(name) + if named_type is not None and self.is_imported_type(named_type): + package = self._scala_package_for_type(named_type) + if package and package != self.get_scala_package(): + return f"{package}.{name}" return name def apply_type_annotation(self, scala_type: str, annotation: str) -> str: @@ -746,10 +753,18 @@ def field_type_has_ref(self, field_type: FieldType) -> bool: ) return False - def is_ref_target_type(self, field_type: FieldType) -> 
bool: + def is_ref_target_type( + self, field_type: FieldType, parent_stack: Optional[List[Message]] = None + ) -> bool: if not isinstance(field_type, NamedType): return False - return self.schema.get_type(field_type.name) is not None + if "." in field_type.name or not parent_stack: + return isinstance(self.schema.get_type(field_type.name), (Message, Union)) + for index in range(len(parent_stack) - 1, -1, -1): + resolved = parent_stack[index].get_nested_type(field_type.name) + if resolved is not None: + return isinstance(resolved, (Message, Union)) + return isinstance(self.schema.get_type(field_type.name), (Message, Union)) def generate_registration_file(self) -> GeneratedFile: imports = { @@ -817,15 +832,18 @@ def generate_type_registration( self, lines: List[str], type_def, owner_path: Optional[str] = None ) -> None: class_ref = f"{owner_path}.{type_def.name}" if owner_path else type_def.name + namespace = self.schema.package or "default" + type_name = type_def.name + if owner_path: + namespace = f"{namespace}.{owner_path}" if isinstance(type_def, Enum): if self.should_register_by_id(type_def): lines.append( f" ScalaSerializers.registerEnum(fory, classOf[{class_ref}], {type_def.type_id}L)" ) else: - namespace = self.schema.package or "default" lines.append( - f' ScalaSerializers.registerEnum(fory, classOf[{class_ref}], "{namespace}", "{type_def.name}")' + f' ScalaSerializers.registerEnum(fory, classOf[{class_ref}], "{namespace}", "{type_name}")' ) return if self.should_register_by_id(type_def): @@ -833,9 +851,8 @@ def generate_type_registration( f" ForySerializer.register(fory, classOf[{class_ref}], {type_def.type_id}L)" ) else: - namespace = self.schema.package or "default" lines.append( - f' ForySerializer.register(fory, classOf[{class_ref}], "{namespace}", "{type_def.name}")' + f' ForySerializer.register(fory, classOf[{class_ref}], "{namespace}", "{type_name}")' ) def safe_identifier(self, name: str) -> str: diff --git 
a/compiler/fory_compiler/ir/construction.py b/compiler/fory_compiler/ir/construction.py index 16aafc1ea8..4120935e2b 100644 --- a/compiler/fory_compiler/ir/construction.py +++ b/compiler/fory_compiler/ir/construction.py @@ -20,10 +20,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Iterable, List, Set +from typing import Dict, Iterable, List, Optional, Set, Tuple from fory_compiler.ir.ast import ( ArrayType, + Enum, + Field, FieldType, ListType, MapType, @@ -31,6 +33,7 @@ NamedType, PrimitiveType, Schema, + Union, ) @@ -51,11 +54,19 @@ def analyze_message_construction_shapes( ``any`` field does not force this shape by itself. """ - messages = {message.name: message for message in schema.messages} - graph = { - name: set(_message_dependencies(message, messages)) - for name, message in messages.items() - } + message_entries, union_entries, types = _collect_types( + schema.messages, schema.unions, schema.enums + ) + messages = {name: message for name, message, _ in message_entries} + graph = {} + for name, message, parent_paths in message_entries: + graph[name] = set( + _field_dependencies(message.fields, types, (*parent_paths, message.name)) + ) + for name, union, parent_paths in union_entries: + graph[name] = set( + _field_dependencies(union.fields, types, (*parent_paths, union.name)) + ) cycle_owned = _cycle_nodes(graph) return { name: MessageConstructionShape(cycle_owned=name in cycle_owned) @@ -63,32 +74,90 @@ def analyze_message_construction_shapes( } -def _message_dependencies( - message: Message, messages: Dict[str, Message] +def _collect_types( + messages: Iterable[Message], + unions: Iterable[Union], + enums: Iterable[Enum], + parent_paths: Optional[Tuple[str, ...]] = None, +) -> Tuple[ + List[Tuple[str, Message, Tuple[str, ...]]], + List[Tuple[str, Union, Tuple[str, ...]]], + Dict[str, object], +]: + parent_paths = parent_paths or () + message_entries: List[Tuple[str, Message, Tuple[str, ...]]] = [] + 
union_entries: List[Tuple[str, Union, Tuple[str, ...]]] = [] + types: Dict[str, object] = {} + for union in unions: + name = ".".join((*parent_paths, union.name)) + union_entries.append((name, union, parent_paths)) + types[name] = union + for enum in enums: + name = ".".join((*parent_paths, enum.name)) + types[name] = enum + for message in messages: + name = ".".join((*parent_paths, message.name)) + message_entries.append((name, message, parent_paths)) + types[name] = message + nested_messages, nested_unions, nested_types = _collect_types( + message.nested_messages, + message.nested_unions, + message.nested_enums, + (*parent_paths, message.name), + ) + message_entries.extend(nested_messages) + union_entries.extend(nested_unions) + types.update(nested_types) + return message_entries, union_entries, types + + +def _field_dependencies( + fields: Iterable[Field], types: Dict[str, object], current_path: Tuple[str, ...] ) -> Iterable[str]: - for field in message.fields: - yield from _field_type_dependencies(field.field_type, messages) + for field in fields: + yield from _field_type_dependencies(field.field_type, types, current_path) def _field_type_dependencies( - field_type: FieldType, messages: Dict[str, Message] + field_type: FieldType, types: Dict[str, object], current_path: Tuple[str, ...] 
) -> Iterable[str]: if isinstance(field_type, PrimitiveType): return if isinstance(field_type, NamedType): - root_name = field_type.name.split(".", 1)[0] - if root_name in messages: - yield root_name + resolved = _resolve_type_name(field_type.name, types, current_path) + if resolved is not None and isinstance(resolved[1], (Message, Union)): + yield resolved[0] return if isinstance(field_type, ListType): - yield from _field_type_dependencies(field_type.element_type, messages) + yield from _field_type_dependencies( + field_type.element_type, types, current_path + ) return if isinstance(field_type, ArrayType): - yield from _field_type_dependencies(field_type.element_type, messages) + yield from _field_type_dependencies( + field_type.element_type, types, current_path + ) return if isinstance(field_type, MapType): - yield from _field_type_dependencies(field_type.key_type, messages) - yield from _field_type_dependencies(field_type.value_type, messages) + yield from _field_type_dependencies(field_type.key_type, types, current_path) + yield from _field_type_dependencies(field_type.value_type, types, current_path) + + +def _resolve_type_name( + name: str, types: Dict[str, object], parent_paths: Tuple[str, ...] +) -> Optional[Tuple[str, object]]: + if "." 
in name: + resolved = types.get(name) + return (name, resolved) if resolved is not None else None + for index in range(len(parent_paths), 0, -1): + candidate = ".".join((*parent_paths[:index], name)) + resolved = types.get(candidate) + if resolved is not None: + return candidate, resolved + resolved = types.get(name) + if resolved is not None: + return name, resolved + return None def _cycle_nodes(graph: Dict[str, Set[str]]) -> Set[str]: diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index 468682e48c..af840808c8 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -449,6 +449,80 @@ def test_generated_code_nested_messages_equivalent(): assert_all_languages_equal(schemas) +def test_java_nested_name_registration_uses_owner_namespace(): + schema = parse_fdl( + """ + option enable_auto_type_id = false; + package demo; + + message Envelope { + message Payload { + int32 value = 1; + } + + enum Kind { + UNKNOWN = 0; + ACTIVE = 1; + } + + union Choice { + Payload payload = 1; + string note = 2; + } + + Payload payload = 1; + Kind kind = 2; + Choice choice = 3; + list payloads = 4; + } + """ + ) + output = render_files(generate_files(schema, JavaGenerator)) + + assert "import org.apache.fory.annotation.Ref;" in output + assert "private List<@Ref Payload> payloads;" in output + assert 'resolver.register(Envelope.class, "demo", "Envelope");' in output + assert ( + 'resolver.register(Envelope.Payload.class, "demo.Envelope", "Payload");' + in output + ) + assert 'resolver.register(Envelope.Kind.class, "demo.Envelope", "Kind");' in output + assert ( + 'resolver.registerUnion(Envelope.Choice.class, "demo.Envelope", "Choice"' + in output + ) + + +def test_java_nested_enum_shadowing_does_not_emit_ref_annotations(): + schema = parse_fdl( + """ + package demo; + + message Node { + Envelope owner = 1; + } + + message Envelope { + enum Node { + 
UNKNOWN = 0; + ACTIVE = 1; + } + + Node kind = 1; + list kinds = 2; + map indexed = 3; + } + """ + ) + output = render_files(generate_files(schema, JavaGenerator)) + + assert "import org.apache.fory.annotation.Ref;" not in output + assert "private Node kind;" in output + assert "private List kinds;" in output + assert "private Map indexed;" in output + assert "@Ref" not in output + + def test_generated_code_tree_ref_options_equivalent(): fdl = dedent( """ diff --git a/compiler/fory_compiler/tests/test_package_options.py b/compiler/fory_compiler/tests/test_package_options.py index 3bbf20d583..3cddcd1bb5 100644 --- a/compiler/fory_compiler/tests/test_package_options.py +++ b/compiler/fory_compiler/tests/test_package_options.py @@ -510,10 +510,21 @@ def test_outer_classname_registration_uses_prefix(self): """Test that registration uses outer class as prefix.""" source = """ package myapp; + option enable_auto_type_id = false; option java_outer_classname = "DescriptorProtos"; message User { + message Payload { + string value = 1; + } + + union Choice { + Payload payload = 1; + } + string name = 1; + Payload payload = 2; + Choice choice = 3; } """ lexer = Lexer(source) @@ -529,6 +540,18 @@ def test_outer_classname_registration_uses_prefix(self): # Should reference types with outer class prefix assert "DescriptorProtos.User.class" in registration_file.content + assert ( + 'resolver.register(DescriptorProtos.User.class, "myapp", "User");' + in registration_file.content + ) + assert ( + 'resolver.register(DescriptorProtos.User.Payload.class, "myapp.User", "Payload");' + in registration_file.content + ) + assert ( + 'resolver.registerUnion(DescriptorProtos.User.Choice.class, "myapp.User", "Choice"' + in registration_file.content + ) def test_outer_classname_with_nested_types(self): """Test java_outer_classname with nested types.""" diff --git a/compiler/fory_compiler/tests/test_scala_generator.py b/compiler/fory_compiler/tests/test_scala_generator.py index 
40a5470da3..5c9a21e48d 100644 --- a/compiler/fory_compiler/tests/test_scala_generator.py +++ b/compiler/fory_compiler/tests/test_scala_generator.py @@ -98,6 +98,208 @@ def test_scala_generator_uses_mutable_normal_class_for_construction_cycles(): assert "var parent: Option[Node @Ref] = None" in node +def test_scala_generator_uses_mutable_normal_class_for_nested_construction_cycles(): + files = generate_scala( + """ + package graph; + + message Envelope [id=120] { + message Node [id=121] { + string id = 1; + ref Node parent = 2; + } + + Node root = 1; + } + """ + ) + + envelope = files["graph/Envelope.scala"] + assert "final case class Envelope(" in envelope + assert "object Envelope {" in envelope + assert "final class Node() derives ForySerializer" in envelope + assert 'var id: String = ""' in envelope + assert "var parent: Option[Envelope.Node @Ref] = None" in envelope + + +def test_scala_generator_marks_nested_owner_child_cycles_mutable(): + files = generate_scala( + """ + package graph; + + message Envelope [id=130] { + message Node [id=131] { + string id = 1; + ref Envelope owner = 2; + } + + Node root = 1; + } + """ + ) + + envelope = files["graph/Envelope.scala"] + assert "final class Envelope() derives ForySerializer" in envelope + assert "var root: Option[Envelope.Node] = None" in envelope + assert "final class Node() derives ForySerializer" in envelope + assert "var owner: Option[Envelope @Ref] = None" in envelope + + +def test_scala_generator_marks_union_mediated_cycles_mutable(): + files = generate_scala( + """ + package graph; + + message Node [id=140] { + string id = 1; + ref Choice choice = 2; + } + + union Choice [id=141] { + Node node = 1; + } + """ + ) + + node = files["graph/Node.scala"] + assert "final class Node() derives ForySerializer" in node + assert 'var id: String = ""' in node + assert "var choice: Choice @Ref = null" in node + + +def test_scala_generator_marks_nested_union_mediated_cycles_mutable(): + files = generate_scala( + """ + 
package graph; + + message Envelope [id=150] { + union Choice [id=151] { + Node node = 1; + } + + message Node [id=152] { + string id = 1; + ref Choice choice = 2; + } + + Node root = 1; + } + """ + ) + + envelope = files["graph/Envelope.scala"] + assert "final case class Envelope(" in envelope + assert "@ForyField(id = 1) root: Option[Envelope.Node]" in envelope + assert "enum Choice derives ForySerializer" in envelope + assert "case NodeCase(value: Envelope.Node)" in envelope + assert "final class Node() derives ForySerializer" in envelope + assert 'var id: String = ""' in envelope + assert "var choice: Envelope.Choice @Ref = null" in envelope + + +def test_scala_generator_resolves_shadowed_nested_types_before_top_level_types(): + files = generate_scala( + """ + package graph; + + message Node { + string label = 1; + } + + message Envelope { + message Node { + string id = 1; + ref Node parent = 2; + } + + Node root = 1; + } + """ + ) + + envelope = files["graph/Envelope.scala"] + assert "final class Node() derives ForySerializer" in envelope + assert "var parent: Option[Envelope.Node @Ref] = None" in envelope + assert "@ForyField(id = 1) root: Option[Envelope.Node]" in envelope + + +def test_scala_generator_does_not_make_cycles_from_shadowed_nested_enums(): + files = generate_scala( + """ + package graph; + + message Node [id=160] { + Envelope owner = 1; + } + + message Envelope [id=161] { + enum Node { + UNKNOWN = 0; + ACTIVE = 1; + } + + Node kind = 1; + list kinds = 2; + } + """ + ) + + node = files["graph/Node.scala"] + envelope = files["graph/Envelope.scala"] + assert "final case class Node(" in node + assert "final case class Envelope(" in envelope + assert "@ForyField(id = 1) kind: Envelope.Node" in envelope + assert "@ForyField(id = 2) kinds: List[Envelope.Node]" in envelope + assert "List[Envelope.Node @Ref]" not in envelope + + +def test_scala_generator_uses_jvm_nested_names_for_name_registration(): + files = generate_scala( + """ + option 
enable_auto_type_id = false; + package demo; + + message Envelope { + message Payload { + int32 value = 1; + } + + enum Kind { + UNKNOWN = 0; + ACTIVE = 1; + } + + union Choice { + Payload payload = 1; + string note = 2; + } + + Payload payload = 1; + Kind kind = 2; + Choice choice = 3; + } + """ + ) + + registration = files["demo/DemoForyRegistration.scala"] + assert ( + 'ForySerializer.register(fory, classOf[Envelope], "demo", "Envelope")' + in registration + ) + assert ( + 'ForySerializer.register(fory, classOf[Envelope.Payload], "demo.Envelope", "Payload")' + in registration + ) + assert ( + 'ScalaSerializers.registerEnum(fory, classOf[Envelope.Kind], "demo.Envelope", "Kind")' + in registration + ) + assert ( + 'ForySerializer.register(fory, classOf[Envelope.Choice], "demo.Envelope", "Choice")' + in registration + ) + + def test_scala_generator_keeps_imported_types_in_owner_package(): repo_root = Path(__file__).resolve().parents[3] idl_dir = repo_root / "integration_tests" / "idl_tests" / "idl" diff --git a/docs/compiler/schema-idl.md b/docs/compiler/schema-idl.md index e0a9dc6bfe..f585863a13 100644 --- a/docs/compiler/schema-idl.md +++ b/docs/compiler/schema-idl.md @@ -846,8 +846,7 @@ field_type field_name = field_number; ```protobuf optional list tags = 1; // Nullable list list tags = 2; // Elements may be null -ref list nodes = 3; // Collection tracked as a reference -list nodes = 4; // Elements tracked as references +list nodes = 3; // Elements tracked as references ``` **Grammar:** @@ -860,8 +859,9 @@ list_type := 'list' '<' { 'optional' | 'ref' | scalar_encoding } field_type ' array_type := 'array' '<' array_element_type '>' ``` -Modifiers apply to the field/collection. Use `list<...>` to describe element -modifiers. `repeated` is accepted as an alias for `list`. +`optional` before `list` applies to the collection field. `ref` is only valid +for named message/union fields; for collection contents, use `list` or +`map`. 
`repeated` is accepted as an alias for `list`. ### Field Modifiers @@ -962,23 +962,22 @@ Modifiers can be combined: message Example { optional list tags = 1; // Nullable list list aliases = 2; // Elements may be null - ref list nodes = 3; // Collection tracked as a reference - list children = 4; // Elements tracked as references - optional ref User owner = 5; // Nullable tracked reference + list children = 3; // Elements tracked as references + optional ref User owner = 4; // Nullable tracked reference } ``` -Modifiers before `list` apply to the field/collection. Modifiers after `list` -apply to elements. `repeated` is accepted as an alias for `list`. +`optional` before `list` applies to the field/collection. `ref` before `list` or +`map` is invalid; put `ref` inside the element/value type instead. `repeated` is +accepted as an alias for `list`. **List modifier mapping:** -| Fory IDL | Java | Python | Go | Rust | C++ | Dart | Scala | -| ----------------------- | --------------------------------------- | --------------------------------------- | ----------------------- | --------------------- | ----------------------------------------- | ------------------------------------------------------------- | ---------------------- | -| `optional list` | `@Nullable List` | `Optional[List[str]]` | `[]string` + `nullable` | `Option>` | `std::optional>` | `List?` | `Option[List[String]]` | -| `list` | `List` (nullable elements) | `List[Optional[str]]` | `[]*string` | `Vec>` | `std::vector>` | `List` | `List[Option[String]]` | -| `ref list` | `List` + `@ForyField(ref = true)` | `List[User]` + `pyfory.field(ref=True)` | `[]User` + `ref` | `Arc>` | `std::shared_ptr>` | `List` + `@ForyField(ref: true)` | `List[User] @Ref` | -| `list` | `List` | `List[User]` | `[]*User` + `ref=false` | `Vec>` | `std::vector>` | `List` + `@ListField(element: DeclaredType(ref: true))` | `List[User @Ref]` | +| Fory IDL | Java | Python | Go | Rust | C++ | Dart | Scala | +| ----------------------- | 
---------------------------------- | --------------------- | ----------------------- | --------------------- | ----------------------------------------- | ------------------------------------------------------------- | ---------------------- | +| `optional list` | `@Nullable List` | `Optional[List[str]]` | `[]string` + `nullable` | `Option>` | `std::optional>` | `List?` | `Option[List[String]]` | +| `list` | `List` (nullable elements) | `List[Optional[str]]` | `[]*string` | `Vec>` | `std::vector>` | `List` | `List[Option[String]]` | +| `list` | `List` | `List[User]` | `[]*User` + `ref=false` | `Vec>` | `std::vector>` | `List` + `@ListField(element: DeclaredType(ref: true))` | `List[User @Ref]` | Use `ref(thread_safe=false)` in Fory IDL (or `[(fory).thread_safe_pointer = false]` in protobuf) to generate `Rc` instead of `Arc` in Rust. @@ -1363,6 +1362,11 @@ code. When `enable_auto_type_id = false`, types without explicit IDs are registered by namespace and name instead. Collisions are detected at compile-time across the current file and all imports; when a collision occurs, the compiler raises an error and asks for an explicit `id` or an `alias`. +For Java and Scala generated code, nested name registration appends the parent +path to the namespace and keeps the nested type's simple name. For example, +`package demo; message Envelope { message Payload { ... } }` registers +`Payload` as namespace `demo.Envelope` and type name `Payload` in those JVM +targets. ```protobuf enum Color [id=100] { ... } diff --git a/docs/guide/scala/schema-idl.md b/docs/guide/scala/schema-idl.md index 1b002434dd..b18a63d54d 100644 --- a/docs/guide/scala/schema-idl.md +++ b/docs/guide/scala/schema-idl.md @@ -68,6 +68,10 @@ Messages in compiler-detected construction cycles generate normal classes with mutable serialized fields so the deserializer can allocate and register the object before reading fields that can point back to it. 
A top-level `ref Foo`, nested `list`, or `any` field does not by itself force this shape. +The compiler analyzes message and union dependencies together, so +message-to-union-to-message cycles also make the participating messages normal +classes. Acyclic owner messages that only contain a cyclic nested type remain +case classes. Reference tracking is expressed with the shared `@Ref` annotation, including type-use positions: @@ -154,3 +158,10 @@ direct assignments for mutable post-construction fields. It builds descriptor metadata from Scala compile-time types, including nested generics, `Option`, arrays, scalar encoding annotations, nullability, and `@Ref` metadata. Java reflection is not the source of truth for generated Scala metadata. + +During copy, cyclic graphs are supported when the copied root can be allocated +and registered before cyclic fields are copied, which is the normal-class shape +used by schema IDL for construction cycles. If a copy starts at an immutable +constructor-owned value that participates in the cycle, such as a Scala enum +case or case class, the serializer fails with a clear error because no copied +identity can be published until construction has completed. diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index 2a5fdcebe4..2863c738ed 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -163,30 +163,30 @@ Notes: The Scala schema IDL target emits Scala 3 source only. The `fory-scala` runtime artifact remains cross-built for Scala 2.13 and Scala 3. 
-| Fory schema kind | Scala generated carrier | -| ------------------------------------- | -------------------------------------------------------------------------- | -| `optional T` | `Option[T]` | -| `bool` | `Boolean` | -| `int8`, `int16`, `int32`, `int64` | `Byte`, `Short`, `Int`, `Long` | -| `uint8`, `uint16`, `uint32`, `uint64` | `Int`, `Int`, `Long`, `Long` plus unsigned Fory type metadata | -| `float16`, `bfloat16` | JVM half-float and bfloat16 carriers | -| `float32`, `float64` | `Float`, `Double` | -| `string` | `String` | -| `binary` | `Array[Byte]` | -| `list`, `set`, `map` | `List[T]`, `Set[T]`, `Map[K, V]` | -| `array` | `Array[Boolean]` | -| `array`, `array` | `Array[Byte]` with signed/unsigned descriptor metadata | -| `array`, `array` | `Array[Short]` with signed/unsigned descriptor metadata | -| `array`, `array` | `Array[Int]` with signed/unsigned descriptor metadata | -| `array`, `array` | `Array[Long]` with signed/unsigned descriptor metadata | -| `array`, `array` | `Array[Short]` with reduced-precision descriptor metadata | -| `array`, `array` | `Array[Float]`, `Array[Double]` | -| `date`, `timestamp`, `duration` | `java.time.LocalDate`, `java.time.Instant`, `java.time.Duration` | -| `decimal` | `java.math.BigDecimal` | -| `message` | Scala 3 `case class` by default; normal class only for construction cycles | -| `enum` | Scala 3 `enum` with stable Fory enum IDs | -| `union` | Scala 3 ADT `enum derives ForySerializer` | -| `any` | `AnyRef` | +| Fory schema kind | Scala generated carrier | +| ------------------------------------- | ---------------------------------------------------------------------------------------- | +| `optional T` | `Option[T]` | +| `bool` | `Boolean` | +| `int8`, `int16`, `int32`, `int64` | `Byte`, `Short`, `Int`, `Long` | +| `uint8`, `uint16`, `uint32`, `uint64` | `Int`, `Int`, `Long`, `Long` plus unsigned Fory type metadata | +| `float16`, `bfloat16` | JVM half-float and bfloat16 carriers | +| `float32`, `float64` | 
`Float`, `Double` | +| `string` | `String` | +| `binary` | `Array[Byte]` | +| `list`, `set`, `map` | `List[T]`, `Set[T]`, `Map[K, V]` | +| `array` | `Array[Boolean]` | +| `array`, `array` | `Array[Byte]` with signed/unsigned descriptor metadata | +| `array`, `array` | `Array[Short]` with signed/unsigned descriptor metadata | +| `array`, `array` | `Array[Int]` with signed/unsigned descriptor metadata | +| `array`, `array` | `Array[Long]` with signed/unsigned descriptor metadata | +| `array`, `array` | `Array[Short]` with reduced-precision descriptor metadata | +| `array`, `array` | `Array[Float]`, `Array[Double]` | +| `date`, `timestamp`, `duration` | `java.time.LocalDate`, `java.time.Instant`, `java.time.Duration` | +| `decimal` | `java.math.BigDecimal` | +| `message` | Scala 3 `case class` by default; normal class only for message/union construction cycles | +| `enum` | Scala 3 `enum` with stable Fory enum IDs | +| `union` | Scala 3 ADT `enum derives ForySerializer` | +| `any` | `AnyRef` | Generated Scala descriptor metadata is produced by Scala 3 macro derivation from Scala compile-time types, including nested generics, `Option`, arrays, diff --git a/integration_tests/idl_tests/generate_idl.py b/integration_tests/idl_tests/generate_idl.py index b53442e997..466b170ef9 100755 --- a/integration_tests/idl_tests/generate_idl.py +++ b/integration_tests/idl_tests/generate_idl.py @@ -42,6 +42,11 @@ IDL_DIR / "idl" / "example.fdl", ] +LANG_EXTRA_SCHEMAS = { + "java": [IDL_DIR / "idl" / "nested_name.fdl"], + "scala": [IDL_DIR / "idl" / "nested_name.fdl"], +} + LANG_OUTPUTS = { "java": REPO_ROOT / "integration_tests/idl_tests/java/src/main/java/generated", "python": REPO_ROOT / "integration_tests/idl_tests/python/idl_tests/generated", @@ -127,7 +132,18 @@ def main() -> int: env=env, ) - for schema in SCHEMAS: + schemas_by_lang = { + lang: [*SCHEMAS, *LANG_EXTRA_SCHEMAS.get(lang, [])] for lang in langs + } + schemas = [] + seen_schemas = set() + for lang in langs: + for 
schema in schemas_by_lang[lang]: + if schema not in seen_schemas: + schemas.append(schema) + seen_schemas.add(schema) + + for schema in schemas: cmd = [ sys.executable, "-m", @@ -137,6 +153,8 @@ def main() -> int: ] for lang in langs: + if schema not in schemas_by_lang[lang]: + continue out_dir = LANG_OUTPUTS[lang] if lang == "go": out_dir = GO_OUTPUT_OVERRIDES.get(schema.name, out_dir) diff --git a/integration_tests/idl_tests/idl/nested_name.fdl b/integration_tests/idl_tests/idl/nested_name.fdl new file mode 100644 index 0000000000..6b49346c9a --- /dev/null +++ b/integration_tests/idl_tests/idl/nested_name.fdl @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +option enable_auto_type_id = false; + +package nested_name; + +message Envelope { + message Node { + string id = 1; + ref(weak=true) Node parent = 2; + list children = 3; + } + + enum Kind { + UNKNOWN = 0; + ACTIVE = 1; + } + + union Choice { + Node node = 1; + string note = 2; + } + + ref Node root = 1; + Kind kind = 2; + Choice choice = 3; +} diff --git a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java index 44d320e50e..797ad18813 100644 --- a/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java +++ b/integration_tests/idl_tests/java/src/test/java/org/apache/fory/idl_tests/IdlRoundTripTest.java @@ -87,6 +87,7 @@ import monster.Monster; import monster.MonsterForyRegistration; import monster.Vec3; +import nested_name.NestedNameForyRegistration; import optional_types.AllOptionalTypes; import optional_types.OptionalHolder; import optional_types.OptionalTypesForyRegistration; @@ -137,6 +138,16 @@ public void testAutoIdRoundTripSchemaConsistent() throws Exception { runAutoIdRoundTrip(false); } + @Test + public void testNestedNameRoundTripCompatible() throws Exception { + runNestedNameRoundTrip(true); + } + + @Test + public void testNestedNameRoundTripSchemaConsistent() throws Exception { + runNestedNameRoundTrip(false); + } + @Test public void testEvolvingRoundTrip() { runEvolvingRoundTrip(); @@ -204,6 +215,34 @@ private void runAutoIdRoundTrip(boolean compatible) throws Exception { } } + private void runNestedNameRoundTrip(boolean compatible) throws Exception { + Fory fory = buildRefFory(compatible); + NestedNameForyRegistration.register(fory); + + nested_name.Envelope envelope = buildNestedNameEnvelope(); + byte[] bytes = fory.serialize(envelope); + Object decoded = fory.deserialize(bytes); + + Assert.assertTrue(decoded instanceof nested_name.Envelope); + 
assertNestedNameEnvelope((nested_name.Envelope) decoded); + + for (String peer : resolvePeers("scala")) { + Path dataFile = Files.createTempFile("idl-nested-name-" + peer + "-", ".bin"); + dataFile.toFile().deleteOnExit(); + Files.write(dataFile, bytes); + + Map env = new HashMap<>(); + env.put("DATA_FILE_NESTED_NAME", dataFile.toAbsolutePath().toString()); + PeerCommand command = buildPeerCommand(peer, env, compatible); + runPeer(command, peer); + + byte[] peerBytes = Files.readAllBytes(dataFile); + Object peerRoundTrip = fory.deserialize(peerBytes); + Assert.assertTrue(peerRoundTrip instanceof nested_name.Envelope); + assertNestedNameEnvelope((nested_name.Envelope) peerRoundTrip); + } + } + private void runEvolvingRoundTrip() { Fory foryV1 = buildFory(true); Fory foryV2 = buildFory(true); @@ -713,6 +752,15 @@ private List resolvePeers() { return peers; } + private List resolvePeers(String... supportedPeers) { + List peers = resolvePeers(); + if (supportedPeers.length == 0) { + return peers; + } + List supported = Arrays.asList(supportedPeers); + return peers.stream().filter(supported::contains).collect(Collectors.toList()); + } + private PeerCommand buildPeerCommand( String peer, Map environment, boolean compatible) { Path repoRoot = repoRoot(); @@ -953,6 +1001,35 @@ private Wrapper buildAutoIdWrapper(Envelope envelope) { return Wrapper.ofEnvelope(envelope); } + private nested_name.Envelope buildNestedNameEnvelope() { + nested_name.Envelope.Node root = new nested_name.Envelope.Node(); + root.setId("root"); + nested_name.Envelope.Node child = new nested_name.Envelope.Node(); + child.setId("child"); + child.setParent(root); + child.setChildren(Collections.emptyList()); + root.setChildren(Collections.singletonList(child)); + + nested_name.Envelope envelope = new nested_name.Envelope(); + envelope.setRoot(root); + envelope.setKind(nested_name.Envelope.Kind.ACTIVE); + envelope.setChoice(nested_name.Envelope.Choice.ofNode(child)); + return envelope; + } + + private 
void assertNestedNameEnvelope(nested_name.Envelope envelope) { + Assert.assertEquals(envelope.getKind(), nested_name.Envelope.Kind.ACTIVE); + Assert.assertNotNull(envelope.getRoot()); + nested_name.Envelope.Node root = envelope.getRoot(); + Assert.assertEquals(root.getId(), "root"); + Assert.assertEquals(root.getChildren().size(), 1); + nested_name.Envelope.Node child = root.getChildren().get(0); + Assert.assertEquals(child.getId(), "child"); + Assert.assertSame(child.getParent(), root); + Assert.assertTrue(envelope.getChoice().hasNode()); + Assert.assertSame(envelope.getChoice().getNode(), child); + } + private PrimitiveTypes buildPrimitiveTypes() { PrimitiveTypes types = new PrimitiveTypes(); types.setBoolValue(true); diff --git a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala index 0caabecd08..3fcdab052a 100644 --- a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala +++ b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripPeer.scala @@ -27,6 +27,7 @@ import complex_pb.ComplexPbForyRegistration import example.ExampleForyRegistration import graph.GraphForyRegistration import monster.MonsterForyRegistration +import nested_name.NestedNameForyRegistration import optional_types.OptionalTypesForyRegistration import org.apache.fory.Fory import org.apache.fory.serializer.scala.ScalaSerializers @@ -39,6 +40,8 @@ object ScalaIdlRoundTripPeer { val compatible = sys.env.get("IDL_COMPATIBLE").forall(_.toBoolean) roundTrip("DATA_FILE", compatible, refTracking = false)(AddressbookForyRegistration.register) roundTrip("DATA_FILE_AUTO_ID", compatible, refTracking = false)(AutoIdForyRegistration.register) + roundTrip("DATA_FILE_NESTED_NAME", compatible, refTracking = true)( + NestedNameForyRegistration.register) 
roundTrip("DATA_FILE_PRIMITIVES", compatible, refTracking = false) { fory => AddressbookForyRegistration.register(fory) ComplexPbForyRegistration.register(fory) diff --git a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala index 5deb278663..9de941711d 100644 --- a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala +++ b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala @@ -28,6 +28,7 @@ import collection.{ NumericCollectionsArray } import example.{ExampleForyRegistration, ExampleMessage, ExampleState} +import nested_name.NestedNameForyRegistration import org.apache.fory.Fory import org.apache.fory.meta.FieldTypes import org.apache.fory.scala.{ForyScalaEnum, ForySerializer} @@ -176,5 +177,38 @@ final class ScalaIdlRoundTripTest extends AnyWordSpec with Matchers { roundTrip.children.head.id shouldEqual "child" roundTrip.children.head.parent.get shouldBe theSameInstanceAs(roundTrip) } + + "round trip name-registered nested messages, enums, and unions" in { + val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + NestedNameForyRegistration.register(fory) + + val root = new nested_name.Envelope.Node() + root.id = "root" + val child = new nested_name.Envelope.Node() + child.id = "child" + child.parent = Some(root) + child.children = List.empty + root.children = List(child) + val envelope = nested_name.Envelope( + Some(root), + nested_name.Envelope.Kind.Active, + nested_name.Envelope.Choice.NodeCase(child)) + + val roundTrip = fory.deserialize(fory.serialize(envelope)).asInstanceOf[nested_name.Envelope] + roundTrip.kind shouldBe nested_name.Envelope.Kind.Active + val roundTripRoot = 
roundTrip.root.get + roundTripRoot.id shouldEqual "root" + val roundTripChild = roundTripRoot.children.head + roundTripChild.id shouldEqual "child" + roundTripChild.parent.get shouldBe theSameInstanceAs(roundTripRoot) + roundTrip.choice.asInstanceOf[nested_name.Envelope.Choice.NodeCase].value shouldBe + theSameInstanceAs(roundTripChild) + } } } diff --git a/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java b/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java index 5eea194405..94fe27bc47 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java @@ -21,6 +21,7 @@ import java.util.Arrays; import org.apache.fory.collection.IdentityMap; +import org.apache.fory.exception.CopyException; import org.apache.fory.resolver.ClassResolver; import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.TypeResolver; @@ -38,6 +39,8 @@ */ @SuppressWarnings("unchecked") public final class CopyContext { + private static final Object COPY_IN_PROGRESS = new Object(); + private final TypeResolver typeResolver; private final boolean copyRefTracking; private final IdentityMap originToCopyMap; @@ -86,9 +89,36 @@ public void reference(T origin, T copied) { } } + /** + * Marks an origin as being copied before the destination value can be constructed. + * + *

    Constructor-owned immutable values cannot publish a copy early. Serializers for those values + * use this marker so recursive copies fail with a clear error instead of recursing until stack + * overflow. + */ + public void markCopyInProgress(T origin) { + if (copyRefTracking && origin != null) { + originToCopyMap.put(origin, COPY_IN_PROGRESS); + } + } + + /** Clears a copy-in-progress marker if no completed copy replaced it. */ + public void clearCopyInProgress(T origin) { + if (copyRefTracking && origin != null && originToCopyMap.get(origin) == COPY_IN_PROGRESS) { + originToCopyMap.remove(origin); + } + } + /** Returns the previously registered copy for {@code origin}, or {@code null} if absent. */ public T getCopyObject(T origin) { - return (T) originToCopyMap.get(origin); + Object copied = originToCopyMap.get(origin); + if (copied == COPY_IN_PROGRESS) { + throw new CopyException( + "Cannot copy cyclic object graph rooted at constructor-owned immutable value " + + origin.getClass().getName() + + " because its copy cannot be referenced before construction completes"); + } + return (T) copied; } /** diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java index 0fbf7cf22d..38061f62a6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java @@ -456,13 +456,23 @@ protected final void debugRemoteReadField( + buffer.readerIndex()); } - protected final Object copyFieldValue( + public static Object copyFieldValue( CopyContext copyContext, Object fieldValue, SerializationFieldInfo fieldInfo) { + if (fieldValue == null) { + return null; + } if (fieldInfo.containerSerializerOverride != null) { @SuppressWarnings("unchecked") Serializer serializer = (Serializer) 
fieldInfo.containerSerializerOverride; return copyContext.copyObject(fieldValue, serializer); } + if (fieldInfo.codecCategory == FieldGroups.FieldCodecCategory.CONTAINER + && fieldInfo.containerTypeInfo != null) { + @SuppressWarnings("unchecked") + Serializer serializer = + (Serializer) fieldInfo.containerTypeInfo.getSerializer(); + return copyContext.copyObject(fieldValue, serializer); + } return copyContext.copyObject(fieldValue, fieldInfo.dispatchId); } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java index 991b5955b5..9109dcc286 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/UnionSerializer.java @@ -197,6 +197,12 @@ public Union copy(CopyContext copyContext, Union union) { return factory.apply(union.getIndex(), copiedValue); } + /** Copies a schema-defined union case payload for generated non-Java union carriers. */ + public static Object copyCaseValue( + CopyContext copyContext, FieldGroups.SerializationFieldInfo fieldInfo, Object value) { + return StaticGeneratedStructSerializer.copyFieldValue(copyContext, value, fieldInfo); + } + /** * Writes a schema-defined union case for generated non-Java union carriers. 
* diff --git a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala index 042969a439..867a13e6f4 100644 --- a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala +++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala @@ -401,46 +401,92 @@ object ForySerializerMacros { } } - def copyValue(valueExpr: Expr[T]): Expr[T] = { + def copiedValueArg( + valueExpr: Expr[T], + field: FieldMeta, + copyContextExpr: Expr[org.apache.fory.context.CopyContext], + fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[Any] = { + val selected = selectValue(valueExpr, field) + val wireValue = + if field.option then '{ $selected.asInstanceOf[Option[Any]].orNull } + else selected + val copied = + '{ + StaticGeneratedStructSerializer.copyFieldValue( + $copyContextExpr, + $wireValue, + $fieldsByIdExpr(${ Expr(field.index) })) + } + decodeValue(copied, field) + } + + def referenceCopy( + copyContextExpr: Expr[org.apache.fory.context.CopyContext], + sourceExpr: Expr[T], + copiedExpr: Expr[T]): Term = + '{ $copyContextExpr.reference($sourceExpr, $copiedExpr) }.asTerm + + def copyValue( + valueExpr: Expr[T], + copyContextExpr: Expr[org.apache.fory.context.CopyContext], + fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[T] = { val constructorOwned = fields.filter(_.constructorOwned) val postConstruction = fields.filterNot(_.constructorOwned) - if postConstruction.isEmpty then { - val args = constructorOwned.map { field => - val selected = selectValue(valueExpr, field) - field.sourceType.asType match { - case '[a] => '{ $selected.asInstanceOf[a] }.asTerm + + def copyBody(): Expr[T] = + if postConstruction.isEmpty then { + val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) + val args = constructorOwned.map { field => + 
copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr).asTerm } - } - Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args).asExprOf[T] - } else if constructorOwned.isEmpty then { - val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) - val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil) - val assignments = postConstruction.map { field => - val selected = selectValue(valueExpr, field) - val copied = field.sourceType.asType match { - case '[a] => '{ $selected.asInstanceOf[a] } + val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args) + Block( + ValDef(obj, Some(construct)) :: + referenceCopy(copyContextExpr, valueExpr, Ref(obj).asExprOf[T]) :: + Nil, + Ref(obj)).asExprOf[T] + } else if constructorOwned.isEmpty then { + val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) + val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil) + val assignments = postConstruction.map { field => + val copied = copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr) + Assign(Select.unique(Ref(obj), field.name), copied.asTerm) } - Assign(Select.unique(Ref(obj), field.name), copied.asTerm) - } - Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T] - } else { - val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) - val args = constructorOwned.map { field => - val selected = selectValue(valueExpr, field) - field.sourceType.asType match { - case '[a] => '{ $selected.asInstanceOf[a] }.asTerm + Block( + ValDef(obj, Some(construct)) :: + referenceCopy(copyContextExpr, valueExpr, Ref(obj).asExprOf[T]) :: + assignments, + Ref(obj)).asExprOf[T] + } else { + val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) + val args = constructorOwned.map { field => + 
copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr).asTerm } - } - val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args) - val assignments = postConstruction.map { field => - val selected = selectValue(valueExpr, field) - val copied = field.sourceType.asType match { - case '[a] => '{ $selected.asInstanceOf[a] } + val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args) + val assignments = postConstruction.map { field => + val copied = copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr) + Assign(Select.unique(Ref(obj), field.name), copied.asTerm) } - Assign(Select.unique(Ref(obj), field.name), copied.asTerm) + Block( + ValDef(obj, Some(construct)) :: + referenceCopy(copyContextExpr, valueExpr, Ref(obj).asExprOf[T]) :: + assignments, + Ref(obj)).asExprOf[T] } - Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T] - } + + if constructorOwned.nonEmpty then { + '{ + if $copyContextExpr.copyTrackingRef() then { + $copyContextExpr.markCopyInProgress($valueExpr) + try ${ copyBody() } + catch { + case throwable: Throwable => + $copyContextExpr.clearCopyInProgress($valueExpr) + throw throwable + } + } else ${ copyBody() } + } + } else copyBody() } def constructRead(valuesExpr: Expr[Array[Any]], readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[T] = { @@ -652,7 +698,7 @@ object ForySerializerMacros { override def copy( copyContext: org.apache.fory.context.CopyContext, - value: T): T = ${ copyValue('value) } + value: T): T = ${ copyValue('value, 'copyContext, 'fieldsById) } } } } @@ -667,11 +713,35 @@ object ForySerializerMacros { symbol: Symbol, id: Int, payloadType: TypeRepr, + option: Boolean, payloadName: String, unknownIdName: String, unknown: Boolean, fieldIndex: Int) + def peelAnnotations(tpe: TypeRepr): (TypeRepr, List[Term]) = { + tpe match { + case AnnotatedType(underlying, annotation) => + val (base, annotations) = peelAnnotations(underlying) + 
(base, annotation :: annotations) + case other => + other.dealias match { + case AnnotatedType(underlying, annotation) => + val (base, annotations) = peelAnnotations(underlying) + (base, annotation :: annotations) + case dealiased => (dealiased, Nil) + } + } + } + + def optionElement(tpe: TypeRepr): Option[TypeRepr] = { + peelAnnotations(tpe)._1.dealias match { + case AppliedType(base, List(arg)) if base.typeSymbol.fullName == "scala.Option" => + Some(arg) + case _ => None + } + } + def payloadMeta(child: Symbol, id: Int): (TypeRepr, String, String) = { val params = child.primaryConstructor.paramSymss.flatten if id == 0 then { @@ -708,7 +778,15 @@ object ForySerializerMacros { annotationIntArg[ForyCase](child, "id").map { id => if id < 0 then report.errorAndAbort(s"${child.fullName} @ForyCase id must be >= 0") val (tpe, payloadName, unknownIdName) = payloadMeta(child, id) - CaseMeta(child, id, tpe, payloadName, unknownIdName, id == 0, -1) + CaseMeta( + child, + id, + tpe, + optionElement(tpe).nonEmpty, + payloadName, + unknownIdName, + id == 0, + -1) } } var nextFieldIndex = 0 @@ -731,29 +809,6 @@ object ForySerializerMacros { cases.find(_.unknown).getOrElse( report.errorAndAbort(s"${owner.fullName} must define @ForyCase(id = 0) unknown case")) - def peelAnnotations(tpe: TypeRepr): (TypeRepr, List[Term]) = { - tpe match { - case AnnotatedType(underlying, annotation) => - val (base, annotations) = peelAnnotations(underlying) - (base, annotation :: annotations) - case other => - other.dealias match { - case AnnotatedType(underlying, annotation) => - val (base, annotations) = peelAnnotations(underlying) - (base, annotation :: annotations) - case dealiased => (dealiased, Nil) - } - } - } - - def optionElement(tpe: TypeRepr): Option[TypeRepr] = { - peelAnnotations(tpe)._1.dealias match { - case AppliedType(base, List(arg)) if base.typeSymbol.fullName == "scala.Option" => - Some(arg) - case _ => None - } - } - def boxedIfPrimitive(tpe: TypeRepr): TypeRepr = { val (base, 
annotations) = peelAnnotations(tpe) val boxed = @@ -906,6 +961,31 @@ object ForySerializerMacros { java.util.Collections.unmodifiableList(${ Expr.ofList(knownCases.map(caseDescriptor)) }.asJava) } + def wirePayload(payloadExpr: Expr[Any], unionCase: CaseMeta): Expr[Any] = + if unionCase.option then '{ $payloadExpr.asInstanceOf[Option[Any]].orNull } + else payloadExpr + + def decodePayload(payloadExpr: Expr[Any], unionCase: CaseMeta): Expr[Any] = { + optionElement(unionCase.payloadType) match { + case Some(inner) => + inner.asType match { + case '[p] => + unionCase.payloadType.asType match { + case '[a] => + '{ + val rawPayload = $payloadExpr + if rawPayload == null then None.asInstanceOf[a] + else Option(${ coercePayload[p]('rawPayload, inner) }).asInstanceOf[a] + } + } + } + case None => + unionCase.payloadType.asType match { + case '[p] => coercePayload[p](payloadExpr, unionCase.payloadType) + } + } + } + def writeDispatch( valueExpr: Expr[T], writeContextExpr: Expr[org.apache.fory.context.WriteContext], @@ -938,13 +1018,14 @@ object ForySerializerMacros { Select.unique( '{ $valueExpr.asInstanceOf[c] }.asTerm, unionCase.payloadName).asExpr + val payloadValue = wirePayload(payload, unionCase) '{ if $valueExpr.isInstanceOf[c] then { UnionSerializer.writeCaseValue( $resolverExpr, $writeContextExpr, $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) }), - $payload, + $payloadValue, ${ Expr(unionCase.id) }) } else { $next @@ -968,19 +1049,59 @@ object ForySerializerMacros { val unknownPayload = '{ $readContextExpr.readRef() } val unknownExpr = construct(unknown, List(caseIdExpr.asTerm, unknownPayload.asTerm)) knownCases.foldRight(unknownExpr) { (unionCase, next) => - unionCase.payloadType.asType match { - case '[p] => - val rawPayload = + val rawPayload = + '{ + UnionSerializer.readCaseValue( + $resolverExpr, + $readContextExpr, + $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) })) + } + val payload = decodePayload(rawPayload, unionCase) + val current = 
construct(unionCase, List(payload.asTerm)) + '{ + if $caseIdExpr == ${ Expr(unionCase.id) } then $current else $next + } + } + } + + def copyDispatch( + valueExpr: Expr[T], + copyContextExpr: Expr[org.apache.fory.context.CopyContext], + caseFieldInfosExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[T] = { + cases.foldRight( + '{ + throw new IllegalStateException("Unknown Scala union case " + $valueExpr) + }: Expr[T]) { (unionCase, next) => + unionCase.symbol.typeRef.asType match { + case '[c] => + val payload = + Select.unique( + '{ $valueExpr.asInstanceOf[c] }.asTerm, + unionCase.payloadName).asExpr + if unionCase.unknown then { + val originalId = + Select.unique( + '{ $valueExpr.asInstanceOf[c] }.asTerm, + unionCase.unknownIdName).asExprOf[Int] + val copiedPayload = '{ $copyContextExpr.copyObject($payload) } + val current = construct(unknown, List(originalId.asTerm, copiedPayload.asTerm)) + '{ + if $valueExpr.isInstanceOf[c] then $current else $next + } + } else { + val payloadValue = wirePayload(payload, unionCase) + val copiedPayload = + '{ + UnionSerializer.copyCaseValue( + $copyContextExpr, + $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) }), + $payloadValue) + } + val coerced = decodePayload(copiedPayload, unionCase) + val current = construct(unionCase, List(coerced.asTerm)) '{ - UnionSerializer.readCaseValue( - $resolverExpr, - $readContextExpr, - $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) })) + if $valueExpr.isInstanceOf[c] then $current else $next } - val payload = coercePayload[p](rawPayload, unionCase.payloadType) - val current = construct(unionCase, List(payload.asTerm)) - '{ - if $caseIdExpr == ${ Expr(unionCase.id) } then $current else $next } } } @@ -1081,8 +1202,22 @@ object ForySerializerMacros { ${ readDispatch('caseId, 'readContext, 'resolver, 'caseFieldInfos) } } - override def copy(copyContext: org.apache.fory.context.CopyContext, value: T): T = - value + override def copy(copyContext: org.apache.fory.context.CopyContext, 
value: T): T = { + if copyContext.copyTrackingRef() then { + copyContext.markCopyInProgress(value) + try { + val copied = ${ copyDispatch('value, 'copyContext, 'caseFieldInfos) } + copyContext.reference(value, copied) + copied + } catch { + case throwable: Throwable => + copyContext.clearCopyInProgress(value) + throw throwable + } + } else { + ${ copyDispatch('value, 'copyContext, 'caseFieldInfos) } + } + } } } } diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala index a9014103b6..bfd51ce613 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala @@ -53,10 +53,79 @@ abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.I override def copy(copyContext: CopyContext, value: T): T = { if (isImmutable) { value + } else if ( + value.isInstanceOf[mutable.IndexedSeq[_]] && + !value.isInstanceOf[mutable.Growable[_]]) { + copyIndexedSeq( + copyContext, + value, + value.asInstanceOf[mutable.IndexedSeq[A]]) + } else if (value.isInstanceOf[mutable.Iterable[_]]) { + newMutableCopy(value, value.size) match { + case result: mutable.Iterable[_] with mutable.Growable[_] => + val growable = result.asInstanceOf[mutable.Iterable[A] with mutable.Growable[A]] + copyContext.reference(value, growable.asInstanceOf[T]) + copyElements(copyContext, value, growable) + growable.asInstanceOf[T] + case _ => + copyWithBuilder(copyContext, value, value.iterableFactory.newBuilder[A]) + } } else { - super.copy(copyContext, value) + copyWithBuilder(copyContext, value, newBuilder(value.size)) } } + + protected def newMutableCopy(value: T, numElements: Int): scala.collection.Iterable[A] = { + val builder = value.iterableFactory.newBuilder[A] + builder.sizeHint(numElements) + builder.result() + } + + private def 
copyElements( + copyContext: CopyContext, + value: T, + result: mutable.Growable[A]): Unit = { + val iterator = value.iterator + while (iterator.hasNext) { + result.addOne(copyContext.copyObject(iterator.next()).asInstanceOf[A]) + } + } + + private def copyWithBuilder( + copyContext: CopyContext, + value: T, + builder: mutable.Builder[A, _ <: scala.collection.Iterable[A]]): T = { + val iterator = value.iterator + while (iterator.hasNext) { + builder.addOne(copyContext.copyObject(iterator.next()).asInstanceOf[A]) + } + val result = builder.result().asInstanceOf[T] + copyContext.reference(value, result) + result + } + + private def copyIndexedSeq( + copyContext: CopyContext, + value: T, + indexed: mutable.IndexedSeq[A]): T = { + val result = indexed match { + case arraySeq: mutable.ArraySeq[_] => + val sourceArray = arraySeq.array.asInstanceOf[AnyRef] + val array = + java.lang.reflect.Array.newInstance(sourceArray.getClass.getComponentType, indexed.size) + mutable.ArraySeq.make(array.asInstanceOf[Array[_]]).asInstanceOf[mutable.IndexedSeq[A]] + case _ => + mutable.ArraySeq.make(new Array[Any](indexed.size)).asInstanceOf[mutable.IndexedSeq[A]] + } + val copied = result.asInstanceOf[T] + copyContext.reference(value, copied) + var i = 0 + while (i < indexed.size) { + result.update(i, copyContext.copyObject(indexed(i)).asInstanceOf[A]) + i += 1 + } + copied + } } class ScalaXlangSeqSerializer[A, T <: scala.collection.Seq[A]]( @@ -68,6 +137,7 @@ class ScalaXlangSeqSerializer[A, T <: scala.collection.Seq[A]]( builder.sizeHint(numElements) builder.asInstanceOf[mutable.Builder[A, T]] } + } class ScalaXlangSetSerializer[A, T <: scala.collection.Set[A]]( @@ -79,6 +149,7 @@ class ScalaXlangSetSerializer[A, T <: scala.collection.Set[A]]( builder.sizeHint(numElements) builder.asInstanceOf[mutable.Builder[A, T]] } + } class ScalaXlangCollectionSerializer[A, T <: scala.collection.Iterable[A]]( @@ -90,6 +161,7 @@ class ScalaXlangCollectionSerializer[A, T <: 
scala.collection.Iterable[A]]( builder.sizeHint(numElements) builder.asInstanceOf[mutable.Builder[A, T]] } + } private final class XlangCollectionAdapter[A](coll: scala.collection.Iterable[A]) @@ -142,6 +214,64 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K } override def onMapCopy(map: util.Map[_, _]): T = onMapRead(map) + + override def copy(copyContext: CopyContext, value: T): T = { + if (isImmutable) { + value + } else if (value.isInstanceOf[mutable.Map[_, _]]) { + newMutableMapCopy(value, value.size) match { + case result: mutable.Map[_, _] => + val mutableResult = result.asInstanceOf[mutable.Map[K, V]] + copyContext.reference(value, mutableResult.asInstanceOf[T]) + copyEntries(copyContext, value, mutableResult) + mutableResult.asInstanceOf[T] + case _ => + copyWithBuilder(copyContext, value, value.mapFactory.newBuilder[K, V]) + } + } else { + val builder = simmutable.Map.newBuilder[K, V] + builder.sizeHint(value.size) + copyWithBuilder(copyContext, value, builder) + } + } + + private def newMutableMapCopy(value: T, numElements: Int): scala.collection.Map[K, V] = { + val builder = value.mapFactory.newBuilder[K, V] + builder.sizeHint(numElements) + builder.result() + } + + private def copyEntries( + copyContext: CopyContext, + value: T, + result: mutable.Map[K, V]): Unit = { + val iterator = value.iterator + while (iterator.hasNext) { + val entry = iterator.next() + result.addOne( + ( + copyContext.copyObject(entry._1).asInstanceOf[K], + copyContext.copyObject(entry._2).asInstanceOf[V])) + } + } + + private def copyWithBuilder( + copyContext: CopyContext, + value: T, + builder: mutable.Builder[(K, V), _ <: scala.collection.Map[K, V]]) + : T = { + val iterator = value.iterator + while (iterator.hasNext) { + val entry = iterator.next() + builder.addOne( + ( + copyContext.copyObject(entry._1).asInstanceOf[K], + copyContext.copyObject(entry._2).asInstanceOf[V])) + } + val result = builder.result().asInstanceOf[T] + 
copyContext.reference(value, result) + result + } } class ScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K, V]]( diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala index bd4b4ce77d..1cfcbe6cbb 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala @@ -47,6 +47,13 @@ object ForySerializerDerivationTest { @ForyField(id = 3) scores: Map[String, Int]) derives ForySerializer + @ForyStruct + final case class CopyBox( + @ForyField(id = 1) user: SearchUser, + @ForyField(id = 2) names: List[String], + @ForyField(id = 3) values: Array[Int]) + derives ForySerializer + @ForyStruct final class RefNode() derives ForySerializer { @ForyField(id = 1) @@ -57,6 +64,16 @@ object ForySerializerDerivationTest { var parent: Option[RefNode @Ref] = None } + @ForyStruct + final class UnionRefNode() derives ForySerializer { + @ForyField(id = 1) + var name: String = "" + + @Ref + @ForyField(id = 2) + var choice: Option[UnionCycle @Ref] = None + } + @ForyStruct final class MixedRecord(@ForyField(id = 1) val id: Int) derives ForySerializer { @ForyField(id = 2) @@ -73,12 +90,25 @@ object ForySerializerDerivationTest { @ForyCase(id = 2) case FixedIdCase(value: Int) + + @ForyCase(id = 3) + case OptionalUserCase(value: Option[SearchUser]) + } + + @ForyUnion + enum UnionCycle derives ForySerializer { + @ForyCase(id = 0) + case UnknownCase(caseId: Int, value: Any) + + @ForyCase(id = 1) + case NodeCase(value: UnionRefNode) } def xlangFory(): Fory = { val fory = Fory.builder() .withXlang(true) .withRefTracking(true) + .withRefCopy(true) .withScalaOptimizationEnabled(true) .requireClassRegistration(true) .suppressClassRegistrationWarnings(false) @@ -87,8 +117,12 @@ object ForySerializerDerivationTest { 
ForySerializer.register(fory, classOf[Person], "scala_test", "Person") ForySerializer.register(fory, classOf[SearchUser], "scala_test", "SearchUser") ForySerializer.register(fory, classOf[CollectionBox], "scala_test", "CollectionBox") + ForySerializer.register(fory, classOf[CopyBox], "scala_test", "CopyBox") + ForySerializer.register(fory, classOf[RefNode], "scala_test", "RefNode") + ForySerializer.register(fory, classOf[UnionRefNode], "scala_test", "UnionRefNode") ForySerializer.register(fory, classOf[MixedRecord], "scala_test", "MixedRecord") ForySerializer.register(fory, classOf[SearchTarget], "scala_test", "SearchTarget") + ForySerializer.register(fory, classOf[UnionCycle], "scala_test", "UnionCycle") fory } } @@ -147,5 +181,110 @@ class ForySerializerDerivationTest extends AnyWordSpec with Matchers { val unknown = SearchTarget.UnknownCase(99, SearchUser("Future")) fory.deserialize(fory.serialize(unknown)) shouldEqual unknown } + + "serialize and copy derived union Option payloads" in { + val fory = xlangFory() + val some: SearchTarget.OptionalUserCase = + SearchTarget.OptionalUserCase(Some(SearchUser("Ada"))) + val none: SearchTarget.OptionalUserCase = SearchTarget.OptionalUserCase(None) + + fory.deserialize(fory.serialize(some)) shouldEqual some + fory.deserialize(fory.serialize(none)) shouldEqual none + + val copiedSome = fory.copy(some).asInstanceOf[SearchTarget.OptionalUserCase] + val copiedNone = fory.copy(none).asInstanceOf[SearchTarget.OptionalUserCase] + + copiedSome shouldEqual some + copiedSome should not be theSameInstanceAs(some) + copiedSome.value.get should not be theSameInstanceAs(some.value.get) + copiedNone shouldEqual none + copiedNone should not be theSameInstanceAs(none) + } + + "copy derived case classes through field serializers" in { + val fory = xlangFory() + val box = CopyBox(SearchUser("Ada"), List("compiler", "runtime"), Array(1, 2, 3)) + + val copied = fory.copy(box) + + copied should not be theSameInstanceAs(box) + copied.user 
shouldEqual box.user + copied.user should not be theSameInstanceAs(box.user) + copied.names shouldEqual box.names + copied.names should not be theSameInstanceAs(box.names) + copied.values.sameElements(box.values) shouldBe true + copied.values should not be theSameInstanceAs(box.values) + } + + "copy derived normal classes with ref cycles" in { + val fory = xlangFory() + val root = new RefNode() + val child = new RefNode() + child.parent = Some(root) + root.children = List(child) + + val copied = fory.copy(root) + + copied should not be theSameInstanceAs(root) + copied.children.head should not be theSameInstanceAs(child) + copied.children.head.parent.get shouldBe theSameInstanceAs(copied) + } + + "copy cyclic graphs rooted at mutable classes with union edges" in { + val fory = xlangFory() + val root = new UnionRefNode() + root.name = "root" + val choice = UnionCycle.NodeCase(root) + root.choice = Some(choice) + + val copied = fory.copy(root) + + copied should not be theSameInstanceAs(root) + copied.name shouldBe "root" + copied.choice.get should not be theSameInstanceAs(choice) + copied.choice.get match { + case UnionCycle.NodeCase(value) => value shouldBe theSameInstanceAs(copied) + case other => fail(s"Unexpected copied union case $other") + } + } + + "reject cyclic copies rooted at immutable union values" in { + val fory = xlangFory() + val root = new UnionRefNode() + val choice = UnionCycle.NodeCase(root) + root.choice = Some(choice) + + val error = intercept[org.apache.fory.exception.CopyException] { + fory.copy(choice) + } + + error.getMessage should include("constructor-owned immutable value") + error.getMessage should include(classOf[UnionCycle.NodeCase].getName) + } + + "copy derived union cases through payload serializers" in { + val fory = xlangFory() + val target = SearchTarget.UserCase(SearchUser("Ada")) + + val copied = fory.copy(target) + + copied shouldEqual target + copied should not be theSameInstanceAs(target) + 
copied.asInstanceOf[SearchTarget.UserCase].value should not be theSameInstanceAs( + target.asInstanceOf[SearchTarget.UserCase].value) + } + + "copy derived union unknown cases" in { + val fory = xlangFory() + val unknown: SearchTarget.UnknownCase = + SearchTarget.UnknownCase(99, SearchUser("Future")) + + val copied = fory.copy(unknown).asInstanceOf[SearchTarget.UnknownCase] + + copied.caseId shouldBe 99 + copied.value.asInstanceOf[SearchUser] shouldEqual unknown.value.asInstanceOf[SearchUser] + copied.value.asInstanceOf[SearchUser] should not be theSameInstanceAs( + unknown.value.asInstanceOf[SearchUser]) + } } } diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala index 8769df203b..d285d12455 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala @@ -30,6 +30,7 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { val runtime = Fory.builder() .withXlang(true) .withRefTracking(true) + .withRefCopy(true) .withScalaOptimizationEnabled(true) .requireClassRegistration(false) .suppressClassRegistrationWarnings(false) @@ -61,5 +62,64 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { .toMap shouldEqual map } + "copy mutable collections with cyclic references" in { + val runtime = fory + val list = scala.collection.mutable.ArrayBuffer.empty[AnyRef] + list += list + + val copiedList = runtime.copy(list).asInstanceOf[scala.collection.Seq[AnyRef]] + + copiedList should not be theSameInstanceAs(list) + copiedList.head shouldBe theSameInstanceAs(copiedList) + } + + "copy mutable maps with cyclic references" in { + val runtime = fory + val map = scala.collection.mutable.LinkedHashMap.empty[String, AnyRef] + map.put("self", map) + + val copiedMap = 
runtime.copy(map).asInstanceOf[scala.collection.Map[String, AnyRef]] + + copiedMap should not be theSameInstanceAs(map) + copiedMap("self") shouldBe theSameInstanceAs(copiedMap) + } + + "copy concrete mutable collection classes" in { + val runtime = fory + val set = scala.collection.mutable.HashSet("a", "b") + val map = scala.collection.mutable.HashMap("a" -> "b", "c" -> "d") + + val copiedSet = runtime.copy(set).asInstanceOf[scala.collection.mutable.HashSet[String]] + val copiedMap = + runtime.copy(map).asInstanceOf[scala.collection.mutable.HashMap[String, String]] + + copiedSet shouldEqual set + copiedSet should not be theSameInstanceAs(set) + copiedMap shouldEqual map + copiedMap should not be theSameInstanceAs(map) + } + + "copy fixed-size mutable collections" in { + val runtime = fory + val arraySeq = scala.collection.mutable.ArraySeq("a", "b") + val intArraySeq = scala.collection.mutable.ArraySeq(1, 2) + val cyclic = scala.collection.mutable.ArraySeq[AnyRef](null) + cyclic.update(0, cyclic) + + val copied = + runtime.copy(arraySeq).asInstanceOf[scala.collection.mutable.ArraySeq[String]] + val copiedIntArraySeq = + runtime.copy(intArraySeq).asInstanceOf[scala.collection.mutable.ArraySeq[Int]] + val copiedCyclic = + runtime.copy(cyclic).asInstanceOf[scala.collection.mutable.ArraySeq[AnyRef]] + + copied shouldEqual arraySeq + copied should not be theSameInstanceAs(arraySeq) + copiedIntArraySeq shouldEqual intArraySeq + copiedIntArraySeq should not be theSameInstanceAs(intArraySeq) + copiedIntArraySeq.getClass shouldBe intArraySeq.getClass + copiedCyclic should not be theSameInstanceAs(cyclic) + copiedCyclic(0) shouldBe theSameInstanceAs(copiedCyclic) + } } } From 4f0b61c5335777eab9f24d4e5d6495474a5c37a5 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 17:12:25 +0800 Subject: [PATCH 7/9] feat(scala): support xlang schema idl --- compiler/README.md | 4 +- compiler/fory_compiler/generators/java.py | 9 +- compiler/fory_compiler/generators/scala.py | 
177 +++++++-- compiler/fory_compiler/ir/construction.py | 82 +++- compiler/fory_compiler/ir/validator.py | 5 + compiler/fory_compiler/tests/test_auto_id.py | 18 + .../tests/test_generated_code.py | 4 + .../tests/test_scala_generator.py | 118 +++++- docs/compiler/generated-code.md | 18 +- docs/compiler/schema-idl.md | 8 +- docs/guide/java/field-configuration.md | 40 +- .../kotlin/static-generated-serializers.md | 4 +- docs/guide/scala/schema-idl.md | 30 +- docs/guide/xlang/field-reference-tracking.md | 11 +- .../specification/xlang_serialization_spec.md | 8 +- docs/specification/xlang_type_mapping.md | 6 +- .../idl_tests/ScalaIdlRoundTripTest.scala | 38 +- .../processing/ForyStructProcessor.java | 32 +- .../annotation/processing/SourceField.java | 3 + .../StaticSerializerSourceWriter.java | 2 + .../processing/ForyStructProcessorTest.java | 57 ++- .../org/apache/fory/annotation/ForyField.java | 8 - .../java/org/apache/fory/annotation/Ref.java | 5 +- .../fory/builder/BaseObjectCodecBuilder.java | 11 +- .../org/apache/fory/context/CopyContext.java | 32 +- .../java/org/apache/fory/meta/FieldInfo.java | 4 +- .../java/org/apache/fory/meta/FieldTypes.java | 80 ++-- .../org/apache/fory/meta/TypeExtMeta.java | 25 +- .../apache/fory/resolver/ClassResolver.java | 40 -- .../StaticGeneratedSerializerRegistry.java | 38 -- .../apache/fory/resolver/TypeResolver.java | 128 +++--- .../apache/fory/resolver/XtypeResolver.java | 40 -- .../apache/fory/serializer/FieldGroups.java | 6 +- .../StaticGeneratedStructSerializer.java | 51 ++- ...taticGeneratedStructSerializerFactory.java | 49 --- .../fory/serializer/struct/Fingerprint.java | 6 +- .../java/org/apache/fory/type/Descriptor.java | 115 +++++- .../apache/fory/type/DescriptorBuilder.java | 9 +- .../java/org/apache/fory/type/ScalaTypes.java | 25 ++ .../ForyFieldSerializationTest.java | 64 +-- .../apache/fory/annotation/ForyFieldTest.java | 71 +++- .../StaticCompatibleCodecBuilderTest.java | 34 +- 
.../serializer/struct/FingerprintTest.java | 24 ++ .../org/apache/fory/xlang/XlangTestBase.java | 18 +- .../integration_tests/RecordXlangTest.java | 5 +- .../kotlin/ksp/ForyKotlinSymbolProcessor.kt | 116 +++++- .../org/apache/fory/kotlin/ksp/Model.kt | 3 +- .../org/apache/fory/scala/ForyScalaEnum.java | 25 -- .../serializer/scala/ScalaDispatcher.java | 22 +- .../serializer/scala/ScalaEnumSerializer.java | 122 +++++- .../serializer/scala/ScalaSerializers.java | 45 ++- .../apache/fory/scala/ForySerializer.scala | 37 +- .../scala/internal/ForySerializerMacros.scala | 368 +++++++++++++----- .../scala/XlangCollectionSerializer.scala | 259 +++++++++++- .../scala/ForySerializerDerivationTest.scala | 229 ++++++++++- .../fory/serializer/scala/ScalaEnumTest.scala | 33 ++ .../serializer/scala/ScalaXlangPeer.scala | 38 +- 57 files changed, 2059 insertions(+), 800 deletions(-) delete mode 100644 java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java delete mode 100644 scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java diff --git a/compiler/README.md b/compiler/README.md index 4d78589996..fb443ea4d9 100644 --- a/compiler/README.md +++ b/compiler/README.md @@ -318,12 +318,12 @@ Each generator extends `BaseGenerator` and implements: Generates POJOs with: - Private fields with getters/setters -- `@Nullable` annotations for nullable fields and `@ForyField` annotations for ref fields +- `@Nullable` annotations for nullable fields and `@Ref` annotations for ref fields - Registration helper class ```java public class Cat { - @ForyField(ref = true) + @Ref private Dog friend; @Nullable diff --git a/compiler/fory_compiler/generators/java.py b/compiler/fory_compiler/generators/java.py index e55847ce75..2f119d7d46 100644 --- a/compiler/fory_compiler/generators/java.py +++ b/compiler/fory_compiler/generators/java.py @@ -1149,9 +1149,6 @@ def generate_field( nullable = field.optional or is_any if field.tag_id is not None: 
annotations.append(f"id = {field.tag_id}") - if field.ref: - annotations.append("ref = true") - if annotations: lines.append(f"@ForyField({', '.join(annotations)})") @@ -1180,6 +1177,8 @@ def generate_field( ) if nullable: java_type = self.apply_top_level_type_use_annotation(java_type, "@Nullable") + if field.ref: + java_type = self.apply_top_level_type_use_annotation(java_type, "@Ref") lines.append(f"private {java_type} {self.to_camel_case(field.name)};") lines.append("") @@ -1449,7 +1448,9 @@ def collect_field_imports( self.collect_array_imports(field, imports) if nullable: imports.add("org.apache.fory.annotation.Nullable") - if field.ref or field.tag_id is not None: + if field.ref: + imports.add("org.apache.fory.annotation.Ref") + if field.tag_id is not None: imports.add("org.apache.fory.annotation.ForyField") def is_ref_target_type( diff --git a/compiler/fory_compiler/generators/scala.py b/compiler/fory_compiler/generators/scala.py index 15126a0a85..39e3858919 100644 --- a/compiler/fory_compiler/generators/scala.py +++ b/compiler/fory_compiler/generators/scala.py @@ -20,7 +20,7 @@ from __future__ import annotations from pathlib import Path -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from fory_compiler.generators.base import BaseGenerator, GeneratedFile from fory_compiler.frontend.utils import parse_idl_file @@ -258,12 +258,7 @@ def generate(self) -> List[GeneratedFile]: return files def generate_enum_file(self, enum: Enum) -> GeneratedFile: - lines = self.source_header( - { - "org.apache.fory.annotation.ForyEnumId", - "org.apache.fory.scala.ForyScalaEnum", - } - ) + lines = self.source_header({"org.apache.fory.annotation.ForyEnumId"}) comment = self.format_type_id_comment(enum, "//") if comment: lines.append(comment) @@ -315,17 +310,13 @@ def source_file(self, type_name: str, lines: List[str]) -> GeneratedFile: def generate_enum(self, enum: Enum, indent: int = 0) -> List[str]: ind = self.indent_str * indent - 
lines = [f"{ind}enum {enum.name}(val foryId: Int) extends ForyScalaEnum {{"] + lines = [f"{ind}enum {enum.name} {{"] for value in enum.values: case_name = self.safe_identifier( self.to_pascal_case(self.strip_enum_prefix(enum.name, value.name)) ) - lines.append( - f"{ind} case {case_name} extends {enum.name}({value.value})" - ) - lines.append("") - lines.append(f"{ind} @ForyEnumId") - lines.append(f"{ind} def getForyId: Int = foryId") + lines.append(f"{ind} @ForyEnumId({value.value})") + lines.append(f"{ind} case {case_name}") lines.append(f"{ind}}}") lines.append("") return lines @@ -658,9 +649,9 @@ def collect_message_imports(self, message: Message, imports: Set[str]) -> None: imports.add("org.apache.fory.annotation.Ref") for enum in message.nested_enums: imports.add("org.apache.fory.annotation.ForyEnumId") - imports.add("org.apache.fory.scala.ForyScalaEnum") for union in message.nested_unions: imports.add("org.apache.fory.annotation.{ForyCase, ForyUnion}") + self.collect_union_imports(union, imports) for nested in message.nested_messages: self.collect_message_imports(nested, imports) @@ -799,37 +790,134 @@ def generate_registration_file(self) -> GeneratedFile: lines.append("") lines.append(" def register(fory: Fory): Unit = {") lines.append(" ScalaSerializers.registerSerializers(fory)") - for enum in self.schema.enums: - if self.is_imported_type(enum): - continue - self.generate_type_registration(lines, enum) - for message in self.schema.messages: - if self.is_imported_type(message): - continue - self.generate_type_registration(lines, message) - self.generate_nested_registration(lines, message.name, message) - for union in self.schema.unions: - if self.is_imported_type(union): - continue - self.generate_type_registration(lines, union) + registrations = self.registration_order() + for type_def, owner_path in registrations: + if isinstance(type_def, Message): + self.generate_type_registration(lines, type_def, owner_path, type_only=True) + for type_def, owner_path 
in registrations: + if isinstance(type_def, Message): + self.generate_serializer_registration(lines, type_def, owner_path) + else: + self.generate_type_registration(lines, type_def, owner_path) lines.append(" }") lines.append("}") return self.source_file(class_name, lines) - def generate_nested_registration( - self, lines: List[str], owner_path: str, message: Message + def registration_order(self) -> List[Tuple[object, Optional[str]]]: + entries: List[Tuple[object, Optional[str], List[Message]]] = [] + + def message_path(messages: List[Message]) -> str: + return ".".join(message.name for message in messages) + + def add_message(message: Message, parent_stack: List[Message]) -> None: + owner_path = ".".join(owner.name for owner in parent_stack) or None + entries.append((message, owner_path, parent_stack)) + current_stack = [*parent_stack, message] + for enum in message.nested_enums: + entries.append((enum, message_path(current_stack), current_stack)) + for union in message.nested_unions: + entries.append((union, message_path(current_stack), current_stack)) + for nested in message.nested_messages: + add_message(nested, current_stack) + + for enum in self.schema.enums: + entries.append((enum, None, [])) + for union in self.schema.unions: + entries.append((union, None, [])) + for message in self.schema.messages: + add_message(message, []) + + local_entries: Dict[int, Tuple[object, Optional[str], List[Message]]] = { + id(type_def): (type_def, owner_path, parent_stack) + for type_def, owner_path, parent_stack in entries + if not self.is_imported_type(type_def) + } + ordered: List[Tuple[object, Optional[str]]] = [] + visiting: Set[int] = set() + visited: Set[int] = set() + + def visit(type_def: object) -> None: + key = id(type_def) + if key in visited or key not in local_entries: + return + if key in visiting: + return + visiting.add(key) + _, _, parent_stack = local_entries[key] + for dependency in self.registration_dependencies(type_def, parent_stack): + 
visit(dependency) + visiting.remove(key) + visited.add(key) + current, owner_path, _ = local_entries[key] + ordered.append((current, owner_path)) + + for type_def, _, _ in entries: + visit(type_def) + return ordered + + def registration_dependencies( + self, type_def: object, parent_stack: List[Message] + ) -> List[object]: + dependencies: List[object] = [] + if isinstance(type_def, Message): + lookup_stack = [*parent_stack, type_def] + for field in type_def.fields: + self.collect_registration_dependencies( + field.field_type, lookup_stack, dependencies + ) + elif isinstance(type_def, Union): + for field in type_def.fields: + self.collect_registration_dependencies( + field.field_type, parent_stack, dependencies + ) + return [dependency for dependency in dependencies if dependency is not type_def] + + def collect_registration_dependencies( + self, + field_type: FieldType, + parent_stack: List[Message], + dependencies: List[object], ) -> None: - for enum in message.nested_enums: - self.generate_type_registration(lines, enum, owner_path) - for nested in message.nested_messages: - nested_path = f"{owner_path}.{nested.name}" - self.generate_type_registration(lines, nested, owner_path) - self.generate_nested_registration(lines, nested_path, nested) - for union in message.nested_unions: - self.generate_type_registration(lines, union, owner_path) + if isinstance(field_type, NamedType): + dependency = self.resolve_named_type(field_type.name, parent_stack) + if dependency is not None and dependency not in dependencies: + dependencies.append(dependency) + return + if isinstance(field_type, ListType): + self.collect_registration_dependencies( + field_type.element_type, parent_stack, dependencies + ) + return + if isinstance(field_type, ArrayType): + self.collect_registration_dependencies( + field_type.element_type, parent_stack, dependencies + ) + return + if isinstance(field_type, MapType): + self.collect_registration_dependencies( + field_type.key_type, parent_stack, 
dependencies + ) + self.collect_registration_dependencies( + field_type.value_type, parent_stack, dependencies + ) + + def resolve_named_type( + self, name: str, parent_stack: List[Message] + ) -> Optional[object]: + if "." in name: + return self.schema.get_type(name) + for index in range(len(parent_stack) - 1, -1, -1): + nested = parent_stack[index].get_nested_type(name) + if nested is not None: + return nested + return self.schema.get_type(name) def generate_type_registration( - self, lines: List[str], type_def, owner_path: Optional[str] = None + self, + lines: List[str], + type_def, + owner_path: Optional[str] = None, + type_only: bool = False, ) -> None: class_ref = f"{owner_path}.{type_def.name}" if owner_path else type_def.name namespace = self.schema.package or "default" @@ -846,14 +934,21 @@ def generate_type_registration( f' ScalaSerializers.registerEnum(fory, classOf[{class_ref}], "{namespace}", "{type_name}")' ) return + method = "registerType" if type_only else "register" if self.should_register_by_id(type_def): lines.append( - f" ForySerializer.register(fory, classOf[{class_ref}], {type_def.type_id}L)" + f" ForySerializer.{method}(fory, classOf[{class_ref}], {type_def.type_id}L)" ) else: lines.append( - f' ForySerializer.register(fory, classOf[{class_ref}], "{namespace}", "{type_name}")' + f' ForySerializer.{method}(fory, classOf[{class_ref}], "{namespace}", "{type_name}")' ) + def generate_serializer_registration( + self, lines: List[str], type_def, owner_path: Optional[str] = None + ) -> None: + class_ref = f"{owner_path}.{type_def.name}" if owner_path else type_def.name + lines.append(f" ForySerializer.registerSerializer(fory, classOf[{class_ref}])") + def safe_identifier(self, name: str) -> str: return f"`{name}`" if name in self.RESERVED else name diff --git a/compiler/fory_compiler/ir/construction.py b/compiler/fory_compiler/ir/construction.py index 4120935e2b..b9f9125dcb 100644 --- a/compiler/fory_compiler/ir/construction.py +++ 
b/compiler/fory_compiler/ir/construction.py @@ -44,6 +44,14 @@ class MessageConstructionShape: cycle_owned: bool +@dataclass(frozen=True) +class _Dependency: + """Message/union dependency used by construction-shape analysis.""" + + name: str + constructor_owned: bool + + def analyze_message_construction_shapes( schema: Schema, ) -> Dict[str, MessageConstructionShape]: @@ -58,16 +66,31 @@ def analyze_message_construction_shapes( schema.messages, schema.unions, schema.enums ) messages = {name: message for name, message, _ in message_entries} - graph = {} + graph: Dict[str, Set[str]] = {} + constructor_graph: Dict[str, Set[str]] = {} for name, message, parent_paths in message_entries: - graph[name] = set( - _field_dependencies(message.fields, types, (*parent_paths, message.name)) + dependencies = list( + _field_dependencies( + message.fields, types, (*parent_paths, message.name), False + ) ) + graph[name] = {dependency.name for dependency in dependencies} + constructor_graph[name] = { + dependency.name + for dependency in dependencies + if dependency.constructor_owned + } for name, union, parent_paths in union_entries: - graph[name] = set( - _field_dependencies(union.fields, types, (*parent_paths, union.name)) + dependencies = list( + _field_dependencies(union.fields, types, (*parent_paths, union.name), False) ) - cycle_owned = _cycle_nodes(graph) + graph[name] = {dependency.name for dependency in dependencies} + constructor_graph[name] = { + dependency.name + for dependency in dependencies + if dependency.constructor_owned + } + cycle_owned = _cycle_nodes(graph, constructor_graph) return { name: MessageConstructionShape(cycle_owned=name in cycle_owned) for name in messages @@ -112,35 +135,48 @@ def _collect_types( def _field_dependencies( - fields: Iterable[Field], types: Dict[str, object], current_path: Tuple[str, ...] 
-) -> Iterable[str]: + fields: Iterable[Field], + types: Dict[str, object], + current_path: Tuple[str, ...], + nested: bool, +) -> Iterable[_Dependency]: for field in fields: - yield from _field_type_dependencies(field.field_type, types, current_path) + yield from _field_type_dependencies( + field.field_type, types, current_path, nested + ) def _field_type_dependencies( - field_type: FieldType, types: Dict[str, object], current_path: Tuple[str, ...] -) -> Iterable[str]: + field_type: FieldType, + types: Dict[str, object], + current_path: Tuple[str, ...], + nested: bool, +) -> Iterable[_Dependency]: if isinstance(field_type, PrimitiveType): return if isinstance(field_type, NamedType): resolved = _resolve_type_name(field_type.name, types, current_path) if resolved is not None and isinstance(resolved[1], (Message, Union)): - yield resolved[0] + yield _Dependency(resolved[0], constructor_owned=not nested) return if isinstance(field_type, ListType): yield from _field_type_dependencies( - field_type.element_type, types, current_path + field_type.element_type, types, current_path, True ) return if isinstance(field_type, ArrayType): yield from _field_type_dependencies( - field_type.element_type, types, current_path + field_type.element_type, types, current_path, True ) return if isinstance(field_type, MapType): - yield from _field_type_dependencies(field_type.key_type, types, current_path) - yield from _field_type_dependencies(field_type.value_type, types, current_path) + yield from _field_type_dependencies( + field_type.key_type, types, current_path, True + ) + yield from _field_type_dependencies( + field_type.value_type, types, current_path, True + ) + return def _resolve_type_name( @@ -160,7 +196,9 @@ def _resolve_type_name( return None -def _cycle_nodes(graph: Dict[str, Set[str]]) -> Set[str]: +def _cycle_nodes( + graph: Dict[str, Set[str]], constructor_graph: Dict[str, Set[str]] +) -> Set[str]: index = 0 stack: List[str] = [] on_stack: Set[str] = set() @@ -194,8 
+232,14 @@ def strong_connect(node: str) -> None: if current == node: break if len(component) > 1: - result.update(component) - elif component[0] in graph[component[0]]: + component_set = set(component) + if any( + target in component_set + for current in component + for target in constructor_graph[current] + ): + result.update(component) + elif component[0] in constructor_graph[component[0]]: result.add(component[0]) for node in graph: diff --git a/compiler/fory_compiler/ir/validator.py b/compiler/fory_compiler/ir/validator.py index eccd99ce89..e335ae96d8 100644 --- a/compiler/fory_compiler/ir/validator.py +++ b/compiler/fory_compiler/ir/validator.py @@ -270,6 +270,11 @@ def validate_union(union: Union, parent_path: str = ""): case_numbers = {} case_names = {} for f in union.fields: + if f.number <= 0: + self._error( + f"Union case id {f.number} in {full_name} is reserved; use ids starting from 1", + f.location, + ) if f.number in case_numbers: self._error( f"Duplicate union case id {f.number} in {full_name}: {f.name} and {case_numbers[f.number].name}", diff --git a/compiler/fory_compiler/tests/test_auto_id.py b/compiler/fory_compiler/tests/test_auto_id.py index cee482c5c4..4d196b9879 100644 --- a/compiler/fory_compiler/tests/test_auto_id.py +++ b/compiler/fory_compiler/tests/test_auto_id.py @@ -58,6 +58,24 @@ def test_auto_id_generation_for_message_and_union(): assert union.id_source == "demo.Item" +def test_union_case_ids_must_start_from_one(): + source = """ + package demo; + + union Bad { + string zero = 0; + string negative = -1; + } + """ + schema = parse_schema(source) + validator = SchemaValidator(schema) + + assert not validator.validate() + messages = [issue.message for issue in validator.errors] + assert any("Union case id 0" in message for message in messages) + assert any("Union case id -1" in message for message in messages) + + def test_alias_used_for_auto_id(): source = """ package demo; diff --git 
a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index af840808c8..b4fde196ba 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -583,6 +583,10 @@ def test_generated_code_tree_ref_options_equivalent(): ) assert 'Parent *TreeNode `fory:"id=4,nullable=false,ref"`' in go_output + java_output = render_files(generate_files(schemas["fdl"], JavaGenerator)) + assert "import org.apache.fory.annotation.Ref;" in java_output + assert "private @Ref TreeNode parent;" in java_output + def test_java_float16_equals_hash_contract_generation(): schema = parse_fdl( diff --git a/compiler/fory_compiler/tests/test_scala_generator.py b/compiler/fory_compiler/tests/test_scala_generator.py index 5c9a21e48d..7266e65038 100644 --- a/compiler/fory_compiler/tests/test_scala_generator.py +++ b/compiler/fory_compiler/tests/test_scala_generator.py @@ -64,10 +64,12 @@ def test_scala_generator_emits_case_classes_options_enums_and_unions(): assert "derives ForySerializer" in user status = files["demo/Status.scala"] - assert "enum Status(val foryId: Int)" in status - assert "case Unknown extends Status(0)" in status - assert "case Ok extends Status(7)" in status - assert "@ForyEnumId" in status + assert "enum Status {" in status + assert "@ForyEnumId(0)" in status + assert "case Unknown" in status + assert "@ForyEnumId(7)" in status + assert "case Ok" in status + assert "ForyScalaEnum" not in status union = files["demo/SearchTarget.scala"] assert "@ForyUnion" in union @@ -122,6 +124,51 @@ def test_scala_generator_uses_mutable_normal_class_for_nested_construction_cycle assert "var parent: Option[Envelope.Node @Ref] = None" in envelope +def test_scala_generator_keeps_container_recursive_messages_as_case_classes(): + files = generate_scala( + """ + package graph; + + message Node [id=125] { + string id = 1; + list children = 2; + map lookup = 3; + } + """ + ) + + node = 
files["graph/Node.scala"] + assert "final case class Node(" in node + assert "children: List[Node @Ref]" in node + assert "lookup: Map[String, Node]" in node + assert "final class Node() derives ForySerializer" not in node + + +def test_scala_generator_marks_container_cycle_with_constructor_edge_mutable(): + files = generate_scala( + """ + package graph; + + message Node [id=126] { + string id = 1; + list edges = 2; + } + + message Edge [id=127] { + string id = 1; + ref Node owner = 2; + } + """ + ) + + node = files["graph/Node.scala"] + edge = files["graph/Edge.scala"] + assert "final class Node() derives ForySerializer" in node + assert "var edges: List[Edge @Ref] = List.empty" in node + assert "final class Edge() derives ForySerializer" in edge + assert "var owner: Option[Node @Ref] = None" in edge + + def test_scala_generator_marks_nested_owner_child_cycles_mutable(): files = generate_scala( """ @@ -167,6 +214,37 @@ def test_scala_generator_marks_union_mediated_cycles_mutable(): assert "var choice: Choice @Ref = null" in node +def test_scala_generator_collects_nested_union_payload_imports(): + files = generate_scala( + """ + package demo; + + message Envelope [id=150] { + message User [id=151] { + string name = 1; + } + + union Target [id=152] { + fixed int32 fixed_id = 1; + list users = 2; + } + + Target target = 1; + } + """ + ) + + envelope = files["demo/Envelope.scala"] + assert "import org.apache.fory.annotation.Int32Type" in envelope + assert "import org.apache.fory.annotation.Ref" in envelope + assert "import org.apache.fory.config.Int32Encoding" in envelope + assert ( + "case FixedIdCase(value: Int @Int32Type(encoding = Int32Encoding.FIXED))" + in envelope + ) + assert "case UsersCase(value: List[Envelope.User @Ref])" in envelope + + def test_scala_generator_marks_nested_union_mediated_cycles_mutable(): files = generate_scala( """ @@ -283,13 +361,14 @@ def test_scala_generator_uses_jvm_nested_names_for_name_registration(): registration = 
files["demo/DemoForyRegistration.scala"] assert ( - 'ForySerializer.register(fory, classOf[Envelope], "demo", "Envelope")' + 'ForySerializer.registerType(fory, classOf[Envelope], "demo", "Envelope")' in registration ) assert ( - 'ForySerializer.register(fory, classOf[Envelope.Payload], "demo.Envelope", "Payload")' + 'ForySerializer.registerType(fory, classOf[Envelope.Payload], "demo.Envelope", "Payload")' in registration ) + assert "ForySerializer.registerSerializer(fory, classOf[Envelope.Payload])" in registration assert ( 'ScalaSerializers.registerEnum(fory, classOf[Envelope.Kind], "demo.Envelope", "Kind")' in registration @@ -298,6 +377,33 @@ def test_scala_generator_uses_jvm_nested_names_for_name_registration(): 'ForySerializer.register(fory, classOf[Envelope.Choice], "demo.Envelope", "Choice")' in registration ) + assert "ForySerializer.registerSerializer(fory, classOf[Envelope])" in registration + + +def test_scala_generator_pre_registers_message_type_graph_before_serializers(): + files = generate_scala( + """ + package graph; + + message Node { + list edges = 1; + } + + message Edge { + ref Node node = 1; + } + """ + ) + + registration = files["graph/GraphForyRegistration.scala"] + node_type = registration.index("ForySerializer.registerType(fory, classOf[Node]") + edge_type = registration.index("ForySerializer.registerType(fory, classOf[Edge]") + node_serializer = registration.index("ForySerializer.registerSerializer(fory, classOf[Node])") + edge_serializer = registration.index("ForySerializer.registerSerializer(fory, classOf[Edge])") + assert node_type < node_serializer + assert edge_type < node_serializer + assert node_type < edge_serializer + assert edge_type < edge_serializer def test_scala_generator_keeps_imported_types_in_owner_package(): diff --git a/docs/compiler/generated-code.md b/docs/compiler/generated-code.md index 5056d7822f..723a4eb4cf 100644 --- a/docs/compiler/generated-code.md +++ b/docs/compiler/generated-code.md @@ -1095,15 +1095,16 @@ 
Enums generate Scala 3 enums with stable Fory IDs: ```scala import org.apache.fory.annotation.ForyEnumId -import org.apache.fory.scala.ForyScalaEnum -enum PhoneType(val foryId: Int) extends ForyScalaEnum { - case Mobile extends PhoneType(0) - case Home extends PhoneType(1) - case Work extends PhoneType(2) +enum PhoneType { + @ForyEnumId(0) + case Mobile - @ForyEnumId - def getForyId: Int = foryId + @ForyEnumId(1) + case Home + + @ForyEnumId(2) + case Work } ``` @@ -1127,8 +1128,7 @@ enum Animal derives ForySerializer { } ``` -`optional T` fields generate `Option[T]`. Reference tracking uses `@Ref`; -`@ForyField(ref = true)` is not emitted by the Scala generator. +`optional T` fields generate `Option[T]`. Reference tracking uses `@Ref`. ### Registration diff --git a/docs/compiler/schema-idl.md b/docs/compiler/schema-idl.md index f585863a13..ab98a0e07f 100644 --- a/docs/compiler/schema-idl.md +++ b/docs/compiler/schema-idl.md @@ -817,8 +817,8 @@ message Person [id=100] { ### Rules -- Case IDs must be unique within the union -- Case IDs must be positive. Case ID `0` is reserved for generated unknown-case carriers in languages that expose one. +- Case IDs must be positive and unique within the union +- Case ID `0` is reserved for language runtimes that expose an unknown-case carrier - Cases cannot be `optional` or `ref` - Union cases do not support field options - Case types can be primitives, enums, messages, or other named types @@ -918,7 +918,7 @@ message Node { | Language | Without `ref` | With `ref` | | ---------- | -------------- | ------------------------------------------ | -| Java | `Node parent` | `Node parent` with `@ForyField(ref=true)` | +| Java | `Node parent` | `Node parent` with `@Ref` | | Python | `parent: Node` | `parent: Node = pyfory.field(ref=True)` | | Go | `Parent Node` | `Parent *Node` with `fory:"ref"` | | Rust | `parent: Node` | `parent: Arc` | @@ -977,7 +977,7 @@ accepted as an alias for `list`. 
| ----------------------- | ---------------------------------- | --------------------- | ----------------------- | --------------------- | ----------------------------------------- | ------------------------------------------------------------- | ---------------------- | | `optional list` | `@Nullable List` | `Optional[List[str]]` | `[]string` + `nullable` | `Option>` | `std::optional>` | `List?` | `Option[List[String]]` | | `list` | `List` (nullable elements) | `List[Optional[str]]` | `[]*string` | `Vec>` | `std::vector>` | `List` | `List[Option[String]]` | -| `list` | `List` | `List[User]` | `[]*User` + `ref=false` | `Vec>` | `std::vector>` | `List` + `@ListField(element: DeclaredType(ref: true))` | `List[User @Ref]` | +| `list` | `List<@Ref User>` | `List[User]` | `[]*User` + `ref=false` | `Vec>` | `std::vector>` | `List` + `@ListField(element: DeclaredType(ref: true))` | `List[User @Ref]` | Use `ref(thread_safe=false)` in Fory IDL (or `[(fory).thread_safe_pointer = false]` in protobuf) to generate `Rc` instead of `Arc` in Rust. 
diff --git a/docs/guide/java/field-configuration.md b/docs/guide/java/field-configuration.md index 1d119c1fc0..d625eeccad 100644 --- a/docs/guide/java/field-configuration.md +++ b/docs/guide/java/field-configuration.md @@ -25,8 +25,9 @@ This page explains how to configure field-level metadata for serialization in Ja Apache Fory™ provides field-level configuration through annotations: -- **`@ForyField`**: Configure field metadata (id, ref, dynamic) +- **`@ForyField`**: Configure field metadata (id, dynamic) - **`@Nullable`**: Mark a field type or nested type position as nullable +- **`@Ref`**: Enable field or nested-element reference tracking - **`@Ignore`**: Exclude fields from serialization - **Integer type annotations**: Control integer encoding (varint, fixed, tagged, unsigned) @@ -76,8 +77,8 @@ public class User { @ForyField(id = 2) private String email; - @ForyField(id = 3, ref = true) - private List friends; + @ForyField(id = 3) + private List<@Ref User> friends; @ForyField(id = 4, dynamic = ForyField.Dynamic.TRUE) private Object data; @@ -89,11 +90,11 @@ public class User { | Parameter | Type | Default | Description | | --------- | --------- | ------- | -------------------------------------- | | `id` | `int` | `-1` | Non-negative field tag ID, or no ID | -| `ref` | `boolean` | `false` | Enable reference tracking | | `dynamic` | `Dynamic` | `AUTO` | Control polymorphism for struct fields | Use `@Nullable` on the field type or nested type position for nullable schema -metadata. `@ForyField` does not carry nullability. +metadata and `@Ref` for reference tracking. `@ForyField` does not carry either +setting. ## Field ID (`id`) @@ -164,7 +165,7 @@ public class Record { - When a field is non-nullable, Fory skips writing the null flag. - Boxed types (`Integer`, `Long`, etc.) that can be null should use `@Nullable`. 
-## Reference Tracking (`ref`) +## Reference Tracking (`@Ref`) Enable reference tracking for fields that may be shared or circular: @@ -172,11 +173,13 @@ Enable reference tracking for fields that may be shared or circular: public class RefOuter { // Both fields may point to the same inner object @Nullable - @ForyField(id = 0, ref = true) + @ForyField(id = 0) + @Ref private RefInner inner1; @Nullable - @ForyField(id = 1, ref = true) + @ForyField(id = 1) + @Ref private RefInner inner2; } @@ -186,7 +189,8 @@ public class CircularRef { // Self-referencing field for circular references @Nullable - @ForyField(id = 1, ref = true) + @ForyField(id = 1) + @Ref private CircularRef selfRef; } ``` @@ -198,8 +202,8 @@ public class CircularRef { **Notes**: -- Default is `ref = false` (no reference tracking) -- When `ref = false`, avoids IdentityMap overhead and skips ref tracking flag +- Fields without `@Ref` do not use field-wrapper reference tracking +- Avoid `@Ref` when values are not shared or circular, so Fory can skip the reference flag - Reference tracking only takes effect when global ref tracking is enabled ## Dynamic (Polymorphism Control) @@ -451,7 +455,8 @@ public class Document { // Reference-tracked field for shared/circular references @Nullable - @ForyField(id = 9, ref = true) + @ForyField(id = 9) + @Ref private Document parent; // Ignored field (not serialized) @@ -578,13 +583,13 @@ public class User { Xlang mode has **stricter default values** due to type system differences between languages: - **Nullable**: Fields are non-nullable by default -- **Ref tracking**: Disabled by default (`ref = false`) +- **Ref tracking**: Disabled by default unless the field type uses `@Ref` - **Polymorphism**: Concrete types are non-polymorphic by default In xlang mode, you **need to configure fields** when: - A field can be null (use `@Nullable`) -- A field needs reference tracking for shared/circular objects (use `ref = true`) +- A field needs reference tracking for 
shared/circular objects (use `@Ref`) - Integer types need specific encoding for cross-language compatibility - You want to reduce metadata size (use field IDs) @@ -602,7 +607,8 @@ public class User { private String email; @Nullable - @ForyField(id = 3, ref = true) // Must declare ref for shared objects + @ForyField(id = 3) + @Ref // Must declare @Ref for shared objects private User friend; } ``` @@ -619,7 +625,7 @@ public class User { 1. **Configure field IDs**: Recommended for compatible mode to reduce serialization cost 2. **Use `@Nullable` for nullable fields**: Required for fields that can be null -3. **Enable ref tracking for shared objects**: Use `ref = true` when objects are shared or circular +3. **Enable ref tracking for shared objects**: Use `@Ref` when objects are shared or circular 4. **Use `@Ignore` or `transient` for sensitive data**: Passwords, tokens, internal state 5. **Choose appropriate encoding**: `varint` for small values, `fixed` for full-range values 6. **Keep IDs stable**: Once assigned, don't change field IDs @@ -631,7 +637,7 @@ public class User { | ----------------------------- | -------------------------------------- | | `@ForyField(id = N)` | Field tag ID to reduce metadata size | | `@Nullable` | Mark field or nested type as nullable | -| `@ForyField(ref = true)` | Enable reference tracking | +| `@Ref` | Enable reference tracking | | `@ForyField(dynamic = ...)` | Control polymorphism for struct fields | | `@Ignore` | Exclude field from serialization | | `@Int32Type(encoding = ...)` | 32-bit signed integer encoding | diff --git a/docs/guide/kotlin/static-generated-serializers.md b/docs/guide/kotlin/static-generated-serializers.md index 8836ac44e1..99eb4edb38 100644 --- a/docs/guide/kotlin/static-generated-serializers.md +++ b/docs/guide/kotlin/static-generated-serializers.md @@ -131,8 +131,8 @@ the schema is always read from Kotlin source nullability. 
## References -`@ForyField(ref = true)` is not supported by Kotlin xlang generated -serializers. Generated reads construct Kotlin values through primary +Kotlin xlang generated serializers reject every `@Ref` annotation, including +`@Ref(enable = false)`. Generated reads construct Kotlin values through primary constructors, so they cannot publish partially constructed objects for cyclic back-references. Use non-cyclic schemas for Kotlin xlang structs. diff --git a/docs/guide/scala/schema-idl.md b/docs/guide/scala/schema-idl.md index b18a63d54d..7f7dd6cf33 100644 --- a/docs/guide/scala/schema-idl.md +++ b/docs/guide/scala/schema-idl.md @@ -47,6 +47,13 @@ ExampleForyRegistration.register(fory) For `ThreadSafeFory`, generated registration helpers install a callback so each runtime instance gets the same serializers. +Generated helpers register message type identities before installing message +serializers. This two-phase order lets mutually recursive message graphs build +descriptor metadata through the normal `TypeResolver` path without placeholder +serializers or Scala-specific registration state in Java core. Enums and unions +are registered with their serializers directly because their derived serializers +own case dispatch. + ## Generated Messages Acyclic messages generate case classes: @@ -88,7 +95,13 @@ final class Node() derives ForySerializer { } ``` -`@ForyField(ref = true)` is not the Scala macro or IDL API. +`@Ref` is the JVM reference-tracking annotation for Scala macro and IDL APIs. + +Generated xlang collection fields use immutable Scala collection types: +`List[T]`, `Set[T]`, and `Map[K, V]`. The runtime xlang serializers can also +rebuild supported mutable collection interfaces such as `scala.collection.Seq` +and `scala.collection.Map`, but concrete mutable collection classes are outside +the schema IDL surface unless explicitly generated. ## Generated Enums @@ -96,19 +109,18 @@ IDL enums generate Scala 3 enums only. No Java enum sidecar is emitted. 
```scala import org.apache.fory.annotation.ForyEnumId -import org.apache.fory.scala.ForyScalaEnum -enum Status(val foryId: Int) extends ForyScalaEnum { - case Unknown extends Status(0) - case Ok extends Status(1) +enum Status { + @ForyEnumId(0) + case Unknown - @ForyEnumId - def getForyId: Int = foryId + @ForyEnumId(1) + case Ok } ``` -Generated registration uses `ScalaSerializers.registerEnum(...)` so stable Fory -enum IDs are used in xlang mode. +Generated registration uses `ScalaSerializers.registerEnum(...)` so the stable +Fory enum IDs from case-level `@ForyEnumId` metadata are used in xlang mode. ## Generated Unions diff --git a/docs/guide/xlang/field-reference-tracking.md b/docs/guide/xlang/field-reference-tracking.md index 1a5260c167..a580cd2070 100644 --- a/docs/guide/xlang/field-reference-tracking.md +++ b/docs/guide/xlang/field-reference-tracking.md @@ -136,7 +136,7 @@ By default, **most fields do not track references** even when global `refTrackin ### Customizing Per-Field Ref Tracking -#### Java: @ForyField Annotation +#### Java: @Ref Annotation ```java public class Document { @@ -144,12 +144,11 @@ public class Document { String title; // Enable ref tracking for this field - @ForyField(ref = true) + @Ref Author author; // Shared across documents, track refs to avoid duplicates - @ForyField(ref = true) - List tags; + List<@Ref Tag> tags; } ``` @@ -197,8 +196,8 @@ struct Document { #### Scala: @Ref Annotation -Scala schema IDL and Scala 3 macro derivation use the shared JVM `@Ref` -annotation instead of `@ForyField(ref = true)`: +Scala schema IDL and Scala 3 macro derivation use the same shared JVM `@Ref` +annotation: ```scala import org.apache.fory.annotation.{ForyField, ForyStruct, Ref} diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index e6b901b415..477fcef506 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -444,10 +444,10 
@@ In xlang mode, for cross-language compatibility: **Annotation examples:** ```java -// Java: use @ForyField annotation +// Java: use @Ref for reference tracking public class MyClass { @Nullable - @ForyField(ref = true) + @Ref private Object refField; private String requiredField; @@ -456,10 +456,10 @@ public class MyClass { ```python # Python: use typing with fory field descriptors -from pyfory import Fory, ForyField +from pyfory import ForyField, Ref class MyClass: - ref_field: ForyField(SomeType, nullable=True, ref=True) + ref_field: ForyField(Ref[SomeType], nullable=True) required_field: ForyField(str, nullable=False) ``` diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index 2863c738ed..a2a49ebfb8 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -184,7 +184,7 @@ artifact remains cross-built for Scala 2.13 and Scala 3. | `date`, `timestamp`, `duration` | `java.time.LocalDate`, `java.time.Instant`, `java.time.Duration` | | `decimal` | `java.math.BigDecimal` | | `message` | Scala 3 `case class` by default; normal class only for message/union construction cycles | -| `enum` | Scala 3 `enum` with stable Fory enum IDs | +| `enum` | Scala 3 `enum` with stable Fory enum IDs on case-level `@ForyEnumId` annotations | | `union` | Scala 3 ADT `enum derives ForySerializer` | | `any` | `AnyRef` | @@ -192,8 +192,8 @@ Generated Scala descriptor metadata is produced by Scala 3 macro derivation from Scala compile-time types, including nested generics, `Option`, arrays, scalar encoding annotations, nullability, and `@Ref`. Java reflection is not the source of truth for generated Scala TypeDef metadata. Scala `@Ref` metadata is -represented by the shared `org.apache.fory.annotation.Ref` annotation; generated -Scala does not use `@ForyField(ref = true)`. 
+represented by the shared `org.apache.fory.annotation.Ref` annotation; `@Ref` +is the JVM owner for reference tracking metadata. ## Type info diff --git a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala index 9de941711d..4f4cb2343a 100644 --- a/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala +++ b/integration_tests/idl_tests/scala/src/test/scala/org/apache/fory/idl_tests/ScalaIdlRoundTripTest.scala @@ -30,10 +30,11 @@ import collection.{ import example.{ExampleForyRegistration, ExampleMessage, ExampleState} import nested_name.NestedNameForyRegistration import org.apache.fory.Fory +import org.apache.fory.annotation.ForyEnumId import org.apache.fory.meta.FieldTypes -import org.apache.fory.scala.{ForyScalaEnum, ForySerializer} -import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory -import org.apache.fory.`type`.{TypeUtils, Types} +import org.apache.fory.scala.ForySerializer +import org.apache.fory.serializer.StaticGeneratedStructSerializer +import org.apache.fory.`type`.{ScalaTypes, TypeUtils, Types} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec import tree.{TreeForyRegistration, TreeNode} @@ -108,11 +109,23 @@ final class ScalaIdlRoundTripTest extends AnyWordSpec with Matchers { } "preserve generated Scala enum metadata in nested descriptors" in { - classOf[ForyScalaEnum].isAssignableFrom(classOf[ExampleState]) shouldBe true - val factory = + ScalaTypes.isScalaEnumType(classOf[ExampleState]) shouldBe true + val readyCase = + Class.forName("example.ExampleState$").getDeclaredField("Ready").getAnnotation(classOf[ForyEnumId]) + readyCase.value() shouldBe 1 + val fory = Fory.builder() + .withXlang(true) + .withCompatible(false) + .withRefTracking(false) + .withScalaOptimizationEnabled(true) + 
.requireClassRegistration(true) + .build() + ExampleForyRegistration.register(fory) + val serializer = summon[ForySerializer[ExampleMessage]] - .asInstanceOf[StaticGeneratedStructSerializerFactory[ExampleMessage]] - val descriptors = factory.getGeneratedDescriptors.asScala + .createSerializer(fory.getTypeResolver) + .asInstanceOf[StaticGeneratedStructSerializer[ExampleMessage]] + val descriptors = serializer.getGeneratedDescriptors.asScala val enumValue = descriptors.find(_.getName == "enumValue").get val enumList = descriptors.find(_.getName == "enumList").get val enumMap = descriptors.find(_.getName == "enumValuesByName").get @@ -127,16 +140,7 @@ final class ScalaIdlRoundTripTest extends AnyWordSpec with Matchers { TypeUtils.getMapKeyValueType(uint8ArrayMap.getTypeRef).f1.getComponentType.getTypeExtMeta .typeId() shouldBe Types.UINT8 - val fory = Fory.builder() - .withXlang(true) - .withCompatible(false) - .withRefTracking(false) - .withScalaOptimizationEnabled(true) - .requireClassRegistration(true) - .build() - ExampleForyRegistration.register(fory) - val serializer = factory.newSerializer(fory.getTypeResolver, classOf[ExampleMessage], null) - val fieldGroups = serializer.buildLocalFieldGroups(factory.getGeneratedDescriptors) + val fieldGroups = serializer.buildLocalFieldGroups(serializer.getGeneratedDescriptors) val enumListInfo = fieldGroups.allFields.find(_.descriptor.getName == "enumList").get enumListInfo.genericType.getTypeParameter0.getTypeRef.getTypeExtMeta.typeId() shouldBe Types.ENUM diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java index 6f0d12922f..33015c2959 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java +++ 
b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java @@ -386,6 +386,8 @@ private SourceField buildField( ForyFieldMeta foryField = foryField(field); Object fieldTypeTree = typeTree(field); boolean nullable = fieldNullable(field.asType(), fieldTypeTree, mode); + boolean trackingRef = fieldTrackingRef(field, fieldTypeTree); + boolean hasTrackingRefMetadata = fieldHasTrackingRefMetadata(field, fieldTypeTree); SourceTypeNode typeNode = buildFieldTypeNode(field.asType(), fieldTypeTree, nullable, field); String erasedType = canonicalName(types.erasure(field.asType())); String declaringClass = @@ -437,7 +439,8 @@ private SourceField buildField( foryField.hasForyField, foryField.id, nullable, - foryField.hasForyField && foryField.ref, + trackingRef, + hasTrackingRefMetadata, foryField.dynamic); } @@ -454,6 +457,22 @@ private boolean fieldNullable(TypeMirror type, Object tree, SerializerMode mode) return isOptionalType(type); } + private boolean fieldTrackingRef(VariableElement field, Object tree) { + TypeUseAnnotation ref = typeUseAnnotation(field.asType(), typeTreeInfo(tree).annotations, REF); + if (ref == null) { + AnnotationMirror fieldRef = annotationMirror(field, REF); + ref = fieldRef == null ? 
null : new TypeUseAnnotation(fieldRef, null); + } + return ref != null && booleanValue(ref, "enable", true); + } + + private boolean fieldHasTrackingRefMetadata(VariableElement field, Object tree) { + if (hasAnnotation(field, REF)) { + return true; + } + return typeUseAnnotation(field.asType(), typeTreeInfo(tree).annotations, REF) != null; + } + private boolean isOptionalType(TypeMirror type) { String erasedType = canonicalName(types.erasure(type)); return erasedType.equals("java.util.Optional") @@ -1169,7 +1188,6 @@ private ForyFieldMeta foryField(VariableElement field) { Map values = elements.getElementValuesWithDefaults(mirror); int id = -1; - boolean ref = false; String dynamic = "AUTO"; for (Map.Entry entry : values.entrySet()) { @@ -1177,8 +1195,6 @@ private ForyFieldMeta foryField(VariableElement field) { Object value = entry.getValue().getValue(); if ("id".equals(name)) { id = ((Number) value).intValue(); - } else if ("ref".equals(name)) { - ref = (Boolean) value; } else if ("dynamic".equals(name)) { dynamic = String.valueOf(value); } @@ -1187,7 +1203,7 @@ private ForyFieldMeta foryField(VariableElement field) { throw new InvalidStructException( "@ForyField id must be -1 (no tag ID) or a non-negative tag ID", field); } - return new ForyFieldMeta(true, id, ref, dynamic); + return new ForyFieldMeta(true, id, dynamic); } private String canonicalName(TypeMirror type) { @@ -1278,17 +1294,15 @@ private static final class InvalidStructException extends RuntimeException { } private static final class ForyFieldMeta { - static final ForyFieldMeta NONE = new ForyFieldMeta(false, -1, false, "AUTO"); + static final ForyFieldMeta NONE = new ForyFieldMeta(false, -1, "AUTO"); final boolean hasForyField; final int id; - final boolean ref; final String dynamic; - ForyFieldMeta(boolean hasForyField, int id, boolean ref, String dynamic) { + ForyFieldMeta(boolean hasForyField, int id, String dynamic) { this.hasForyField = hasForyField; this.id = id; - this.ref = ref; 
this.dynamic = dynamic; } } diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java index 005586483d..3d33d37422 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java @@ -41,6 +41,7 @@ enum AccessKind { final int foryFieldId; final boolean nullable; final boolean trackingRef; + final boolean hasTrackingRefMetadata; final String dynamic; SourceField( @@ -60,6 +61,7 @@ enum AccessKind { int foryFieldId, boolean nullable, boolean trackingRef, + boolean hasTrackingRefMetadata, String dynamic) { this.id = id; this.name = name; @@ -77,6 +79,7 @@ enum AccessKind { this.foryFieldId = foryFieldId; this.nullable = nullable; this.trackingRef = trackingRef; + this.hasTrackingRefMetadata = hasTrackingRefMetadata; this.dynamic = dynamic; } diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java index 00f502bf5d..ce52ad379d 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java @@ -118,6 +118,8 @@ private void writeDescriptors() { .append(field.nullable) .append(", ") .append(field.trackingRef) + .append(", ") + .append(field.hasTrackingRefMetadata) .append(", ForyField.Dynamic.") .append(field.dynamic) .append(", ") diff --git a/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java 
b/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java index 7daf9c5fcd..abef8487c6 100644 --- a/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java +++ b/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java @@ -42,6 +42,8 @@ import org.apache.fory.context.MetaWriteContext; import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.SerializationException; +import org.apache.fory.meta.FieldInfo; +import org.apache.fory.meta.TypeDef; import org.apache.fory.serializer.StaticGeneratedStructSerializer; import org.apache.fory.type.Descriptor; import org.apache.fory.type.Types; @@ -369,6 +371,49 @@ public void testGeneratedDescriptorsCarryNestedTypeMetadata() throws Exception { } } + @Test + public void testGeneratedDescriptorsPreserveRefMetadataPresence() throws Exception { + CompilationResult result = + compile( + "test.RefMetadataStruct", + "package test;\n" + + "import org.apache.fory.annotation.ForyStruct;\n" + + "import org.apache.fory.annotation.Ref;\n" + + "@ForyStruct public class RefMetadataStruct {\n" + + " public Customer peer;\n" + + " @Ref(enable = false) public Customer localOnly;\n" + + " public RefMetadataStruct() {}\n" + + " public static class Customer { public String id; public Customer() {} }\n" + + "}\n"); + Assert.assertTrue(result.success, result.diagnostics()); + try (URLClassLoader loader = result.classLoader()) { + Class type = loader.loadClass("test.RefMetadataStruct"); + Class serializerType = loader.loadClass("test.RefMetadataStruct_ForyNativeSerializer"); + Fory fory = + Fory.builder() + .withClassLoader(loader) + .withCodegen(false) + .withRefTracking(true) + .requireClassRegistration(false) + .build(); + StaticGeneratedStructSerializer serializer = + (StaticGeneratedStructSerializer) + serializerType + 
.getConstructor(org.apache.fory.resolver.TypeResolver.class, Class.class) + .newInstance(fory.getTypeResolver(), type); + Descriptor peer = descriptor(serializer.getDescriptors(), "peer"); + Assert.assertFalse(peer.hasTrackingRefMetadata()); + Assert.assertFalse(peer.isTrackingRef()); + Descriptor localOnly = descriptor(serializer.getDescriptors(), "localOnly"); + Assert.assertTrue(localOnly.hasTrackingRefMetadata()); + Assert.assertFalse(localOnly.isTrackingRef()); + + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), type); + Assert.assertTrue(fieldInfo(typeDef, "peer").getFieldType().trackingRef()); + Assert.assertFalse(fieldInfo(typeDef, "localOnly").getFieldType().trackingRef()); + } + } + @Test public void testGeneratedUnsignedScalarWritesValidateRange() throws Exception { CompilationResult result = @@ -547,9 +592,10 @@ public void testStaticSerializerHandlesMonomorphicRecursiveField() throws Except + "import org.apache.fory.annotation.ForyField.Dynamic;\n" + "import org.apache.fory.annotation.ForyStruct;\n" + "import org.apache.fory.annotation.Nullable;\n" + + "import org.apache.fory.annotation.Ref;\n" + "@ForyStruct public class RecursiveStruct {\n" + " public int id;\n" - + " @Nullable @ForyField(ref = true, dynamic = Dynamic.FALSE)\n" + + " @Nullable @Ref @ForyField(dynamic = Dynamic.FALSE)\n" + " public RecursiveStruct next;\n" + " public RecursiveStruct() {}\n" + "}\n"); @@ -1007,6 +1053,15 @@ private static Descriptor descriptor(List descriptors, String name) throw new AssertionError("Missing descriptor " + name); } + private static FieldInfo fieldInfo(TypeDef typeDef, String name) { + for (FieldInfo fieldInfo : typeDef.getFieldsInfo()) { + if (fieldInfo.getFieldName().equals(name)) { + return fieldInfo; + } + } + throw new AssertionError("Missing field info " + name); + } + private static final class CompilationResult { final Path classRoot; final Path generatedRoot; diff --git 
a/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java index 50815b6051..bbec77819d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java +++ b/java/fory-core/src/main/java/org/apache/fory/annotation/ForyField.java @@ -62,14 +62,6 @@ enum Dynamic { */ int id() default -1; - /** - * Whether to track references for this field. When set to false (default): - Avoids adding the - * object to IdentityMap (saves hash map overhead) - Skips writing ref tracking flag (saves 1 byte - * for non-nullable fields) When set to true, enables reference tracking for shared/circular - * references. Default: false (no reference tracking, aligned with xlang protocol defaults) - */ - boolean ref() default false; - /** * Controls polymorphism behavior for this field in cross-language serialization. * diff --git a/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java b/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java index b90f40f85b..f37fe4191d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java +++ b/java/fory-core/src/main/java/org/apache/fory/annotation/Ref.java @@ -24,7 +24,10 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; -/** Type-use annotation to explicitly enable/disable reference tracking for generic elements. */ +/** + * Type-use annotation to explicitly enable or disable reference tracking for a field or nested + * generic element. 
+ */ @Retention(RetentionPolicy.RUNTIME) @Target({ ElementType.FIELD, diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java index 2558bad1d6..0848c5d2db 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java @@ -2184,8 +2184,8 @@ protected Expression deserializeField( Expression value = deserializeForNotNullForField(buffer, descriptor, null); if (serializerCallsReference) { - // When a field explicitly disables ref tracking (@ForyField(ref = false)) - // but global ref tracking is enabled, the serializer will call reference(). + // When field-wrapper ref tracking is disabled but global ref tracking is enabled, the + // serializer will call reference(). // We need to preserve a -1 id so that when the deserializer calls reference(), // it will pop this -1 and skip the setReadRef call. Expression preserveStubRefId = @@ -2218,11 +2218,8 @@ protected Expression deserializeField( protected Expression deserializeCompatibleListArrayField(Descriptor descriptor) { TypeExtMeta extMeta = descriptor.getTypeRef().getTypeExtMeta(); boolean nullable = extMeta == null ? descriptor.isNullable() : extMeta.nullable(); - // A top-level @Nullable TypeExtMeta must not erase @ForyField(ref = true). - boolean trackingRef = - extMeta == null || descriptor.hasForyField() - ? descriptor.isTrackingRef() - : extMeta.trackingRef(); + // Descriptor owns field-wrapper reference tracking through @Ref or generated metadata. + boolean trackingRef = descriptor.isTrackingRef(); Class targetType = descriptor.getField() == null ? 
descriptor.getRawType() : descriptor.getField().getType(); return new StaticInvoke( diff --git a/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java b/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java index 94fe27bc47..5eea194405 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/CopyContext.java @@ -21,7 +21,6 @@ import java.util.Arrays; import org.apache.fory.collection.IdentityMap; -import org.apache.fory.exception.CopyException; import org.apache.fory.resolver.ClassResolver; import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.TypeResolver; @@ -39,8 +38,6 @@ */ @SuppressWarnings("unchecked") public final class CopyContext { - private static final Object COPY_IN_PROGRESS = new Object(); - private final TypeResolver typeResolver; private final boolean copyRefTracking; private final IdentityMap originToCopyMap; @@ -89,36 +86,9 @@ public void reference(T origin, T copied) { } } - /** - * Marks an origin as being copied before the destination value can be constructed. - * - *

    Constructor-owned immutable values cannot publish a copy early. Serializers for those values - * use this marker so recursive copies fail with a clear error instead of recursing until stack - * overflow. - */ - public void markCopyInProgress(T origin) { - if (copyRefTracking && origin != null) { - originToCopyMap.put(origin, COPY_IN_PROGRESS); - } - } - - /** Clears a copy-in-progress marker if no completed copy replaced it. */ - public void clearCopyInProgress(T origin) { - if (copyRefTracking && origin != null && originToCopyMap.get(origin) == COPY_IN_PROGRESS) { - originToCopyMap.remove(origin); - } - } - /** Returns the previously registered copy for {@code origin}, or {@code null} if absent. */ public T getCopyObject(T origin) { - Object copied = originToCopyMap.get(origin); - if (copied == COPY_IN_PROGRESS) { - throw new CopyException( - "Cannot copy cyclic object graph rooted at constructor-owned immutable value " - + origin.getClass().getName() - + " because its copy cannot be referenced before construction completes"); - } - return (T) copied; + return (T) originToCopyMap.get(origin); } /** diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java index 389fc1c8f7..ec630e7fe1 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java @@ -24,6 +24,7 @@ import java.lang.reflect.Modifier; import java.util.Objects; import org.apache.fory.annotation.ForyField; +import org.apache.fory.annotation.Internal; import org.apache.fory.reflect.TypeRef; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.converter.FieldConverter; @@ -90,7 +91,8 @@ public FieldTypes.FieldType getFieldType() { * null. Don't invoke this method if class does have fieldName field. In such case, * reflection should be used to get the descriptor. 
*/ - Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { + @Internal + public Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { TypeRef declared = descriptor != null ? descriptor.getTypeRef() : primitiveListCarrierType(); TypeRef typeRef = fieldType.toTypeToken(resolver, declared); String typeName = fieldType.getTypeName(resolver, typeRef); diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java index 7fc0b46867..bcadd42b77 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java @@ -198,22 +198,28 @@ private static FieldType buildFieldType( typeId = Types.UNKNOWN; } } - // For xlang: ref tracking is false by default (no shared ownership like Rust's Rc/Arc) - // For native: use the type's default tracking behavior - boolean descriptorCarriesFieldOptions = descriptor != null && field == null; - boolean trackingRef = - descriptorCarriesFieldOptions - ? descriptor.isTrackingRef() - : isXlang - ? typeExtMeta != null && typeExtMeta.trackingRef() - : genericType.trackingRef(resolver); + // For top-level fields, Descriptor owns explicit wrapper ref metadata from @Ref or generated + // schema descriptors. When no such metadata exists, native mode uses the declared type default + // and xlang keeps the protocol default of no wrapper tracking. 
+ boolean trackingRef; + if (descriptor != null && isXlang) { + trackingRef = descriptor.isTrackingRef(); + } else if (descriptor != null && descriptor.hasTrackingRefMetadata()) { + trackingRef = descriptor.isTrackingRef(); + } else if (descriptor != null) { + trackingRef = resolver.needToWriteRef(TypeRef.of(rawType)); + } else if (isXlang) { + trackingRef = typeExtMeta != null && typeExtMeta.trackingRef(); + } else { + trackingRef = genericType.trackingRef(resolver); + } // For xlang: nullable is false by default for top-level fields. // Nested element types are nullable by default to align with cross-language collection // semantics. // Optional types are nullable (like Rust's Option). // For native: non-primitive types are nullable by default. boolean nullable; - if (descriptorCarriesFieldOptions) { + if (descriptor != null && field == null) { nullable = descriptor.isNullable(); } else if (isXlang) { if (typeExtMeta != null) { @@ -228,11 +234,9 @@ private static FieldType buildFieldType( nullable = !genericType.getCls().isPrimitive(); } - // @ForyField owns wrapper ref tracking and makes Java fields required unless @Nullable or - // type-use metadata marks them nullable. This keeps native TypeDef metadata aligned with the - // normalized Descriptor used by object serializers. + // @ForyField makes Java fields required unless @Nullable or type-use metadata marks them + // nullable. Reference tracking is already resolved above through the Descriptor-owned @Ref bit. 
if (descriptor != null && descriptor.hasForyField()) { - trackingRef = descriptor.isTrackingRef(); nullable = descriptor.isNullable(); } @@ -635,14 +639,15 @@ public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { Class cls; int internalTypeId = typeId; if (declared != null && internalTypeId == Types.ENUM && declared.getRawType().isEnum()) { - return TypeRef.of(declared.getRawType(), new TypeExtMeta(typeId, nullable, trackingRef)); + return TypeRef.of( + declared.getRawType(), typeExtMeta(typeId, nullable, trackingRef, declared)); } if (Types.isPrimitiveType(internalTypeId)) { if (declared != null) { TypeInfo declaredInfo = resolver.getTypeInfo(declared.getRawType(), false); if (declaredInfo != null && declaredInfo.getTypeId() == typeId) { return TypeRef.of( - declared.getRawType(), new TypeExtMeta(typeId, nullable, trackingRef)); + declared.getRawType(), typeExtMeta(typeId, nullable, trackingRef, declared)); } } cls = Types.getClassForTypeId(internalTypeId); @@ -663,17 +668,19 @@ public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { cls = declared.getRawType(); } } - return TypeRef.of(cls, new TypeExtMeta(typeId, nullable, trackingRef)); + return TypeRef.of(cls, typeExtMeta(typeId, nullable, trackingRef, declared)); } if (Types.isPrimitiveArray(internalTypeId)) { if (declared != null) { Class declaredRaw = declared.getRawType(); if (declaredRaw.isArray()) { - return TypeRef.of(declaredRaw, new TypeExtMeta(typeId, nullable, trackingRef)); + return TypeRef.of( + declaredRaw, typeExtMeta(typeId, nullable, trackingRef, declared)); } Class listClass = getPrimitiveListClass(internalTypeId); if (listClass != null && listClass.isAssignableFrom(declaredRaw)) { - return TypeRef.of(declaredRaw, new TypeExtMeta(typeId, nullable, trackingRef)); + return TypeRef.of( + declaredRaw, typeExtMeta(typeId, nullable, trackingRef, declared)); } } cls = getPrimitiveArrayClass(internalTypeId); @@ -683,7 +690,8 @@ public TypeRef toTypeToken(TypeResolver 
resolver, TypeRef declared) { } if (Types.isUserDefinedType((byte) internalTypeId)) { if (declared != null) { - return TypeRef.of(declared.getRawType(), new TypeExtMeta(typeId, nullable, trackingRef)); + return TypeRef.of( + declared.getRawType(), typeExtMeta(typeId, nullable, trackingRef, declared)); } LOG.warn("Class {} not registered, take it as Struct type for deserialization.", typeId); boolean isEnum = internalTypeId == Types.ENUM; @@ -702,7 +710,7 @@ public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { boolean isEnum = internalTypeId == Types.ENUM; cls = UnknownClass.getUnknowClass(isEnum, 0, resolver.isShareMeta()); } - return TypeRef.of(cls, new TypeExtMeta(typeId, nullable, trackingRef)); + return TypeRef.of(cls, typeExtMeta(typeId, nullable, trackingRef, declared)); } @Override @@ -882,7 +890,6 @@ public FieldType getElementType() { @Override public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { - // TODO support preserve element TypeExtMeta Class declaredClass; TypeRef declElementType; if (declared == null) { @@ -902,13 +909,13 @@ public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { } TypeRef elementType = this.elementType.toTypeToken(resolver, declElementType); if (declared == null) { - return collectionOf(elementType, new TypeExtMeta(typeId, nullable, trackingRef)); + return collectionOf(elementType, TypeExtMeta.of(typeId, nullable, trackingRef)); } if (!declaredClass.isArray()) { if (declElementType.equals(elementType)) { return declared; } - TypeExtMeta extMeta = new TypeExtMeta(typeId, nullable, trackingRef); + TypeExtMeta extMeta = typeExtMeta(typeId, nullable, trackingRef, declared); if (!java.util.Collection.class.isAssignableFrom(declaredClass) && resolver.isCollection(declaredClass)) { return TypeRef.of( @@ -929,7 +936,7 @@ public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { // Apply field metadata (nullable, trackingRef) to outermost array only TypeExtMeta meta = (i == 
dimensionsToAdd - 1) - ? new TypeExtMeta(typeId, nullable, trackingRef) + ? typeExtMeta(typeId, nullable, trackingRef, declared) : currentType.getTypeExtMeta(); currentType = TypeRef.of(arrayClass, meta); } @@ -998,7 +1005,6 @@ public FieldType getValueType() { @Override public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { - // TODO support preserve element TypeExtMeta, it will be lost when building other TypeRef TypeRef keyDecl = null; TypeRef valueDecl = null; if (declared != null) { @@ -1013,7 +1019,7 @@ public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { // handle generic bound valueDecl = valueDecl.resolveAllWildcards(); } - TypeExtMeta extMeta = new TypeExtMeta(typeId, nullable, trackingRef); + TypeExtMeta extMeta = typeExtMeta(typeId, nullable, trackingRef, declared); TypeRef keyTypeRef = keyType.toTypeToken(classResolver, keyDecl); TypeRef valueTypeRef = valueType.toTypeToken(classResolver, valueDecl); Class declaredClass = declared.getRawType(); @@ -1027,7 +1033,7 @@ public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { return mapOf( keyType.toTypeToken(classResolver, keyDecl), valueType.toTypeToken(classResolver, valueDecl), - new TypeExtMeta(typeId, nullable, trackingRef)); + TypeExtMeta.of(typeId, nullable, trackingRef)); } @Override @@ -1073,7 +1079,7 @@ public EnumFieldType(boolean nullable, int typeId, int userTypeId) { @Override public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { if (declared != null) { - return TypeRef.of(declared.getRawType(), new TypeExtMeta(Types.ENUM, nullable, false)); + return TypeRef.of(declared.getRawType(), typeExtMeta(Types.ENUM, nullable, false, declared)); } return TypeRef.of(UnknownClass.UnknownEnum.class); } @@ -1121,11 +1127,11 @@ public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { if (UnknownClass.class.isAssignableFrom(componentRawType)) { return TypeRef.of( UnknownClass.getUnknowClass(componentType 
instanceof EnumFieldType, dimensions, true), - new TypeExtMeta(typeId, nullable, trackingRef)); + typeExtMeta(typeId, nullable, trackingRef, declared)); } else { return TypeRef.of( Array.newInstance(componentRawType, new int[dimensions]).getClass(), - new TypeExtMeta(typeId, nullable, trackingRef)); + typeExtMeta(typeId, nullable, trackingRef, declared)); } } @@ -1191,7 +1197,7 @@ public ObjectFieldType(int typeId, boolean nullable, boolean trackingRef) { @Override public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { Class clz = declared == null ? Object.class : declared.getRawType(); - return TypeRef.of(clz, new TypeExtMeta(typeId, nullable, trackingRef)); + return TypeRef.of(clz, typeExtMeta(typeId, nullable, trackingRef, declared)); } @Override @@ -1214,6 +1220,16 @@ public String toString() { } } + private static TypeExtMeta typeExtMeta( + int typeId, boolean nullable, boolean trackingRef, TypeRef declared) { + TypeExtMeta declaredMeta = declared == null ? null : declared.getTypeExtMeta(); + return TypeExtMeta.of( + typeId, + nullable, + trackingRef, + declaredMeta != null && declaredMeta.nullableWrapper()); + } + /** Class for Union field type. Union types use declared type. 
*/ public static class UnionFieldType extends FieldType { diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeExtMeta.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeExtMeta.java index 26dac4d7fd..d0be968651 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeExtMeta.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeExtMeta.java @@ -25,15 +25,26 @@ public class TypeExtMeta { private final int typeId; private final boolean nullable; private final boolean trackingRef; + private final boolean nullableWrapper; public static TypeExtMeta of(int typeId, boolean nullable, boolean trackingRef) { return new TypeExtMeta(typeId, nullable, trackingRef); } + public static TypeExtMeta of( + int typeId, boolean nullable, boolean trackingRef, boolean nullableWrapper) { + return new TypeExtMeta(typeId, nullable, trackingRef, nullableWrapper); + } + TypeExtMeta(int typeId, boolean nullable, boolean trackingRef) { + this(typeId, nullable, trackingRef, false); + } + + TypeExtMeta(int typeId, boolean nullable, boolean trackingRef, boolean nullableWrapper) { this.typeId = typeId; this.nullable = nullable; this.trackingRef = trackingRef; + this.nullableWrapper = nullableWrapper; } public int typeId() { @@ -48,6 +59,11 @@ public boolean trackingRef() { return trackingRef; } + /** Whether the local source type wraps a nullable value in a language-level container. 
*/ + public boolean nullableWrapper() { + return nullableWrapper; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -57,12 +73,15 @@ public boolean equals(Object o) { return false; } TypeExtMeta that = (TypeExtMeta) o; - return typeId == that.typeId && nullable == that.nullable && trackingRef == that.trackingRef; + return typeId == that.typeId + && nullable == that.nullable + && trackingRef == that.trackingRef + && nullableWrapper == that.nullableWrapper; } @Override public int hashCode() { - return Objects.hash(typeId, nullable, trackingRef); + return Objects.hash(typeId, nullable, trackingRef, nullableWrapper); } @Override @@ -74,6 +93,8 @@ public String toString() { + nullable + ", trackingRef=" + trackingRef + + ", nullableWrapper=" + + nullableWrapper + '}'; } } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 238e8dec6b..71b44eb6b9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -595,46 +595,6 @@ public void registerUnion(Class cls, String namespace, String name, Serialize registerGraalvmClass(cls); } - @Override - public void registerUnionCase(Class unionType, Class caseType) { - checkRegisterAllowed(); - TypeInfo typeInfo = classInfoMap.get(unionType); - Preconditions.checkArgument( - typeInfo != null && Types.isUnionType(typeInfo.typeId), - "Union type %s must be registered before case type %s", - unionType, - caseType); - TypeInfo existingInfo = classInfoMap.get(caseType); - Preconditions.checkArgument( - existingInfo == null || existingInfo == typeInfo, - "Union case type %s has been registered as %s", - caseType, - existingInfo); - classInfoMap.put(caseType, typeInfo); - extRegistry.registeredClasses.put(caseType.getName(), caseType); - registerGraalvmClass(caseType); - } - - @Override - public void 
registerEnumCase(Class enumType, Class caseType) { - checkRegisterAllowed(); - TypeInfo typeInfo = classInfoMap.get(enumType); - Preconditions.checkArgument( - typeInfo != null && Types.isEnumType(typeInfo.typeId), - "Enum type %s must be registered before case type %s", - enumType, - caseType); - TypeInfo existingInfo = classInfoMap.get(caseType); - Preconditions.checkArgument( - existingInfo == null || existingInfo == typeInfo, - "Enum case type %s has been registered as %s", - caseType, - existingInfo); - classInfoMap.put(caseType, typeInfo); - extRegistry.registeredClasses.put(caseType.getName(), caseType); - registerGraalvmClass(caseType); - } - @Override public void registerEnum(Class cls, long userId, Serializer serializer) { checkRegisterAllowed(); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java b/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java index 6725a38699..ca516748d3 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java @@ -26,7 +26,6 @@ import org.apache.fory.exception.ForyException; import org.apache.fory.meta.TypeDef; import org.apache.fory.serializer.StaticGeneratedStructSerializer; -import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory; import org.apache.fory.type.Descriptor; /** Shared registry of build-time generated static serializer mappings. 
*/ @@ -91,48 +90,11 @@ List getGeneratedDescriptors() { private final ConcurrentHashMap, Entry> xlangSerializers = new ConcurrentHashMap<>(); private final ConcurrentHashMap, Entry> nativeSerializers = new ConcurrentHashMap<>(); - private final ConcurrentHashMap, StaticGeneratedStructSerializerFactory> - xlangFactories = new ConcurrentHashMap<>(); - private final ConcurrentHashMap, StaticGeneratedStructSerializerFactory> - nativeFactories = new ConcurrentHashMap<>(); private final ConcurrentHashMap, Boolean> missingXlangSerializers = new ConcurrentHashMap<>(); private final ConcurrentHashMap, Boolean> missingNativeSerializers = new ConcurrentHashMap<>(); - void registerFactory( - Class targetType, boolean xlang, StaticGeneratedStructSerializerFactory factory) { - ConcurrentHashMap, StaticGeneratedStructSerializerFactory> factories = - xlang ? xlangFactories : nativeFactories; - StaticGeneratedStructSerializerFactory existing = factories.putIfAbsent(targetType, factory); - if (existing != null && existing != factory && existing.getClass() != factory.getClass()) { - throw new IllegalArgumentException( - "Conflicting static generated serializer factory for " + targetType.getName()); - } - } - - StaticGeneratedStructSerializer newRegisteredSerializer( - TypeResolver resolver, Class targetType, TypeDef typeDef) { - StaticGeneratedStructSerializerFactory factory = - getRegisteredFactory(targetType, resolver.isCrossLanguage()); - if (factory == null) { - return null; - } - return factory.newSerializer(resolver, targetType, typeDef); - } - - List getRegisteredDescriptors(Class targetType, boolean xlang) { - StaticGeneratedStructSerializerFactory factory = getRegisteredFactory(targetType, xlang); - return factory == null ? null : factory.getGeneratedDescriptors(); - } - - private StaticGeneratedStructSerializerFactory getRegisteredFactory( - Class targetType, boolean xlang) { - ConcurrentHashMap, StaticGeneratedStructSerializerFactory> factories = - xlang ? 
xlangFactories : nativeFactories; - return factories.get(targetType); - } - Class getSerializerClass( Class targetType, boolean xlang) { Entry entry = getEntry(targetType, xlang); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 8f767dafb6..85a3a6543d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -89,7 +89,6 @@ import org.apache.fory.serializer.SerializerFactory; import org.apache.fory.serializer.Serializers; import org.apache.fory.serializer.StaticGeneratedStructSerializer; -import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory; import org.apache.fory.serializer.UnknownClass; import org.apache.fory.serializer.UnknownClass.UnknownEmptyStruct; import org.apache.fory.serializer.UnknownClass.UnknownStruct; @@ -271,6 +270,48 @@ public void register(String className, String namespace, String typeName) { register(loadClass(className), namespace, typeName); } + /** + * Registers {@code runtimeType} to use the already-registered metadata and serializer of {@code + * canonicalType}. + * + *

    This is only for language runtimes where one source type has multiple JVM runtime classes, + * such as enum or ADT case classes. The resolver owns the runtime class lookup; the caller owns + * deciding which runtime classes are aliases. + */ + @Internal + public final void registerRuntimeTypeAlias(Class runtimeType, Class canonicalType) { + Preconditions.checkNotNull(runtimeType, "runtimeType"); + Preconditions.checkNotNull(canonicalType, "canonicalType"); + if (runtimeType == canonicalType) { + return; + } + checkRegisterAllowed(); + TypeInfo canonicalInfo = classInfoMap.get(canonicalType); + Preconditions.checkArgument( + canonicalInfo != null, + "Canonical type must be registered before registering a runtime type alias: " + + canonicalType.getName()); + TypeInfo existingInfo = classInfoMap.get(runtimeType); + Preconditions.checkArgument( + existingInfo == null || existingInfo == canonicalInfo, + "Runtime type is already registered with different type metadata: " + + runtimeType.getName()); + String runtimeTypeName = runtimeType.getName(); + Class registeredType = extRegistry.registeredClasses.get(runtimeTypeName); + Preconditions.checkArgument( + registeredType == null || registeredType == runtimeType, + "Runtime type alias name is already registered with different class: " + + runtimeTypeName); + String registeredName = extRegistry.registeredClasses.inverse().get(runtimeType); + Preconditions.checkArgument( + registeredName == null || registeredName.equals(runtimeTypeName), + "Runtime type alias is already registered with different name: " + + registeredName); + classInfoMap.put(runtimeType, canonicalInfo); + extRegistry.registeredClasses.put(runtimeTypeName, runtimeType); + registerGraalvmClass(runtimeType); + } + /** * Registers a union type with a user-specified ID and serializer. 
* @@ -291,27 +332,6 @@ public void register(String className, String namespace, String typeName) { public abstract void registerUnion( Class type, String namespace, String typeName, Serializer serializer); - /** - * Registers {@code caseType} as a runtime class alias for an already registered union type. - * - *

    Some JVM languages compile a union value to concrete case subclasses even though the wire - * type is owned by the sealed union base. This method makes runtime dispatch for those case - * subclasses use the base union {@link TypeInfo}; it must not create another wire type name or - * user type ID. - */ - @Internal - public abstract void registerUnionCase(Class unionType, Class caseType); - - /** - * Registers {@code caseType} as a runtime class alias for an already registered enum type. - * - *

    Some JVM languages compile enum cases to concrete singleton subclasses even though the wire - * type is owned by the enum base. This method makes runtime dispatch for those case subclasses - * use the base enum {@link TypeInfo}; it must not create another wire type name or user type ID. - */ - @Internal - public abstract void registerEnumCase(Class enumType, Class caseType); - /** Registers a non-Java enum type with a user-specified ID and serializer. */ @Internal public abstract void registerEnum(Class type, long id, Serializer serializer); @@ -1068,11 +1088,10 @@ private TypeInfo getMetaSharedTypeInfo(TypeDef typeDef, Class clz) { // type metadata or a concrete target-class transformation. return typeInfo; } - StaticGeneratedStructSerializer registeredStaticSerializer = - sharedRegistry.staticGeneratedSerializerRegistry.newRegisteredSerializer( - this, cls, typeDef); - if (registeredStaticSerializer != null) { - typeInfo.setSerializer(this, registeredStaticSerializer); + StaticGeneratedStructSerializer copiedStaticSerializer = + copyRegisteredStaticGeneratedStructSerializer(cls, typeDef); + if (copiedStaticSerializer != null) { + typeInfo.setSerializer(this, copiedStaticSerializer); return typeInfo; } Class sc = @@ -1326,21 +1345,6 @@ private Serializer getNativeTypedValueSerializer(int typeId, Class rawType public abstract void setSerializerIfAbsent(Class cls, Serializer serializer); - @Internal - @SuppressWarnings({"rawtypes", "unchecked"}) - public final void registerStaticGeneratedStructSerializerFactory( - Class cls, StaticGeneratedStructSerializerFactory factory) { - if (!isRegistered(cls)) { - register(cls); - } - sharedRegistry.staticGeneratedSerializerRegistry.registerFactory( - cls, isCrossLanguage(), factory); - Serializer serializer = - new DeferedLazySerializer.DeferredLazyObjectSerializer( - this, cls, () -> Tuple2.of(true, factory.newSerializer(this, cls, null))); - setSerializer(cls, serializer); - } - /** * Reset serializer if {@code 
serializer} is not null, otherwise clear serializer for {@code cls}. */ @@ -1616,8 +1620,7 @@ public final DescriptorGrouper groupDescriptors( private List buildFieldDescriptors(Class clz, boolean searchParent) { List registeredStaticDescriptors = - sharedRegistry.staticGeneratedSerializerRegistry.getRegisteredDescriptors( - clz, isCrossLanguage()); + getRegisteredStaticGeneratedStructDescriptors(clz); if (registeredStaticDescriptors != null) { return normalizeFieldDescriptors(clz, searchParent, registeredStaticDescriptors); } @@ -1651,23 +1654,18 @@ private List buildFieldDescriptors( if (!searchParent && !descriptor.getDeclaringClass().equals(clz.getName())) { continue; } - boolean hasForyField = descriptor.hasForyField(); // Compute the final isTrackingRef value: - // For xlang mode: "Reference tracking is disabled by default" (xlang spec) - // - Only enable ref tracking if explicitly set via @ForyField(ref=true) - // For Java mode: - // - If global ref tracking is enabled and no @ForyField, use global setting - // - If @ForyField(ref=true) is set, use that (but can be overridden if global is off) + // For xlang mode: reference tracking is disabled by default and only @Ref enables it. + // For Java mode: @Ref explicitly overrides the type-based default in both directions. 
boolean ref = globalRefTracking; if (globalRefTracking) { if (isXlang) { - // In xlang mode, only track refs if explicitly annotated with @ForyField(ref=true) - ref = hasForyField && descriptor.isTrackingRef(); + ref = descriptor.isTrackingRef(); } else { - if (hasForyField) { + if (descriptor.hasTrackingRefMetadata()) { ref = descriptor.isTrackingRef(); } else { - ref = needToWriteRef(descriptor.getTypeRef()); + ref = needToWriteRef(TypeRef.of(descriptor.getRawType())); } } } @@ -1677,7 +1675,7 @@ private List buildFieldDescriptors( if (needsUpdate) { Descriptor newDescriptor = - new DescriptorBuilder(descriptor).trackingRef(ref).nullable(nullable).build(); + new DescriptorBuilder(descriptor).inferredTrackingRef(ref).nullable(nullable).build(); result.add(newDescriptor); } else { result.add(descriptor); @@ -1763,6 +1761,26 @@ private List getStaticGeneratedStructDescriptors(Class cls) { cls, isCrossLanguage()); } + private List getRegisteredStaticGeneratedStructDescriptors(Class cls) { + TypeInfo typeInfo = getTypeInfo(cls, false); + if (typeInfo == null + || !(typeInfo.getSerializer() instanceof StaticGeneratedStructSerializer)) { + return null; + } + return ((StaticGeneratedStructSerializer) typeInfo.getSerializer()).getGeneratedDescriptors(); + } + + private StaticGeneratedStructSerializer copyRegisteredStaticGeneratedStructSerializer( + Class cls, TypeDef typeDef) { + TypeInfo typeInfo = getTypeInfo(cls, false); + if (typeInfo == null + || !(typeInfo.getSerializer() instanceof StaticGeneratedStructSerializer)) { + return null; + } + return ((StaticGeneratedStructSerializer) typeInfo.getSerializer()) + .copySerializer(this, cls, typeDef); + } + protected final boolean shouldPreferStaticGeneratedSerializer(Class cls) { return AndroidSupport.IS_ANDROID || isKotlinClass(cls); } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index 087aefb893..b41fee4052 
100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -388,46 +388,6 @@ public void registerUnion( register(type, serializer, namespace, typeName, xtypeId, -1); } - @Override - public void registerUnionCase(Class unionType, Class caseType) { - checkRegisterAllowed(); - TypeInfo typeInfo = classInfoMap.get(unionType); - Preconditions.checkArgument( - typeInfo != null && Types.isUnionType(typeInfo.typeId), - "Union type %s must be registered before case type %s", - unionType, - caseType); - TypeInfo existingInfo = classInfoMap.get(caseType); - Preconditions.checkArgument( - existingInfo == null || existingInfo == typeInfo, - "Union case type %s has been registered as %s", - caseType, - existingInfo); - classInfoMap.put(caseType, typeInfo); - extRegistry.registeredClasses.put(caseType.getName(), caseType); - registerGraalvmClass(caseType); - } - - @Override - public void registerEnumCase(Class enumType, Class caseType) { - checkRegisterAllowed(); - TypeInfo typeInfo = classInfoMap.get(enumType); - Preconditions.checkArgument( - typeInfo != null && Types.isEnumType(typeInfo.typeId), - "Enum type %s must be registered before case type %s", - enumType, - caseType); - TypeInfo existingInfo = classInfoMap.get(caseType); - Preconditions.checkArgument( - existingInfo == null || existingInfo == typeInfo, - "Enum case type %s has been registered as %s", - caseType, - existingInfo); - classInfoMap.put(caseType, typeInfo); - extRegistry.registeredClasses.put(caseType.getName(), caseType); - registerGraalvmClass(caseType); - } - @Override public void registerEnum(Class type, long userTypeId, Serializer serializer) { checkRegisterAllowed(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java b/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java index ed2cbcedfd..3dab3aa492 100644 --- 
a/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java @@ -246,9 +246,9 @@ public SerializationFieldInfo(TypeResolver resolver, Descriptor d) { && !primitiveListArray && !primitiveListCollection) { nullable = extMeta.nullable(); - // A top-level @Nullable TypeExtMeta owns only nullability; @ForyField remains the owner of - // field-wrapper reference tracking. - trackingRef = d.hasForyField() ? d.isTrackingRef() : extMeta.trackingRef(); + // Descriptor owns field-wrapper reference tracking through @Ref or generated metadata; a + // top-level TypeExtMeta can also carry only nullability. + trackingRef = d.isTrackingRef(); } else { nullable = d.isNullable(); trackingRef = d.isTrackingRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java index 38061f62a6..6ce05338c9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java @@ -19,6 +19,7 @@ package org.apache.fory.serializer; +import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -29,6 +30,7 @@ import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; import org.apache.fory.exception.DeserializationException; +import org.apache.fory.exception.ForyException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.FieldInfo; import org.apache.fory.meta.TypeDef; @@ -101,6 +103,36 @@ private void setSerializerIfAbsent(TypeResolver typeResolver, Class type) { @Override public abstract T copy(CopyContext copyContext, T value); + /** + * Creates an equivalent serializer for another local/remote TypeDef view of the same 
generated + * struct. + * + *

    Named Java/Kotlin generated serializers are rediscovered through their generated class. + * Macro-generated serializers may instead override this method so compatible xlang reads can + * reuse the same serializer-owned construction logic without a separate factory object. + */ + @Internal + @SuppressWarnings("unchecked") + public StaticGeneratedStructSerializer copySerializer( + TypeResolver typeResolver, Class type, TypeDef typeDef) { + try { + Constructor constructor = + getClass() + .asSubclass(StaticGeneratedStructSerializer.class) + .getDeclaredConstructor(TypeResolver.class, Class.class, TypeDef.class); + constructor.setAccessible(true); + return (StaticGeneratedStructSerializer) + constructor.newInstance(typeResolver, type, typeDef); + } catch (ReflectiveOperationException e) { + throw new ForyException( + "Failed to copy static generated serializer " + + getClass().getName() + + " for " + + type.getName(), + e); + } + } + public abstract List getGeneratedDescriptors(); public final List getDescriptors() { @@ -122,15 +154,10 @@ public final FieldGroups buildFieldGroups(List descriptors) { } public final FieldGroups buildLocalFieldGroups(List descriptors) { - if (!typeResolver.isShareMeta()) { - return buildFieldGroups(descriptors); - } - // Meta-share writers use the local TypeDef-reified descriptor grouping, matching - // ObjectSerializer. The constructor TypeDef may be a remote schema for compatible reads, so it - // must not own local field access ordering. - DescriptorGrouper grouper = - typeResolver.createDescriptorGrouper(typeResolver.getTypeDef(type, true), type); - return FieldGroups.buildFieldInfos(typeResolver, grouper); + // Generated descriptors carry source-only field metadata such as Scala Option wrappers. A + // schema TypeDef descriptor is the canonical remote contract, but it cannot replace the local + // generated descriptor view used to choose allocation-free field readers and writers. 
+ return buildFieldGroups(descriptors); } protected final List runtimeDescriptors(List descriptors) { @@ -567,6 +594,12 @@ private void appendRemoteFields( int matchedId = matchField(fieldInfo, fieldIds, fields); Descriptor localDescriptor = matchedId == UNKNOWN_FIELD ? null : localDescriptors.get(matchedId); + if (localDescriptor != null) { + Descriptor readDescriptor = fieldInfo.toDescriptor(typeResolver, localDescriptor); + serializationFieldInfo = + new SerializationFieldInfo( + typeResolver, readDescriptor, serializationFieldInfo.codecCategory); + } remoteFields.add( new RemoteFieldInfo( typeResolver, diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java deleted file mode 100644 index 3882e3ceea..0000000000 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializerFactory.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.fory.serializer; - -import java.util.List; -import org.apache.fory.annotation.Internal; -import org.apache.fory.meta.TypeDef; -import org.apache.fory.resolver.TypeResolver; -import org.apache.fory.type.Descriptor; - -/** - * Factory for statically shaped struct serializers whose implementation is generated by a non-Java - * compiler. - * - *

    Named generated serializer classes are still discovered by {@link - * org.apache.fory.resolver.StaticGeneratedSerializerRegistry}. This factory path is for language - * frontends such as Scala 3 macro derivation where the serializer code is emitted at the typeclass - * call site rather than as a separately named JVM class. - */ -@Internal -public interface StaticGeneratedStructSerializerFactory { - /** Descriptor metadata generated from the source-language type model. */ - List getGeneratedDescriptors(); - - /** - * Create a serializer for {@code type}. - * - * @param typeDef remote TypeDef for compatible reads, or {@code null} for local schema reads. - */ - StaticGeneratedStructSerializer newSerializer( - TypeResolver typeResolver, Class type, TypeDef typeDef); -} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/struct/Fingerprint.java b/java/fory-core/src/main/java/org/apache/fory/serializer/struct/Fingerprint.java index 689c41cc65..f872da4069 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/struct/Fingerprint.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/struct/Fingerprint.java @@ -108,10 +108,8 @@ public static String computeStructFingerprint( fieldIdentifier = descriptor.getSnakeCaseName(); } - // Get ref flag from @ForyField annotation only (compile-time info) - // If annotation is absent or ref not explicitly set to true, ref is 0 - // This allows fingerprint to be computed at compile time for C++/Rust - char ref = (descriptor.hasForyField() && descriptor.isTrackingRef()) ? '1' : '0'; + // Descriptor owns the normalized top-level ref bit from @Ref or generated metadata. + char ref = descriptor.isTrackingRef() ? 
'1' : '0'; // Get nullable flag: // - Primitives are always non-nullable diff --git a/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java b/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java index 2be8603b17..f28c9bbdef 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java @@ -52,6 +52,7 @@ import org.apache.fory.annotation.Int64Type; import org.apache.fory.annotation.Int8Type; import org.apache.fory.annotation.Internal; +import org.apache.fory.annotation.Ref; import org.apache.fory.annotation.UInt16Type; import org.apache.fory.annotation.UInt32Type; import org.apache.fory.annotation.UInt64Type; @@ -63,6 +64,7 @@ import org.apache.fory.collection.Tuple2; import org.apache.fory.meta.TypeExtMeta; import org.apache.fory.reflect.TypeRef; +import org.apache.fory.reflect.TypeUseMetadata; import org.apache.fory.serializer.converter.FieldConverter; import org.apache.fory.util.ExceptionUtils; import org.apache.fory.util.StringUtils; @@ -107,8 +109,9 @@ public static void clearDescriptorCache() { private final Annotation typeAnnotation; private final boolean arrayType; private boolean nullable; - // trackingRef should only be true if explicitly set to true via @ForyField(ref=true) - // If no annotation or ref not specified, trackingRef stays false and type-based tracking applies + // Ref metadata needs a presence bit because @Ref(enable = false) must override native + // type-based defaults, while unrelated scalar/nullability metadata must not. + private final boolean hasTrackingRefMetadata; private final boolean trackingRef; private FieldConverter fieldConverter; @@ -205,7 +208,8 @@ private Descriptor( nullableOverride == null ? 
resolveNullable(typeRef, !hasForyField, field, null, readMethod) : nullableOverride; - this.trackingRef = hasForyField && foryField.ref(); + this.hasTrackingRefMetadata = hasTrackingRefMetadata(typeRef, field, readMethod); + this.trackingRef = resolveTrackingRef(typeRef, field, readMethod); } public Descriptor( @@ -220,6 +224,34 @@ public Descriptor( boolean trackingRef, ForyField.Dynamic dynamic, boolean arrayType) { + this( + generatedType, + typeName, + name, + modifier, + declaringClass, + hasForyField, + foryFieldId, + nullable, + trackingRef, + trackingRef, + dynamic, + arrayType); + } + + public Descriptor( + GeneratedType generatedType, + String typeName, + String name, + int modifier, + String declaringClass, + boolean hasForyField, + int foryFieldId, + boolean nullable, + boolean trackingRef, + boolean hasTrackingRefMetadata, + ForyField.Dynamic dynamic, + boolean arrayType) { this( toTypeRef(generatedType), typeName, @@ -230,6 +262,7 @@ public Descriptor( foryFieldId, nullable, trackingRef, + hasTrackingRefMetadata, dynamic, arrayType); } @@ -257,6 +290,7 @@ public Descriptor( typeAnnotation = null; arrayType = false; this.nullable = nullable; + this.hasTrackingRefMetadata = true; this.trackingRef = trackingRef; } @@ -297,6 +331,34 @@ public Descriptor( boolean trackingRef, ForyField.Dynamic dynamic, boolean arrayType) { + this( + typeRef, + typeName, + name, + modifier, + declaringClass, + hasForyField, + foryFieldId, + nullable, + trackingRef, + true, + dynamic, + arrayType); + } + + public Descriptor( + TypeRef typeRef, + String typeName, + String name, + int modifier, + String declaringClass, + boolean hasForyField, + int foryFieldId, + boolean nullable, + boolean trackingRef, + boolean hasTrackingRefMetadata, + ForyField.Dynamic dynamic, + boolean arrayType) { this.field = null; this.typeName = typeName; this.name = name; @@ -320,8 +382,7 @@ public Descriptor( } else { this.nullable = nullable; } - // Synthetic descriptors created from remote 
TypeDef fields must preserve schema-owned wrapper - // ref tracking even when the field has no tag id/@ForyField metadata. + this.hasTrackingRefMetadata = hasTrackingRefMetadata; this.trackingRef = trackingRef; } @@ -349,7 +410,8 @@ private Descriptor( typeAnnotation = getAnnotation(field); arrayType = field.isAnnotationPresent(ArrayType.class); this.nullable = resolveNullable(typeRef, !hasForyField, field, recordComponent, readMethod); - this.trackingRef = hasForyField && foryField.ref(); + this.hasTrackingRefMetadata = hasTrackingRefMetadata(typeRef, field, readMethod); + this.trackingRef = resolveTrackingRef(typeRef, field, readMethod); } private Descriptor(Method readMethod) { @@ -370,7 +432,8 @@ private Descriptor(Method readMethod) { typeAnnotation = TypeUtils.getMethodReturnTypeUseAnnotation(readMethod, readMethod.getName()); arrayType = readMethod.isAnnotationPresent(ArrayType.class); this.nullable = resolveNullable(typeRef, !hasForyField, null, null, readMethod); - this.trackingRef = hasForyField && foryField.ref(); + this.hasTrackingRefMetadata = hasTrackingRefMetadata(typeRef, null, readMethod); + this.trackingRef = resolveTrackingRef(typeRef, null, readMethod); } public Descriptor(DescriptorBuilder builder) { @@ -382,6 +445,7 @@ public Descriptor(DescriptorBuilder builder) { this.field = builder.field; this.readMethod = builder.readMethod; this.writeMethod = builder.writeMethod; + this.hasTrackingRefMetadata = builder.hasTrackingRefMetadata; this.trackingRef = builder.trackingRef; this.foryField = builder.foryField != null @@ -433,6 +497,39 @@ private static boolean resolveNullable(TypeRef typeRef, boolean defaultNullab return TypeUtils.isNullable(typeRef, defaultNullable); } + private static boolean hasTrackingRefMetadata( + TypeRef typeRef, Field field, Method readMethod) { + if (field != null && field.getAnnotation(Ref.class) != null) { + return true; + } + if (getTypeUseRef(field, readMethod) != null) { + return true; + } + TypeExtMeta typeExtMeta = 
typeRef.getTypeExtMeta(); + return typeExtMeta != null && typeExtMeta.trackingRef(); + } + + private static boolean resolveTrackingRef(TypeRef typeRef, Field field, Method readMethod) { + Ref ref = field == null ? null : field.getAnnotation(Ref.class); + if (ref != null) { + return ref.enable(); + } + ref = getTypeUseRef(field, readMethod); + if (ref != null) { + return ref.enable(); + } + TypeExtMeta typeExtMeta = typeRef.getTypeExtMeta(); + return typeExtMeta != null && typeExtMeta.trackingRef(); + } + + private static Ref getTypeUseRef(Field field, Method readMethod) { + Object typeUse = + field != null + ? TypeUseMetadata.fieldTypeUse(field) + : (readMethod == null ? null : TypeUseMetadata.methodReturnTypeUse(readMethod)); + return typeUse == null ? null : TypeUseMetadata.typeUseAnnotation(typeUse, Ref.class); + } + private static boolean resolveNullable( TypeRef typeRef, boolean defaultNullable, @@ -470,6 +567,10 @@ public boolean isTrackingRef() { return trackingRef; } + public boolean hasTrackingRefMetadata() { + return hasTrackingRefMetadata; + } + public int getModifier() { return modifier; } diff --git a/java/fory-core/src/main/java/org/apache/fory/type/DescriptorBuilder.java b/java/fory-core/src/main/java/org/apache/fory/type/DescriptorBuilder.java index d7a101521c..bb447a92c1 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/DescriptorBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/DescriptorBuilder.java @@ -42,6 +42,7 @@ public class DescriptorBuilder { ForyField.Dynamic dynamic = ForyField.Dynamic.AUTO; boolean arrayType; boolean nullable; + boolean hasTrackingRefMetadata; boolean trackingRef; FieldConverter fieldConverter; @@ -61,6 +62,7 @@ public DescriptorBuilder(Descriptor descriptor) { this.dynamic = descriptor.getMorphic(); this.arrayType = descriptor.isArrayType(); this.nullable = descriptor.isNullable(); + this.hasTrackingRefMetadata = descriptor.hasTrackingRefMetadata(); this.trackingRef = 
descriptor.isTrackingRef(); this.fieldConverter = descriptor.getFieldConverter(); } @@ -116,6 +118,12 @@ public DescriptorBuilder nullable(boolean nullable) { } public DescriptorBuilder trackingRef(boolean trackingRef) { + this.hasTrackingRefMetadata = true; + this.trackingRef = trackingRef; + return this; + } + + public DescriptorBuilder inferredTrackingRef(boolean trackingRef) { this.trackingRef = trackingRef; return this; } @@ -129,7 +137,6 @@ public DescriptorBuilder foryField(ForyField foryField) { throw new IllegalArgumentException( "@ForyField id must be -1 (no tag ID) or a non-negative tag ID"); } - this.trackingRef = foryField.ref(); this.dynamic = foryField.dynamic(); } else { this.foryFieldId = -1; diff --git a/java/fory-core/src/main/java/org/apache/fory/type/ScalaTypes.java b/java/fory-core/src/main/java/org/apache/fory/type/ScalaTypes.java index 4c32c1b93e..fb0903be27 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/ScalaTypes.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/ScalaTypes.java @@ -28,6 +28,7 @@ /** Scala types utils using reflection without dependency on scala library. 
*/ @SuppressWarnings({"unchecked", "rawtypes"}) public class ScalaTypes { + private static final String SCALA_ENUM_TYPE_NAME = "scala.reflect.Enum"; private static volatile Class SCALA_MAP_TYPE; private static volatile Class SCALA_SEQ_TYPE; private static volatile Class SCALA_SET_TYPE; @@ -115,4 +116,28 @@ public static Class getScalaProductType() { public static boolean isScalaProductType(Class cls) { return getScalaProductType().isAssignableFrom(cls); } + + public static boolean isScalaEnumType(Class cls) { + return resolveScalaEnumClass(cls) != null; + } + + public static Class resolveScalaEnumClass(Class cls) { + for (Class current = cls; + current != null && current != Object.class; + current = current.getSuperclass()) { + if (implementsInterfaceNamed(current, SCALA_ENUM_TYPE_NAME)) { + return current; + } + } + return null; + } + + private static boolean implementsInterfaceNamed(Class cls, String interfaceName) { + for (Class iface : cls.getInterfaces()) { + if (iface.getName().equals(interfaceName) || implementsInterfaceNamed(iface, interfaceName)) { + return true; + } + } + return false; + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldSerializationTest.java b/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldSerializationTest.java index e8f85b825a..67d5c984d6 100644 --- a/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldSerializationTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldSerializationTest.java @@ -329,23 +329,25 @@ public void testMixedTagIdAndFieldName( "Mixed mode - %s/%s/codegen=%s: %d bytes%n", xlang, compatible, codegen, bytes.length); } - /** Test class for nullable and ref flags */ + /** Test class for nullable and @Ref flags */ @Data @NoArgsConstructor @AllArgsConstructor public static class TestNullableRef { - @ForyField(id = 0, ref = false) + @ForyField(id = 0) String nonNullableNoRef; @Nullable - @ForyField(id = 1, ref = false) + 
@ForyField(id = 1) String nullableNoRef; - @ForyField(id = 2, ref = true) + @ForyField(id = 2) + @Ref String nonNullableWithRef; @Nullable - @ForyField(id = 3, ref = true) + @ForyField(id = 3) + @Ref String nullableWithRef; } @@ -379,69 +381,75 @@ public void testNullableAndRefFlagsInPayload( xlang, compatible, codegen, bytes.length); } - /** Test class with all fields non-nullable, ref=false for size comparison */ + /** Test class with all fields non-nullable, no-ref for size comparison */ @Data @NoArgsConstructor @AllArgsConstructor public static class AllNonNullableNoRef { - @ForyField(id = 0, ref = false) + @ForyField(id = 0) String field1; - @ForyField(id = 1, ref = false) + @ForyField(id = 1) String field2; - @ForyField(id = 2, ref = false) + @ForyField(id = 2) String field3; } - /** Test class with all fields @Nullable, ref=false for size comparison */ + /** Test class with all fields @Nullable, no-ref for size comparison */ @Data @NoArgsConstructor @AllArgsConstructor public static class AllNullableNoRef { @Nullable - @ForyField(id = 0, ref = false) + @ForyField(id = 0) String field1; @Nullable - @ForyField(id = 1, ref = false) + @ForyField(id = 1) String field2; @Nullable - @ForyField(id = 2, ref = false) + @ForyField(id = 2) String field3; } - /** Test class with all fields non-nullable, ref=true for size comparison */ + /** Test class with all fields non-nullable, @Ref for size comparison */ @Data @NoArgsConstructor @AllArgsConstructor public static class AllNonNullableWithRef { - @ForyField(id = 0, ref = true) + @ForyField(id = 0) + @Ref String field1; - @ForyField(id = 1, ref = true) + @ForyField(id = 1) + @Ref String field2; - @ForyField(id = 2, ref = true) + @ForyField(id = 2) + @Ref String field3; } - /** Test class with all fields @Nullable, ref=true for size comparison */ + /** Test class with all fields @Nullable, @Ref for size comparison */ @Data @NoArgsConstructor @AllArgsConstructor public static class AllNullableWithRef { @Nullable - 
@ForyField(id = 0, ref = true) + @ForyField(id = 0) + @Ref String field1; @Nullable - @ForyField(id = 1, ref = true) + @ForyField(id = 1) + @Ref String field2; @Nullable - @ForyField(id = 2, ref = true) + @ForyField(id = 2) + @Ref String field3; } @@ -532,8 +540,8 @@ public void testRefFlagReducesPayloadSize( "Ref flag test - %s/%s/codegen=%s/registered=%s - NoRef: %d bytes, WithRef: %d bytes%n", xlang, compatible, codegen, registered, bytesNoRef.length, bytesWithRef.length); - // ref=false should produce smaller or equal payload - // Each ref=true field may add overhead for reference tracking + // no-ref should produce smaller or equal payload + // Each @Ref field may add overhead for reference tracking assertTrue( bytesNoRef.length <= bytesWithRef.length, String.format( @@ -558,9 +566,9 @@ public void testCombinedNullableAndRefFlagsReducePayloadSize( } // Create objects with same data - // Most optimized: non-nullable, ref=false + // Most optimized: non-nullable, no-ref AllNonNullableNoRef optimized = new AllNonNullableNoRef("value1", "value2", "value3"); - // Least optimized: @Nullable, ref=true + // Least optimized: @Nullable, @Ref AllNullableWithRef unoptimized = new AllNullableWithRef("value1", "value2", "value3"); byte[] bytesOptimized = fory.serialize(optimized); @@ -590,12 +598,12 @@ public void testCombinedNullableAndRefFlagsReducePayloadSize( bytesUnoptimized.length - bytesOptimized.length, 100.0 * (bytesUnoptimized.length - bytesOptimized.length) / bytesUnoptimized.length); - // Optimized (non-nullable, ref=false) should be smaller than unoptimized (@Nullable, - // ref=true) + // Optimized (non-nullable, no-ref) should be smaller than unoptimized (@Nullable, + // @Ref) assertTrue( bytesOptimized.length < bytesUnoptimized.length, String.format( - "Expected optimized (non-nullable,ref=false) %d bytes to be < unoptimized (@Nullable,ref=true) %d bytes in mode %s/%s/codegen=%s/registered=%s", + "Expected optimized (non-nullable,no-ref) %d bytes to be < 
unoptimized (@Nullable,@Ref) %d bytes in mode %s/%s/codegen=%s/registered=%s", bytesOptimized.length, bytesUnoptimized.length, xlang, diff --git a/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldTest.java b/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldTest.java index f71c3e635f..fef83e95f8 100644 --- a/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/annotation/ForyFieldTest.java @@ -20,6 +20,7 @@ package org.apache.fory.annotation; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; @@ -30,6 +31,7 @@ import lombok.NoArgsConstructor; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.type.Descriptor; import org.testng.annotations.Test; public class ForyFieldTest extends ForyTestBase { @@ -39,10 +41,10 @@ public class ForyFieldTest extends ForyTestBase { @NoArgsConstructor @AllArgsConstructor public static class Point { - @ForyField(id = 0, ref = false) + @ForyField(id = 0) public double x; - @ForyField(id = 1, ref = false) + @ForyField(id = 1) public double y; } @@ -61,21 +63,21 @@ public void testSimpleValueObject() { @NoArgsConstructor @AllArgsConstructor public static class User { - @ForyField(id = 0, ref = false) + @ForyField(id = 0) public long userId; - @ForyField(id = 1, ref = false) + @ForyField(id = 1) public String username; @Nullable - @ForyField(id = 2, ref = false) + @ForyField(id = 2) public String email; // Can be null during account creation @Nullable - @ForyField(id = 3, ref = false) + @ForyField(id = 3) public String phoneNumber; // Optional contact method - @ForyField(id = 4, ref = false) + @ForyField(id = 4) public long createdAt; } @@ -108,10 +110,10 @@ public void testEntityWithAllNullOptionalFields() { @NoArgsConstructor @AllArgsConstructor public static class 
Customer { - @ForyField(id = 0, ref = false) + @ForyField(id = 0) public long customerId; - @ForyField(id = 1, ref = false) + @ForyField(id = 1) public String name; } @@ -119,14 +121,15 @@ public static class Customer { @NoArgsConstructor @AllArgsConstructor public static class Order { - @ForyField(id = 0, ref = false) + @ForyField(id = 0) public long orderId; - @ForyField(id = 1, ref = true) + @ForyField(id = 1) + @Ref public Customer customer; // Same Customer might appear in many orders @Nullable - @ForyField(id = 2, ref = false) + @ForyField(id = 2) public String notes; // Unique per order } @@ -156,7 +159,7 @@ public void testSharedObjectReferences() { assertEquals(deserialized.get(1).orderId, 101L); assertEquals(deserialized.get(0).customer.customerId, 1L); assertEquals(deserialized.get(1).customer.customerId, 1L); - // Both orders should reference the same customer object due to ref=true + // Both orders should reference the same customer object due to @Ref // (though this is more about serialization efficiency than behavior) } @@ -183,15 +186,16 @@ public void testNullableDefaults() { assertNull(deserialized.field2); } - /** Test ref defaults */ + /** Test field-wrapper reference tracking defaults */ @Data @NoArgsConstructor @AllArgsConstructor public static class DefaultRefTest { - @ForyField(id = 0) // ref defaults to false + @ForyField(id = 0) // no @Ref, so field-wrapper ref tracking is disabled public String field1; - @ForyField(id = 1, ref = true) + @ForyField(id = 1) + @Ref public String field2; } @@ -236,6 +240,41 @@ public void testMixedAnnotatedAndRegularFields() { assertNull(deserialized.anotherAnnotatedField); } + @Data + @NoArgsConstructor + @AllArgsConstructor + public static class RefOwnerDefaults { + @ForyField(id = 0) + public Customer foryFieldOnly; + + @Ref(enable = false) + public Customer refDisabled; + } + + @Test + public void testRefAnnotationOwnsNativeRefOverride() { + Fory fory = + Fory.builder() + .withXlang(false) + 
.withRefTracking(true) + .requireClassRegistration(false) + .build(); + List descriptors = + fory.getTypeResolver().getFieldDescriptors(RefOwnerDefaults.class, true); + + assertTrue(descriptorNamed(descriptors, "foryFieldOnly").isTrackingRef()); + assertFalse(descriptorNamed(descriptors, "refDisabled").isTrackingRef()); + } + + private static Descriptor descriptorNamed(List descriptors, String name) { + for (Descriptor descriptor : descriptors) { + if (descriptor.getName().equals(name)) { + return descriptor; + } + } + throw new AssertionError("Descriptor not found: " + name); + } + /** Test with primitive types */ @Data @NoArgsConstructor diff --git a/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java b/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java index 7c31e06844..ffb3d3fb4f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java @@ -19,11 +19,6 @@ package org.apache.fory.builder; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.atLeastOnce; - import java.io.IOException; import java.lang.reflect.Field; import java.net.URL; @@ -55,8 +50,6 @@ import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.StaticGeneratedStructSerializer; import org.apache.fory.serializer.StaticGeneratedStructSerializer.RemoteFieldInfo; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.testng.Assert; import org.testng.SkipException; import org.testng.annotations.DataProvider; @@ -295,11 +288,11 @@ public void testStaticCompatibleSkipsUnknownBackReferenceField() throws Exceptio "test.StaticCompatibleRefPayload", "package test;\n" + "import java.util.List;\n" - + "import 
org.apache.fory.annotation.ForyField;\n" + "import org.apache.fory.annotation.Nullable;\n" + + "import org.apache.fory.annotation.Ref;\n" + "public class StaticCompatibleRefPayload {\n" - + " @Nullable @ForyField(ref = true) public String name;\n" - + " @Nullable @ForyField(ref = true) public String nameAlias;\n" + + " @Nullable @Ref public String name;\n" + + " @Nullable @Ref public String nameAlias;\n" + " public List after;\n" + " public StaticCompatibleRefPayload() {}\n" + "}\n"); @@ -308,10 +301,10 @@ public void testStaticCompatibleSkipsUnknownBackReferenceField() throws Exceptio "test.StaticCompatibleRefPayload", "package test;\n" + "import java.util.List;\n" - + "import org.apache.fory.annotation.ForyField;\n" + "import org.apache.fory.annotation.Nullable;\n" + + "import org.apache.fory.annotation.Ref;\n" + "public class StaticCompatibleRefPayload {\n" - + " @Nullable @ForyField(ref = true) public String name;\n" + + " @Nullable @Ref public String name;\n" + " public List after;\n" + " public StaticCompatibleRefPayload() {}\n" + "}\n"); @@ -460,22 +453,7 @@ private static Object roundTripThroughStaticCompatibleSerializer( writer.setMetaWriteContext(new MetaWriteContext()); byte[] bytes = writer.serialize(writerValue); reader.setMetaReadContext(new MetaReadContext()); - try (MockedStatic codecUtils = - Mockito.mockStatic(CodecUtils.class, Mockito.CALLS_REAL_METHODS)) { - codecUtils - .when( - () -> - CodecUtils.loadOrGenCompatibleCodecClass( - same(reader.getTypeResolver()), eq(readerClass), any(TypeDef.class))) - .thenReturn(compatibleSerializerClass); - Object result = reader.deserialize(bytes); - codecUtils.verify( - () -> - CodecUtils.loadOrGenCompatibleCodecClass( - same(reader.getTypeResolver()), eq(readerClass), any(TypeDef.class)), - atLeastOnce()); - return result; - } + return reader.deserialize(bytes); } private static Map remoteCodecCategories( diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/struct/FingerprintTest.java 
b/java/fory-core/src/test/java/org/apache/fory/serializer/struct/FingerprintTest.java index 351c0d7f3a..ee7e6c1404 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/struct/FingerprintTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/struct/FingerprintTest.java @@ -20,13 +20,18 @@ package org.apache.fory.serializer.struct; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; import java.util.List; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.annotation.ForyStruct; +import org.apache.fory.annotation.Ref; +import org.apache.fory.config.Language; import org.apache.fory.data.AllUnsignedFields; import org.apache.fory.data.UnsignedArrayFields; import org.apache.fory.data.UnsignedScalarFields; +import org.apache.fory.meta.TypeDef; import org.apache.fory.type.Descriptor; import org.apache.fory.type.Types; import org.testng.annotations.Test; @@ -34,6 +39,11 @@ /** Tests for {@link Fingerprint} with unsigned integer types and unsigned integer array types. 
*/ public class FingerprintTest extends ForyTestBase { + @ForyStruct + public static class RefWithoutForyFieldStruct { + @Ref public RefWithoutForyFieldStruct peer; + } + @Test public void testUnsignedScalarFieldsFingerprint() { Fory fory = Fory.builder().build(); @@ -137,4 +147,18 @@ public void testAllUnsignedFieldsFingerprint() { + ",0,1;"; assertEquals(fingerprint, expected); } + + @Test + public void testRefWithoutForyFieldAffectsFingerprintAndTypeDef() { + Fory fory = Fory.builder().withLanguage(Language.XLANG).withRefTracking(true).build(); + fory.register(RefWithoutForyFieldStruct.class, 701); + List descriptors = Descriptor.getDescriptors(RefWithoutForyFieldStruct.class); + + String fingerprint = Fingerprint.computeStructFingerprint(fory, descriptors); + + assertEquals(fingerprint, "peer," + Types.UNKNOWN + ",1,0;"); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), RefWithoutForyFieldStruct.class); + assertEquals(typeDef.getFieldCount(), 1); + assertTrue(typeDef.getFieldsInfo().get(0).getFieldType().trackingRef()); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java index 0708a10a4a..0b07ff23ef 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java @@ -2707,11 +2707,13 @@ static class RefInnerSchemaConsistent { @ForyStruct static class RefOuterSchemaConsistent { @Nullable - @ForyField(ref = true, dynamic = ForyField.Dynamic.FALSE) + @ForyField(dynamic = ForyField.Dynamic.FALSE) + @Ref RefInnerSchemaConsistent inner1; @Nullable - @ForyField(ref = true, dynamic = ForyField.Dynamic.FALSE) + @ForyField(dynamic = ForyField.Dynamic.FALSE) + @Ref RefInnerSchemaConsistent inner2; } @@ -2787,13 +2789,9 @@ static class RefInnerCompatible { @Data @ForyStruct static class RefOuterCompatible { - @Nullable - @ForyField(ref = true) - 
RefInnerCompatible inner1; + @Nullable @Ref RefInnerCompatible inner1; - @Nullable - @ForyField(ref = true) - RefInnerCompatible inner2; + @Nullable @Ref RefInnerCompatible inner2; } /** @@ -3017,9 +3015,7 @@ public void testCollectionElementRefRemoteTracking(boolean enableCodegen) static class CircularRefStruct { String name; - @Nullable - @ForyField(ref = true) - CircularRefStruct selfRef; + @Nullable @Ref CircularRefStruct selfRef; } /** diff --git a/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/RecordXlangTest.java b/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/RecordXlangTest.java index 1eaef93a64..0afb6c7a47 100644 --- a/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/RecordXlangTest.java +++ b/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/RecordXlangTest.java @@ -30,6 +30,7 @@ import org.apache.fory.Fory; import org.apache.fory.annotation.ForyField; import org.apache.fory.annotation.Nullable; +import org.apache.fory.annotation.Ref; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.DeflaterMetaCompressor; import org.testng.Assert; @@ -651,8 +652,8 @@ public record RefInnerRecord(int id, String name) {} * same RefInnerRecord instance. */ public record RefOuterRecord( - @Nullable @ForyField(ref = true, dynamic = ForyField.Dynamic.FALSE) RefInnerRecord inner1, - @Nullable @ForyField(ref = true, dynamic = ForyField.Dynamic.FALSE) RefInnerRecord inner2) {} + @Nullable @Ref @ForyField(dynamic = ForyField.Dynamic.FALSE) RefInnerRecord inner1, + @Nullable @Ref @ForyField(dynamic = ForyField.Dynamic.FALSE) RefInnerRecord inner2) {} /** * Test reference tracking with Record in SCHEMA_CONSISTENT mode. 
Creates an outer struct with two diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt index afa0c9cb7c..b56deeea66 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt @@ -25,6 +25,7 @@ import com.google.devtools.ksp.processing.KSPLogger import com.google.devtools.ksp.processing.Resolver import com.google.devtools.ksp.processing.SymbolProcessor import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.symbol.AnnotationUseSiteTarget import com.google.devtools.ksp.symbol.ClassKind import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSAnnotation @@ -41,6 +42,8 @@ import java.util.Locale private const val MAX_CONSTRUCTOR_FIELDS = Long.SIZE_BITS - 1 private const val MAX_DEFAULT_CONSTRUCTOR_FIELDS = 12 +private const val REF_NOT_SUPPORTED_DIAGNOSTIC = + "@Ref is not supported by Kotlin KSP xlang serializers because constructor-based reads cannot publish partially constructed objects" internal fun constructorFieldLimitDiagnostic(fieldCount: Int): String? 
= if (fieldCount > MAX_CONSTRUCTOR_FIELDS) { @@ -232,14 +235,12 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso ) return null } - val type = parameter.type.resolve() - if (fieldMeta.ref) { - logger.error( - "@ForyField(ref = true) is not supported by Kotlin KSP xlang serializers because constructor-based reads cannot publish partially constructed objects", - property, - ) + val hasFieldRef = resolveFieldRef(property, parameter) ?: return null + if (hasFieldRef) { + logger.error(REF_NOT_SUPPORTED_DIAGNOSTIC, property) return null } + val type = parameter.type.resolve() val typeNode = parseType(type, property, arrayType = hasFieldAnnotation(property, ARRAY_TYPE)) ?: return null @@ -256,7 +257,7 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso type = typeNode, hasForyField = fieldMeta.hasAnnotation, foryFieldId = fieldMeta.id, - trackingRef = fieldMeta.ref, + trackingRef = false, dynamic = fieldMeta.dynamic, arrayType = hasFieldAnnotation(property, ARRAY_TYPE), hasDefault = parameter.hasDefault, @@ -374,20 +375,100 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso return property.annotations.any { isAnnotation(it, qualifiedName) } || false } + private fun resolveFieldRef( + property: KSPropertyDeclaration, + parameter: KSValueParameter, + ): Boolean? 
{ + val getterHasRef = property.getter?.annotations?.any { isAnnotation(it, REF) } == true + val setterHasRef = property.setter?.annotations?.any { isAnnotation(it, REF) } == true + if (getterHasRef || setterHasRef) { + logger.error("@get:Ref and @set:Ref are not valid for Kotlin xlang schema fields", property) + return null + } + val refs = mutableListOf() + if (!appendFieldRefAnnotations(refs, property.annotations, property)) { + return null + } + if (!appendParameterRefAnnotations(refs, parameter.annotations, property)) { + return null + } + return refs.isNotEmpty() + } + + private fun appendFieldRefAnnotations( + refs: MutableList, + annotations: Sequence, + owner: KSAnnotated, + ): Boolean { + for (annotation in annotations) { + if (!isAnnotation(annotation, REF)) { + continue + } + val useSiteTarget = annotation.useSiteTarget + when (useSiteTarget) { + null, + AnnotationUseSiteTarget.PROPERTY, + AnnotationUseSiteTarget.FIELD -> refs.add(annotation) + AnnotationUseSiteTarget.GET, + AnnotationUseSiteTarget.SET -> + logger.error( + "@get:Ref and @set:Ref are not valid for Kotlin xlang schema fields", + owner, + ) + else -> + logger.error( + "@${useSiteTarget.name.lowercase(Locale.ROOT)}:Ref is not valid for Kotlin xlang schema fields", + owner, + ) + } + if ( + useSiteTarget != null && + useSiteTarget != AnnotationUseSiteTarget.PROPERTY && + useSiteTarget != AnnotationUseSiteTarget.FIELD + ) { + return false + } + } + return true + } + + private fun appendParameterRefAnnotations( + refs: MutableList, + annotations: Sequence, + owner: KSAnnotated, + ): Boolean { + for (annotation in annotations) { + if (!isAnnotation(annotation, REF)) { + continue + } + val useSiteTarget = annotation.useSiteTarget + when (useSiteTarget) { + null, + AnnotationUseSiteTarget.PARAM -> refs.add(annotation) + else -> { + logger.error( + "@${useSiteTarget.name.lowercase(Locale.ROOT)}:Ref is not valid for Kotlin constructor parameters", + owner, + ) + return false + } + } + } + return 
true + } + private fun foryFieldMeta(annotations: Sequence): ForyFieldMeta? { val annotation = annotations.firstOrNull { isAnnotation(it, FORY_FIELD) } ?: return null var id = -1 - var ref = false var dynamic = "AUTO" for (argument in annotation.arguments) { when (argument.name?.asString()) { "id" -> id = argument.value as Int - "ref" -> ref = argument.value as Boolean "dynamic" -> dynamic = argument.value.toString().substringAfterLast('.').uppercase(Locale.ROOT) } } - return ForyFieldMeta(true, id, ref, dynamic) + return ForyFieldMeta(true, id, dynamic) } private fun parseType( @@ -406,6 +487,11 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso return null } val nullable = type.nullability == Nullability.NULLABLE + val hasTypeRef = hasRefAnnotation(type, owner) ?: return null + if (hasTypeRef) { + logger.error(REF_NOT_SUPPORTED_DIAGNOSTIC, owner) + return null + } val encoding = encodingAnnotation(type, owner) if (encoding == Encoding.Invalid) { return null @@ -662,6 +748,15 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso return encodings.firstOrNull() } + private fun hasRefAnnotation(type: KSType, owner: KSAnnotated): Boolean? 
{ + val refs = type.annotations.filter { isAnnotation(it, REF) }.toList() + if (refs.size > 1) { + logger.error("Kotlin xlang field types must not repeat @Ref", owner) + return null + } + return refs.isNotEmpty() + } + private fun scalarType( qualifiedName: String?, encoding: Encoding?, @@ -1162,6 +1257,7 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso const val FORY_FIELD = "org.apache.fory.annotation.ForyField" const val ARRAY_TYPE = "org.apache.fory.annotation.ArrayType" const val NULLABLE = "org.apache.fory.annotation.Nullable" + const val REF = "org.apache.fory.annotation.Ref" const val FIXED = "org.apache.fory.kotlin.Fixed" const val VAR_INT = "org.apache.fory.kotlin.VarInt" const val TAGGED = "org.apache.fory.kotlin.Tagged" diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt index 676f7b6fe0..dc1aba75be 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt @@ -126,10 +126,9 @@ internal enum class CollectionFactory { internal data class ForyFieldMeta( val hasAnnotation: Boolean, val id: Int, - val ref: Boolean, val dynamic: String, ) { companion object { - val NONE: ForyFieldMeta = ForyFieldMeta(false, -1, false, "AUTO") + val NONE: ForyFieldMeta = ForyFieldMeta(false, -1, "AUTO") } } diff --git a/scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java b/scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java deleted file mode 100644 index a69e7c5635..0000000000 --- a/scala/src/main/java/org/apache/fory/scala/ForyScalaEnum.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.fory.scala; - -/** Marker interface for Scala 3 enums generated from Fory schema enum definitions. */ -public interface ForyScalaEnum { - int getForyId(); -} diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java index 6f9ca18a96..7dd3a1c3b6 100644 --- a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaDispatcher.java @@ -19,21 +19,28 @@ package org.apache.fory.serializer.scala; +import java.lang.reflect.Method; import org.apache.fory.resolver.TypeResolver; -import org.apache.fory.scala.ForyScalaEnum; import org.apache.fory.serializer.JavaSerializer; import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.SerializerFactory; import org.apache.fory.util.Preconditions; import scala.collection.generic.DefaultSerializable; -import java.lang.reflect.Method; - /** * Serializer dispatcher for scala types. */ @SuppressWarnings({"rawtypes", "unchecked"}) public class ScalaDispatcher implements SerializerFactory { + private final SerializerFactory delegate; + + public ScalaDispatcher() { + this(null); + } + + public ScalaDispatcher(SerializerFactory delegate) { + this.delegate = delegate; + } /** * Get Serializer for scala type. 
@@ -43,7 +50,14 @@ public class ScalaDispatcher implements SerializerFactory { */ @Override public Serializer createSerializer(TypeResolver typeResolver, Class clz) { - if (ForyScalaEnum.class.isAssignableFrom(clz)) { + Serializer serializer; + if (delegate != null) { + serializer = delegate.createSerializer(typeResolver, clz); + if (serializer != null) { + return serializer; + } + } + if (ScalaEnumSerializer.canSerialize(clz)) { return new ScalaEnumSerializer(typeResolver, clz); } if (scala.Option.class.isAssignableFrom(clz)) { diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java index 3a591cb3c3..f6f6e3a7a7 100644 --- a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaEnumSerializer.java @@ -19,46 +19,56 @@ package org.apache.fory.serializer.scala; +import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.lang.reflect.Modifier; import java.util.Arrays; +import java.util.IdentityHashMap; +import org.apache.fory.annotation.ForyEnumId; +import org.apache.fory.collection.IdentityObjectIntMap; import org.apache.fory.collection.LongMap; import org.apache.fory.config.Config; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; -import org.apache.fory.scala.ForyScalaEnum; import org.apache.fory.serializer.ImmutableSerializer; import org.apache.fory.serializer.Shareable; +import org.apache.fory.type.ScalaTypes; import org.apache.fory.util.Preconditions; -/** Serializer for Scala 3 enums generated by the Fory Scala schema IDL target. */ +/** Serializer for Scala 3 enums with parameterless cases. 
*/ @SuppressWarnings({"unchecked", "rawtypes"}) public final class ScalaEnumSerializer extends ImmutableSerializer implements Shareable { private static final int MAX_ENUM_ID_ARRAY_SIZE = 2048; private final Config config; private final Object[] enumConstants; + private final IdentityObjectIntMap tagByValue; private final Object[] enumConstantByTagArray; private final LongMap enumConstantByTagMap; public ScalaEnumSerializer(org.apache.fory.resolver.TypeResolver resolver, Class cls) { - super(resolver.getConfig(), (Class) cls, false); + super(resolver.getConfig(), resolveSerializerClass(cls), false); config = resolver.getConfig(); - Preconditions.checkArgument( - ForyScalaEnum.class.isAssignableFrom(cls), - "Scala enum %s must implement %s", - cls, - ForyScalaEnum.class.getName()); - enumConstants = loadValues(cls); + Class enumClass = ScalaTypes.resolveScalaEnumClass(cls); + enumConstants = loadValues(enumClass); + IdentityHashMap explicitTags = loadExplicitTags(enumClass); + tagByValue = new IdentityObjectIntMap<>(enumConstants.length, 0.5f); LongMap constantsByTag = new LongMap<>(enumConstants.length); int maxTag = 0; - for (Object enumConstant : enumConstants) { - int tag = ((ForyScalaEnum) enumConstant).getForyId(); + for (int i = 0; i < enumConstants.length; i++) { + Object enumConstant = enumConstants[i]; + int tag = explicitTags == null ? 
i : explicitTags.getOrDefault(enumConstant, -1); + Preconditions.checkArgument( + tag >= 0, + "Scala enum %s must annotate every case with @ForyEnumId when any case uses it", + enumClass.getName()); + tagByValue.put(enumConstant, tag); Object previous = constantsByTag.put(tag, enumConstant); Preconditions.checkArgument( previous == null, "Scala enum %s reuses Fory enum id %s for %s and %s", - cls.getName(), + enumClass.getName(), tag, previous, enumConstant); @@ -78,7 +88,9 @@ public ScalaEnumSerializer(org.apache.fory.resolver.TypeResolver resolver, Class @Override public void write(WriteContext writeContext, Object value) { - writeContext.getBuffer().writeVarUInt32Small7(((ForyScalaEnum) value).getForyId()); + int tag = tagByValue.get(value, -1); + Preconditions.checkArgument(tag >= 0, "Scala enum value %s is not a registered case", value); + writeContext.getBuffer().writeVarUInt32Small7(tag); } @Override @@ -111,14 +123,92 @@ private Object handleUnknownEnumValue(int tag) { } static Object[] loadValues(Class cls) { + Class enumClass = ScalaTypes.resolveScalaEnumClass(cls); + Preconditions.checkArgument( + enumClass != null, "Scala enum %s must implement scala.reflect.Enum", cls); try { - Method values = cls.getMethod("values"); + Method values = enumClass.getMethod("values"); Object result = values.invoke(null); Preconditions.checkArgument( - result instanceof Object[], "Scala enum %s values() did not return an array", cls); + result instanceof Object[], "Scala enum %s values() did not return an array", enumClass); return (Object[]) result; } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { - throw new IllegalArgumentException("Failed to load Scala enum values for " + cls.getName(), e); + throw new IllegalArgumentException( + "Failed to load Scala enum values for " + enumClass.getName(), e); + } + } + + static boolean canSerialize(Class cls) { + Class enumClass = ScalaTypes.resolveScalaEnumClass(cls); + if (enumClass == 
null) { + return false; + } + try { + enumClass.getMethod("values"); + return true; + } catch (NoSuchMethodException e) { + return false; + } + } + + private static Class resolveSerializerClass(Class cls) { + Class enumClass = ScalaTypes.resolveScalaEnumClass(cls); + Preconditions.checkArgument( + enumClass != null, "Scala enum %s must implement scala.reflect.Enum", cls); + Preconditions.checkArgument( + canSerialize(enumClass), + "Scala enum %s must define values() to use ScalaEnumSerializer", + enumClass); + return (Class) enumClass; + } + + private static IdentityHashMap loadExplicitTags(Class enumClass) { + Class companion = loadCompanionClass(enumClass); + IdentityHashMap tagsByValue = new IdentityHashMap<>(); + for (Field field : companion.getDeclaredFields()) { + if (!Modifier.isStatic(field.getModifiers()) + || !enumClass.isAssignableFrom(field.getType())) { + continue; + } + ForyEnumId annotation = field.getAnnotation(ForyEnumId.class); + if (annotation == null) { + continue; + } + Preconditions.checkArgument( + annotation.value() >= 0, + "Scala enum %s case %s annotated with @ForyEnumId must declare a non-negative value", + enumClass.getName(), + field.getName()); + field.setAccessible(true); + Object enumConstant = readCaseField(enumClass, field); + Integer previous = tagsByValue.put(enumConstant, annotation.value()); + Preconditions.checkArgument( + previous == null, + "Scala enum %s case %s has multiple @ForyEnumId declarations", + enumClass.getName(), + field.getName()); + } + return tagsByValue.isEmpty() ? 
null : tagsByValue; + } + + private static Class loadCompanionClass(Class enumClass) { + try { + return Class.forName(enumClass.getName() + "$", false, enumClass.getClassLoader()); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException( + "Failed to load Scala enum companion for " + enumClass.getName(), e); + } + } + + private static Object readCaseField(Class enumClass, Field field) { + try { + return field.get(null); + } catch (IllegalAccessException e) { + throw new IllegalArgumentException( + String.format( + "Failed to read @ForyEnumId case field %s on Scala enum %s", + field.getName(), enumClass.getName()), + e); } } } diff --git a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java index 4541d3d3f3..3b214e44c0 100644 --- a/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java +++ b/scala/src/main/java/org/apache/fory/serializer/scala/ScalaSerializers.java @@ -22,9 +22,9 @@ import org.apache.fory.AbstractThreadSafeFory; import org.apache.fory.Fory; import org.apache.fory.ThreadSafeFory; +import org.apache.fory.annotation.Internal; import org.apache.fory.config.Config; import org.apache.fory.resolver.TypeResolver; -import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.SerializerFactory; import scala.collection.immutable.NumericRange; import scala.collection.immutable.Range; @@ -40,7 +40,8 @@ public static void registerSerializers(ThreadSafeFory fory) { } public static void registerSerializers(Fory fory) { - TypeResolver resolver = setSerializerFactory(fory); + TypeResolver resolver = fory.getTypeResolver(); + ensureScalaDispatcher(fory); if (resolver.isCrossLanguage()) { return; } @@ -177,40 +178,38 @@ public static void registerSerializers(Fory fory) { public static void registerEnum(Fory fory, Class cls, long typeId) { TypeResolver resolver = fory.getTypeResolver(); resolver.registerEnum(cls, 
typeId, new ScalaEnumSerializer(resolver, cls)); - registerEnumCases(resolver, cls); + registerEnumRuntimeAliases(fory, cls); } public static void registerEnum(Fory fory, Class cls, String namespace, String typeName) { TypeResolver resolver = fory.getTypeResolver(); resolver.registerEnum(cls, namespace, typeName, new ScalaEnumSerializer(resolver, cls)); - registerEnumCases(resolver, cls); + registerEnumRuntimeAliases(fory, cls); + } + + @Internal + public static void registerRuntimeTypeAlias( + Fory fory, Class runtimeClass, Class canonicalClass) { + fory.getTypeResolver().registerRuntimeTypeAlias(runtimeClass, canonicalClass); } - private static void registerEnumCases(TypeResolver resolver, Class cls) { - for (Object enumConstant : ScalaEnumSerializer.loadValues(cls)) { - Class caseClass = enumConstant.getClass(); - if (caseClass != cls) { - resolver.registerEnumCase(cls, caseClass); + private static void registerEnumRuntimeAliases(Fory fory, Class cls) { + for (Object value : ScalaEnumSerializer.loadValues(cls)) { + Class runtimeClass = value.getClass(); + if (runtimeClass != cls) { + registerRuntimeTypeAlias(fory, runtimeClass, cls); } } } - private static TypeResolver setSerializerFactory(Fory fory) { + private static ScalaDispatcher ensureScalaDispatcher(Fory fory) { TypeResolver resolver = fory.getTypeResolver(); - ScalaDispatcher dispatcher = new ScalaDispatcher(); SerializerFactory factory = resolver.getSerializerFactory(); - if (factory != null) { - SerializerFactory newFactory = (typeResolver, cls) -> { - Serializer serializer = factory.createSerializer(typeResolver, cls); - if (serializer == null) { - serializer = dispatcher.createSerializer(typeResolver, cls); - } - return serializer; - }; - resolver.setSerializerFactory(newFactory); - } else { - resolver.setSerializerFactory(dispatcher); + if (factory instanceof ScalaDispatcher) { + return (ScalaDispatcher) factory; } - return resolver; + ScalaDispatcher dispatcher = new ScalaDispatcher(factory); + 
resolver.setSerializerFactory(dispatcher); + return dispatcher; } } diff --git a/scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala b/scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala index 613fa6be8b..e51b5ae66d 100644 --- a/scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala +++ b/scala/src/main/scala-3/org/apache/fory/scala/ForySerializer.scala @@ -20,12 +20,11 @@ package org.apache.fory.scala import org.apache.fory.{Fory, ThreadSafeFory} +import org.apache.fory.annotation.Internal import org.apache.fory.meta.TypeDef import org.apache.fory.resolver.TypeResolver -import org.apache.fory.serializer.{ - Serializer, - StaticGeneratedStructSerializerFactory -} +import org.apache.fory.serializer.Serializer +import org.apache.fory.serializer.scala.ScalaSerializers trait ForySerializer[T] { def createSerializer(typeResolver: TypeResolver): Serializer[T] = @@ -35,7 +34,7 @@ trait ForySerializer[T] { def isUnion: Boolean = false - def registrationClasses(cls: Class[T]): Array[Class[_]] = Array(cls) + private[scala] def handledRuntimeClasses(cls: Class[T]): Array[Class[_]] = Array.empty } object ForySerializer { @@ -61,6 +60,25 @@ object ForySerializer { register(fory, cls, null, namespace, typeName) } + @Internal + def registerType[T](fory: Fory, cls: Class[T], typeId: Long): Unit = { + registerType(fory, cls, java.lang.Long.valueOf(typeId), null, null) + } + + @Internal + def registerType[T](fory: Fory, cls: Class[T], namespace: String, typeName: String): Unit = { + registerType(fory, cls, null, namespace, typeName) + } + + @Internal + def registerSerializer[T](fory: Fory, cls: Class[T])(using serializer: ForySerializer[T]): Unit = { + if serializer.isUnion then { + throw new IllegalArgumentException("Use ForySerializer.register for Scala union serializers") + } + val resolver = fory.getTypeResolver + resolver.setSerializer(cls, serializer.createSerializer(resolver)) + } + private def register[T]( fory: Fory, cls: Class[T], @@ 
-69,9 +87,6 @@ object ForySerializer { typeName: String)(using serializer: ForySerializer[T]): Unit = { val resolver = fory.getTypeResolver serializer match { - case factory: StaticGeneratedStructSerializerFactory[T] @unchecked => - registerType(fory, cls, typeId, namespace, typeName) - resolver.registerStaticGeneratedStructSerializerFactory(cls, factory) case _ if serializer.isUnion => val unionSerializer = serializer.createSerializer(resolver) if typeId != null then { @@ -86,10 +101,8 @@ object ForySerializer { unionTypeName, unionSerializer) } - serializer.registrationClasses(cls).foreach { registrationClass => - if registrationClass != cls then { - resolver.registerUnionCase(cls, registrationClass) - } + serializer.handledRuntimeClasses(cls).foreach { runtimeClass => + ScalaSerializers.registerRuntimeTypeAlias(fory, runtimeClass, cls) } case _ => registerType(fory, cls, typeId, namespace, typeName) diff --git a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala index 867a13e6f4..86534989db 100644 --- a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala +++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala @@ -22,16 +22,14 @@ package org.apache.fory.scala.internal import org.apache.fory.annotation.{ForyCase, ForyField, ForyStruct, ForyUnion} import org.apache.fory.meta.{TypeDef => ForyTypeDef, TypeExtMeta} import org.apache.fory.resolver.TypeResolver -import org.apache.fory.scala.ForyScalaEnum import org.apache.fory.scala.ForySerializer import org.apache.fory.serializer.{ FieldGroups, Serializer, StaticGeneratedStructSerializer, - StaticGeneratedStructSerializerFactory, UnionSerializer } -import org.apache.fory.`type`.{Descriptor, Types} +import org.apache.fory.`type`.{Descriptor, ScalaTypes, Types} import java.lang.reflect.Modifier import scala.quoted.* @@ -58,6 +56,7 @@ object 
ForySerializerMacros { option: Boolean, nullable: Boolean, trackingRef: Boolean, + hasTrackingRefMetadata: Boolean, constructorOwned: Boolean) if !hasAnnotation[ForyStruct](owner) then { @@ -188,6 +187,7 @@ object ForySerializerMacros { case Some(inner) => (boxedIfPrimitive(inner), true, true) case None => (sourceType, false, false) } + val refTracking = refAnnotation(field).orElse(topLevelTypeRefTracking(sourceType)) FieldMeta( field, field.name, @@ -197,16 +197,20 @@ object ForySerializerMacros { wireType, option, nullable, - hasRef(field) || topLevelTypeHasRef(sourceType), + refTracking.getOrElse(false), + refTracking.nonEmpty, constructorFieldSet.contains(field)) } + val hasNestedCompatibleStructFields = + fields.exists(field => hasNestedCompatibleStruct(field.sourceType)) def generatedType(tpe: TypeRepr): Expr[Descriptor.GeneratedType] = { val (outer, outerAnnotations) = peelAnnotations(tpe) val option = optionElement(outer) val fieldSource = option.map(boxedIfPrimitive).getOrElse(outer) val (base, baseAnnotations) = peelAnnotations(fieldSource) - val annotations = outerAnnotations ++ baseAnnotations + val optionInnerAnnotations = option.toList.flatMap(inner => peelAnnotations(inner)._2) + val annotations = outerAnnotations ++ baseAnnotations ++ optionInnerAnnotations val argumentSource = fieldSource def appliedType(tpe: TypeRepr): Option[(TypeRepr, List[TypeRepr])] = { val directArgs = tpe.typeArgs @@ -251,7 +255,10 @@ object ForySerializerMacros { val componentExpr: Expr[Descriptor.GeneratedType] = component.getOrElse('{ null.asInstanceOf[Descriptor.GeneratedType] }) val typeId = - annotations.flatMap(typeIdForAnnotation).headOption + annotations + .flatMap(typeIdForAnnotation) + .headOption + .orElse(option.map(inner => wireTypeId(peelAnnotations(boxedIfPrimitive(inner))._1))) .orElse { if hasAnnotation[ForyUnion](base.typeSymbol) then Some(Types.UNION) else None } @@ -263,11 +270,35 @@ object ForySerializerMacros { val typeExtMeta = generatedTypeExtMeta( 
typeId, nullable = option.nonEmpty, - trackingRef = annotations.exists(isRefAnnotation), + trackingRef = refTrackingFromAnnotations(annotations).getOrElse(false), + hasTrackingRefMetadata = refTrackingFromAnnotations(annotations).nonEmpty, + nullableWrapper = option.nonEmpty, rawClass = Some(rawClass)) '{ Descriptor.generatedType($rawClass, $typeExtMeta, $argList, $componentExpr) } } + def wireTypeId(tpe: TypeRepr): Int = { + val normalized = peelAnnotations(tpe.widen)._1.dealias + if normalized =:= TypeRepr.of[Boolean] then Types.BOOL + else if normalized =:= TypeRepr.of[Byte] then Types.INT8 + else if normalized =:= TypeRepr.of[Short] then Types.INT16 + else if normalized =:= TypeRepr.of[Int] then Types.INT32 + else if normalized =:= TypeRepr.of[Long] then Types.INT64 + else if normalized =:= TypeRepr.of[Float] then Types.FLOAT32 + else if normalized =:= TypeRepr.of[Double] then Types.FLOAT64 + else { + val fullName = normalized.typeSymbol.fullName + if normalized =:= TypeRepr.of[String] || + normalized.typeSymbol == TypeRepr.of[String].typeSymbol || + fullName == "scala.Predef.String" || + fullName == "scala.Predef$.String" || + fullName.endsWith("Predef.String") || + fullName.endsWith("Predef$.String") + then Types.STRING + else Types.UNKNOWN + } + } + def descriptor(field: FieldMeta): Expr[Descriptor] = { '{ new Descriptor( @@ -280,6 +311,7 @@ object ForySerializerMacros { ${ Expr(field.fieldId) }, ${ Expr(field.nullable) }, ${ Expr(field.trackingRef) }, + ${ Expr(field.hasTrackingRefMetadata) }, ForyField.Dynamic.AUTO, false ) @@ -420,6 +452,21 @@ object ForySerializerMacros { decodeValue(copied, field) } + def failIfCopiedDuringConstructorArgCopy( + valueExpr: Expr[T], + copyContextExpr: Expr[org.apache.fory.context.CopyContext]): Expr[Unit] = { + '{ + if $copyContextExpr.copyTrackingRef() && + $copyContextExpr.getCopyObject($valueExpr) != null + then { + throw new org.apache.fory.exception.CopyException( + "Cannot copy cyclic object graph rooted at 
constructor-owned immutable value " + + $valueExpr.getClass.getName + + " because its copy cannot be referenced before construction completes") + } + } + } + def referenceCopy( copyContextExpr: Expr[org.apache.fory.context.CopyContext], sourceExpr: Expr[T], @@ -439,9 +486,13 @@ object ForySerializerMacros { val args = constructorOwned.map { field => copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr).asTerm } + val cycleCheck = failIfCopiedDuringConstructorArgCopy( + valueExpr, + copyContextExpr).asTerm val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args) Block( - ValDef(obj, Some(construct)) :: + cycleCheck :: + ValDef(obj, Some(construct)) :: referenceCopy(copyContextExpr, valueExpr, Ref(obj).asExprOf[T]) :: Nil, Ref(obj)).asExprOf[T] @@ -462,31 +513,23 @@ object ForySerializerMacros { val args = constructorOwned.map { field => copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr).asTerm } + val cycleCheck = failIfCopiedDuringConstructorArgCopy( + valueExpr, + copyContextExpr).asTerm val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args) val assignments = postConstruction.map { field => val copied = copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr) Assign(Select.unique(Ref(obj), field.name), copied.asTerm) } Block( - ValDef(obj, Some(construct)) :: + cycleCheck :: + ValDef(obj, Some(construct)) :: referenceCopy(copyContextExpr, valueExpr, Ref(obj).asExprOf[T]) :: assignments, Ref(obj)).asExprOf[T] } - if constructorOwned.nonEmpty then { - '{ - if $copyContextExpr.copyTrackingRef() then { - $copyContextExpr.markCopyInProgress($valueExpr) - try ${ copyBody() } - catch { - case throwable: Throwable => - $copyContextExpr.clearCopyInProgress($valueExpr) - throw throwable - } - } else ${ copyBody() } - } - } else copyBody() + copyBody() } def constructRead(valuesExpr: Expr[Array[Any]], readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[T] = { @@ -612,21 
+655,13 @@ object ForySerializerMacros { '{ Class.forName(${ Expr(ownerClassName) }).asInstanceOf[Class[T]] } '{ - new ForySerializer[T] with StaticGeneratedStructSerializerFactory[T] { + new ForySerializer[T] { private val descriptors: java.util.List[Descriptor] = $descriptorsExpr - override def getGeneratedDescriptors(): java.util.List[Descriptor] = descriptors - override def createSerializer( resolver: TypeResolver, remoteTypeDef: ForyTypeDef): Serializer[T] = { - newSerializer(resolver, $classExpr, remoteTypeDef) - } - - override def newSerializer( - resolver: TypeResolver, - cls: Class[?], - remoteTypeDef: ForyTypeDef): StaticGeneratedStructSerializer[T] = { + val cls = $classExpr new StaticGeneratedStructSerializer[T](resolver, cls, remoteTypeDef, descriptors) { private val generatedSerializer: StaticGeneratedStructSerializer[T] = this private val fieldGroups: FieldGroups = @@ -647,10 +682,17 @@ object ForySerializerMacros { if resolver.checkClassVersion() then computeClassVersionHash(descriptors) else 0 private val sameSchemaCompatible: Boolean = remoteTypeDef != null && + !${ Expr(hasNestedCompatibleStructFields) } && remoteTypeDef.getId == ForyTypeDef.buildTypeDef(resolver, cls).getId override def getGeneratedDescriptors(): java.util.List[Descriptor] = descriptors + override def copySerializer( + typeResolver: TypeResolver, + typeClass: Class[?], + typeDef: ForyTypeDef): StaticGeneratedStructSerializer[T] = + createSerializer(typeResolver, typeDef).asInstanceOf[StaticGeneratedStructSerializer[T]] + override def write( writeContext: org.apache.fory.context.WriteContext, value: T): Unit = { @@ -714,6 +756,8 @@ object ForySerializerMacros { id: Int, payloadType: TypeRepr, option: Boolean, + trackingRef: Boolean, + hasTrackingRefMetadata: Boolean, payloadName: String, unknownIdName: String, unknown: Boolean, @@ -774,20 +818,24 @@ object ForySerializerMacros { } } - val rawCases = owner.children.flatMap { child => - annotationIntArg[ForyCase](child, "id").map { 
id => - if id < 0 then report.errorAndAbort(s"${child.fullName} @ForyCase id must be >= 0") - val (tpe, payloadName, unknownIdName) = payloadMeta(child, id) - CaseMeta( - child, - id, - tpe, - optionElement(tpe).nonEmpty, - payloadName, - unknownIdName, - id == 0, - -1) + val rawCases = owner.children.filter(_.flags.is(Flags.Case)).map { child => + val id = annotationIntArg[ForyCase](child, "id").getOrElse { + report.errorAndAbort(s"${child.fullName} must be annotated with @ForyCase") } + if id < 0 then report.errorAndAbort(s"${child.fullName} @ForyCase id must be >= 0") + val (tpe, payloadName, unknownIdName) = payloadMeta(child, id) + val refTracking = topLevelTypeRefTracking(tpe) + CaseMeta( + child, + id, + tpe, + optionElement(tpe).nonEmpty, + refTracking.getOrElse(false), + refTracking.nonEmpty, + payloadName, + unknownIdName, + id == 0, + -1) } var nextFieldIndex = 0 val cases = rawCases.map { unionCase => @@ -875,7 +923,8 @@ object ForySerializerMacros { val option = optionElement(outer) val fieldSource = option.map(boxedIfPrimitive).getOrElse(outer) val (base, baseAnnotations) = peelAnnotations(fieldSource) - val annotations = outerAnnotations ++ baseAnnotations + val optionInnerAnnotations = option.toList.flatMap(inner => peelAnnotations(inner)._2) + val annotations = outerAnnotations ++ baseAnnotations ++ optionInnerAnnotations val argumentSource = fieldSource def appliedType(tpe: TypeRepr): Option[(TypeRepr, List[TypeRepr])] = { val directArgs = tpe.typeArgs @@ -920,7 +969,10 @@ object ForySerializerMacros { val componentExpr: Expr[Descriptor.GeneratedType] = component.getOrElse('{ null.asInstanceOf[Descriptor.GeneratedType] }) val typeId = - annotations.flatMap(typeIdForAnnotation).headOption + annotations + .flatMap(typeIdForAnnotation) + .headOption + .orElse(option.map(inner => wireTypeId(peelAnnotations(boxedIfPrimitive(inner))._1))) .orElse { if hasAnnotation[ForyUnion](base.typeSymbol) then Some(Types.UNION) else None } @@ -932,11 +984,35 @@ 
object ForySerializerMacros { val typeExtMeta = generatedTypeExtMeta( typeId, nullable = option.nonEmpty, - trackingRef = annotations.exists(isRefAnnotation), + trackingRef = refTrackingFromAnnotations(annotations).getOrElse(false), + hasTrackingRefMetadata = refTrackingFromAnnotations(annotations).nonEmpty, + nullableWrapper = option.nonEmpty, rawClass = Some(rawClass)) '{ Descriptor.generatedType($rawClass, $typeExtMeta, $argList, $componentExpr) } } + def wireTypeId(tpe: TypeRepr): Int = { + val normalized = peelAnnotations(tpe.widen)._1.dealias + if normalized =:= TypeRepr.of[Boolean] then Types.BOOL + else if normalized =:= TypeRepr.of[Byte] then Types.INT8 + else if normalized =:= TypeRepr.of[Short] then Types.INT16 + else if normalized =:= TypeRepr.of[Int] then Types.INT32 + else if normalized =:= TypeRepr.of[Long] then Types.INT64 + else if normalized =:= TypeRepr.of[Float] then Types.FLOAT32 + else if normalized =:= TypeRepr.of[Double] then Types.FLOAT64 + else { + val fullName = normalized.typeSymbol.fullName + if normalized =:= TypeRepr.of[String] || + normalized.typeSymbol == TypeRepr.of[String].typeSymbol || + fullName == "scala.Predef.String" || + fullName == "scala.Predef$.String" || + fullName.endsWith("Predef.String") || + fullName.endsWith("Predef$.String") + then Types.STRING + else Types.UNKNOWN + } + } + def caseDescriptor(unionCase: CaseMeta): Expr[Descriptor] = { '{ new Descriptor( @@ -947,8 +1023,9 @@ object ForySerializerMacros { ${ Expr(owner.fullName.replace("$.", "$")) }, true, ${ Expr(unionCase.id) }, - false, - false, + ${ Expr(unionCase.option) }, + ${ Expr(unionCase.trackingRef) }, + ${ Expr(unionCase.hasTrackingRefMetadata) }, ForyField.Dynamic.AUTO, false ) @@ -1068,6 +1145,19 @@ object ForySerializerMacros { valueExpr: Expr[T], copyContextExpr: Expr[org.apache.fory.context.CopyContext], caseFieldInfosExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[T] = { + def failIfCopiedDuringPayloadCopy(): Expr[Unit] = { + '{ + if 
$copyContextExpr.copyTrackingRef() && + $copyContextExpr.getCopyObject($valueExpr) != null + then { + throw new org.apache.fory.exception.CopyException( + "Cannot copy cyclic object graph rooted at constructor-owned immutable value " + + $valueExpr.getClass.getName + + " because its copy cannot be referenced before construction completes") + } + } + } + cases.foldRight( '{ throw new IllegalStateException("Unknown Scala union case " + $valueExpr) @@ -1083,24 +1173,25 @@ object ForySerializerMacros { Select.unique( '{ $valueExpr.asInstanceOf[c] }.asTerm, unionCase.unknownIdName).asExprOf[Int] - val copiedPayload = '{ $copyContextExpr.copyObject($payload) } - val current = construct(unknown, List(originalId.asTerm, copiedPayload.asTerm)) '{ - if $valueExpr.isInstanceOf[c] then $current else $next + if $valueExpr.isInstanceOf[c] then { + val copiedPayload = $copyContextExpr.copyObject($payload) + ${ failIfCopiedDuringPayloadCopy() } + ${ construct(unknown, List(originalId.asTerm, 'copiedPayload.asTerm)) } + } else $next } } else { val payloadValue = wirePayload(payload, unionCase) - val copiedPayload = - '{ - UnionSerializer.copyCaseValue( - $copyContextExpr, - $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) }), - $payloadValue) - } - val coerced = decodePayload(copiedPayload, unionCase) - val current = construct(unionCase, List(coerced.asTerm)) '{ - if $valueExpr.isInstanceOf[c] then $current else $next + if $valueExpr.isInstanceOf[c] then { + val copiedPayload = + UnionSerializer.copyCaseValue( + $copyContextExpr, + $caseFieldInfosExpr(${ Expr(unionCase.fieldIndex) }), + $payloadValue) + ${ failIfCopiedDuringPayloadCopy() } + ${ construct(unionCase, List(decodePayload('copiedPayload, unionCase).asTerm)) } + } else $next } } } @@ -1164,16 +1255,19 @@ object ForySerializerMacros { val ownerClassName = owner.fullName.replace("$.", "$") val classExpr: Expr[Class[T]] = '{ Class.forName(${ Expr(ownerClassName) }).asInstanceOf[Class[T]] } - val caseClassesExpr: 
Expr[List[Class[_]]] = - Expr.ofList(cases.map(unionCase => - '{ Class.forName(${ Expr(ownerClassName + "$" + unionCase.symbol.name) }) })) + val caseClassesExpr: Expr[Array[Class[_]]] = { + val caseClassExprs = cases.map { unionCase => + '{ Class.forName(${ Expr(unionCase.symbol.fullName.replace("$.", "$")) }) } + } + '{ Array[Class[_]](${ Varargs(caseClassExprs) }*) } + } '{ new ForySerializer[T] { override def isUnion: Boolean = true - override def registrationClasses(cls: Class[T]): Array[Class[_]] = - (cls :: $caseClassesExpr).toArray + override private[scala] def handledRuntimeClasses(cls: Class[T]): Array[Class[_]] = + $caseClassesExpr override def createSerializer( resolver: TypeResolver, @@ -1203,20 +1297,9 @@ object ForySerializerMacros { } override def copy(copyContext: org.apache.fory.context.CopyContext, value: T): T = { - if copyContext.copyTrackingRef() then { - copyContext.markCopyInProgress(value) - try { - val copied = ${ copyDispatch('value, 'copyContext, 'caseFieldInfos) } - copyContext.reference(value, copied) - copied - } catch { - case throwable: Throwable => - copyContext.clearCopyInProgress(value) - throw throwable - } - } else { - ${ copyDispatch('value, 'copyContext, 'caseFieldInfos) } - } + val copied = ${ copyDispatch('value, 'copyContext, 'caseFieldInfos) } + copyContext.reference(value, copied) + copied } } } @@ -1245,12 +1328,12 @@ object ForySerializerMacros { symbol.annotations.exists(_.tpe <:< TypeRepr.of[A]) } - private def hasRef(using q: Quotes)(symbol: q.reflect.Symbol): Boolean = { - import q.reflect.* - symbol.annotations.exists(_.tpe.typeSymbol.fullName == "org.apache.fory.annotation.Ref") + private def refAnnotation(using q: Quotes)(symbol: q.reflect.Symbol): Option[Boolean] = { + symbol.annotations.find(isRefAnnotation).map(refAnnotationEnabled) } - private def topLevelTypeHasRef(using q: Quotes)(tpe: q.reflect.TypeRepr): Boolean = { + private def topLevelTypeRefTracking(using q: Quotes)( + tpe: q.reflect.TypeRepr): 
Option[Boolean] = { import q.reflect.* def peelAnnotations(tpe: TypeRepr): (TypeRepr, List[Term]) = { @@ -1268,15 +1351,12 @@ object ForySerializerMacros { } } - def isRef(annotation: Term): Boolean = - annotation.tpe.typeSymbol.fullName == "org.apache.fory.annotation.Ref" - val (base, annotations) = peelAnnotations(tpe) base.dealias match { case AppliedType(optionType, List(inner)) if optionType.typeSymbol.fullName == "scala.Option" => - peelAnnotations(inner)._2.exists(isRef) - case _ => annotations.exists(isRef) + refTrackingFromAnnotations(peelAnnotations(inner)._2) + case _ => refTrackingFromAnnotations(annotations) } } @@ -1284,30 +1364,105 @@ object ForySerializerMacros { typeId: Int, nullable: Boolean, trackingRef: Boolean, + hasTrackingRefMetadata: Boolean, + nullableWrapper: Boolean = false, rawClass: Option[Expr[Class[?]]] = None): Expr[TypeExtMeta] = { if typeId == Types.UNKNOWN && rawClass.nonEmpty then { val raw = rawClass.get '{ val resolvedTypeId = - if classOf[ForyScalaEnum].isAssignableFrom($raw) then Types.ENUM else Types.UNKNOWN - if resolvedTypeId == Types.UNKNOWN && !${ Expr(nullable) } && !${ Expr(trackingRef) } then { + if ScalaTypes.isScalaEnumType($raw) then Types.ENUM else Types.UNKNOWN + if resolvedTypeId == Types.UNKNOWN && + !${ Expr(nullable) } && + !${ Expr(hasTrackingRefMetadata) } && + !${ Expr(nullableWrapper) } then { null.asInstanceOf[TypeExtMeta] } else { - TypeExtMeta.of(resolvedTypeId, ${ Expr(nullable) }, ${ Expr(trackingRef) }) + TypeExtMeta.of( + resolvedTypeId, + ${ Expr(nullable) }, + ${ Expr(trackingRef) }, + ${ Expr(nullableWrapper) }) } } - } else if typeId == Types.UNKNOWN && !nullable && !trackingRef then { + } else if typeId == Types.UNKNOWN && !nullable && !hasTrackingRefMetadata && !nullableWrapper then { '{ null.asInstanceOf[TypeExtMeta] } } else { - '{ TypeExtMeta.of(${ Expr(typeId) }, ${ Expr(nullable) }, ${ Expr(trackingRef) }) } + '{ + TypeExtMeta.of( + ${ Expr(typeId) }, + ${ Expr(nullable) }, + ${ 
Expr(trackingRef) }, + ${ Expr(nullableWrapper) }) + } } } private def isScalaEnumType(using q: Quotes)(tpe: q.reflect.TypeRepr): Boolean = { import q.reflect.* tpe.typeSymbol.flags.is(Flags.Enum) || - tpe <:< TypeRepr.of[ForyScalaEnum] || - tpe.baseClasses.exists(_.fullName == "org.apache.fory.scala.ForyScalaEnum") + tpe.baseClasses.exists(_.fullName == "scala.reflect.Enum") + } + + private def hasNestedCompatibleStruct(using q: Quotes)(tpe: q.reflect.TypeRepr): Boolean = { + import q.reflect.* + + def peel(tpe: TypeRepr): TypeRepr = { + tpe match { + case AnnotatedType(underlying, _) => peel(underlying) + case other => + other.dealias match { + case AnnotatedType(underlying, _) => peel(underlying) + case dealiased => dealiased + } + } + } + + def evolutionValue(annotation: Term): Option[String] = { + annotation match { + case Apply(_, args) => + args.collectFirst { + case NamedArg("evolution", Select(_, name)) => name + case NamedArg("evolution", term) => + val rendered = term.show + if rendered.endsWith(".ENABLED") then "ENABLED" + else if rendered.endsWith(".DISABLED") then "DISABLED" + else "INHERIT" + } + case _ => None + } + } + + def evolvingValue(annotation: Term): Option[Boolean] = { + annotation match { + case Apply(_, args) => + args.collectFirst { + case NamedArg("evolving", Literal(BooleanConstant(value))) => value + } + case _ => None + } + } + + def compatibleStruct(symbol: Symbol): Boolean = { + symbol.annotations.find(_.tpe <:< TypeRepr.of[ForyStruct]) match { + case Some(annotation) => + val evolution = evolutionValue(annotation).getOrElse("INHERIT") + if evolution == "DISABLED" then false + else evolvingValue(annotation).getOrElse(true) || evolution == "ENABLED" + case None => false + } + } + + def loop(tpe: TypeRepr): Boolean = { + val base = peel(tpe.widen) + compatibleStruct(base.typeSymbol) || + (base match { + case AppliedType(_, args) => args.exists(loop) + case _ => false + }) + } + + loop(tpe) } private def typeIdForAnnotation(using q: 
Quotes)(annotation: q.reflect.Term): Option[Int] = { @@ -1353,4 +1508,21 @@ object ForySerializerMacros { private def isRefAnnotation(using q: Quotes)(annotation: q.reflect.Term): Boolean = { annotation.tpe.typeSymbol.fullName == "org.apache.fory.annotation.Ref" } + + private def refTrackingFromAnnotations(using q: Quotes)( + annotations: Iterable[q.reflect.Term]): Option[Boolean] = { + annotations.find(isRefAnnotation).map(refAnnotationEnabled) + } + + private def refAnnotationEnabled(using q: Quotes)(annotation: q.reflect.Term): Boolean = { + import q.reflect.* + annotation match { + case Apply(_, args) => + args.collectFirst { + case NamedArg("enable", Literal(BooleanConstant(value))) => value + case Literal(BooleanConstant(value)) => value + }.getOrElse(true) + case _ => true + } + } } diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala index bfd51ce613..49edf89d52 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala @@ -35,19 +35,28 @@ abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.I override def onCollectionWrite(writeContext: WriteContext, value: T): util.Collection[_] = { writeContext.getBuffer.writeVarUInt32Small7(value.size) - new XlangCollectionAdapter[A](value) + if (ScalaXlangCollectionShape.hasOptionElement(writeContext)) { + new XlangOptionCollectionAdapter[A](value) + } else { + new XlangCollectionAdapter[A](value) + } } override def newCollection(readContext: ReadContext): util.Collection[_] = { val numElements = readCollectionSize(readContext.getBuffer) setNumElements(numElements) - new XlangCollectionBuilder[A, T](newBuilder(numElements)) + val builder = newBuilder(numElements) + if (ScalaXlangCollectionShape.hasOptionElement(readContext)) { + new 
XlangOptionCollectionBuilder[A, T](builder) + } else { + new XlangCollectionBuilder[A, T](builder) + } } protected def newBuilder(numElements: Int): mutable.Builder[A, T] override def onCollectionRead(collection: util.Collection[_]): T = { - collection.asInstanceOf[XlangCollectionBuilder[A, T]].builder.result() + collection.asInstanceOf[XlangBuilderResult[T]].result() } override def copy(copyContext: CopyContext, value: T): T = { @@ -133,9 +142,7 @@ class ScalaXlangSeqSerializer[A, T <: scala.collection.Seq[A]]( cls: Class[T]) extends AbstractScalaXlangCollectionSerializer[A, T](typeResolver, cls) { override protected def newBuilder(numElements: Int): mutable.Builder[A, T] = { - val builder = simmutable.List.newBuilder[A] - builder.sizeHint(numElements) - builder.asInstanceOf[mutable.Builder[A, T]] + ScalaXlangCollectionShape.seqBuilder[A, T](cls, numElements) } } @@ -145,9 +152,7 @@ class ScalaXlangSetSerializer[A, T <: scala.collection.Set[A]]( cls: Class[T]) extends AbstractScalaXlangCollectionSerializer[A, T](typeResolver, cls) { override protected def newBuilder(numElements: Int): mutable.Builder[A, T] = { - val builder = simmutable.Set.newBuilder[A] - builder.sizeHint(numElements) - builder.asInstanceOf[mutable.Builder[A, T]] + ScalaXlangCollectionShape.setBuilder[A, T](cls, numElements) } } @@ -157,11 +162,122 @@ class ScalaXlangCollectionSerializer[A, T <: scala.collection.Iterable[A]]( cls: Class[T]) extends AbstractScalaXlangCollectionSerializer[A, T](typeResolver, cls) { override protected def newBuilder(numElements: Int): mutable.Builder[A, T] = { - val builder = simmutable.List.newBuilder[A] - builder.sizeHint(numElements) + ScalaXlangCollectionShape.iterableBuilder[A, T](cls, numElements) + } + +} + +private object ScalaXlangCollectionShape { + def hasOptionElement(writeContext: WriteContext): Boolean = { + val genericType = writeContext.getGenerics.nextGenericType(writeContext.getDepth) + genericType != null && 
isExplicitNullable(genericType.getTypeParameter0) + } + + def hasOptionElement(readContext: ReadContext): Boolean = { + val genericType = readContext.getGenerics.nextGenericType(readContext.getDepth) + genericType != null && isExplicitNullable(genericType.getTypeParameter0) + } + + def hasOptionKey(writeContext: WriteContext): Boolean = { + val genericType = writeContext.getGenerics.nextGenericType(writeContext.getDepth) + genericType != null && + genericType.getTypeParametersCount >= 2 && + isExplicitNullable(genericType.getTypeParameter0) + } + + def hasOptionValue(writeContext: WriteContext): Boolean = { + val genericType = writeContext.getGenerics.nextGenericType(writeContext.getDepth) + genericType != null && + genericType.getTypeParametersCount >= 2 && + isExplicitNullable(genericType.getTypeParameter1) + } + + def hasOptionKey(readContext: ReadContext): Boolean = { + val genericType = readContext.getGenerics.nextGenericType(readContext.getDepth) + genericType != null && + genericType.getTypeParametersCount >= 2 && + isExplicitNullable(genericType.getTypeParameter0) + } + + def hasOptionValue(readContext: ReadContext): Boolean = { + val genericType = readContext.getGenerics.nextGenericType(readContext.getDepth) + genericType != null && + genericType.getTypeParametersCount >= 2 && + isExplicitNullable(genericType.getTypeParameter1) + } + + def seqBuilder[A, T](declared: Class[_], size: Int): mutable.Builder[A, T] = { + val builder = + if (accepts(declared, classOf[mutable.ArrayBuffer[_]])) { + mutable.ArrayBuffer.newBuilder[A] + } else if (accepts(declared, classOf[simmutable.Vector[_]])) { + simmutable.Vector.newBuilder[A] + } else if (accepts(declared, classOf[simmutable.List[_]])) { + simmutable.List.newBuilder[A] + } else { + unsupported("sequence", declared) + } + builder.sizeHint(size) + builder.asInstanceOf[mutable.Builder[A, T]] + } + + def iterableBuilder[A, T](declared: Class[_], size: Int): mutable.Builder[A, T] = { + val builder = + if 
(accepts(declared, classOf[mutable.ArrayBuffer[_]])) { + mutable.ArrayBuffer.newBuilder[A] + } else if (accepts(declared, classOf[simmutable.List[_]])) { + simmutable.List.newBuilder[A] + } else { + unsupported("iterable", declared) + } + builder.sizeHint(size) builder.asInstanceOf[mutable.Builder[A, T]] } + def setBuilder[A, T](declared: Class[_], size: Int): mutable.Builder[A, T] = { + val builder = + if (accepts(declared, classOf[mutable.HashSet[_]])) { + mutable.HashSet.newBuilder[A] + } else if (accepts(declared, classOf[simmutable.HashSet[_]])) { + simmutable.HashSet.newBuilder[A] + } else if (accepts(declared, classOf[simmutable.Set[_]])) { + simmutable.Set.newBuilder[A] + } else { + unsupported("set", declared) + } + builder.sizeHint(size) + builder.asInstanceOf[mutable.Builder[A, T]] + } + + def mapBuilder[K, V, T](declared: Class[_], size: Int): mutable.Builder[(K, V), T] = { + val builder = + if (accepts(declared, classOf[mutable.HashMap[_, _]])) { + mutable.HashMap.newBuilder[K, V] + } else if (accepts(declared, classOf[simmutable.HashMap[_, _]])) { + simmutable.HashMap.newBuilder[K, V] + } else if (accepts(declared, classOf[simmutable.Map[_, _]])) { + simmutable.Map.newBuilder[K, V] + } else { + unsupported("map", declared) + } + builder.sizeHint(size) + builder.asInstanceOf[mutable.Builder[(K, V), T]] + } + + private def isExplicitNullable(genericType: org.apache.fory.`type`.GenericType): Boolean = + genericType != null && + genericType.getTypeRef.getTypeExtMeta != null && + genericType.getTypeRef.getTypeExtMeta.nullableWrapper() + + private def accepts(declared: Class[_], result: Class[_]): Boolean = + declared.isAssignableFrom(result) + + private def unsupported(kind: String, declared: Class[_]): Nothing = { + throw new IllegalArgumentException( + "Scala xlang " + kind + " serializer cannot rebuild declared type " + + declared.getName + + ". 
Use a supported immutable collection type or a mutable collection interface.") + } } private final class XlangCollectionAdapter[A](coll: scala.collection.Iterable[A]) @@ -177,8 +293,29 @@ private final class XlangCollectionAdapter[A](coll: scala.collection.Iterable[A] override def size(): Int = coll.size } +private final class XlangOptionCollectionAdapter[A](coll: scala.collection.Iterable[A]) + extends util.AbstractCollection[Any] { + override def iterator(): util.Iterator[Any] = new util.Iterator[Any] { + private val it = coll.iterator + + override def hasNext: Boolean = it.hasNext + + override def next(): Any = { + val value = it.next() + if (value == null) null else value.asInstanceOf[Option[_]].getOrElse(null) + } + } + + override def size(): Int = coll.size +} + +private trait XlangBuilderResult[T] { + def result(): T +} + private final class XlangCollectionBuilder[A, T](val builder: mutable.Builder[A, T]) - extends util.AbstractCollection[A] { + extends util.AbstractCollection[A] + with XlangBuilderResult[T] { override def add(e: A): Boolean = { builder.addOne(e) true @@ -189,6 +326,25 @@ private final class XlangCollectionBuilder[A, T](val builder: mutable.Builder[A, override def size(): Int = throw new UnsupportedOperationException("Scala xlang collection builder is write-only") + + override def result(): T = builder.result() +} + +private final class XlangOptionCollectionBuilder[A, T](val builder: mutable.Builder[A, T]) + extends util.AbstractCollection[Any] + with XlangBuilderResult[T] { + override def add(e: Any): Boolean = { + builder.addOne(Option(e).asInstanceOf[A]) + true + } + + override def iterator(): util.Iterator[Any] = + throw new UnsupportedOperationException("Scala xlang collection builder is write-only") + + override def size(): Int = + throw new UnsupportedOperationException("Scala xlang collection builder is write-only") + + override def result(): T = builder.result() } abstract class AbstractScalaXlangMapSerializer[K, V, T <: 
scala.collection.Map[K, V]]( @@ -198,19 +354,30 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K override def onMapWrite(writeContext: WriteContext, value: T): util.Map[_, _] = { writeContext.getBuffer.writeVarUInt32Small7(value.size) - new XlangMapAdapter[K, V](value) + val optionKey = ScalaXlangCollectionShape.hasOptionKey(writeContext) + val optionValue = ScalaXlangCollectionShape.hasOptionValue(writeContext) + if (optionKey || optionValue) { + new XlangOptionMapAdapter[K, V](value, optionKey, optionValue) + } else { + new XlangMapAdapter[K, V](value) + } } override def newMap(readContext: ReadContext): util.Map[_, _] = { val numElements = readMapSize(readContext.getBuffer) setNumElements(numElements) - val builder = simmutable.Map.newBuilder[K, V] - builder.sizeHint(numElements) - new XlangMapBuilder[K, V, T](builder.asInstanceOf[mutable.Builder[(K, V), T]]) + val builder = ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, numElements) + val optionKey = ScalaXlangCollectionShape.hasOptionKey(readContext) + val optionValue = ScalaXlangCollectionShape.hasOptionValue(readContext) + if (optionKey || optionValue) { + new XlangOptionMapBuilder[K, V, T](builder, optionKey, optionValue) + } else { + new XlangMapBuilder[K, V, T](builder) + } } override def onMapRead(map: util.Map[_, _]): T = { - map.asInstanceOf[XlangMapBuilder[K, V, T]].builder.result() + map.asInstanceOf[XlangBuilderResult[T]].result() } override def onMapCopy(map: util.Map[_, _]): T = onMapRead(map) @@ -229,9 +396,7 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K copyWithBuilder(copyContext, value, value.mapFactory.newBuilder[K, V]) } } else { - val builder = simmutable.Map.newBuilder[K, V] - builder.sizeHint(value.size) - copyWithBuilder(copyContext, value, builder) + copyWithBuilder(copyContext, value, ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, value.size)) } } @@ -299,8 +464,38 @@ private final class XlangMapAdapter[K, 
V](map: scala.collection.Map[K, V]) } } +private final class XlangOptionMapAdapter[K, V]( + map: scala.collection.Map[K, V], + optionKey: Boolean, + optionValue: Boolean) + extends util.AbstractMap[Any, Any] { + override def entrySet(): util.Set[util.Map.Entry[Any, Any]] = + new util.AbstractSet[util.Map.Entry[Any, Any]] { + override def size(): Int = map.size + + override def iterator(): util.Iterator[util.Map.Entry[Any, Any]] = + new util.Iterator[util.Map.Entry[Any, Any]] { + private val it = map.iterator + + override def hasNext: Boolean = it.hasNext + + override def next(): util.Map.Entry[Any, Any] = { + val entry = it.next() + new org.apache.fory.collection.MapEntry[Any, Any]( + unwrap(entry._1, optionKey), + unwrap(entry._2, optionValue)) + } + } + } + + private def unwrap(value: Any, option: Boolean): Any = { + if (option && value != null) value.asInstanceOf[Option[_]].getOrElse(null) else value + } +} + private final class XlangMapBuilder[K, V, T](val builder: mutable.Builder[(K, V), T]) - extends util.AbstractMap[K, V] { + extends util.AbstractMap[K, V] + with XlangBuilderResult[T] { override def entrySet(): util.Set[util.Map.Entry[K, V]] = throw new UnsupportedOperationException("Scala xlang map builder is write-only") @@ -308,6 +503,28 @@ private final class XlangMapBuilder[K, V, T](val builder: mutable.Builder[(K, V) builder.addOne((key, value)) value } + + override def result(): T = builder.result() +} + +private final class XlangOptionMapBuilder[K, V, T]( + val builder: mutable.Builder[(K, V), T], + optionKey: Boolean, + optionValue: Boolean) + extends util.AbstractMap[Any, Any] + with XlangBuilderResult[T] { + override def entrySet(): util.Set[util.Map.Entry[Any, Any]] = + throw new UnsupportedOperationException("Scala xlang map builder is write-only") + + override def put(key: Any, value: Any): Any = { + builder.addOne((wrap(key, optionKey).asInstanceOf[K], wrap(value, optionValue).asInstanceOf[V])) + value + } + + private def wrap(value: Any, 
option: Boolean): Any = + if (option) Option(value) else value + + override def result(): T = builder.result() } final class ScalaOptionSerializer(typeResolver: TypeResolver, cls: Class[_]) diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala index 1cfcbe6cbb..ad7ef8b821 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala @@ -20,13 +20,24 @@ package org.apache.fory.serializer.scala import org.apache.fory.Fory -import org.apache.fory.annotation.{ForyCase, ForyField, ForyStruct, ForyUnion, Ref} +import org.apache.fory.annotation.{ + ForyCase, + ForyField, + ForyStruct, + ForyUnion, + Ref, + UInt64Type, + UInt8Type +} +import org.apache.fory.config.Int64Encoding +import org.apache.fory.meta.TypeDef import org.apache.fory.scala.ForySerializer -import org.apache.fory.serializer.StaticGeneratedStructSerializerFactory -import org.apache.fory.`type`.TypeUtils +import org.apache.fory.serializer.StaticGeneratedStructSerializer +import org.apache.fory.`type`.{Types, TypeUtils} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import scala.compiletime.testing.typeCheckErrors import scala.jdk.CollectionConverters._ object ForySerializerDerivationTest { @@ -47,6 +58,25 @@ object ForySerializerDerivationTest { @ForyField(id = 3) scores: Map[String, Int]) derives ForySerializer + @ForyStruct + final case class OptionalCollectionBox( + @ForyField(id = 1) names: List[Option[String]], + @ForyField(id = 2) unsigned: List[Option[Int @UInt8Type]], + @ForyField(id = 3) + scores: Map[String, Option[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]], + @ForyField(id = 4) keyed: Map[Option[String], Int]) + derives ForySerializer + + @ForyStruct + final case class 
OptionalCollectionBoxWriter( + @ForyField(id = 1) names: List[Option[String]], + @ForyField(id = 2) unsigned: List[Option[Int @UInt8Type]], + @ForyField(id = 3) + scores: Map[String, Option[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]], + @ForyField(id = 4) keyed: Map[Option[String], Int], + @ForyField(id = 5) extra: String) + derives ForySerializer + @ForyStruct final case class CopyBox( @ForyField(id = 1) user: SearchUser, @@ -54,6 +84,13 @@ object ForySerializerDerivationTest { @ForyField(id = 3) values: Array[Int]) derives ForySerializer + @ForyStruct + final case class RefMetadataBox( + @ForyField(id = 1) peer: SearchUser, + @ForyField(id = 2) localOnly: SearchUser @Ref(enable = false), + @ForyField(id = 3) shared: SearchUser @Ref) + derives ForySerializer + @ForyStruct final class RefNode() derives ForySerializer { @ForyField(id = 1) @@ -93,6 +130,9 @@ object ForySerializerDerivationTest { @ForyCase(id = 3) case OptionalUserCase(value: Option[SearchUser]) + + @ForyCase(id = 4) + case OptionalTaggedCase(value: Option[Long @UInt64Type(encoding = Int64Encoding.TAGGED)]) } @ForyUnion @@ -117,7 +157,13 @@ object ForySerializerDerivationTest { ForySerializer.register(fory, classOf[Person], "scala_test", "Person") ForySerializer.register(fory, classOf[SearchUser], "scala_test", "SearchUser") ForySerializer.register(fory, classOf[CollectionBox], "scala_test", "CollectionBox") + ForySerializer.register( + fory, + classOf[OptionalCollectionBox], + "scala_test", + "OptionalCollectionBox") ForySerializer.register(fory, classOf[CopyBox], "scala_test", "CopyBox") + ForySerializer.register(fory, classOf[RefMetadataBox], "scala_test", "RefMetadataBox") ForySerializer.register(fory, classOf[RefNode], "scala_test", "RefNode") ForySerializer.register(fory, classOf[UnionRefNode], "scala_test", "UnionRefNode") ForySerializer.register(fory, classOf[MixedRecord], "scala_test", "MixedRecord") @@ -125,6 +171,20 @@ object ForySerializerDerivationTest { 
ForySerializer.register(fory, classOf[UnionCycle], "scala_test", "UnionCycle") fory } + + def compatibleXlangFory(): Fory = { + val fory = Fory.builder() + .withXlang(true) + .withCompatible(true) + .withRefTracking(true) + .withRefCopy(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .suppressClassRegistrationWarnings(false) + .build() + ScalaSerializers.registerSerializers(fory) + fory + } } class ForySerializerDerivationTest extends AnyWordSpec with Matchers { @@ -151,6 +211,94 @@ class ForySerializerDerivationTest extends AnyWordSpec with Matchers { val fory = xlangFory() val box = CollectionBox(List("a", "b"), Set("x", "y"), Map("a" -> 1, "b" -> 2)) fory.deserialize(fory.serialize(box)) shouldEqual box + + val serializer = + summon[ForySerializer[CollectionBox]] + .createSerializer(fory.getTypeResolver) + .asInstanceOf[StaticGeneratedStructSerializer[CollectionBox]] + val tags = serializer.getGeneratedDescriptors.asScala.find(_.getName == "tags").get + val tagMeta = TypeUtils.getElementType(tags.getTypeRef).getTypeExtMeta + if tagMeta != null then { + tagMeta.nullableWrapper() shouldBe false + } + } + + "serialize derived Scala collection fields with Option elements" in { + val fory = xlangFory() + val box = OptionalCollectionBox( + List(Some("a"), None), + List(Some(1), None), + Map("a" -> Some(9L), "b" -> None), + Map(Some("a") -> 1, None -> 2)) + + fory.deserialize(fory.serialize(box)) shouldEqual box + } + + "preserve Option collection wrappers on compatible remote reads" in { + val writerFory = ForySerializerDerivationTest.compatibleXlangFory() + ForySerializer.register( + writerFory, + classOf[OptionalCollectionBoxWriter], + "scala_test", + "OptionalCollectionBox") + val readerFory = ForySerializerDerivationTest.compatibleXlangFory() + ForySerializer.register( + readerFory, + classOf[OptionalCollectionBox], + "scala_test", + "OptionalCollectionBox") + + val writerValue = OptionalCollectionBoxWriter( + List(Some("a"), None), 
+ List(Some(1), None), + Map("a" -> Some(9L), "b" -> None), + Map(Some("a") -> 1, None -> 2), + "ignored") + val readerValue = + readerFory.deserialize(writerFory.serialize(writerValue)).asInstanceOf[OptionalCollectionBox] + + readerValue shouldEqual OptionalCollectionBox( + List(Some("a"), None), + List(Some(1), None), + Map("a" -> Some(9L), "b" -> None), + Map(Some("a") -> 1, None -> 2)) + } + + "emit inner nullable metadata for Option collection elements" in { + val fory = xlangFory() + val serializer = + summon[ForySerializer[OptionalCollectionBox]] + .createSerializer(fory.getTypeResolver) + .asInstanceOf[StaticGeneratedStructSerializer[OptionalCollectionBox]] + val descriptors = serializer.getGeneratedDescriptors.asScala + val names = descriptors.find(_.getName == "names").get + val unsigned = descriptors.find(_.getName == "unsigned").get + val scores = descriptors.find(_.getName == "scores").get + val keyed = descriptors.find(_.getName == "keyed").get + + val nameElement = TypeUtils.getElementType(names.getTypeRef) + nameElement.getRawType shouldBe classOf[String] + nameElement.getTypeExtMeta.nullable() shouldBe true + nameElement.getTypeExtMeta.nullableWrapper() shouldBe true + nameElement.getTypeExtMeta.typeId() shouldBe Types.STRING + + val unsignedElement = TypeUtils.getElementType(unsigned.getTypeRef) + unsignedElement.getRawType shouldBe classOf[java.lang.Integer] + unsignedElement.getTypeExtMeta.nullable() shouldBe true + unsignedElement.getTypeExtMeta.nullableWrapper() shouldBe true + unsignedElement.getTypeExtMeta.typeId() shouldBe Types.UINT8 + + val mapValue = TypeUtils.getMapKeyValueType(scores.getTypeRef).f1 + mapValue.getRawType shouldBe classOf[java.lang.Long] + mapValue.getTypeExtMeta.nullable() shouldBe true + mapValue.getTypeExtMeta.nullableWrapper() shouldBe true + mapValue.getTypeExtMeta.typeId() shouldBe Types.TAGGED_UINT64 + + val mapKey = TypeUtils.getMapKeyValueType(keyed.getTypeRef).f0 + mapKey.getRawType shouldBe classOf[String] + 
mapKey.getTypeExtMeta.nullable() shouldBe true + mapKey.getTypeExtMeta.nullableWrapper() shouldBe true + mapKey.getTypeExtMeta.typeId() shouldBe Types.STRING } "serialize mixed constructor and mutable field classes" in { @@ -163,10 +311,12 @@ class ForySerializerDerivationTest extends AnyWordSpec with Matchers { } "preserve nested reference metadata in generated descriptors" in { - val factory = + val fory = xlangFory() + val serializer = summon[ForySerializer[RefNode]] - .asInstanceOf[StaticGeneratedStructSerializerFactory[RefNode]] - val descriptors = factory.getGeneratedDescriptors.asScala + .createSerializer(fory.getTypeResolver) + .asInstanceOf[StaticGeneratedStructSerializer[RefNode]] + val descriptors = serializer.getGeneratedDescriptors.asScala val children = descriptors.find(_.getName == "children").get val parent = descriptors.find(_.getName == "parent").get @@ -176,6 +326,66 @@ class ForySerializerDerivationTest extends AnyWordSpec with Matchers { parent.isTrackingRef shouldBe true } + "preserve explicit top-level ref metadata presence in generated descriptors" in { + val fory = xlangFory() + val serializer = + summon[ForySerializer[RefMetadataBox]] + .createSerializer(fory.getTypeResolver) + .asInstanceOf[StaticGeneratedStructSerializer[RefMetadataBox]] + val descriptors = serializer.getGeneratedDescriptors.asScala + val peer = descriptors.find(_.getName == "peer").get + val localOnly = descriptors.find(_.getName == "localOnly").get + val shared = descriptors.find(_.getName == "shared").get + + peer.hasTrackingRefMetadata shouldBe false + peer.isTrackingRef shouldBe false + localOnly.hasTrackingRefMetadata shouldBe true + localOnly.isTrackingRef shouldBe false + shared.hasTrackingRefMetadata shouldBe true + shared.isTrackingRef shouldBe true + } + + "disable same-schema compatible fast path for nested compatible structs" in { + def sameSchemaCompatible(serializer: AnyRef): Boolean = { + val field = 
serializer.getClass.getDeclaredField("sameSchemaCompatible") + field.setAccessible(true) + field.getBoolean(serializer) + } + + val fory = xlangFory() + val resolver = fory.getTypeResolver + val personFactory = summon[ForySerializer[Person]] + val copyBoxFactory = summon[ForySerializer[CopyBox]] + + val personSerializer = + personFactory.createSerializer(resolver, TypeDef.buildTypeDef(resolver, classOf[Person])) + val copyBoxSerializer = + copyBoxFactory.createSerializer(resolver, TypeDef.buildTypeDef(resolver, classOf[CopyBox])) + + sameSchemaCompatible(personSerializer.asInstanceOf[AnyRef]) shouldBe true + sameSchemaCompatible(copyBoxSerializer.asInstanceOf[AnyRef]) shouldBe false + } + + "reject union enum cases without ForyCase metadata" in { + val errors = typeCheckErrors(""" + import org.apache.fory.annotation.{ForyCase, ForyStruct, ForyUnion} + import org.apache.fory.scala.ForySerializer + + @ForyStruct + final case class MissingCaseUser(name: String) derives ForySerializer + + @ForyUnion + enum MissingCaseUnion derives ForySerializer { + @ForyCase(id = 0) + case UnknownCase(caseId: Int, value: Any) + + case UserCase(value: MissingCaseUser) + } + """) + + errors.exists(_.message.contains("must be annotated with @ForyCase")) shouldBe true + } + "serialize derived union unknown cases with original ids" in { val fory = xlangFory() val unknown = SearchTarget.UnknownCase(99, SearchUser("Future")) @@ -187,18 +397,25 @@ class ForySerializerDerivationTest extends AnyWordSpec with Matchers { val some: SearchTarget.OptionalUserCase = SearchTarget.OptionalUserCase(Some(SearchUser("Ada"))) val none: SearchTarget.OptionalUserCase = SearchTarget.OptionalUserCase(None) + val taggedSome = SearchTarget.OptionalTaggedCase(Some(99L)) + val taggedNone = SearchTarget.OptionalTaggedCase(None) fory.deserialize(fory.serialize(some)) shouldEqual some fory.deserialize(fory.serialize(none)) shouldEqual none + fory.deserialize(fory.serialize(taggedSome)) shouldEqual taggedSome + 
fory.deserialize(fory.serialize(taggedNone)) shouldEqual taggedNone val copiedSome = fory.copy(some).asInstanceOf[SearchTarget.OptionalUserCase] val copiedNone = fory.copy(none).asInstanceOf[SearchTarget.OptionalUserCase] + val copiedTagged = fory.copy(taggedSome).asInstanceOf[SearchTarget.OptionalTaggedCase] copiedSome shouldEqual some copiedSome should not be theSameInstanceAs(some) copiedSome.value.get should not be theSameInstanceAs(some.value.get) copiedNone shouldEqual none copiedNone should not be theSameInstanceAs(none) + copiedTagged shouldEqual taggedSome + copiedTagged should not be theSameInstanceAs(taggedSome) } "copy derived case classes through field serializers" in { diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaEnumTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaEnumTest.scala index 4ff57e12d0..74421d3970 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaEnumTest.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaEnumTest.scala @@ -20,12 +20,27 @@ package org.apache.fory.serializer.scala import org.apache.fory.Fory +import org.apache.fory.annotation.ForyEnumId import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec object ScalaEnumTest { enum ColorEnum { case Red, Green, Blue } + enum StableColorV1 { + @ForyEnumId(7) + case Red + @ForyEnumId(3) + case Green + } + + enum StableColorV2 { + @ForyEnumId(3) + case Green + @ForyEnumId(7) + case Red + } + case class Colors(set: Set[ColorEnum]) } @@ -48,5 +63,23 @@ class ScalaEnumTest extends AnyWordSpec with Matchers { val bytes = fory.serialize(colors) fory.deserialize(bytes) shouldEqual colors } + "use case-level ForyEnumId metadata for stable xlang enum tags" in { + val writer = Fory.builder() + .withXlang(true) + .withRefTracking(false) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + val reader = Fory.builder() + .withXlang(true) + 
.withRefTracking(false) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(true) + .build() + ScalaSerializers.registerEnum(writer, classOf[StableColorV1], "scala_test", "StableColor") + ScalaSerializers.registerEnum(reader, classOf[StableColorV2], "scala_test", "StableColor") + + reader.deserialize(writer.serialize(StableColorV1.Green)) shouldBe StableColorV2.Green + } } } diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala index 60689af816..4e22e11957 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala @@ -23,6 +23,7 @@ import org.apache.fory.Fory import org.apache.fory.annotation.{ ArrayType, ForyCase, + ForyEnumId, ForyField, ForyStruct, ForyUnion, @@ -41,7 +42,7 @@ import org.apache.fory.context.{ReadContext, WriteContext} import org.apache.fory.memory.{MemoryBuffer, MemoryUtils} import org.apache.fory.meta.MetaCompressor import org.apache.fory.resolver.TypeResolver -import org.apache.fory.scala.{ForyScalaEnum, ForySerializer} +import org.apache.fory.scala.ForySerializer import org.apache.fory.serializer.Serializer import org.apache.fory.`type`.{BFloat16, Float16} import org.apache.fory.`type`.union.Union2 @@ -54,21 +55,24 @@ import java.time.{Instant, LocalDate} import java.util import scala.jdk.CollectionConverters.* -enum Color(val foryId: Int) extends ForyScalaEnum { - case Green extends Color(0) - case Red extends Color(1) - case Blue extends Color(2) - case White extends Color(3) - - override def getForyId(): Int = foryId +enum Color { + @ForyEnumId(0) + case Green + @ForyEnumId(1) + case Red + @ForyEnumId(2) + case Blue + @ForyEnumId(3) + case White } -enum TestEnum(val foryId: Int) extends ForyScalaEnum { - case VALUE_A extends TestEnum(0) - case VALUE_B extends TestEnum(1) - case VALUE_C extends 
TestEnum(2) - - override def getForyId(): Int = foryId +enum TestEnum { + @ForyEnumId(0) + case VALUE_A + @ForyEnumId(1) + case VALUE_B + @ForyEnumId(2) + case VALUE_C } @ForyStruct @@ -254,9 +258,9 @@ final case class RefOverrideElement(id: Int, name: String) derives ForySerialize @ForyStruct final case class RefOverrideContainer( - listField: util.List[RefOverrideElement @Ref(enable = false)], - setField: util.Set[RefOverrideElement @Ref(enable = false)], - mapField: util.Map[String, RefOverrideElement @Ref(enable = false)]) + listField: util.List[RefOverrideElement @Ref], + setField: util.Set[RefOverrideElement @Ref], + mapField: util.Map[String, RefOverrideElement @Ref]) derives ForySerializer @ForyStruct From acfe279b08e0c5e60c3ed9f2709612a3c97b87a6 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 17:19:19 +0800 Subject: [PATCH 8/9] style: format scala xlang idl changes --- compiler/fory_compiler/generators/scala.py | 8 ++++++-- .../fory_compiler/tests/test_scala_generator.py | 13 ++++++++++--- .../main/java/org/apache/fory/meta/FieldTypes.java | 14 +++++--------- .../org/apache/fory/resolver/TypeResolver.java | 10 ++++------ .../main/java/org/apache/fory/type/Descriptor.java | 12 ++++-------- 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/compiler/fory_compiler/generators/scala.py b/compiler/fory_compiler/generators/scala.py index 39e3858919..59a859be98 100644 --- a/compiler/fory_compiler/generators/scala.py +++ b/compiler/fory_compiler/generators/scala.py @@ -793,7 +793,9 @@ def generate_registration_file(self) -> GeneratedFile: registrations = self.registration_order() for type_def, owner_path in registrations: if isinstance(type_def, Message): - self.generate_type_registration(lines, type_def, owner_path, type_only=True) + self.generate_type_registration( + lines, type_def, owner_path, type_only=True + ) for type_def, owner_path in registrations: if isinstance(type_def, Message): 
self.generate_serializer_registration(lines, type_def, owner_path) @@ -948,7 +950,9 @@ def generate_serializer_registration( self, lines: List[str], type_def, owner_path: Optional[str] = None ) -> None: class_ref = f"{owner_path}.{type_def.name}" if owner_path else type_def.name - lines.append(f" ForySerializer.registerSerializer(fory, classOf[{class_ref}])") + lines.append( + f" ForySerializer.registerSerializer(fory, classOf[{class_ref}])" + ) def safe_identifier(self, name: str) -> str: return f"`{name}`" if name in self.RESERVED else name diff --git a/compiler/fory_compiler/tests/test_scala_generator.py b/compiler/fory_compiler/tests/test_scala_generator.py index 7266e65038..2831824121 100644 --- a/compiler/fory_compiler/tests/test_scala_generator.py +++ b/compiler/fory_compiler/tests/test_scala_generator.py @@ -368,7 +368,10 @@ def test_scala_generator_uses_jvm_nested_names_for_name_registration(): 'ForySerializer.registerType(fory, classOf[Envelope.Payload], "demo.Envelope", "Payload")' in registration ) - assert "ForySerializer.registerSerializer(fory, classOf[Envelope.Payload])" in registration + assert ( + "ForySerializer.registerSerializer(fory, classOf[Envelope.Payload])" + in registration + ) assert ( 'ScalaSerializers.registerEnum(fory, classOf[Envelope.Kind], "demo.Envelope", "Kind")' in registration @@ -398,8 +401,12 @@ def test_scala_generator_pre_registers_message_type_graph_before_serializers(): registration = files["graph/GraphForyRegistration.scala"] node_type = registration.index("ForySerializer.registerType(fory, classOf[Node]") edge_type = registration.index("ForySerializer.registerType(fory, classOf[Edge]") - node_serializer = registration.index("ForySerializer.registerSerializer(fory, classOf[Node])") - edge_serializer = registration.index("ForySerializer.registerSerializer(fory, classOf[Edge])") + node_serializer = registration.index( + "ForySerializer.registerSerializer(fory, classOf[Node])" + ) + edge_serializer = registration.index( + 
"ForySerializer.registerSerializer(fory, classOf[Edge])" + ) assert node_type < node_serializer assert edge_type < node_serializer assert node_type < edge_serializer diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java index bcadd42b77..c4d1779911 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java @@ -674,13 +674,11 @@ public TypeRef toTypeToken(TypeResolver resolver, TypeRef declared) { if (declared != null) { Class declaredRaw = declared.getRawType(); if (declaredRaw.isArray()) { - return TypeRef.of( - declaredRaw, typeExtMeta(typeId, nullable, trackingRef, declared)); + return TypeRef.of(declaredRaw, typeExtMeta(typeId, nullable, trackingRef, declared)); } Class listClass = getPrimitiveListClass(internalTypeId); if (listClass != null && listClass.isAssignableFrom(declaredRaw)) { - return TypeRef.of( - declaredRaw, typeExtMeta(typeId, nullable, trackingRef, declared)); + return TypeRef.of(declaredRaw, typeExtMeta(typeId, nullable, trackingRef, declared)); } } cls = getPrimitiveArrayClass(internalTypeId); @@ -1079,7 +1077,8 @@ public EnumFieldType(boolean nullable, int typeId, int userTypeId) { @Override public TypeRef toTypeToken(TypeResolver classResolver, TypeRef declared) { if (declared != null) { - return TypeRef.of(declared.getRawType(), typeExtMeta(Types.ENUM, nullable, false, declared)); + return TypeRef.of( + declared.getRawType(), typeExtMeta(Types.ENUM, nullable, false, declared)); } return TypeRef.of(UnknownClass.UnknownEnum.class); } @@ -1224,10 +1223,7 @@ private static TypeExtMeta typeExtMeta( int typeId, boolean nullable, boolean trackingRef, TypeRef declared) { TypeExtMeta declaredMeta = declared == null ? 
null : declared.getTypeExtMeta(); return TypeExtMeta.of( - typeId, - nullable, - trackingRef, - declaredMeta != null && declaredMeta.nullableWrapper()); + typeId, nullable, trackingRef, declaredMeta != null && declaredMeta.nullableWrapper()); } /** Class for Union field type. Union types use declared type. */ diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 85a3a6543d..d204cd6cbb 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -82,7 +82,6 @@ import org.apache.fory.serializer.CodegenSerializer; import org.apache.fory.serializer.CodegenSerializer.LazyInitBeanSerializer; import org.apache.fory.serializer.CompatibleSerializer; -import org.apache.fory.serializer.DeferedLazySerializer; import org.apache.fory.serializer.ObjectSerializer; import org.apache.fory.serializer.PrimitiveSerializers; import org.apache.fory.serializer.Serializer; @@ -300,13 +299,11 @@ public final void registerRuntimeTypeAlias(Class runtimeType, Class canoni Class registeredType = extRegistry.registeredClasses.get(runtimeTypeName); Preconditions.checkArgument( registeredType == null || registeredType == runtimeType, - "Runtime type alias name is already registered with different class: " - + runtimeTypeName); + "Runtime type alias name is already registered with different class: " + runtimeTypeName); String registeredName = extRegistry.registeredClasses.inverse().get(runtimeType); Preconditions.checkArgument( registeredName == null || registeredName.equals(runtimeTypeName), - "Runtime type alias is already registered with different name: " - + registeredName); + "Runtime type alias is already registered with different name: " + registeredName); classInfoMap.put(runtimeType, canonicalInfo); extRegistry.registeredClasses.put(runtimeTypeName, runtimeType); 
registerGraalvmClass(runtimeType); @@ -1767,7 +1764,8 @@ private List getRegisteredStaticGeneratedStructDescriptors(Class || !(typeInfo.getSerializer() instanceof StaticGeneratedStructSerializer)) { return null; } - return ((StaticGeneratedStructSerializer) typeInfo.getSerializer()).getGeneratedDescriptors(); + return ((StaticGeneratedStructSerializer) typeInfo.getSerializer()) + .getGeneratedDescriptors(); } private StaticGeneratedStructSerializer copyRegisteredStaticGeneratedStructSerializer( diff --git a/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java b/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java index f28c9bbdef..419e8b4065 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/Descriptor.java @@ -208,7 +208,7 @@ private Descriptor( nullableOverride == null ? resolveNullable(typeRef, !hasForyField, field, null, readMethod) : nullableOverride; - this.hasTrackingRefMetadata = hasTrackingRefMetadata(typeRef, field, readMethod); + this.hasTrackingRefMetadata = resolveHasTrackingRefMetadata(typeRef, field, readMethod); this.trackingRef = resolveTrackingRef(typeRef, field, readMethod); } @@ -410,7 +410,7 @@ private Descriptor( typeAnnotation = getAnnotation(field); arrayType = field.isAnnotationPresent(ArrayType.class); this.nullable = resolveNullable(typeRef, !hasForyField, field, recordComponent, readMethod); - this.hasTrackingRefMetadata = hasTrackingRefMetadata(typeRef, field, readMethod); + this.hasTrackingRefMetadata = resolveHasTrackingRefMetadata(typeRef, field, readMethod); this.trackingRef = resolveTrackingRef(typeRef, field, readMethod); } @@ -432,7 +432,7 @@ private Descriptor(Method readMethod) { typeAnnotation = TypeUtils.getMethodReturnTypeUseAnnotation(readMethod, readMethod.getName()); arrayType = readMethod.isAnnotationPresent(ArrayType.class); this.nullable = resolveNullable(typeRef, !hasForyField, null, null, readMethod); - 
this.hasTrackingRefMetadata = hasTrackingRefMetadata(typeRef, null, readMethod); + this.hasTrackingRefMetadata = resolveHasTrackingRefMetadata(typeRef, null, readMethod); this.trackingRef = resolveTrackingRef(typeRef, null, readMethod); } @@ -493,11 +493,7 @@ private static int resolveForyFieldId(ForyField foryField, String fieldName) { return id; } - private static boolean resolveNullable(TypeRef typeRef, boolean defaultNullable) { - return TypeUtils.isNullable(typeRef, defaultNullable); - } - - private static boolean hasTrackingRefMetadata( + private static boolean resolveHasTrackingRefMetadata( TypeRef typeRef, Field field, Method readMethod) { if (field != null && field.getAnnotation(Ref.class) != null) { return true; From 2c800ab569877fe8ace1c260fb8feaa3776aff16 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 15 May 2026 19:07:58 +0800 Subject: [PATCH 9/9] fix(java): preserve static meta-share field shapes --- AGENTS.md | 1 + .../StaticGeneratedStructSerializer.java | 51 +++++++++++- .../ExampleStaticGeneratedSerializerTest.java | 80 ++++++++++++++++--- 3 files changed, 118 insertions(+), 14 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 801df22a97..bce23d1ae4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -67,6 +67,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Only add tests that verify internal behaviors or fix specific bugs; do not create unnecessary tests unless requested. - Do not add cleanup-sentinel tests that only pin deleted APIs or removed fields. - Tests must exercise the actual code you wrote or changed. Do not write tests that pass by exercising a pre-existing code path that produces similar-looking results. Before writing a test, identify the exact new code path (annotation, codegen output, new API) and verify the test would fail if that code path were removed. 
When the change involves codegen or annotations, the test must use those annotations on real structs, run through the codegen pipeline, and verify the generated output drives the expected runtime behavior. +- Keep test method names concise. Name the behavior under test without encoding the whole scenario or expected result in the method name. - When reading code, skip files not tracked by git by default unless you generated them yourself or the task explicitly requires them. - Maintain cross-language consistency while respecting language-specific idioms. - Keep one active ownership path per concept. Do not leave duplicate serializers, resolvers, helpers, or registration paths for the same type family unless the split is deliberate and documented. diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java index 6ce05338c9..e3990ac7c7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java @@ -23,8 +23,10 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.fory.annotation.Internal; import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; @@ -34,6 +36,8 @@ import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.FieldInfo; import org.apache.fory.meta.TypeDef; +import org.apache.fory.meta.TypeExtMeta; +import org.apache.fory.reflect.TypeRef; import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.FieldGroups.SerializationFieldInfo; @@ -154,16 +158,55 @@ public final FieldGroups buildFieldGroups(List descriptors) { } 
public final FieldGroups buildLocalFieldGroups(List descriptors) { - // Generated descriptors carry source-only field metadata such as Scala Option wrappers. A - // schema TypeDef descriptor is the canonical remote contract, but it cannot replace the local - // generated descriptor view used to choose allocation-free field readers and writers. - return buildFieldGroups(descriptors); + if (!typeResolver.isShareMeta() || hasSourceOnlyMetadata(descriptors)) { + return buildFieldGroups(descriptors); + } + // In meta-share mode, Java static-generated writers must use the same local TypeDef-reified + // descriptor view as ObjectSerializer; otherwise readers can disagree on built-in and + // collection wire shapes during compatible skips. Scala macro descriptors are the exception: + // Option wrappers are source-only metadata used for generated accessor adaptation and are not + // represented in TypeDef, so those descriptors must stay on the generated path. + DescriptorGrouper grouper = + typeResolver.createDescriptorGrouper(typeResolver.getTypeDef(type, true), type); + return FieldGroups.buildFieldInfos(typeResolver, grouper); } protected final List runtimeDescriptors(List descriptors) { return typeResolver.normalizeFieldDescriptors(type, true, descriptors); } + private static boolean hasSourceOnlyMetadata(List descriptors) { + Set> visitedTypes = Collections.newSetFromMap(new IdentityHashMap<>()); + for (Descriptor descriptor : descriptors) { + if (hasSourceOnlyMetadata(descriptor.getTypeRef(), visitedTypes)) { + return true; + } + } + return false; + } + + private static boolean hasSourceOnlyMetadata(TypeRef typeRef, Set> visitedTypes) { + if (!visitedTypes.add(typeRef)) { + return false; + } + TypeExtMeta meta = typeRef.getTypeExtMeta(); + if (meta != null && meta.nullableWrapper()) { + return true; + } + for (TypeRef argument : typeRef.getTypeArguments()) { + if (hasSourceOnlyMetadata(argument, visitedTypes)) { + return true; + } + } + // TypeRef.getComponentType() 
is only meaningful for arrays here; walking it for arbitrary + // TypeRefs can loop through self-like component views while checking a cold descriptor path. + if (!typeRef.isArray()) { + return false; + } + TypeRef componentType = typeRef.getComponentType(); + return componentType != null && hasSourceOnlyMetadata(componentType, visitedTypes); + } + public final int[] localFieldIds( SerializationFieldInfo[] fieldInfos, List descriptors) { Map localIds = new HashMap<>(); diff --git a/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/ExampleStaticGeneratedSerializerTest.java b/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/ExampleStaticGeneratedSerializerTest.java index ba28669a1f..ed7b002e20 100644 --- a/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/ExampleStaticGeneratedSerializerTest.java +++ b/java/fory-latest-jdk-tests/src/test/java/org/apache/fory/integration_tests/ExampleStaticGeneratedSerializerTest.java @@ -34,6 +34,7 @@ import java.time.Instant; import java.time.LocalDate; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -215,11 +216,52 @@ public void testStaticCompatibleBuilder(boolean xlang) throws Exception { } } + @Test + public void testMetaShareStaticSkipShapes() throws Exception { + StaticGeneratedMetaShareWriter value = new StaticGeneratedMetaShareWriter(); + value.date = LocalDate.of(2026, 5, 15); + value.flags = Arrays.asList(true, false, true); + value.after = "after"; + Fory writer = foryWithNativeId(StaticGeneratedMetaShareWriter.class, 1700); + Fory reader = foryWithNativeId(StaticGeneratedMetaShareReader.class, 1700); + assertStaticSerializer(writer, StaticGeneratedMetaShareWriter.class); + assertStaticSerializer(reader, StaticGeneratedMetaShareReader.class); + + writer.setMetaWriteContext(new MetaWriteContext()); + byte[] bytes = writer.serialize(value); + 
reader.setMetaReadContext(new MetaReadContext()); + StaticGeneratedMetaShareReader result = + (StaticGeneratedMetaShareReader) reader.deserialize(bytes); + Assert.assertEquals(result.after, "after"); + } + @ForyStruct public static class EmptyMessage { public EmptyMessage() {} } + @ForyStruct + public static class StaticGeneratedMetaShareWriter { + @ForyField(id = 1) + public LocalDate date; + + @ForyField(id = 2) + public List flags; + + @ForyField(id = 3) + public String after; + + public StaticGeneratedMetaShareWriter() {} + } + + @ForyStruct + public static class StaticGeneratedMetaShareReader { + @ForyField(id = 3) + public String after; + + public StaticGeneratedMetaShareReader() {} + } + public static class RuntimeEmptyMessage { public RuntimeEmptyMessage() {} } @@ -269,20 +311,29 @@ private static ThreadSafeFory threadSafeFory( } private static Fory fory(Class type, boolean xlang, boolean compatible, boolean codegen) { - Fory fory = - Fory.builder() - .withName("latest-static-" + FORY_ID.incrementAndGet()) - .withXlang(xlang) - .withCodegen(codegen) - .withMetaShare(compatible) - .withScopedMetaShare(false) - .withCompatible(compatible) - .requireClassRegistration(false) - .build(); + Fory fory = newFory(xlang, compatible, codegen); registerType(fory, type, xlang); return fory; } + private static Fory foryWithNativeId(Class type, int nativeId) { + Fory fory = newFory(false, true, false); + register(fory, type, false, nativeId, type.getSimpleName()); + return fory; + } + + private static Fory newFory(boolean xlang, boolean compatible, boolean codegen) { + return Fory.builder() + .withName("latest-static-" + FORY_ID.incrementAndGet()) + .withXlang(xlang) + .withCodegen(codegen) + .withMetaShare(compatible) + .withScopedMetaShare(false) + .withCompatible(compatible) + .requireClassRegistration(false) + .build(); + } + private static void registerType(org.apache.fory.BaseFory fory, Class type, boolean xlang) { if (type == ExampleRecordMessage.class || type == 
EmptyRecordMessage.class @@ -315,6 +366,15 @@ private static void assertStaticSerializer(ThreadSafeFory fory, Class type) { serializer instanceof StaticGeneratedStructSerializer, serializer.getClass().getName()); } + private static void assertStaticSerializer(Fory fory, Class type) { + Serializer serializer = fory.getTypeResolver().getTypeInfo(type).getSerializer(); + if (serializer instanceof DeferedLazySerializer) { + serializer = ((DeferedLazySerializer) serializer).resolveSerializer(); + } + Assert.assertTrue( + serializer instanceof StaticGeneratedStructSerializer, serializer.getClass().getName()); + } + private static void assertNotStaticSerializer(Fory fory, Class type) { Serializer serializer = fory.getTypeResolver().getTypeInfo(type).getSerializer(); if (serializer instanceof DeferedLazySerializer) {