diff --git a/core/src/main/java/kafka/automq/table/binder/RecordBinder.java b/core/src/main/java/kafka/automq/table/binder/RecordBinder.java index 1b47374264..47835c0355 100644 --- a/core/src/main/java/kafka/automq/table/binder/RecordBinder.java +++ b/core/src/main/java/kafka/automq/table/binder/RecordBinder.java @@ -22,6 +22,7 @@ import kafka.automq.table.metric.FieldMetric; import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; import org.apache.avro.generic.GenericRecord; import org.apache.iceberg.avro.AvroSchemaUtil; import org.apache.iceberg.data.Record; @@ -29,12 +30,14 @@ import org.apache.iceberg.types.Types; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.HashMap; import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; +import static org.apache.avro.Schema.Type.ARRAY; import static org.apache.avro.Schema.Type.NULL; /** @@ -124,12 +127,8 @@ private FieldMapping[] buildFieldMappings(Schema avroSchema, org.apache.iceberg. 
Schema recordSchema = avroSchema; FieldMapping[] mappings = new FieldMapping[icebergSchema.columns().size()]; - if (recordSchema.getType() == Schema.Type.UNION) { - recordSchema = recordSchema.getTypes().stream() - .filter(s -> s.getType() == Schema.Type.RECORD) - .findFirst() - .orElseThrow(() -> new IllegalArgumentException("UNION schema does not contain a RECORD type: " + avroSchema)); - } + // Unwrap UNION if it contains only one non-NULL type + recordSchema = resolveUnionElement(recordSchema); for (int icebergPos = 0; icebergPos < icebergSchema.columns().size(); icebergPos++) { Types.NestedField icebergField = icebergSchema.columns().get(icebergPos); @@ -162,17 +161,29 @@ private FieldMapping buildFieldMapping(String avroFieldName, int avroPosition, T } private Schema resolveUnionElement(Schema schema) { - Schema resolved = schema; - if (schema.getType() == Schema.Type.UNION) { - resolved = null; - for (Schema unionMember : schema.getTypes()) { - if (unionMember.getType() != NULL) { - resolved = unionMember; - break; - } + if (schema.getType() != Schema.Type.UNION) { + return schema; + } + + // Collect all non-NULL types + List nonNullTypes = new ArrayList<>(); + for (Schema s : schema.getTypes()) { + if (s.getType() != NULL) { + nonNullTypes.add(s); } } - return resolved; + + if (nonNullTypes.isEmpty()) { + throw new IllegalArgumentException("UNION schema contains only NULL type: " + schema); + } else if (nonNullTypes.size() == 1) { + // Only unwrap UNION if it contains exactly one non-NULL type (optional union) + return nonNullTypes.get(0); + } else { + // Multiple non-NULL types: non-optional union not supported + throw new UnsupportedOperationException( + "Non-optional UNION with multiple non-NULL types is not supported. 
" + + "Found " + nonNullTypes.size() + " non-NULL types in UNION: " + schema); + } } @@ -184,53 +195,135 @@ private Map precomputeBindersMap(TypeAdapter typeA for (FieldMapping mapping : fieldMappings) { if (mapping != null) { - Type type = mapping.icebergType(); - if (type.isPrimitiveType()) { - } else if (type.isStructType()) { - org.apache.iceberg.Schema schema = type.asStructType().asSchema(); - RecordBinder structBinder = new RecordBinder( - schema, - mapping.avroSchema(), - typeAdapter, - batchFieldCount - ); - binders.put(mapping.avroSchema(), structBinder); - } else if (type.isListType()) { - Types.ListType listType = type.asListType(); - Type elementType = listType.elementType(); - if (elementType.isStructType()) { - org.apache.iceberg.Schema schema = elementType.asStructType().asSchema(); - RecordBinder elementBinder = new RecordBinder( - schema, - mapping.avroSchema().getElementType(), - typeAdapter, - batchFieldCount - ); - binders.put(mapping.avroSchema().getElementType(), elementBinder); - } - } else if (type.isMapType()) { - Types.MapType mapType = type.asMapType(); - Type keyType = mapType.keyType(); - Type valueType = mapType.valueType(); - if (keyType.isStructType()) { - throw new UnsupportedOperationException("Struct keys in MAP types are not supported"); - } - if (valueType.isStructType()) { - org.apache.iceberg.Schema schema = valueType.asStructType().asSchema(); - RecordBinder valueBinder = new RecordBinder( - schema, - mapping.avroSchema().getValueType(), - typeAdapter, - batchFieldCount - ); - binders.put(mapping.avroSchema().getValueType(), valueBinder); - } - } + precomputeBindersForType(mapping.icebergType(), mapping.avroSchema(), binders, typeAdapter); } } return binders; } + /** + * Recursively precomputes binders for a given Iceberg type and its corresponding Avro schema. 
+ */ + private void precomputeBindersForType(Type icebergType, Schema avroSchema, + Map binders, + TypeAdapter typeAdapter) { + if (icebergType.isPrimitiveType()) { + return; // No binders needed for primitive types + } + + if (icebergType.isStructType() && !avroSchema.isUnion()) { + createStructBinder(icebergType.asStructType(), avroSchema, binders, typeAdapter); + } else if (icebergType.isStructType() && avroSchema.isUnion()) { + createUnionStructBinders(icebergType.asStructType(), avroSchema, binders, typeAdapter); + } else if (icebergType.isListType()) { + createListBinder(icebergType.asListType(), avroSchema, binders, typeAdapter); + } else if (icebergType.isMapType()) { + createMapBinder(icebergType.asMapType(), avroSchema, binders, typeAdapter); + } + } + + /** + * Creates binders for STRUCT types represented as Avro UNIONs. + */ + private void createUnionStructBinders(Types.StructType structType, Schema avroSchema, + Map binders, + TypeAdapter typeAdapter) { + org.apache.iceberg.Schema schema = structType.asSchema(); + SchemaBuilder.FieldAssembler schemaBuilder = SchemaBuilder.record(avroSchema.getName()).fields() + .name("tag").type().intType().noDefault(); + int tag = 0; + for (Schema unionMember : avroSchema.getTypes()) { + if (unionMember.getType() != NULL) { + schemaBuilder.name("field" + tag).type(unionMember).noDefault(); + tag++; + } + } + RecordBinder structBinder = new RecordBinder(schema, schemaBuilder.endRecord(), typeAdapter, batchFieldCount); + binders.put(avroSchema, structBinder); + } + + /** + * Creates a binder for a STRUCT type field. 
+ */ + private void createStructBinder(Types.StructType structType, Schema avroSchema, + Map binders, + TypeAdapter typeAdapter) { + org.apache.iceberg.Schema schema = structType.asSchema(); + RecordBinder structBinder = new RecordBinder(schema, avroSchema, typeAdapter, batchFieldCount); + binders.put(avroSchema, structBinder); + } + + /** + * Creates binders for LIST type elements (if they are STRUCT types). + */ + private void createListBinder(Types.ListType listType, Schema avroSchema, + Map binders, + TypeAdapter typeAdapter) { + Type elementType = listType.elementType(); + if (elementType.isStructType()) { + Schema elementAvroSchema = avroSchema.getElementType(); + createStructBinder(elementType.asStructType(), elementAvroSchema, binders, typeAdapter); + } + } + + /** + * Creates binders for MAP type keys and values (if they are STRUCT types). + * Handles two Avro representations: ARRAY of key-value records, or native MAP. + */ + private void createMapBinder(Types.MapType mapType, Schema avroSchema, + Map binders, + TypeAdapter typeAdapter) { + Type keyType = mapType.keyType(); + Type valueType = mapType.valueType(); + + if (ARRAY.equals(avroSchema.getType())) { + // Avro represents MAP as ARRAY of records with "key" and "value" fields + createMapAsArrayBinder(keyType, valueType, avroSchema, binders, typeAdapter); + } else { + // Avro represents MAP as native MAP type + createMapAsMapBinder(keyType, valueType, avroSchema, binders, typeAdapter); + } + } + + /** + * Handles MAP represented as Avro ARRAY of {key, value} records. 
+ */ + private void createMapAsArrayBinder(Type keyType, Type valueType, Schema avroSchema, + Map binders, + TypeAdapter typeAdapter) { + Schema elementSchema = avroSchema.getElementType(); + + // Process key if it's a STRUCT + if (keyType.isStructType()) { + Schema keyAvroSchema = elementSchema.getField("key").schema(); + createStructBinder(keyType.asStructType(), keyAvroSchema, binders, typeAdapter); + } + + // Process value if it's a STRUCT + if (valueType.isStructType()) { + Schema valueAvroSchema = elementSchema.getField("value").schema(); + createStructBinder(valueType.asStructType(), valueAvroSchema, binders, typeAdapter); + } + } + + /** + * Handles MAP represented as Avro native MAP type. + */ + private void createMapAsMapBinder(Type keyType, Type valueType, Schema avroSchema, + Map binders, + TypeAdapter typeAdapter) { + // Struct keys in native MAP are not supported by Avro + if (keyType.isStructType()) { + throw new UnsupportedOperationException("Struct keys in MAP types are not supported"); + } + + // Process value if it's a STRUCT + if (valueType.isStructType()) { + Schema valueAvroSchema = avroSchema.getValueType(); + createStructBinder(valueType.asStructType(), valueAvroSchema, binders, typeAdapter); + } + } + private static class AvroRecordView implements Record { private final GenericRecord avroRecord; private final org.apache.iceberg.Schema icebergSchema; diff --git a/core/src/main/java/kafka/automq/table/process/convert/LogicalMapProtobufData.java b/core/src/main/java/kafka/automq/table/process/convert/LogicalMapProtobufData.java new file mode 100644 index 0000000000..5206a3d2cd --- /dev/null +++ b/core/src/main/java/kafka/automq/table/process/convert/LogicalMapProtobufData.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025, AutoMQ HK Limited. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.automq.table.process.convert; + +import com.google.protobuf.Descriptors; + +import org.apache.avro.Schema; +import org.apache.avro.protobuf.ProtobufData; +import org.apache.iceberg.avro.CodecSetup; + +import java.util.Arrays; + +/** + * ProtobufData extension that annotates protobuf map fields with Iceberg's LogicalMap logical type so that + * downstream Avro{@literal >}Iceberg conversion keeps them as MAP instead of generic {@literal ARRAY>}. 
+ */ +public class LogicalMapProtobufData extends ProtobufData { + private static final LogicalMapProtobufData INSTANCE = new LogicalMapProtobufData(); + private static final Schema NULL = Schema.create(Schema.Type.NULL); + + public static LogicalMapProtobufData get() { + return INSTANCE; + } + + @Override + public Schema getSchema(Descriptors.FieldDescriptor f) { + Schema schema = super.getSchema(f); + if (f.isMapField()) { + Schema nonNull = resolveNonNull(schema); + // protobuf maps are materialized as ARRAY in Avro + if (nonNull != null && nonNull.getType() == Schema.Type.ARRAY) { + // set logicalType property; LogicalTypes is registered in CodecSetup + CodecSetup.getLogicalMap().addToSchema(nonNull); + } + } else if (f.isOptional() && !f.isRepeated() && f.getContainingOneof() == null + && schema.getType() != Schema.Type.UNION) { + // Proto3 optional scalars/messages: wrap as union(type, null) so the protobuf default (typically non-null) + // remains valid (Avro default must match the first branch). 
+ schema = Schema.createUnion(Arrays.asList(schema, NULL)); + } else if (f.getContainingOneof() != null && !f.isRepeated() && schema.getType() != Schema.Type.UNION) { + // oneof fields: wrap as union(type, null) so that non-set fields can be represented as null + schema = Schema.createUnion(Arrays.asList(schema, NULL)); + } + return schema; + } + + private Schema resolveNonNull(Schema schema) { + if (schema == null) { + return null; + } + if (schema.getType() == Schema.Type.UNION) { + for (Schema member : schema.getTypes()) { + if (member.getType() != Schema.Type.NULL) { + return member; + } + } + return null; + } + return schema; + } +} diff --git a/core/src/main/java/kafka/automq/table/process/convert/ProtoToAvroConverter.java b/core/src/main/java/kafka/automq/table/process/convert/ProtoToAvroConverter.java index b8f6727f5d..e6f0ec2130 100644 --- a/core/src/main/java/kafka/automq/table/process/convert/ProtoToAvroConverter.java +++ b/core/src/main/java/kafka/automq/table/process/convert/ProtoToAvroConverter.java @@ -34,16 +34,22 @@ import org.apache.avro.protobuf.ProtobufData; import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.List; public class ProtoToAvroConverter { + private static final ProtobufData DATA = initProtobufData(); + + private static ProtobufData initProtobufData() { + ProtobufData protobufData = LogicalMapProtobufData.get(); + protobufData.addLogicalTypeConversion(new ProtoConversions.TimestampMicrosConversion()); + return protobufData; + } + public static GenericRecord convert(Message protoMessage, Schema schema) { try { - ProtobufData protobufData = ProtobufData.get(); - protobufData.addLogicalTypeConversion(new ProtoConversions.TimestampMicrosConversion()); - return convertRecord(protoMessage, schema, protobufData); + Schema nonNull = resolveNonNullSchema(schema); + return convertRecord(protoMessage, nonNull, DATA); } catch (Exception e) { throw new ConverterException("Proto to Avro conversion failed", e); } @@ -51,56 
+57,52 @@ public static GenericRecord convert(Message protoMessage, Schema schema) { private static Object convert(Message protoMessage, Schema schema, ProtobufData protobufData) { Conversion conversion = getConversion(protoMessage.getDescriptorForType(), protobufData); - if (conversion != null) { - if (conversion instanceof ProtoConversions.TimestampMicrosConversion) { - ProtoConversions.TimestampMicrosConversion timestampConversion = (ProtoConversions.TimestampMicrosConversion) conversion; - Timestamp.Builder builder = Timestamp.newBuilder(); - Timestamp.getDescriptor().getFields().forEach(field -> { - String fieldName = field.getName(); - Descriptors.FieldDescriptor protoField = protoMessage.getDescriptorForType() - .findFieldByName(fieldName); - if (protoField != null) { - Object value = protoMessage.getField(protoField); - if (value != null) { - builder.setField(field, value); - } - } - }); - Timestamp timestamp = builder.build(); - return timestampConversion.toLong(timestamp, schema, null); - } + if (conversion instanceof ProtoConversions.TimestampMicrosConversion) { + ProtoConversions.TimestampMicrosConversion timestampConversion = (ProtoConversions.TimestampMicrosConversion) conversion; + Timestamp.Builder builder = Timestamp.newBuilder(); + Timestamp.getDescriptor().getFields().forEach(field -> { + Descriptors.FieldDescriptor protoField = protoMessage.getDescriptorForType().findFieldByName(field.getName()); + if (protoField != null && protoMessage.hasField(protoField)) { + builder.setField(field, protoMessage.getField(protoField)); + } + }); + return timestampConversion.toLong(builder.build(), schema, null); } - if (schema.getType() == Schema.Type.RECORD) { - return convertRecord(protoMessage, schema, protobufData); - } else if (schema.getType() == Schema.Type.UNION) { - Schema dataSchema = protobufData.getSchema(protoMessage.getDescriptorForType()); - return convertRecord(protoMessage, dataSchema, protobufData); - } else { - return null; + + Schema 
nonNull = resolveNonNullSchema(schema); + if (nonNull.getType() == Schema.Type.RECORD) { + return convertRecord(protoMessage, nonNull, protobufData); } + return null; } private static Conversion getConversion(Descriptors.Descriptor descriptor, ProtobufData protobufData) { String namespace = protobufData.getNamespace(descriptor.getFile(), descriptor.getContainingType()); String name = descriptor.getName(); - - if (namespace.equals("com.google.protobuf")) { - if (name.equals("Timestamp")) { - return new ProtoConversions.TimestampMicrosConversion(); - } + if ("com.google.protobuf".equals(namespace) && "Timestamp".equals(name)) { + return new ProtoConversions.TimestampMicrosConversion(); } return null; } - private static GenericRecord convertRecord(Message protoMessage, Schema schema, ProtobufData protobufData) { - GenericRecord record = new GenericData.Record(schema); - for (Schema.Field field : schema.getFields()) { + private static GenericRecord convertRecord(Message protoMessage, Schema recordSchema, ProtobufData protobufData) { + GenericRecord record = new GenericData.Record(recordSchema); + Descriptors.Descriptor descriptor = protoMessage.getDescriptorForType(); + + for (Schema.Field field : recordSchema.getFields()) { String fieldName = field.name(); - Descriptors.FieldDescriptor protoField = protoMessage.getDescriptorForType() - .findFieldByName(fieldName); + Descriptors.FieldDescriptor protoField = descriptor.findFieldByName(fieldName); + if (protoField == null) { + continue; + } - if (protoField == null) + boolean hasPresence = protoField.hasPresence() || protoField.getContainingOneof() != null; + if (!protoField.isRepeated() && hasPresence && !protoMessage.hasField(protoField)) { + if (allowsNull(field.schema())) { + record.put(fieldName, null); + } continue; + } Object value = protoMessage.getField(protoField); Object convertedValue = convertValue(value, protoField, field.schema(), protobufData); @@ -111,22 +113,23 @@ private static GenericRecord 
convertRecord(Message protoMessage, Schema schema, private static Object convertValue(Object value, Descriptors.FieldDescriptor fieldDesc, Schema avroSchema, ProtobufData protobufData) { - if (value == null) + if (value == null) { return null; + } + + Schema nonNullSchema = resolveNonNullSchema(avroSchema); - // process repeated fields if (fieldDesc.isRepeated() && value instanceof List) { List protoList = (List) value; - List avroList = new ArrayList<>(); - Schema elementSchema = avroSchema.getElementType(); - + GenericData.Array avroArray = new GenericData.Array<>(protoList.size(), nonNullSchema); + Schema elementSchema = nonNullSchema.getElementType(); for (Object item : protoList) { - avroList.add(convertSingleValue(item, elementSchema, protobufData)); + avroArray.add(convertSingleValue(item, elementSchema, protobufData)); } - return avroList; + return avroArray; } - return convertSingleValue(value, avroSchema, protobufData); + return convertSingleValue(value, nonNullSchema, protobufData); } private static Object convertSingleValue(Object value, Schema avroSchema, ProtobufData protobufData) { @@ -135,41 +138,59 @@ private static Object convertSingleValue(Object value, Schema avroSchema, Protob } else if (value instanceof ByteString) { return ((ByteString) value).asReadOnlyByteBuffer(); } else if (value instanceof Enum) { - return value.toString(); // protobuf Enum is represented as string + return value.toString(); } else if (value instanceof List) { throw new ConverterException("Unexpected list type found; repeated fields should have been handled in convertValue"); } - // primitive types return convertPrimitive(value, avroSchema); } private static Object convertPrimitive(Object value, Schema schema) { - switch (schema.getType()) { - case INT: { + Schema.Type type = schema.getType(); + switch (type) { + case INT: return ((Number) value).intValue(); - } - case LONG: { + case LONG: return ((Number) value).longValue(); - } - case FLOAT: { + case FLOAT: return 
((Number) value).floatValue(); - } - case DOUBLE: { + case DOUBLE: return ((Number) value).doubleValue(); - } - case BOOLEAN: { + case BOOLEAN: return (Boolean) value; - } - case BYTES: { + case BYTES: if (value instanceof byte[]) { return ByteBuffer.wrap((byte[]) value); } return value; - } - default: { + default: return value; + } + } + + private static Schema resolveNonNullSchema(Schema schema) { + if (schema.getType() == Schema.Type.UNION) { + for (Schema type : schema.getTypes()) { + if (type.getType() != Schema.Type.NULL) { + return type; + } + } + } + return schema; + } + + private static boolean allowsNull(Schema schema) { + if (schema.getType() == Schema.Type.NULL) { + return true; + } + if (schema.getType() == Schema.Type.UNION) { + for (Schema type : schema.getTypes()) { + if (type.getType() == Schema.Type.NULL) { + return true; + } } } + return false; } } diff --git a/core/src/main/java/kafka/automq/table/process/convert/ProtobufRegistryConverter.java b/core/src/main/java/kafka/automq/table/process/convert/ProtobufRegistryConverter.java index 3608fb2130..ffa01b881f 100644 --- a/core/src/main/java/kafka/automq/table/process/convert/ProtobufRegistryConverter.java +++ b/core/src/main/java/kafka/automq/table/process/convert/ProtobufRegistryConverter.java @@ -78,7 +78,7 @@ public ConversionResult convert(String topic, ByteBuffer buffer) throws Converte Message protoMessage = deserializer.deserialize(topic, null, buffer); Schema schema = avroSchemaCache.getIfPresent(schemaId); if (schema == null) { - ProtobufData protobufData = ProtobufData.get(); + ProtobufData protobufData = LogicalMapProtobufData.get(); protobufData.addLogicalTypeConversion(new ProtoConversions.TimestampMicrosConversion()); schema = protobufData.getSchema(protoMessage.getDescriptorForType()); avroSchemaCache.put(schemaId, schema); diff --git a/core/src/main/java/org/apache/iceberg/avro/CodecSetup.java b/core/src/main/java/org/apache/iceberg/avro/CodecSetup.java index 2a68b9f8ce..83ec87134c 
100644 --- a/core/src/main/java/org/apache/iceberg/avro/CodecSetup.java +++ b/core/src/main/java/org/apache/iceberg/avro/CodecSetup.java @@ -23,6 +23,10 @@ public class CodecSetup { + public static LogicalMap getLogicalMap() { + return LogicalMap.get(); + } + static { LogicalTypes.register(LogicalMap.NAME, schema -> LogicalMap.get()); } diff --git a/core/src/test/java/kafka/automq/table/binder/AvroRecordBinderTest.java b/core/src/test/java/kafka/automq/table/binder/AvroRecordBinderTest.java index 30492c7f22..cfffed5485 100644 --- a/core/src/test/java/kafka/automq/table/binder/AvroRecordBinderTest.java +++ b/core/src/test/java/kafka/automq/table/binder/AvroRecordBinderTest.java @@ -19,651 +19,91 @@ package kafka.automq.table.binder; -import com.google.common.collect.ImmutableMap; - -import org.apache.avro.Conversions; -import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; -import org.apache.avro.SchemaBuilder; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; -import org.apache.avro.io.DatumReader; -import org.apache.avro.io.DatumWriter; -import org.apache.avro.io.Decoder; -import org.apache.avro.io.DecoderFactory; -import org.apache.avro.io.Encoder; -import org.apache.avro.io.EncoderFactory; -import org.apache.avro.specific.SpecificDatumReader; -import org.apache.avro.specific.SpecificDatumWriter; import org.apache.avro.util.Utf8; -import org.apache.commons.lang3.RandomStringUtils; -import org.apache.iceberg.FileFormat; -import org.apache.iceberg.Table; import org.apache.iceberg.avro.AvroSchemaUtil; import org.apache.iceberg.avro.CodecSetup; -import org.apache.iceberg.catalog.Namespace; -import org.apache.iceberg.catalog.TableIdentifier; -import org.apache.iceberg.data.GenericAppenderFactory; import org.apache.iceberg.data.Record; -import org.apache.iceberg.inmemory.InMemoryCatalog; -import org.apache.iceberg.io.FileAppenderFactory; -import org.apache.iceberg.io.OutputFileFactory; -import 
org.apache.iceberg.io.TaskWriter; -import org.apache.iceberg.io.UnpartitionedWriter; import org.apache.iceberg.types.Types; -import org.apache.iceberg.util.DateTimeUtil; -import org.apache.iceberg.util.UUIDUtil; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; -import org.mockito.MockitoAnnotations; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.math.BigDecimal; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.time.Instant; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.time.OffsetDateTime; -import java.time.temporal.ChronoUnit; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.UUID; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; -import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; -import static org.apache.iceberg.TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @Tag("S3Unit") -public class AvroRecordBinderTest { +class AvroRecordBinderTest { private static final String TEST_NAMESPACE = "kafka.automq.table.binder"; - private static Schema avroSchema; - private InMemoryCatalog catalog; - private Table table; - private TaskWriter writer; - private int tableCounter; - static { CodecSetup.setup(); } - @BeforeEach - void setUp() { - MockitoAnnotations.openMocks(this); - catalog = new InMemoryCatalog(); - catalog.initialize("test", ImmutableMap.of()); - catalog.createNamespace(Namespace.of("default")); - tableCounter = 0; - } - 
- private void testSendRecord(org.apache.iceberg.Schema schema, Record record) { - String tableName = "test_" + tableCounter++; - table = catalog.createTable(TableIdentifier.of(Namespace.of("default"), tableName), schema); - writer = createTableWriter(table); - try { - writer.write(record); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public static TaskWriter createTableWriter(Table table) { - FileAppenderFactory appenderFactory = new GenericAppenderFactory( - table.schema(), - table.spec(), - null, null, null) - .setAll(new HashMap<>(table.properties())) - .set(PARQUET_ROW_GROUP_SIZE_BYTES, "1"); - - OutputFileFactory fileFactory = - OutputFileFactory.builderFor(table, 1, System.currentTimeMillis()) - .defaultSpec(table.spec()) - .operationId(UUID.randomUUID().toString()) - .format(FileFormat.PARQUET) - .build(); - - return new UnpartitionedWriter<>( - table.spec(), - FileFormat.PARQUET, - appenderFactory, - fileFactory, - table.io(), - WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT - ); - } - - private static GenericRecord serializeAndDeserialize(GenericRecord record, Schema schema) { - try { - // Serialize the avro record to a byte array - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - DatumWriter datumWriter = new SpecificDatumWriter<>(schema); - Encoder encoder = EncoderFactory.get().binaryEncoder(outputStream, null); - datumWriter.write(record, encoder); - encoder.flush(); - outputStream.close(); - - byte[] serializedBytes = outputStream.toByteArray(); - - // Deserialize the byte array back to an avro record - DatumReader datumReader = new SpecificDatumReader<>(schema); - ByteArrayInputStream inputStream = new ByteArrayInputStream(serializedBytes); - Decoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); - return datumReader.read(null, decoder); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private static Map toStringKeyMap(Object value) { - if (value == null) { - return null; 
- } - Map map = (Map) value; - Map result = new HashMap<>(map.size()); - for (Map.Entry entry : map.entrySet()) { - String key = entry.getKey() == null ? null : entry.getKey().toString(); - result.put(key, normalizeValue(entry.getValue())); - } - return result; - } - - private static Object normalizeValue(Object value) { - if (value == null) { - return null; - } - if (value instanceof CharSequence) { - return value.toString(); - } - if (value instanceof List) { - List list = (List) value; - List normalized = new ArrayList<>(list.size()); - for (Object element : list) { - normalized.add(normalizeValue(element)); - } - return normalized; - } - if (value instanceof Map) { - return toStringKeyMap(value); - } - return value; - } - - private static Map normalizeMapValues(Object value) { - if (value == null) { - return null; - } - Map map = (Map) value; - Map result = new HashMap<>(map.size()); - for (Map.Entry entry : map.entrySet()) { - @SuppressWarnings("unchecked") - K key = (K) entry.getKey(); - result.put(key, normalizeValue(entry.getValue())); - } - return result; - } - - private static Schema createOptionalSchema(Schema nonNullSchema) { - return Schema.createUnion(Arrays.asList(Schema.create(Schema.Type.NULL), nonNullSchema)); - } - - private static Schema ensureNonNullBranch(Schema schema) { - if (schema.getType() != Schema.Type.UNION) { - return schema; - } - return schema.getTypes().stream() - .filter(type -> type.getType() != Schema.Type.NULL) - .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Union schema lacks non-null branch: " + schema)); - } - - private void runRoundTrip(Schema recordSchema, Consumer avroPopulator, Consumer assertions) { - GenericRecord avroRecord = new GenericData.Record(recordSchema); - avroPopulator.accept(avroRecord); - GenericRecord roundTripRecord = serializeAndDeserialize(avroRecord, recordSchema); - - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(recordSchema); - Record icebergRecord = new 
RecordBinder(icebergSchema, recordSchema).bind(roundTripRecord); - - assertions.accept(icebergRecord); - testSendRecord(icebergSchema, icebergRecord); - } - - // Helper method to test round-trip conversion for a single field - private void assertFieldRoundTrips(String recordPrefix, - String fieldName, - Supplier fieldSchemaSupplier, - Function avroValueSupplier, - Consumer valueAssertion) { - Schema baseFieldSchema = fieldSchemaSupplier.get(); - Schema baseRecordSchema = SchemaBuilder.builder() - .record(recordPrefix + "Base") - .namespace(TEST_NAMESPACE) - .fields() - .name(fieldName).type(baseFieldSchema).noDefault() - .endRecord(); - - // Direct field - runRoundTrip(baseRecordSchema, - record -> record.put(fieldName, avroValueSupplier.apply(baseFieldSchema)), - icebergRecord -> valueAssertion.accept(icebergRecord.getField(fieldName)) - ); - - Schema optionalFieldSchema = createOptionalSchema(fieldSchemaSupplier.get()); - Schema unionRecordSchema = SchemaBuilder.builder() - .record(recordPrefix + "Union") - .namespace(TEST_NAMESPACE) - .fields() - .name(fieldName).type(optionalFieldSchema).withDefault(null) - .endRecord(); - Schema nonNullBranch = ensureNonNullBranch(optionalFieldSchema); - - // Optional field with non-null value - runRoundTrip(unionRecordSchema, - record -> record.put(fieldName, avroValueSupplier.apply(nonNullBranch)), - icebergRecord -> valueAssertion.accept(icebergRecord.getField(fieldName)) - ); - - // Optional field with null value - runRoundTrip(unionRecordSchema, - record -> record.put(fieldName, null), - icebergRecord -> assertNull(icebergRecord.getField(fieldName)) - ); - } - - @Test - public void testSchemaEvolution() { - // Original Avro schema with 3 fields - String originalAvroSchemaJson = "{" - + "\"type\": \"record\"," - + "\"name\": \"User\"," - + "\"fields\": [" - + " {\"name\": \"id\", \"type\": \"long\"}," - + " {\"name\": \"name\", \"type\": \"string\"}," - + " {\"name\": \"email\", \"type\": \"string\"}" - + "]}"; - - // 
Evolved Iceberg schema: added age field, removed email field - org.apache.iceberg.Schema evolvedIcebergSchema = new org.apache.iceberg.Schema( - Types.NestedField.required(1, "id", Types.LongType.get()), - Types.NestedField.required(2, "name", Types.StringType.get()), - Types.NestedField.optional(4, "age", Types.IntegerType.get()) // New field - // email field removed - ); - - Schema avroSchema = new Schema.Parser().parse(originalAvroSchemaJson); - GenericRecord avroRecord = new GenericData.Record(avroSchema); - avroRecord.put("id", 12345L); - avroRecord.put("name", new Utf8("John Doe")); - avroRecord.put("email", new Utf8("john@example.com")); - - // Test wrapper with evolved schema - RecordBinder recordBinder = new RecordBinder(evolvedIcebergSchema, avroSchema); - Record bind = recordBinder.bind(avroRecord); - - assertEquals(12345L, bind.get(0)); // id - assertEquals("John Doe", bind.get(1).toString()); // name - assertNull(bind.get(2)); // age - doesn't exist in Avro record - } - - - @Test - public void testWrapperReusability() { - // Test that the same wrapper can be reused for multiple records - String avroSchemaJson = "{" - + "\"type\": \"record\"," - + "\"name\": \"User\"," - + "\"fields\": [" - + " {\"name\": \"id\", \"type\": \"long\"}," - + " {\"name\": \"name\", \"type\": \"string\"}" - + "]}"; - Schema avroSchema = new Schema.Parser().parse(avroSchemaJson); - - org.apache.iceberg.Schema icebergSchema = new org.apache.iceberg.Schema( - Types.NestedField.required(1, "id", Types.LongType.get()), - Types.NestedField.required(2, "name", Types.StringType.get()) - ); - - RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); - - - // First record - GenericRecord record1 = new GenericData.Record(avroSchema); - record1.put("id", 1L); - record1.put("name", new Utf8("Alice")); - - Record bind1 = recordBinder.bind(record1); - assertEquals(1L, bind1.get(0)); - assertEquals("Alice", bind1.get(1).toString()); - - // Reuse wrapper for second record - 
GenericRecord record2 = new GenericData.Record(avroSchema); - record2.put("id", 2L); - record2.put("name", new Utf8("Bob")); - - Record bind2 = recordBinder.bind(record2); - assertEquals(2L, bind2.get(0)); - assertEquals("Bob", bind2.get(1).toString()); - } - - - // Test method for converting a single string field - @Test - public void testStringConversion() { - assertFieldRoundTrips("String", "stringField", - () -> Schema.create(Schema.Type.STRING), - schema -> "test_string", - value -> assertEquals("test_string", value.toString()) - ); - } - - // Test method for converting a single integer field - @Test - public void testIntegerConversion() { - assertFieldRoundTrips("Int", "intField", - () -> Schema.create(Schema.Type.INT), - schema -> 42, - value -> assertEquals(42, value) - ); - } - - // Test method for converting a single long field - @Test - public void testLongConversion() { - assertFieldRoundTrips("Long", "longField", - () -> Schema.create(Schema.Type.LONG), - schema -> 123456789L, - value -> assertEquals(123456789L, value) - ); - } - - // Test method for converting a single float field - @Test - public void testFloatConversion() { - assertFieldRoundTrips("Float", "floatField", - () -> Schema.create(Schema.Type.FLOAT), - schema -> 3.14f, - value -> assertEquals(3.14f, (Float) value) - ); - } - - // Test method for converting a single double field - @Test - public void testDoubleConversion() { - assertFieldRoundTrips("Double", "doubleField", - () -> Schema.create(Schema.Type.DOUBLE), - schema -> 6.28, - value -> assertEquals(6.28, value) - ); - } - - // Test method for converting a single boolean field - @Test - public void testBooleanConversion() { - assertFieldRoundTrips("Boolean", "booleanField", - () -> Schema.create(Schema.Type.BOOLEAN), - schema -> true, - value -> assertEquals(true, value) - ); - } - - // Test method for converting a single date field (number of days from epoch) - @Test - public void testDateConversion() { - LocalDate localDate = 
LocalDate.of(2020, 1, 1); - int epochDays = (int) ChronoUnit.DAYS.between(LocalDate.ofEpochDay(0), localDate); - assertFieldRoundTrips("Date", "dateField", - () -> LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT)), - schema -> epochDays, - value -> assertEquals(localDate, value) - ); - } - - // Test method for converting a single time field (number of milliseconds from midnight) - @Test - public void testTimeConversion() { - LocalTime localTime = LocalTime.of(10, 0); - long epochMicros = localTime.toNanoOfDay() / 1000; - int epochMillis = (int) (localTime.toNanoOfDay() / 1_000_000); - assertFieldRoundTrips("TimeMicros", "timeField", - () -> LogicalTypes.timeMicros().addToSchema(Schema.create(Schema.Type.LONG)), - schema -> epochMicros, - value -> assertEquals(localTime, value) - ); - - assertFieldRoundTrips("TimeMillis", "timeField2", - () -> LogicalTypes.timeMillis().addToSchema(Schema.create(Schema.Type.INT)), - schema -> epochMillis, - value -> assertEquals(localTime, value) - ); - } - - // Test method for converting a single timestamp field (number of milliseconds from epoch) - // timestamp: Stores microseconds from 1970-01-01 00:00:00.000000. [1] - // timestamptz: Stores microseconds from 1970-01-01 00:00:00.000000 UTC. [1] + /** + * Tests that when the same Schema instance is used in multiple places (direct field and list element), + * the RecordBinder correctly shares the same binder for that schema instance. + * This verifies the IdentityHashMap optimization. 
+ */ @Test - public void testTimestampConversion() { - Instant instant = Instant.parse("2020-01-01T12:34:56.123456Z"); - long timestampMicros = instant.getEpochSecond() * 1_000_000 + instant.getNano() / 1_000; - long timestampMillis = instant.toEpochMilli(); - - Supplier timestampMicrosTzSchema = () -> { - Schema schema = LogicalTypes.timestampMicros().addToSchema(Schema.create(Schema.Type.LONG)); - schema.addProp("adjust-to-utc", true); - return schema; - }; - - Supplier timestampMicrosSchema = () -> { - Schema schema = LogicalTypes.timestampMicros().addToSchema(Schema.create(Schema.Type.LONG)); - schema.addProp("adjust-to-utc", false); - return schema; - }; - - Supplier timestampMillisTzSchema = () -> { - Schema schema = LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); - schema.addProp("adjust-to-utc", true); - return schema; - }; - - Supplier timestampMillisSchema = () -> { - Schema schema = LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); - schema.addProp("adjust-to-utc", false); - return schema; - }; - - OffsetDateTime expectedMicrosTz = DateTimeUtil.timestamptzFromMicros(timestampMicros); - LocalDateTime expectedMicros = DateTimeUtil.timestampFromMicros(timestampMicros); - OffsetDateTime expectedMillisTz = DateTimeUtil.timestamptzFromMicros(timestampMillis * 1000); - LocalDateTime expectedMillis = DateTimeUtil.timestampFromMicros(timestampMillis * 1000); - - assertFieldRoundTrips("TimestampMicrosTz", "timestampField1", - timestampMicrosTzSchema, - schema -> timestampMicros, - value -> assertEquals(expectedMicrosTz, value) - ); - - assertFieldRoundTrips("TimestampMicros", "timestampField2", - timestampMicrosSchema, - schema -> timestampMicros, - value -> assertEquals(expectedMicros, value) - ); - - assertFieldRoundTrips("TimestampMillisTz", "timestampField3", - timestampMillisTzSchema, - schema -> timestampMillis, - value -> assertEquals(expectedMillisTz, value) - ); - - 
assertFieldRoundTrips("TimestampMillis", "timestampField4", - timestampMillisSchema, - schema -> timestampMillis, - value -> assertEquals(expectedMillis, value) - ); - } - - // Test method for converting a single binary field - @Test - public void testBinaryConversion() { - String randomAlphabetic = RandomStringUtils.randomAlphabetic(64); - assertFieldRoundTrips("Binary", "binaryField", - () -> Schema.create(Schema.Type.BYTES), - schema -> ByteBuffer.wrap(randomAlphabetic.getBytes(StandardCharsets.UTF_8)), - value -> { - ByteBuffer binaryField = (ByteBuffer) value; - assertEquals(randomAlphabetic, new String(binaryField.array(), StandardCharsets.UTF_8)); - } - ); - } - - // Test method for converting a single fixed field - @Test - public void testFixedConversion() { - assertFieldRoundTrips("Fixed", "fixedField", - () -> Schema.createFixed("FixedField", null, null, 3), - schema -> new GenericData.Fixed(schema, "bar".getBytes(StandardCharsets.UTF_8)), - value -> assertEquals("bar", new String((byte[]) value, StandardCharsets.UTF_8)) - ); - } - - // Test method for converting a single enum field - @Test - public void testEnumConversion() { - assertFieldRoundTrips("Enum", "enumField", - () -> Schema.createEnum("EnumField", null, null, Arrays.asList("A", "B", "C")), - schema -> new GenericData.EnumSymbol(schema, "B"), - value -> assertEquals("B", value.toString()) - ); - } - - // Test method for converting a single UUID field - @Test - public void testUUIDConversion() { - UUID uuid = UUID.randomUUID(); - assertFieldRoundTrips("UUID", "uuidField", - () -> LogicalTypes.uuid().addToSchema(Schema.create(Schema.Type.STRING)), - schema -> new Conversions.UUIDConversion().toCharSequence(uuid, schema, LogicalTypes.uuid()), - value -> assertEquals(uuid, UUIDUtil.convert((byte[]) value)) - ); - } - - // Test method for converting a single decimal field - @Test - public void testDecimalConversion() { - BigDecimal bigDecimal = BigDecimal.valueOf(1000.00).setScale(2); - 
assertFieldRoundTrips("Decimal", "decimalField", - () -> LogicalTypes.decimal(9, 2).addToSchema(Schema.create(Schema.Type.BYTES)), - schema -> { - LogicalTypes.Decimal decimalType = (LogicalTypes.Decimal) schema.getLogicalType(); - return new Conversions.DecimalConversion().toBytes(bigDecimal, schema, decimalType); - }, - value -> assertEquals(bigDecimal, value) - ); - } - - // Test method for converting a list field - @Test - public void testListConversion() { - List expected = Arrays.asList("a", "b", "c"); - assertFieldRoundTrips("List", "listField", - () -> Schema.createArray(Schema.create(Schema.Type.STRING)), - schema -> new ArrayList<>(expected), - value -> assertEquals(expected, normalizeValue(value)) - ); - } + public void testStructSchemaInstanceReuseSharesBinder() { + Schema sharedStruct = Schema.createRecord("SharedStruct", null, TEST_NAMESPACE, false); + sharedStruct.setFields(Arrays.asList( + new Schema.Field("value", Schema.create(Schema.Type.LONG), null, null) + )); - @Test - public void testListWithNullableElementsConversion() { - assertFieldRoundTrips("ListNullableElements", "listField", - () -> Schema.createArray(Schema.createUnion(Arrays.asList( - Schema.create(Schema.Type.NULL), - Schema.create(Schema.Type.STRING) - ))), - schema -> { - @SuppressWarnings("unchecked") - GenericData.Array listValue = new GenericData.Array<>(3, schema); - listValue.add(new Utf8("a")); - listValue.add(null); - listValue.add(new Utf8("c")); - return listValue; - }, - value -> assertEquals(Arrays.asList("a", null, "c"), normalizeValue(value)) - ); - } + Schema listSchema = Schema.createArray(sharedStruct); - @Test - public void testListOfRecordsConversion() { - String avroSchemaJson = "{\n" - + " \"type\": \"record\",\n" - + " \"name\": \"ListRecordContainer\",\n" - + " \"namespace\": \"" + TEST_NAMESPACE + "\",\n" - + " \"fields\": [\n" - + " {\n" - + " \"name\": \"listField\",\n" - + " \"type\": {\n" - + " \"type\": \"array\",\n" - + " \"items\": {\n" - + " 
\"type\": \"record\",\n" - + " \"name\": \"ListRecordEntry\",\n" - + " \"fields\": [\n" - + " {\"name\": \"innerString\", \"type\": \"string\"},\n" - + " {\"name\": \"innerInt\", \"type\": \"int\"}\n" - + " ]\n" - + " }\n" - + " }\n" - + " }\n" - + " ]\n" - + "}\n"; - - Schema avroSchema = new Schema.Parser().parse(avroSchemaJson); - GenericRecord avroRecord = new GenericData.Record(avroSchema); + Schema parent = Schema.createRecord("SharedStructReuseRoot", null, TEST_NAMESPACE, false); + parent.setFields(Arrays.asList( + new Schema.Field("directField", sharedStruct, null, null), + new Schema.Field("listField", listSchema, null, null) + )); - Schema listFieldSchema = avroSchema.getField("listField").schema(); - Schema listEntrySchema = listFieldSchema.getElementType(); + GenericRecord directValue = new GenericData.Record(sharedStruct); + directValue.put("value", 1L); @SuppressWarnings("unchecked") - GenericData.Array listValue = new GenericData.Array<>(2, listFieldSchema); - - GenericRecord firstEntry = new GenericData.Record(listEntrySchema); - firstEntry.put("innerString", new Utf8("first")); - firstEntry.put("innerInt", 1); - listValue.add(firstEntry); + GenericData.Array listValue = new GenericData.Array<>(2, listSchema); + GenericRecord listEntry1 = new GenericData.Record(sharedStruct); + listEntry1.put("value", 2L); + listValue.add(listEntry1); + GenericRecord listEntry2 = new GenericData.Record(sharedStruct); + listEntry2.put("value", 3L); + listValue.add(listEntry2); - GenericRecord secondEntry = new GenericData.Record(listEntrySchema); - secondEntry.put("innerString", new Utf8("second")); - secondEntry.put("innerInt", 2); - listValue.add(secondEntry); + GenericRecord parentRecord = new GenericData.Record(parent); + parentRecord.put("directField", directValue); + parentRecord.put("listField", listValue); - avroRecord.put("listField", listValue); + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(parent); + Record icebergRecord = new 
RecordBinder(icebergSchema, parent).bind(parentRecord); - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - Record icebergRecord = new RecordBinder(icebergSchema, avroSchema) - .bind(serializeAndDeserialize(avroRecord, avroSchema)); + Record directRecord = (Record) icebergRecord.getField("directField"); + assertEquals(1L, directRecord.getField("value")); @SuppressWarnings("unchecked") List boundList = (List) icebergRecord.getField("listField"); assertEquals(2, boundList.size()); - assertEquals("first", boundList.get(0).getField("innerString").toString()); - assertEquals(1, boundList.get(0).getField("innerInt")); - assertEquals("second", boundList.get(1).getField("innerString").toString()); - assertEquals(2, boundList.get(1).getField("innerInt")); - - testSendRecord(icebergSchema, icebergRecord); + assertEquals(2L, boundList.get(0).getField("value")); + assertEquals(3L, boundList.get(1).getField("value")); } + /** + * Tests that structs with the same full name but different schemas in different contexts + * (direct field vs list element) are handled correctly using IdentityHashMap. + * This ensures schema identity, not name equality, is used for binder lookup. 
+ */ @Test public void testStructBindersHandleDuplicateFullNames() { Schema directStruct = Schema.createRecord("DuplicatedStruct", null, TEST_NAMESPACE, false); @@ -697,8 +137,7 @@ public void testStructBindersHandleDuplicateFullNames() { parentRecord.put("listField", listValue); org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(parent); - Record icebergRecord = new RecordBinder(icebergSchema, parent) - .bind(serializeAndDeserialize(parentRecord, parent)); + Record icebergRecord = new RecordBinder(icebergSchema, parent).bind(parentRecord); Record directField = (Record) icebergRecord.getField("directField"); assertEquals("direct", directField.getField("directOnly").toString()); @@ -709,6 +148,10 @@ public void testStructBindersHandleDuplicateFullNames() { assertEquals(42, boundList.get(0).getField("listOnly")); } + /** + * Tests duplicate struct names in map values context. + * Verifies IdentityHashMap correctly distinguishes between schemas with same name. + */ @Test public void testStructBindersHandleDuplicateFullNamesInMapValues() { Schema directStruct = Schema.createRecord("DuplicatedStruct", null, TEST_NAMESPACE, false); @@ -741,8 +184,7 @@ public void testStructBindersHandleDuplicateFullNamesInMapValues() { parentRecord.put("mapField", mapValue); org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(parent); - Record icebergRecord = new RecordBinder(icebergSchema, parent) - .bind(serializeAndDeserialize(parentRecord, parent)); + Record icebergRecord = new RecordBinder(icebergSchema, parent).bind(parentRecord); Record directField = (Record) icebergRecord.getField("directField"); assertEquals("direct", directField.getField("directOnly").toString()); @@ -750,9 +192,13 @@ public void testStructBindersHandleDuplicateFullNamesInMapValues() { @SuppressWarnings("unchecked") Map boundMap = (Map) icebergRecord.getField("mapField"); assertEquals(1, boundMap.size()); - assertEquals(123L, boundMap.get(new Utf8("key")).getField("mapOnly")); + 
assertEquals(123L, boundMap.get("key").getField("mapOnly")); } + /** + * Tests that AvroValueAdapter throws IllegalStateException when trying to convert + * a struct with missing fields in the source Avro schema. + */ @Test public void testConvertStructThrowsWhenSourceFieldMissing() { Schema nestedSchema = Schema.createRecord("NestedRecord", null, TEST_NAMESPACE, false); @@ -775,531 +221,12 @@ public void testConvertStructThrowsWhenSourceFieldMissing() { assertTrue(exception.getMessage().contains("NestedRecord")); } - @Test - public void testNestedStructsBindRecursively() { - Schema innerStruct = Schema.createRecord("InnerStruct", null, TEST_NAMESPACE, false); - innerStruct.setFields(Arrays.asList( - new Schema.Field("innerField", Schema.create(Schema.Type.INT), null, null) - )); - - Schema middleStruct = Schema.createRecord("MiddleStruct", null, TEST_NAMESPACE, false); - middleStruct.setFields(Arrays.asList( - new Schema.Field("middleField", Schema.create(Schema.Type.STRING), null, null), - new Schema.Field("inner", innerStruct, null, null) - )); - - Schema outerStruct = Schema.createRecord("OuterStruct", null, TEST_NAMESPACE, false); - outerStruct.setFields(Arrays.asList( - new Schema.Field("outerField", Schema.create(Schema.Type.STRING), null, null), - new Schema.Field("middle", middleStruct, null, null) - )); - - GenericRecord innerRecord = new GenericData.Record(innerStruct); - innerRecord.put("innerField", 7); - - GenericRecord middleRecord = new GenericData.Record(middleStruct); - middleRecord.put("middleField", new Utf8("mid")); - middleRecord.put("inner", innerRecord); - - GenericRecord outerRecord = new GenericData.Record(outerStruct); - outerRecord.put("outerField", new Utf8("out")); - outerRecord.put("middle", middleRecord); - - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(outerStruct); - Record icebergRecord = new RecordBinder(icebergSchema, outerStruct) - .bind(serializeAndDeserialize(outerRecord, outerStruct)); - - Record 
middleResult = (Record) icebergRecord.getField("middle"); - assertEquals("mid", middleResult.getField("middleField").toString()); - Record innerResult = (Record) middleResult.getField("inner"); - assertEquals(7, innerResult.getField("innerField")); - } - - @Test - public void testStructSchemaInstanceReuseSharesBinder() { - Schema sharedStruct = Schema.createRecord("SharedStruct", null, TEST_NAMESPACE, false); - sharedStruct.setFields(Arrays.asList( - new Schema.Field("value", Schema.create(Schema.Type.LONG), null, null) - )); - - Schema listSchema = Schema.createArray(sharedStruct); - - Schema parent = Schema.createRecord("SharedStructReuseRoot", null, TEST_NAMESPACE, false); - parent.setFields(Arrays.asList( - new Schema.Field("directField", sharedStruct, null, null), - new Schema.Field("listField", listSchema, null, null) - )); - - GenericRecord directValue = new GenericData.Record(sharedStruct); - directValue.put("value", 1L); - - @SuppressWarnings("unchecked") - GenericData.Array listValue = new GenericData.Array<>(2, listSchema); - GenericRecord listEntry1 = new GenericData.Record(sharedStruct); - listEntry1.put("value", 2L); - listValue.add(listEntry1); - GenericRecord listEntry2 = new GenericData.Record(sharedStruct); - listEntry2.put("value", 3L); - listValue.add(listEntry2); - - GenericRecord parentRecord = new GenericData.Record(parent); - parentRecord.put("directField", directValue); - parentRecord.put("listField", listValue); - - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(parent); - Record icebergRecord = new RecordBinder(icebergSchema, parent) - .bind(serializeAndDeserialize(parentRecord, parent)); - - Record directRecord = (Record) icebergRecord.getField("directField"); - assertEquals(1L, directRecord.getField("value")); - - @SuppressWarnings("unchecked") - List boundList = (List) icebergRecord.getField("listField"); - assertEquals(2, boundList.size()); - assertEquals(2L, boundList.get(0).getField("value")); - assertEquals(3L, 
boundList.get(1).getField("value")); - } - - // Test method for converting a map field - @Test - public void testStringMapConversion() { - Map map = new HashMap<>(); - map.put("key1", "value1"); - map.put("key2", "value2"); - assertFieldRoundTrips("StringMap", "mapField", - () -> Schema.createMap(Schema.create(Schema.Type.STRING)), - schema -> new HashMap<>(map), - value -> assertEquals(map, normalizeValue(value)) - ); - } - - @Test - public void testMapWithRecordValuesConversion() { - String avroSchemaJson = "{\n" - + " \"type\": \"record\",\n" - + " \"name\": \"MapRecordContainer\",\n" - + " \"namespace\": \"" + TEST_NAMESPACE + "\",\n" - + " \"fields\": [\n" - + " {\n" - + " \"name\": \"mapField\",\n" - + " \"type\": {\n" - + " \"type\": \"map\",\n" - + " \"values\": {\n" - + " \"type\": \"record\",\n" - + " \"name\": \"MapValueRecord\",\n" - + " \"fields\": [\n" - + " {\"name\": \"innerString\", \"type\": \"string\"},\n" - + " {\"name\": \"innerLong\", \"type\": \"long\"}\n" - + " ]\n" - + " }\n" - + " }\n" - + " }\n" - + " ]\n" - + "}\n"; - - Schema avroSchema = new Schema.Parser().parse(avroSchemaJson); - GenericRecord avroRecord = new GenericData.Record(avroSchema); - - Schema mapFieldSchema = avroSchema.getField("mapField").schema(); - Schema mapValueSchema = mapFieldSchema.getValueType(); - - Map mapValue = new HashMap<>(); - GenericRecord firstValue = new GenericData.Record(mapValueSchema); - firstValue.put("innerString", new Utf8("first")); - firstValue.put("innerLong", 10L); - mapValue.put("key1", firstValue); - - GenericRecord secondValue = new GenericData.Record(mapValueSchema); - secondValue.put("innerString", new Utf8("second")); - secondValue.put("innerLong", 20L); - mapValue.put("key2", secondValue); - - avroRecord.put("mapField", mapValue); - - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - Record icebergRecord = new RecordBinder(icebergSchema, avroSchema) - .bind(serializeAndDeserialize(avroRecord, 
avroSchema)); - - Map boundMap = normalizeMapValues(icebergRecord.getField("mapField")); - assertEquals(2, boundMap.size()); - - Record key1Record = (Record) boundMap.get(new Utf8("key1")); - assertEquals("first", key1Record.getField("innerString").toString()); - assertEquals(10L, key1Record.getField("innerLong")); - - Record key2Record = (Record) boundMap.get(new Utf8("key2")); - assertEquals("second", key2Record.getField("innerString").toString()); - assertEquals(20L, key2Record.getField("innerLong")); - - testSendRecord(icebergSchema, icebergRecord); - } - - // Test method for converting a map field - @Test - public void testIntMapConversion() { - Map map = new HashMap<>(); - map.put("key1", 1); - map.put("key2", 2); - assertFieldRoundTrips("IntMap", "mapField", - () -> Schema.createMap(Schema.create(Schema.Type.INT)), - schema -> new HashMap<>(map), - value -> assertEquals(map, normalizeValue(value)) - ); - } - - // Test method for converting a map field with non-string keys - // Maps with non-string keys must use an array representation with the map logical type. - // The array representation or Avro’s map type may be used for maps with string keys. 
- @Test - public void testMapWithNonStringKeysConversion() { - // Define Avro schema - String avroSchemaStr = " {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"TestRecord\",\n" + - " \"fields\": [\n" + - " {\n" + - " \"name\": \"mapField\",\n" + - " \"type\": {\n" + - " \"type\": \"array\",\n" + - " \"logicalType\": \"map\",\n" + - " \"items\": {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"MapEntry\",\n" + - " \"fields\": [\n" + - " {\"name\": \"key\", \"type\": \"int\"},\n" + - " {\"name\": \"value\", \"type\": \"string\"}\n" + - " ]\n" + - " }\n" + - " }\n" + - " }\n" + - " ]\n" + - " }\n"; - avroSchema = new Schema.Parser().parse(avroSchemaStr); - // Create Avro record - Map expectedMap = new HashMap<>(); - expectedMap.put(1, "value1"); - expectedMap.put(2, "value2"); - expectedMap.put(3, "value3"); - - GenericRecord avroRecord = new GenericData.Record(avroSchema); - List mapEntries = new ArrayList<>(); - for (Map.Entry entry : expectedMap.entrySet()) { - GenericRecord mapEntry = new GenericData.Record(avroSchema.getField("mapField").schema().getElementType()); - mapEntry.put("key", entry.getKey()); - mapEntry.put("value", entry.getValue()); - mapEntries.add(mapEntry); - } - avroRecord.put("mapField", mapEntries); - - // Convert Avro record to Iceberg record using the wrapper - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - Record icebergRecord = new RecordBinder(icebergSchema, avroSchema).bind(serializeAndDeserialize(avroRecord, avroSchema)); - - // Convert the list of records back to a map - Map mapField = normalizeMapValues(icebergRecord.getField("mapField")); - // Verify the field value - assertEquals(expectedMap, mapField); - - // Send the record to the table - testSendRecord(icebergSchema, icebergRecord); - } - - @Test - public void testMapWithNullableValuesConversion() { - Map expectedMap = new HashMap<>(); - expectedMap.put("key1", "value1"); - expectedMap.put("key2", null); - - 
assertFieldRoundTrips("NullableValueMap", "mapField", - () -> Schema.createMap(Schema.createUnion(Arrays.asList( - Schema.create(Schema.Type.NULL), - Schema.create(Schema.Type.STRING) - ))), - schema -> new HashMap<>(expectedMap), - value -> assertEquals(expectedMap, normalizeValue(value)) - ); - } - - // Test method for converting a record with nested fields - @Test - public void testNestedRecordConversion() { - // Define Avro schema - String avroSchemaStr = " {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"TestRecord\",\n" + - " \"fields\": [\n" + - " {\n" + - " \"name\": \"nestedField\",\n" + - " \"type\": {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"NestedRecord\",\n" + - " \"fields\": [\n" + - " {\"name\": \"nestedStringField\", \"type\": \"string\"},\n" + - " {\"name\": \"nestedIntField\", \"type\": \"int\"}\n" + - " ]\n" + - " }\n" + - " }\n" + - " ]\n" + - " }\n"; - avroSchema = new Schema.Parser().parse(avroSchemaStr); - // Create Avro record - GenericRecord nestedRecord = new GenericData.Record(avroSchema.getField("nestedField").schema()); - nestedRecord.put("nestedStringField", "nested_string"); - nestedRecord.put("nestedIntField", 42); - GenericRecord avroRecord = new GenericData.Record(avroSchema); - avroRecord.put("nestedField", nestedRecord); - - // Convert Avro record to Iceberg record using the wrapper - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - Record icebergRecord = new RecordBinder(icebergSchema, avroSchema).bind(serializeAndDeserialize(avroRecord, avroSchema)); - - // Verify the field values - Record nestedIcebergRecord = (Record) icebergRecord.getField("nestedField"); - assertEquals("nested_string", nestedIcebergRecord.getField("nestedStringField").toString()); - assertEquals(42, nestedIcebergRecord.getField("nestedIntField")); - - // Send the record to the table - testSendRecord(icebergSchema, icebergRecord); - } - - // Test method for converting a record with optional fields - // 
Optional fields must always set the Avro field default value to null. - @Test - public void testOptionalFieldConversion() { - // Define Avro schema - String avroSchemaStr = " {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"TestRecord\",\n" + - " \"fields\": [\n" + - " {\"name\": \"optionalStringField\", \"type\": [\"null\", \"string\"], \"default\": null},\n" + - " {\"name\": \"optionalIntField\", \"type\": [\"null\", \"int\"], \"default\": null},\n" + - " {\"name\": \"optionalStringNullField\", \"type\": [\"null\", \"string\"], \"default\": null},\n" + - " {\"name\": \"optionalIntNullField\", \"type\": [\"null\", \"int\"], \"default\": null}\n" + - " ]\n" + - " }\n"; - avroSchema = new Schema.Parser().parse(avroSchemaStr); - // Create Avro record - GenericRecord avroRecord = new GenericData.Record(avroSchema); - avroRecord.put("optionalStringField", "optional_string"); - avroRecord.put("optionalIntField", 42); - avroRecord.put("optionalStringNullField", null); - avroRecord.put("optionalIntNullField", null); - - // Convert Avro record to Iceberg record using the wrapper - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - Record icebergRecord = new RecordBinder(icebergSchema, avroSchema).bind(serializeAndDeserialize(avroRecord, avroSchema)); - - // Verify the field values - assertEquals("optional_string", icebergRecord.getField("optionalStringField").toString()); - assertEquals(42, icebergRecord.getField("optionalIntField")); - assertNull(icebergRecord.getField("optionalStringNullField")); - assertNull(icebergRecord.getField("optionalIntNullField")); - - // Send the record to the table - testSendRecord(icebergSchema, icebergRecord); - } - - // Test method for converting a record with default values - @Test - public void testDefaultFieldConversion() { - // Define Avro schema - String avroSchemaStr = " {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"TestRecord\",\n" + - " \"fields\": [\n" + - " {\"name\": 
\"defaultStringField\", \"type\": \"string\", \"default\": \"default_string\"},\n" + - " {\"name\": \"defaultIntField\", \"type\": \"int\", \"default\": 42}\n" + - " ]\n" + - " }\n"; - avroSchema = new Schema.Parser().parse(avroSchemaStr); - // Create Avro record - GenericRecord avroRecord = new GenericData.Record(avroSchema); - Schema.Field defaultStringField = avroSchema.getField("defaultStringField"); - Schema.Field defaultIntField = avroSchema.getField("defaultIntField"); - avroRecord.put("defaultStringField", defaultStringField.defaultVal()); - avroRecord.put("defaultIntField", defaultIntField.defaultVal()); - - // Convert Avro record to Iceberg record using the wrapper - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - Record icebergRecord = new RecordBinder(icebergSchema, avroSchema).bind(serializeAndDeserialize(avroRecord, avroSchema)); - - // Verify the field values - assertEquals("default_string", icebergRecord.getField("defaultStringField").toString()); - assertEquals(42, icebergRecord.getField("defaultIntField")); - - // Send the record to the table - testSendRecord(icebergSchema, icebergRecord); - } - - // Test method for converting a record with union fields - // Optional fields, array elements, and map values must be wrapped in an Avro union with null. - // This is the only union type allowed in Iceberg data files. 
- @Test - public void testUnionFieldConversion() { - // Define Avro schema - String avroSchemaStr = " {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"TestRecord\",\n" + - " \"fields\": [\n" + - " {\n" + - " \"name\": \"unionField1\",\n" + - " \"type\": [\"null\", \"string\"]\n" + - " },\n" + - " {\n" + - " \"name\": \"unionField2\",\n" + - " \"type\": [\"null\", \"int\"]\n" + - " },\n" + - " {\n" + - " \"name\": \"unionField3\",\n" + - " \"type\": [\"null\", \"boolean\"]\n" + - " },\n" + - " {\n" + - " \"name\": \"unionField4\",\n" + - " \"type\": [\"null\", \"string\"]\n" + - " },\n" + - " {\n" + - " \"name\": \"unionListField\",\n" + - " \"type\": [\n" + - " \"null\",\n" + - " {\n" + - " \"type\": \"array\",\n" + - " \"items\": \"string\"\n" + - " }\n" + - " ]\n" + - " },\n" + - " {\n" + - " \"name\": \"unionMapField\",\n" + - " \"type\": [\n" + - " \"null\",\n" + - " {\n" + - " \"type\": \"map\",\n" + - " \"values\": \"int\"\n" + - " }\n" + - " ]\n" + - " },\n" + - " {\n" + - " \"name\": \"unionStructField\",\n" + - " \"type\": [\n" + - " \"null\",\n" + - " {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"UnionStruct\",\n" + - " \"fields\": [\n" + - " {\"name\": \"innerString\", \"type\": \"string\"},\n" + - " {\"name\": \"innerInt\", \"type\": \"int\"}\n" + - " ]\n" + - " }\n" + - " ]\n" + - " }\n" + - " ]\n" + - " }\n"; - avroSchema = new Schema.Parser().parse(avroSchemaStr); - // Create Avro record - GenericRecord avroRecord = new GenericData.Record(avroSchema); - avroRecord.put("unionField1", "union_string"); - avroRecord.put("unionField2", 42); - avroRecord.put("unionField3", true); - List unionList = Arrays.asList("item1", "item2"); - avroRecord.put("unionListField", unionList); - Map unionMap = new HashMap<>(); - unionMap.put("one", 1); - unionMap.put("two", 2); - avroRecord.put("unionMapField", unionMap); - Schema unionStructSchema = avroSchema.getField("unionStructField").schema().getTypes().get(1); - GenericRecord unionStruct = new 
GenericData.Record(unionStructSchema); - unionStruct.put("innerString", "nested"); - unionStruct.put("innerInt", 99); - avroRecord.put("unionStructField", unionStruct); - - // Convert Avro record to Iceberg record using the wrapper - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - Record icebergRecord = new RecordBinder(icebergSchema, avroSchema).bind(serializeAndDeserialize(avroRecord, avroSchema)); - - // Verify the field value - Object unionField1 = icebergRecord.getField("unionField1"); - assertEquals("union_string", unionField1.toString()); - - Object unionField2 = icebergRecord.getField("unionField2"); - assertEquals(42, unionField2); - - Object unionField3 = icebergRecord.getField("unionField3"); - assertEquals(true, unionField3); - - assertNull(icebergRecord.getField("unionField4")); - - assertEquals(unionList, normalizeValue(icebergRecord.getField("unionListField"))); - assertEquals(unionMap, normalizeValue(icebergRecord.getField("unionMapField"))); - - Record unionStructRecord = (Record) icebergRecord.getField("unionStructField"); - assertEquals("nested", unionStructRecord.getField("innerString").toString()); - assertEquals(99, unionStructRecord.getField("innerInt")); - - // Send the record to the table - testSendRecord(icebergSchema, icebergRecord); - } - - @Test - public void testBindWithNestedOptionalRecord() { - // Schema representing a record with an optional nested record field, similar to Debezium envelopes. 
- String avroSchemaJson = "{\n" + - " \"type\": \"record\",\n" + - " \"name\": \"Envelope\",\n" + - " \"namespace\": \"inventory.inventory.customers\",\n" + - " \"fields\": [\n" + - " {\n" + - " \"name\": \"before\",\n" + - " \"type\": [\n" + - " \"null\",\n" + - " {\n" + - " \"type\": \"record\",\n" + - " \"name\": \"Value\",\n" + - " \"fields\": [\n" + - " { \"name\": \"id\", \"type\": \"int\" },\n" + - " { \"name\": \"first_name\", \"type\": \"string\" }\n" + - " ]\n" + - " }\n" + - " ],\n" + - " \"default\": null\n" + - " }\n" + - " ]\n" + - "}"; - - Schema avroSchema = new Schema.Parser().parse(avroSchemaJson); - - // Corresponding Iceberg Schema - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - - // This binder will recursively create a nested binder for the 'before' field. - // The nested binder will receive a UNION schema, which is what our fix addresses. - RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); - - // --- Test Case 1: Nested record is present --- - Schema valueSchema = avroSchema.getField("before").schema().getTypes().get(1); - GenericRecord valueRecord = new GenericData.Record(valueSchema); - valueRecord.put("id", 101); - valueRecord.put("first_name", "John"); - - GenericRecord envelopeRecord = new GenericData.Record(avroSchema); - envelopeRecord.put("before", valueRecord); - - Record boundRecord = recordBinder.bind(envelopeRecord); - Record nestedBoundRecord = (Record) boundRecord.getField("before"); - - assertEquals(101, nestedBoundRecord.getField("id")); - assertEquals("John", nestedBoundRecord.getField("first_name")); - - // --- Test Case 2: Nested record is null --- - GenericRecord envelopeRecordWithNull = new GenericData.Record(avroSchema); - envelopeRecordWithNull.put("before", null); - - Record boundRecordWithNull = recordBinder.bind(envelopeRecordWithNull); - assertNull(boundRecordWithNull.getField("before")); - } - - // Test method for field count statistics + /** + * Tests 
field count statistics for various field types and sizes. + * Verifies that small/large strings, binary fields, and primitives are counted correctly. + */ @Test public void testFieldCountStatistics() { - // Test different field types and their count calculations String avroSchemaStr = "{\n" + " \"type\": \"record\",\n" + " \"name\": \"TestRecord\",\n" + @@ -1316,7 +243,6 @@ public void testFieldCountStatistics() { org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); - // Create test record with different field sizes GenericRecord avroRecord = new GenericData.Record(avroSchema); avroRecord.put("smallString", "small"); // 5 chars = 3 field avroRecord.put("largeString", "a".repeat(50)); // 50 chars = 3 + 50/32 = 4 @@ -1324,7 +250,6 @@ public void testFieldCountStatistics() { avroRecord.put("binaryField", ByteBuffer.wrap("test".repeat(10).getBytes())); // 5 avroRecord.put("optionalStringField", "optional"); - // Bind record - this should trigger field counting Record icebergRecord = recordBinder.bind(avroRecord); // Access all fields to trigger counting @@ -1339,14 +264,14 @@ public void testFieldCountStatistics() { // Second call should return 0 (reset) assertEquals(0, recordBinder.getAndResetFieldCount()); - - testSendRecord(icebergSchema.asStruct().asSchema(), icebergRecord); - assertEquals(16, recordBinder.getAndResetFieldCount()); } + /** + * Tests field counting for complex types (LIST and MAP). + * Verifies that list and map elements are counted correctly. 
+ */ @Test public void testFieldCountWithComplexTypes() { - // Test field counting for LIST and MAP types String avroSchemaStr = "{\n" + " \"type\": \"record\",\n" + " \"name\": \"ComplexRecord\",\n" + @@ -1361,10 +286,8 @@ public void testFieldCountWithComplexTypes() { RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); GenericRecord avroRecord = new GenericData.Record(avroSchema); - // List with 3 small strings: 1 (list itself) + 3 * 3 * 1 = 10 fields avroRecord.put("stringList", Arrays.asList("a", "b", "c")); - // Map with 2 entries: 1 (map itself) + 2 * (3 key + 3 value) = 13 fields Map map = new HashMap<>(); map.put("key1", "val1"); map.put("key2", "val2"); @@ -1373,20 +296,20 @@ public void testFieldCountWithComplexTypes() { Record icebergRecord = recordBinder.bind(avroRecord); // Access fields to trigger counting - assertEquals(Arrays.asList("a", "b", "c"), normalizeValue(icebergRecord.getField("stringList"))); - assertEquals(map, normalizeValue(icebergRecord.getField("stringMap"))); + icebergRecord.getField("stringList"); + icebergRecord.getField("stringMap"); // Total: 10 (list) + 13 (map) = 23 fields long fieldCount = recordBinder.getAndResetFieldCount(); assertEquals(23, fieldCount); - - testSendRecord(icebergSchema.asStruct().asSchema(), icebergRecord); - assertEquals(23, recordBinder.getAndResetFieldCount()); } + /** + * Tests field counting for nested struct fields. + * Verifies that nested struct fields contribute to the count correctly. 
+ */ @Test public void testFieldCountWithNestedStructure() { - // Test field counting for nested records String avroSchemaStr = "{\n" + " \"type\": \"record\",\n" + " \"name\": \"NestedRecord\",\n" + @@ -1410,7 +333,6 @@ public void testFieldCountWithNestedStructure() { org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); - // Create nested record GenericRecord nestedRecord = new GenericData.Record(avroSchema.getField("nestedField").schema()); nestedRecord.put("nestedString", "nested"); nestedRecord.put("nestedInt", 123); @@ -1428,17 +350,16 @@ public void testFieldCountWithNestedStructure() { assertEquals(123, nested.getField("nestedInt")); // Total: 3 (simple) + 1(struct) + 3 (nested string) + 1 (nested int) = 8 fields - // Note: STRUCT type itself doesn't add to count, only its leaf fields long fieldCount = recordBinder.getAndResetFieldCount(); assertEquals(8, fieldCount); - - testSendRecord(icebergSchema.asStruct().asSchema(), icebergRecord); - assertEquals(8, recordBinder.getAndResetFieldCount()); } + /** + * Tests that field counts accumulate across multiple record bindings. + * Verifies batch processing statistics. 
+ */ @Test public void testFieldCountBatchAccumulation() { - // Test that field counts accumulate across multiple record bindings String avroSchemaStr = "{\n" + " \"type\": \"record\",\n" + " \"name\": \"SimpleRecord\",\n" + @@ -1455,8 +376,8 @@ public void testFieldCountBatchAccumulation() { // Process multiple records for (int i = 0; i < 3; i++) { GenericRecord avroRecord = new GenericData.Record(avroSchema); - avroRecord.put("stringField", "test" + i); // 1 field each - avroRecord.put("intField", i); // 1 field each + avroRecord.put("stringField", "test" + i); + avroRecord.put("intField", i); Record icebergRecord = recordBinder.bind(avroRecord); // Access fields to trigger counting @@ -1469,9 +390,11 @@ public void testFieldCountBatchAccumulation() { assertEquals(12, totalFieldCount); } + /** + * Tests that null values don't contribute to field count. + */ @Test public void testFieldCountWithNullValues() { - // Test that null values don't contribute to field count String avroSchemaStr = "{\n" + " \"type\": \"record\",\n" + " \"name\": \"NullableRecord\",\n" + @@ -1486,8 +409,8 @@ public void testFieldCountWithNullValues() { RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); GenericRecord avroRecord = new GenericData.Record(avroSchema); - avroRecord.put("nonNullField", "value"); // 1 field - avroRecord.put("nullField", null); // 0 fields + avroRecord.put("nonNullField", "value"); + avroRecord.put("nullField", null); Record icebergRecord = recordBinder.bind(avroRecord); @@ -1498,11 +421,11 @@ public void testFieldCountWithNullValues() { // Only the non-null field should count long fieldCount = recordBinder.getAndResetFieldCount(); assertEquals(3, fieldCount); - - testSendRecord(icebergSchema.asStruct().asSchema(), icebergRecord); - assertEquals(3, recordBinder.getAndResetFieldCount()); } + /** + * Tests field counting for optional union fields with both null and non-null values. 
+ */ @Test public void testFieldCountWithUnionFields() { String avroSchemaStr = "{\n" + @@ -1517,14 +440,15 @@ public void testFieldCountWithUnionFields() { org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); + // Test with non-null value GenericRecord nonNullRecord = new GenericData.Record(avroSchema); nonNullRecord.put("optionalString", "value"); Record icebergRecord = recordBinder.bind(nonNullRecord); assertEquals("value", icebergRecord.getField("optionalString").toString()); - assertEquals(3, recordBinder.getAndResetFieldCount()); + // Test with null value GenericRecord nullRecord = new GenericData.Record(avroSchema); nullRecord.put("optionalString", null); @@ -1532,4 +456,168 @@ public void testFieldCountWithUnionFields() { assertNull(nullIcebergRecord.getField("optionalString")); assertEquals(0, recordBinder.getAndResetFieldCount()); } + + /** + * Tests that binding a null GenericRecord returns null. + */ + @Test + public void testBindNullRecordReturnsNull() { + Schema avroSchema = Schema.createRecord("TestRecord", null, TEST_NAMESPACE, false); + avroSchema.setFields(Arrays.asList( + new Schema.Field("field", Schema.create(Schema.Type.STRING), null, null) + )); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); + + Record result = recordBinder.bind(null); + assertNull(result); + } + + /** + * Tests that accessing a field with negative position throws IndexOutOfBoundsException. 
+ */ + @Test + public void testGetFieldWithNegativePositionThrowsException() { + Schema avroSchema = Schema.createRecord("TestRecord", null, TEST_NAMESPACE, false); + avroSchema.setFields(Arrays.asList( + new Schema.Field("field", Schema.create(Schema.Type.STRING), null, null) + )); + + GenericRecord avroRecord = new GenericData.Record(avroSchema); + avroRecord.put("field", new Utf8("value")); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); + Record icebergRecord = recordBinder.bind(avroRecord); + + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, + () -> icebergRecord.get(-1)); + assertTrue(exception.getMessage().contains("out of bounds")); + } + + /** + * Tests that accessing a field with position >= size throws IndexOutOfBoundsException. + */ + @Test + public void testGetFieldWithExcessivePositionThrowsException() { + Schema avroSchema = Schema.createRecord("TestRecord", null, TEST_NAMESPACE, false); + avroSchema.setFields(Arrays.asList( + new Schema.Field("field", Schema.create(Schema.Type.STRING), null, null) + )); + + GenericRecord avroRecord = new GenericData.Record(avroSchema); + avroRecord.put("field", new Utf8("value")); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); + Record icebergRecord = recordBinder.bind(avroRecord); + + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class, + () -> icebergRecord.get(999)); + assertTrue(exception.getMessage().contains("out of bounds")); + } + + /** + * Tests that accessing a field by an unknown name returns null. 
+ */ + @Test + public void testGetFieldByUnknownNameReturnsNull() { + Schema avroSchema = Schema.createRecord("TestRecord", null, TEST_NAMESPACE, false); + avroSchema.setFields(Arrays.asList( + new Schema.Field("existingField", Schema.create(Schema.Type.STRING), null, null) + )); + + GenericRecord avroRecord = new GenericData.Record(avroSchema); + avroRecord.put("existingField", new Utf8("value")); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); + Record icebergRecord = recordBinder.bind(avroRecord); + + assertNull(icebergRecord.getField("nonExistentField")); + } + + /** + * Tests that a UNION containing only NULL type throws IllegalArgumentException. + */ + @Test + public void testUnionWithOnlyNullThrowsException() { + Schema nullOnlyUnion = Schema.createUnion(Arrays.asList(Schema.create(Schema.Type.NULL))); + + Schema avroSchema = Schema.createRecord("TestRecord", null, TEST_NAMESPACE, false); + avroSchema.setFields(Arrays.asList( + new Schema.Field("nullField", nullOnlyUnion, null, null) + )); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> new RecordBinder(icebergSchema, avroSchema)); + assertTrue(exception.getMessage().contains("UNION schema contains only NULL type")); + } + + /** + * Tests that null elements in Map-as-Array representation are skipped. 
+ */ + @Test + public void testMapAsArrayWithNullElementsSkipped() { + String avroSchemaStr = "{\n" + + " \"type\": \"record\",\n" + + " \"name\": \"MapAsArrayRecord\",\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"mapField\",\n" + + " \"type\": {\n" + + " \"type\": \"array\",\n" + + " \"logicalType\": \"map\",\n" + + " \"items\": {\n" + + " \"type\": \"record\",\n" + + " \"name\": \"MapEntry\",\n" + + " \"fields\": [\n" + + " {\"name\": \"key\", \"type\": \"string\"},\n" + + " {\"name\": \"value\", \"type\": \"int\"}\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + Schema avroSchema = new Schema.Parser().parse(avroSchemaStr); + Schema entrySchema = avroSchema.getField("mapField").schema().getElementType(); + + @SuppressWarnings("unchecked") + GenericData.Array arrayValue = new GenericData.Array<>(3, avroSchema.getField("mapField").schema()); + + // Add valid entry + GenericRecord entry1 = new GenericData.Record(entrySchema); + entry1.put("key", new Utf8("key1")); + entry1.put("value", 100); + arrayValue.add(entry1); + + // Add null entry (should be skipped) + arrayValue.add(null); + + // Add another valid entry + GenericRecord entry2 = new GenericData.Record(entrySchema); + entry2.put("key", new Utf8("key2")); + entry2.put("value", 200); + arrayValue.add(entry2); + + GenericRecord avroRecord = new GenericData.Record(avroSchema); + avroRecord.put("mapField", arrayValue); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + RecordBinder recordBinder = new RecordBinder(icebergSchema, avroSchema); + Record icebergRecord = recordBinder.bind(avroRecord); + + @SuppressWarnings("unchecked") + Map mapField = (Map) icebergRecord.getField("mapField"); + + // Should only contain 2 entries (null entry skipped) + assertEquals(2, mapField.size()); + assertEquals(100, mapField.get(new Utf8("key1"))); + assertEquals(200, mapField.get(new Utf8("key2"))); + } } diff --git 
a/core/src/test/java/kafka/automq/table/binder/AvroRecordBinderTypeTest.java b/core/src/test/java/kafka/automq/table/binder/AvroRecordBinderTypeTest.java new file mode 100644 index 0000000000..a8f6d34571 --- /dev/null +++ b/core/src/test/java/kafka/automq/table/binder/AvroRecordBinderTypeTest.java @@ -0,0 +1,1019 @@ +/* + * Copyright 2025, AutoMQ HK Limited. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.automq.table.binder; + +import com.google.common.collect.ImmutableMap; + +import org.apache.avro.Conversions; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.specific.SpecificDatumReader; +import org.apache.avro.specific.SpecificDatumWriter; +import org.apache.avro.util.Utf8; +import org.apache.commons.lang3.RandomStringUtils; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.avro.CodecSetup; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.GenericAppenderFactory; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.inmemory.InMemoryCatalog; +import org.apache.iceberg.io.FileAppenderFactory; +import org.apache.iceberg.io.OutputFileFactory; +import org.apache.iceberg.io.TaskWriter; +import org.apache.iceberg.io.UnpartitionedWriter; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DateTimeUtil; +import org.apache.iceberg.util.UUIDUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.MockitoAnnotations; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import 
java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.iceberg.TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class AvroRecordBinderTypeTest { + + private static final String TEST_NAMESPACE = "kafka.automq.table.binder"; + + private InMemoryCatalog catalog; + private Table table; + private TaskWriter writer; + private int tableCounter; + + static { + CodecSetup.setup(); + } + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + catalog = new InMemoryCatalog(); + catalog.initialize("test", ImmutableMap.of()); + catalog.createNamespace(Namespace.of("default")); + tableCounter = 0; + } + + // Test method for converting a single string field + @Test + public void testStringConversion() { + assertFieldRoundTrips("String", "stringField", + () -> Schema.create(Schema.Type.STRING), + schema -> "test_string", + value -> assertEquals("test_string", value.toString()) + ); + } + + // Test method for converting a single integer field + @Test + public void testIntegerConversion() { + assertFieldRoundTrips("Int", "intField", + () -> Schema.create(Schema.Type.INT), + schema -> 42, + value -> assertEquals(42, value) + ); + } + + // Test method for converting a single long field + @Test + public void testLongConversion() { + assertFieldRoundTrips("Long", "longField", + () -> Schema.create(Schema.Type.LONG), + schema -> 
123456789L, + value -> assertEquals(123456789L, value) + ); + } + + // Test method for converting a single float field + @Test + public void testFloatConversion() { + assertFieldRoundTrips("Float", "floatField", + () -> Schema.create(Schema.Type.FLOAT), + schema -> 3.14f, + value -> assertEquals(3.14f, (Float) value) + ); + } + + // Test method for converting a single double field + @Test + public void testDoubleConversion() { + assertFieldRoundTrips("Double", "doubleField", + () -> Schema.create(Schema.Type.DOUBLE), + schema -> 6.28, + value -> assertEquals(6.28, value) + ); + } + + // Test method for converting a single boolean field + @Test + public void testBooleanConversion() { + assertFieldRoundTrips("Boolean", "booleanField", + () -> Schema.create(Schema.Type.BOOLEAN), + schema -> true, + value -> assertEquals(true, value) + ); + } + + // Test method for converting a single date field (number of days from epoch) + @Test + public void testDateConversion() { + LocalDate localDate = LocalDate.of(2020, 1, 1); + int epochDays = (int) ChronoUnit.DAYS.between(LocalDate.ofEpochDay(0), localDate); + assertFieldRoundTrips("Date", "dateField", + () -> LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT)), + schema -> epochDays, + value -> assertEquals(localDate, value) + ); + } + + // Test method for converting a single time field (number of milliseconds from midnight) + @Test + public void testTimeConversion() { + LocalTime localTime = LocalTime.of(10, 0); + long epochMicros = localTime.toNanoOfDay() / 1000; + int epochMillis = (int) (localTime.toNanoOfDay() / 1_000_000); + assertFieldRoundTrips("TimeMicros", "timeField", + () -> LogicalTypes.timeMicros().addToSchema(Schema.create(Schema.Type.LONG)), + schema -> epochMicros, + value -> assertEquals(localTime, value) + ); + + assertFieldRoundTrips("TimeMillis", "timeField2", + () -> LogicalTypes.timeMillis().addToSchema(Schema.create(Schema.Type.INT)), + schema -> epochMillis, + value -> 
assertEquals(localTime, value) + ); + } + + // Test method for converting a single timestamp field (number of milliseconds from epoch) + // timestamp: Stores microseconds from 1970-01-01 00:00:00.000000. [1] + // timestamptz: Stores microseconds from 1970-01-01 00:00:00.000000 UTC. [1] + @Test + public void testTimestampConversion() { + Instant instant = Instant.parse("2020-01-01T12:34:56.123456Z"); + long timestampMicros = instant.getEpochSecond() * 1_000_000 + instant.getNano() / 1_000; + long timestampMillis = instant.toEpochMilli(); + + Supplier timestampMicrosTzSchema = () -> { + Schema schema = LogicalTypes.timestampMicros().addToSchema(Schema.create(Schema.Type.LONG)); + schema.addProp("adjust-to-utc", true); + return schema; + }; + + Supplier timestampMicrosSchema = () -> { + Schema schema = LogicalTypes.timestampMicros().addToSchema(Schema.create(Schema.Type.LONG)); + schema.addProp("adjust-to-utc", false); + return schema; + }; + + Supplier timestampMillisTzSchema = () -> { + Schema schema = LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); + schema.addProp("adjust-to-utc", true); + return schema; + }; + + Supplier timestampMillisSchema = () -> { + Schema schema = LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); + schema.addProp("adjust-to-utc", false); + return schema; + }; + + OffsetDateTime expectedMicrosTz = DateTimeUtil.timestamptzFromMicros(timestampMicros); + LocalDateTime expectedMicros = DateTimeUtil.timestampFromMicros(timestampMicros); + OffsetDateTime expectedMillisTz = DateTimeUtil.timestamptzFromMicros(timestampMillis * 1000); + LocalDateTime expectedMillis = DateTimeUtil.timestampFromMicros(timestampMillis * 1000); + + assertFieldRoundTrips("TimestampMicrosTz", "timestampField1", + timestampMicrosTzSchema, + schema -> timestampMicros, + value -> assertEquals(expectedMicrosTz, value) + ); + + assertFieldRoundTrips("TimestampMicros", "timestampField2", + timestampMicrosSchema, + schema -> 
timestampMicros, + value -> assertEquals(expectedMicros, value) + ); + + assertFieldRoundTrips("TimestampMillisTz", "timestampField3", + timestampMillisTzSchema, + schema -> timestampMillis, + value -> assertEquals(expectedMillisTz, value) + ); + + assertFieldRoundTrips("TimestampMillis", "timestampField4", + timestampMillisSchema, + schema -> timestampMillis, + value -> assertEquals(expectedMillis, value) + ); + } + + @Test + public void testLocalTimestampConversion() { + LocalDateTime localDateTime = LocalDateTime.of(2023, 6, 1, 8, 15, 30, 123456000); + long micros = DateTimeUtil.microsFromTimestamp(localDateTime); + long millis = DateTimeUtil.microsToMillis(micros); + + // For millis precision, we need to truncate to milliseconds + LocalDateTime localDateTimeMillis = DateTimeUtil.timestampFromMicros(millis * 1000); + + Supplier localTimestampMillisSchema = () -> + LogicalTypes.localTimestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); + Supplier localTimestampMicrosSchema = () -> + LogicalTypes.localTimestampMicros().addToSchema(Schema.create(Schema.Type.LONG)); + + assertFieldRoundTrips("LocalTimestampMillis", "localTsMillis", + localTimestampMillisSchema, + schema -> millis, + value -> { + assertEquals(millis, value); + assertEquals(localDateTimeMillis, DateTimeUtil.timestampFromMicros(((Long) value) * 1000)); + } + ); + + assertFieldRoundTrips("LocalTimestampMicros", "localTsMicros", + localTimestampMicrosSchema, + schema -> micros, + value -> { + assertEquals(micros, value); + assertEquals(localDateTime, DateTimeUtil.timestampFromMicros((Long) value)); + } + ); + } + + // Test method for converting a single binary field + @Test + public void testBinaryConversion() { + String randomAlphabetic = RandomStringUtils.randomAlphabetic(64); + assertFieldRoundTrips("Binary", "binaryField", + () -> Schema.create(Schema.Type.BYTES), + schema -> ByteBuffer.wrap(randomAlphabetic.getBytes(StandardCharsets.UTF_8)), + value -> { + ByteBuffer binaryField = 
(ByteBuffer) value; + assertEquals(randomAlphabetic, new String(binaryField.array(), StandardCharsets.UTF_8)); + } + ); + } + + // Test method for converting a single fixed field + @Test + public void testFixedConversion() { + assertFieldRoundTrips("Fixed", "fixedField", + () -> Schema.createFixed("FixedField", null, null, 3), + schema -> new GenericData.Fixed(schema, "bar".getBytes(StandardCharsets.UTF_8)), + value -> assertEquals("bar", new String((byte[]) value, StandardCharsets.UTF_8)) + ); + } + + // Test method for converting a single enum field + @Test + public void testEnumConversion() { + assertFieldRoundTrips("Enum", "enumField", + () -> Schema.createEnum("EnumField", null, null, Arrays.asList("A", "B", "C")), + schema -> new GenericData.EnumSymbol(schema, "B"), + value -> assertEquals("B", value.toString()) + ); + } + + // Test method for converting a single UUID field + @Test + public void testUUIDConversion() { + UUID uuid = UUID.randomUUID(); + assertFieldRoundTrips("UUID", "uuidField", + () -> LogicalTypes.uuid().addToSchema(Schema.create(Schema.Type.STRING)), + schema -> new Conversions.UUIDConversion().toCharSequence(uuid, schema, LogicalTypes.uuid()), + value -> assertEquals(uuid, UUIDUtil.convert((byte[]) value)) + ); + } + + // Test method for converting a single decimal field + @Test + public void testDecimalConversion() { + BigDecimal bigDecimal = BigDecimal.valueOf(1000.00).setScale(2); + assertFieldRoundTrips("Decimal", "decimalField", + () -> LogicalTypes.decimal(9, 2).addToSchema(Schema.create(Schema.Type.BYTES)), + schema -> { + LogicalTypes.Decimal decimalType = (LogicalTypes.Decimal) schema.getLogicalType(); + return new Conversions.DecimalConversion().toBytes(bigDecimal, schema, decimalType); + }, + value -> assertEquals(bigDecimal, value) + ); + } + + @Test + public void testStructFieldConversion() { + Schema structSchema = SchemaBuilder.record("NestedStruct") + .fields() + .name("field1").type().stringType().noDefault() + 
.name("field2").type().intType().noDefault()
+            .endRecord();
+
+        GenericRecord expected = new GenericData.Record(structSchema);
+        expected.put("field1", "nested_value");
+        expected.put("field2", 99);
+
+        assertFieldRoundTrips("StructField", "structField",
+            () -> structSchema,
+            schema -> cloneStruct(expected, schema),
+            value -> assertStructEquals(expected, (Record) value)
+        );
+    }
+
+    // Test method for converting a list field
+    @Test
+    public void testListConversion() {
+        List<String> expected = Arrays.asList("a", "b", "c");
+        assertFieldRoundTrips("List", "listField",
+            () -> Schema.createArray(Schema.create(Schema.Type.STRING)),
+            schema -> new ArrayList<>(expected),
+            value -> assertEquals(expected, normalizeValue(value))
+        );
+    }
+
+    // Test method for converting a list of structs
+    @Test
+    public void testListStructConversion() {
+        Schema structSchema = SchemaBuilder.record("Struct")
+            .fields()
+            .name("field1").type().stringType().noDefault()
+            .name("field2").type().intType().noDefault()
+            .endRecord();
+
+        List<GenericRecord> expectedList = new ArrayList<>();
+
+        GenericRecord struct1 = new GenericData.Record(structSchema);
+        struct1.put("field1", "value1");
+        struct1.put("field2", 1);
+        expectedList.add(struct1);
+
+        GenericRecord struct2 = new GenericData.Record(structSchema);
+        struct2.put("field1", "value2");
+        struct2.put("field2", 2);
+        expectedList.add(struct2);
+
+        assertFieldRoundTrips("StructList", "listField",
+            () -> Schema.createArray(structSchema),
+            schema -> new ArrayList<>(expectedList),
+            value -> assertStructListEquals(expectedList, value)
+        );
+    }
+
+    // Test method for converting a list with nullable elements
+    @Test
+    public void testListWithNullableElementsConversion() {
+        assertFieldRoundTrips("ListNullableElements", "listField",
+            () -> Schema.createArray(Schema.createUnion(Arrays.asList(
+                Schema.create(Schema.Type.NULL),
+                Schema.create(Schema.Type.STRING)
+            ))),
+            schema -> {
+                // Element type is Object so that both Utf8 values and a null element can be added.
+                GenericData.Array<Object> listValue = new GenericData.Array<>(3, schema);
+                listValue.add(new Utf8("a"));
+                listValue.add(null);
+                listValue.add(new Utf8("c"));
+                return listValue;
+            },
+            value -> assertEquals(Arrays.asList("a", null, "c"), normalizeValue(value))
+        );
+    }
+
+    // Test method for converting a map with non-string (int) keys, encoded as an array of
+    // key/value records carrying the logical-map type built by createLogicalMapSchema
+    @Test
+    public void testMapWithNonStringKeysConversion() {
+        Map<Integer, String> expected = new LinkedHashMap<>();
+        expected.put(1, "one");
+        expected.put(2, "two");
+
+        Schema logicalMapSchema = createLogicalMapSchema("IntStringEntry",
+            Schema.create(Schema.Type.INT), Schema.create(Schema.Type.STRING));
+
+        assertFieldRoundTrips("IntKeyLogicalMap", "mapField",
+            () -> logicalMapSchema,
+            schema -> createLogicalMapArrayValue(schema, expected),
+            value -> {
+                Map<?, ?> actual = (Map<?, ?>) value;
+                // Values may come back as Utf8/CharSequence; compare on their String form.
+                Map<Integer, String> normalized = new LinkedHashMap<>();
+                actual.forEach((k, v) -> normalized.put((Integer) k, v == null ? null : v.toString()));
+                assertEquals(expected, normalized);
+            }
+        );
+    }
+
+    // Test method for converting a map with string values
+    @Test
+    public void testStringMapConversion() {
+        Map<String, String> map = new HashMap<>();
+        map.put("key1", "value1");
+        map.put("key2", "value2");
+        assertFieldRoundTrips("StringMap", "mapField",
+            () -> Schema.createMap(Schema.create(Schema.Type.STRING)),
+            schema -> new HashMap<>(map),
+            value -> assertEquals(map, normalizeValue(value))
+        );
+    }
+
+    // Test method for converting a map with integer values
+    @Test
+    public void testIntMapConversion() {
+        Map<String, Integer> map = new HashMap<>();
+        map.put("key1", 1);
+        map.put("key2", 2);
+        assertFieldRoundTrips("IntMap", "mapField",
+            () -> Schema.createMap(Schema.create(Schema.Type.INT)),
+            schema -> new HashMap<>(map),
+            value -> assertEquals(map, normalizeValue(value))
+        );
+    }
+
+    // Test method for converting a map with struct values
+    @Test
+    public void testStructMapConversion() {
+        Schema structSchema = SchemaBuilder.record("Struct")
+            .fields()
+            .name("field1").type().stringType().noDefault()
+            .name("field2").type().intType().noDefault()
+            .endRecord();
+
+        Map<String, GenericRecord> map = new HashMap<>();
+        GenericRecord struct1 = new GenericData.Record(structSchema);
+        struct1.put("field1", "value1");
+        struct1.put("field2", 1);
+        map.put("key1", struct1);
+
+        GenericRecord struct2 = new GenericData.Record(structSchema);
+        struct2.put("field1", "value2");
+        struct2.put("field2", 2);
+        map.put("key2", struct2);
+
+        assertFieldRoundTrips("StructMap", "mapField",
+            () -> Schema.createMap(structSchema),
+            schema -> new HashMap<>(map),
+            value -> assertStructMapEquals(map, value)
+        );
+    }
+
+    // Test method for converting a map with nullable values
+    @Test
+    public void testMapWithNullableValuesConversion() {
+        Map<String, String> expectedMap = new HashMap<>();
+        expectedMap.put("key1", "value1");
+        expectedMap.put("key2", null);
+
+        assertFieldRoundTrips("NullableValueMap", "mapField",
+            () -> Schema.createMap(Schema.createUnion(Arrays.asList(
+                Schema.create(Schema.Type.NULL),
+                Schema.create(Schema.Type.STRING)
+            ))),
+            schema -> new HashMap<>(expectedMap),
+            value -> assertEquals(expectedMap, normalizeValue(value))
+        );
+    }
+
+
+    // Test method for binding an Avro FIXED(4) field to an Iceberg binary column: the bound
+    // value is read back as a ByteBuffer and must contain the original four bytes
+    @Test
+    public void testBinaryFieldBackedByFixedConversion() {
+        Schema fixedSchema = Schema.createFixed("FixedBinary", null, null, 4);
+        Schema recordSchema = SchemaBuilder.builder()
+            .record("FixedBinaryRecord")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("binaryField").type(fixedSchema).noDefault()
+            .endRecord();
+
+        Types.StructType structType = Types.StructType.of(
+            Types.NestedField.required(1, "binaryField", Types.BinaryType.get())
+        );
+        org.apache.iceberg.Schema icebergSchema = new org.apache.iceberg.Schema(structType.fields());
+
+        runRoundTrip(recordSchema, icebergSchema,
+            record -> record.put("binaryField", new GenericData.Fixed(fixedSchema, new byte[]{1, 2, 3, 4})),
+            icebergRecord -> {
+                ByteBuffer buffer = (ByteBuffer) icebergRecord.getField("binaryField");
+                byte[] actual = new byte[buffer.remaining()];
+                buffer.get(actual);
+                assertArrayEquals(new byte[]{1, 2, 3, 4}, actual);
+            }
+        );
+    }
+
+    // Test method for deeply nested struct
(3+ levels)
+    @Test
+    public void testDeeplyNestedStructConversion() {
+        Schema innerMostStruct = SchemaBuilder.record("InnerMostStruct")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("deepValue").type().intType().noDefault()
+            .endRecord();
+
+        Schema middleStruct = SchemaBuilder.record("MiddleStruct")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("middleField").type().stringType().noDefault()
+            .name("innerMost").type(innerMostStruct).noDefault()
+            .endRecord();
+
+        Schema outerStruct = SchemaBuilder.record("OuterStruct")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("outerField").type().stringType().noDefault()
+            .name("middle").type(middleStruct).noDefault()
+            .endRecord();
+
+        Schema recordSchema = SchemaBuilder.builder()
+            .record("DeeplyNestedRecord")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("topLevel").type().stringType().noDefault()
+            .name("nested").type(outerStruct).noDefault()
+            .endRecord();
+
+        GenericRecord innerMostRecord = new GenericData.Record(innerMostStruct);
+        innerMostRecord.put("deepValue", 42);
+
+        GenericRecord middleRecord = new GenericData.Record(middleStruct);
+        middleRecord.put("middleField", "middle");
+        middleRecord.put("innerMost", innerMostRecord);
+
+        GenericRecord outerRecord = new GenericData.Record(outerStruct);
+        outerRecord.put("outerField", "outer");
+        outerRecord.put("middle", middleRecord);
+
+        runRoundTrip(recordSchema,
+            record -> {
+                record.put("topLevel", "top");
+                record.put("nested", outerRecord);
+            },
+            icebergRecord -> {
+                assertEquals("top", icebergRecord.getField("topLevel").toString());
+                Record nestedRecord = (Record) icebergRecord.getField("nested");
+                assertNotNull(nestedRecord);
+                assertEquals("outer", nestedRecord.getField("outerField").toString());
+
+                Record middleResult = (Record) nestedRecord.getField("middle");
+                assertNotNull(middleResult);
+                assertEquals("middle", middleResult.getField("middleField").toString());
+
+                Record innerMostResult = (Record) middleResult.getField("innerMost");
+                assertNotNull(innerMostResult);
+                assertEquals(42, innerMostResult.getField("deepValue"));
+            }
+        );
+    }
+
+    // Test method for converting a record with default values
+    @Test
+    public void testDefaultFieldConversion() {
+        Schema recordSchema = SchemaBuilder.builder()
+            .record("DefaultValueRecord")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("defaultStringField").type().stringType().stringDefault("default_string")
+            .name("defaultIntField").type().intType().intDefault(42)
+            .name("defaultBoolField").type().booleanType().booleanDefault(true)
+            .endRecord();
+
+        // Test with default values
+        runRoundTrip(recordSchema,
+            record -> {
+                Schema.Field defaultStringField = recordSchema.getField("defaultStringField");
+                Schema.Field defaultIntField = recordSchema.getField("defaultIntField");
+                Schema.Field defaultBoolField = recordSchema.getField("defaultBoolField");
+                record.put("defaultStringField", defaultStringField.defaultVal());
+                record.put("defaultIntField", defaultIntField.defaultVal());
+                record.put("defaultBoolField", defaultBoolField.defaultVal());
+            },
+            icebergRecord -> {
+                assertEquals("default_string", icebergRecord.getField("defaultStringField").toString());
+                assertEquals(42, icebergRecord.getField("defaultIntField"));
+                assertEquals(true, icebergRecord.getField("defaultBoolField"));
+            }
+        );
+
+        // Test with non-default values
+        runRoundTrip(recordSchema,
+            record -> {
+                record.put("defaultStringField", "custom_value");
+                record.put("defaultIntField", 100);
+                record.put("defaultBoolField", false);
+            },
+            icebergRecord -> {
+                assertEquals("custom_value", icebergRecord.getField("defaultStringField").toString());
+                assertEquals(100, icebergRecord.getField("defaultIntField"));
+                assertEquals(false, icebergRecord.getField("defaultBoolField"));
+            }
+        );
+    }
+
+    // Test that non-optional unions with multiple non-NULL types throw UnsupportedOperationException
+    @Test
+    public void testNonOptionalUnionThrowsException() {
+        // Test case 1: {null, string, int} at record level
+        Schema unionSchema1 = Schema.createUnion(Arrays.asList(
+            Schema.create(Schema.Type.NULL),
+            Schema.create(Schema.Type.STRING),
+            Schema.create(Schema.Type.INT)
+        ));
+        assertNonOptionalUnionRejected(unionSchema1, "non-optional union {null, string, int}");
+
+        // Test case 2: {null, struct1, struct2} at record level
+        Schema struct1Schema = SchemaBuilder.record("Struct1")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("field1").type().stringType().noDefault()
+            .endRecord();
+
+        Schema struct2Schema = SchemaBuilder.record("Struct2")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("field2").type().intType().noDefault()
+            .endRecord();
+
+        Schema unionSchema2 = Schema.createUnion(Arrays.asList(
+            Schema.create(Schema.Type.NULL),
+            struct1Schema,
+            struct2Schema
+        ));
+        assertNonOptionalUnionRejected(unionSchema2, "non-optional union {null, struct1, struct2}");
+
+        // Test case 3: Union in field with multiple non-NULL types
+        Schema unionFieldSchema = Schema.createUnion(Arrays.asList(
+            Schema.create(Schema.Type.NULL),
+            Schema.create(Schema.Type.STRING),
+            Schema.create(Schema.Type.INT)
+        ));
+
+        Schema recordSchema = SchemaBuilder.builder()
+            .record("RecordWithUnionField")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name("id").type().intType().noDefault()
+            .name("unionField").type(unionFieldSchema).withDefault(null)
+            .endRecord();
+        assertNonOptionalUnionRejected(recordSchema, "field with non-optional union");
+    }
+
+    /**
+     * Asserts that constructing a {@link RecordBinder} for {@code avroSchema} fails with an
+     * {@link UnsupportedOperationException} whose message names the unsupported union and
+     * reports two non-NULL branches.
+     *
+     * @param avroSchema  schema containing a union with more than one non-NULL branch
+     * @param description human-readable name of the case, used in the failure message
+     */
+    private static void assertNonOptionalUnionRejected(Schema avroSchema, String description) {
+        // Assertions are fully qualified because this file's static imports are not known to
+        // include assertThrows/assertTrue (the original used a fully qualified fail(...)).
+        UnsupportedOperationException e = org.junit.jupiter.api.Assertions.assertThrows(
+            UnsupportedOperationException.class,
+            () -> new RecordBinder(AvroSchemaUtil.toIceberg(avroSchema), avroSchema),
+            "Expected UnsupportedOperationException for " + description);
+        String message = e.getMessage();
+        assertNotNull(message, "Exception message is null for " + description);
+        org.junit.jupiter.api.Assertions.assertTrue(
+            message.contains("Non-optional UNION with multiple non-NULL types is not supported"), message);
+        org.junit.jupiter.api.Assertions.assertTrue(
+            message.contains("Found 2 non-NULL types"), message);
+    }
+
+
+    /**
+     * Creates a new table named {@code test_<counter>} in the {@code default} namespace and
+     * writes {@code record} to it through {@link #createTableWriter(Table)}.
+     * An {@link IOException} from the writer is rethrown as a {@link RuntimeException}.
+     */
+    private void testSendRecord(org.apache.iceberg.Schema schema, org.apache.iceberg.data.Record record) {
+        String tableName = "test_" + tableCounter++;
+        table = catalog.createTable(TableIdentifier.of(Namespace.of("default"), tableName), schema);
+        writer = createTableWriter(table);
+        try {
+            writer.write(record);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Builds an unpartitioned Parquet {@link TaskWriter} for {@code table}, using the table's
+     * own schema, spec, properties and IO.
+     *
+     * @param table table to write to
+     * @return a writer producing Parquet files for the table
+     */
+    public static TaskWriter<Record> createTableWriter(Table table) {
+        // NOTE(review): the row-group size of "1" presumably forces tiny row groups so that
+        // small test records are flushed eagerly; confirm the intent.
+        FileAppenderFactory<Record> appenderFactory = new GenericAppenderFactory(
+            table.schema(),
+            table.spec(),
+            null, null, null)
+            .setAll(new HashMap<>(table.properties()))
+            .set(PARQUET_ROW_GROUP_SIZE_BYTES, "1");
+
+        OutputFileFactory fileFactory =
+            OutputFileFactory.builderFor(table, 1, System.currentTimeMillis())
+                .defaultSpec(table.spec())
+                .operationId(UUID.randomUUID().toString())
+                .format(FileFormat.PARQUET)
+                .build();
+
+        return new UnpartitionedWriter<>(
+            table.spec(),
+            FileFormat.PARQUET,
+            appenderFactory,
+            fileFactory,
+            table.io(),
+            WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
+        );
+    }
+
+    /**
+     * Encodes {@code record} with the Avro binary encoder and decodes it again with the same
+     * schema, so tests bind the value types a real reader produces rather than the Java
+     * objects the test put into the record.
+     */
+    private static GenericRecord serializeAndDeserialize(GenericRecord record, Schema schema) {
+        try {
+            // Serialize the avro record to a byte array
+            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+            DatumWriter<GenericRecord> datumWriter = new SpecificDatumWriter<>(schema);
+            Encoder encoder = EncoderFactory.get().binaryEncoder(outputStream, null);
+            datumWriter.write(record, encoder);
+            encoder.flush();
+            outputStream.close();
+
+            byte[] serializedBytes = outputStream.toByteArray();
+
+            // Deserialize the byte array back to an avro record
+            DatumReader<GenericRecord> datumReader = new SpecificDatumReader<>(schema);
+            ByteArrayInputStream inputStream = new ByteArrayInputStream(serializedBytes);
+            Decoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
+            return datumReader.read(null, decoder);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+
+    /**
+     * Returns a nullable variant of {@code schema}: a union that already has a NULL branch is
+     * returned as is, any other union gets NULL prepended, and a non-union schema is wrapped
+     * as {@code [null, schema]}.
+     */
+    private static Schema createOptionalSchema(Schema schema) {
+        if (schema.getType() == Schema.Type.UNION) {
+            boolean hasNull = schema.getTypes().stream()
+                .anyMatch(type -> type.getType() == Schema.Type.NULL);
+            if (hasNull) {
+                return schema;
+            }
+            List<Schema> updatedTypes = new ArrayList<>();
+            updatedTypes.add(Schema.create(Schema.Type.NULL));
+            updatedTypes.addAll(schema.getTypes());
+            return Schema.createUnion(updatedTypes);
+        }
+        return Schema.createUnion(Arrays.asList(Schema.create(Schema.Type.NULL), schema));
+    }
+
+    /**
+     * Returns {@code schema} itself when it is not a union, otherwise its first non-NULL branch.
+     *
+     * @throws IllegalArgumentException if the union has no non-NULL branch
+     */
+    private static Schema ensureNonNullBranch(Schema schema) {
+        if (schema.getType() != Schema.Type.UNION) {
+            return schema;
+        }
+        return schema.getTypes().stream()
+            .filter(type -> type.getType() != Schema.Type.NULL)
+            .findFirst()
+            .orElseThrow(() -> new IllegalArgumentException("Union schema lacks non-null branch: " + schema));
+    }
+
+    /** Round trip using the Iceberg schema derived from {@code recordSchema}. */
+    private void runRoundTrip(Schema recordSchema,
+                              Consumer<GenericRecord> avroPopulator,
+                              Consumer<Record> assertions) {
+        runRoundTrip(recordSchema, AvroSchemaUtil.toIceberg(recordSchema), avroPopulator, assertions);
+    }
+
+    /**
+     * Populates a fresh Avro record, passes it through {@link #serializeAndDeserialize}, binds
+     * it with a new {@link RecordBinder}, runs {@code assertions} on the bound record and
+     * finally writes that record to a new table via {@code testSendRecord}.
+     */
+    private void runRoundTrip(Schema recordSchema,
+                              org.apache.iceberg.Schema icebergSchema,
+                              Consumer<GenericRecord> avroPopulator,
+                              Consumer<Record> assertions) {
+        GenericRecord avroRecord = new GenericData.Record(recordSchema);
+        avroPopulator.accept(avroRecord);
+        GenericRecord roundTripRecord = serializeAndDeserialize(avroRecord, recordSchema);
+
+        Record icebergRecord = new RecordBinder(icebergSchema, recordSchema).bind(roundTripRecord);
+
+        assertions.accept(icebergRecord);
+        testSendRecord(icebergSchema, icebergRecord);
+    }
+
+    /**
+     * Helper method to test round-trip conversion for a single field in three shapes: as a
+     * required field, as an optional (nullable union) field holding a value, and as an
+     * optional field holding null.
+     *
+     * @param recordPrefix        prefix for the generated record names ("Base" / "Union" is appended)
+     * @param fieldName           name of the single field under test
+     * @param fieldSchemaSupplier supplies the Avro schema of the field
+     * @param avroValueSupplier   builds the Avro value for the given (non-null branch) schema
+     * @param valueAssertion      checks the value read back from the bound Iceberg record
+     */
+    private void assertFieldRoundTrips(String recordPrefix,
+                                       String fieldName,
+                                       Supplier<Schema> fieldSchemaSupplier,
+                                       Function<Schema, Object> avroValueSupplier,
+                                       Consumer<Object> valueAssertion) {
+        Schema baseFieldSchema = fieldSchemaSupplier.get();
+        Schema baseRecordSchema = SchemaBuilder.builder()
+            .record(recordPrefix + "Base")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name(fieldName).type(baseFieldSchema).noDefault()
+            .endRecord();
+
+        // Direct field
+        runRoundTrip(baseRecordSchema,
+            record -> record.put(fieldName, avroValueSupplier.apply(baseFieldSchema)),
+            icebergRecord -> valueAssertion.accept(icebergRecord.getField(fieldName))
+        );
+
+        Schema optionalFieldSchema = createOptionalSchema(fieldSchemaSupplier.get());
+        Schema unionRecordSchema = SchemaBuilder.builder()
+            .record(recordPrefix + "Union")
+            .namespace(TEST_NAMESPACE)
+            .fields()
+            .name(fieldName).type(optionalFieldSchema).withDefault(null)
+            .endRecord();
+        Schema nonNullBranch = ensureNonNullBranch(optionalFieldSchema);
+
+        // Optional field with non-null value
+        runRoundTrip(unionRecordSchema,
+            record -> record.put(fieldName, avroValueSupplier.apply(nonNullBranch)),
+            icebergRecord -> valueAssertion.accept(icebergRecord.getField(fieldName))
+        );
+
+        // Optional field with null value
+        runRoundTrip(unionRecordSchema,
+            record -> record.put(fieldName, null),
+            icebergRecord -> assertNull(icebergRecord.getField(fieldName))
+        );
+    }
+
+
+    /**
+     * Copies a map into a {@link HashMap} whose keys are the {@code toString()} of the original
+     * keys (a null key stays null) and whose values are passed through {@link #normalizeValue}.
+     * Returns null for a null input.
+     */
+    private static Map<String, Object> toStringKeyMap(Object value) {
+        if (value == null) {
+            return null;
+        }
+        Map<?, ?> map = (Map<?, ?>) value;
+        Map<String, Object> result = new HashMap<>(map.size());
+        for (Map.Entry<?, ?> entry : map.entrySet()) {
+            String key = entry.getKey() == null ? null : entry.getKey().toString();
+            result.put(key, normalizeValue(entry.getValue()));
+        }
+        return result;
+    }
+
+    /** Copies every field declared by {@code schema} from {@code source} into a new record of that schema. */
+    private static GenericRecord cloneStruct(GenericRecord source, Schema schema) {
+        GenericRecord target = new GenericData.Record(schema);
+        for (Schema.Field field : schema.getFields()) {
+            target.put(field.name(), source.get(field.name()));
+        }
+        return target;
+    }
+
+    /**
+     * Builds an array-of-records schema with {@code key} and {@code value} fields and attaches
+     * the logical type returned by {@code CodecSetup.getLogicalMap()} to it.
+     */
+    private static Schema createLogicalMapSchema(String entryName, Schema keySchema, Schema valueSchema) {
+        Schema.Field keyField = new Schema.Field("key", keySchema, null, null);
+        Schema.Field valueField = new Schema.Field("value", valueSchema, null, null);
+        Schema entrySchema = Schema.createRecord(entryName, null, null, false);
+        entrySchema.setFields(Arrays.asList(keyField, valueField));
+        Schema arraySchema = Schema.createArray(entrySchema);
+        return CodecSetup.getLogicalMap().addToSchema(arraySchema);
+    }
+
+    /**
+     * Converts {@code values} into the array-of-entry-records form described by {@code schema}
+     * (or by its non-null branch when {@code schema} is a union).
+     *
+     * @throws IllegalArgumentException if the resolved schema is not an array
+     */
+    private static GenericData.Array<GenericRecord> createLogicalMapArrayValue(Schema schema, Map<?, ?> values) {
+        Schema nonNullSchema = ensureNonNullBranch(schema);
+        if (nonNullSchema.getType() != Schema.Type.ARRAY) {
+            throw new IllegalArgumentException("Expected array schema for logical map but got: " + nonNullSchema);
+        }
+        Schema entrySchema = nonNullSchema.getElementType();
+        Schema.Field keyField = entrySchema.getField("key");
+        Schema.Field valueField = entrySchema.getField("value");
+        GenericData.Array<GenericRecord> entries = new GenericData.Array<>(values.size(), nonNullSchema);
+        for (Map.Entry<?, ?> entry : values.entrySet()) {
+            GenericRecord kv = new GenericData.Record(entrySchema);
+            kv.put(keyField.name(), toAvroValue(entry.getKey(), keyField.schema()));
+            kv.put(valueField.name(), toAvroValue(entry.getValue(), valueField.schema()));
+            entries.add(kv);
+        }
+        return entries;
+    }
+
+    /**
+     * Adapts a plain Java value to the given Avro schema: for STRING schemas a value that is
+     * not already a {@link CharSequence} is wrapped in {@link Utf8}; every other type, and
+     * null, is returned unchanged.
+     */
+    private static Object toAvroValue(Object value, Schema schema) {
+        if (value == null) {
+            return null;
+        }
+        Schema actualSchema = ensureNonNullBranch(schema);
+        if (actualSchema.getType() == Schema.Type.STRING) {
+            return value instanceof CharSequence ? value : new Utf8(value.toString());
+        }
+        return value;
+    }
+
+    /** Casts every element of the given list to an Iceberg {@link Record}; returns null for null. */
+    private static List<Record> toRecordList(Object value) {
+        if (value == null) {
+            return null;
+        }
+        List<?> list = (List<?>) value;
+        List<Record> normalized = new ArrayList<>(list.size());
+        for (Object element : list) {
+            normalized.add((Record) element);
+        }
+        return normalized;
+    }
+
+    /**
+     * Copies the given map into one keyed by the {@code toString()} of each key (a null key
+     * stays null) with values cast to Iceberg {@link Record}; returns null for null.
+     */
+    private static Map<String, Record> toRecordMap(Object value) {
+        if (value == null) {
+            return null;
+        }
+        Map<?, ?> map = (Map<?, ?>) value;
+        Map<String, Record> normalized = new HashMap<>(map.size());
+        for (Map.Entry<?, ?> entry : map.entrySet()) {
+            String key = entry.getKey() == null ? null : entry.getKey().toString();
+            normalized.put(key, (Record) entry.getValue());
+        }
+        return normalized;
+    }
+
+    /** Asserts that {@code actualValue} is a list of records matching {@code expectedList} element by element. */
+    private static void assertStructListEquals(List<GenericRecord> expectedList, Object actualValue) {
+        List<Record> actualList = toRecordList(actualValue);
+        assertNotNull(actualList, "Actual list is null");
+        assertEquals(expectedList.size(), actualList.size());
+        for (int i = 0; i < expectedList.size(); i++) {
+            assertStructEquals(expectedList.get(i), actualList.get(i));
+        }
+    }
+
+    /** Asserts that {@code actualValue} is a map with the same keys as {@code expectedMap} and matching records. */
+    private static void assertStructMapEquals(Map<String, GenericRecord> expectedMap, Object actualValue) {
+        Map<String, Record> actualMap = toRecordMap(actualValue);
+        assertNotNull(actualMap, "Actual map is null");
+        assertEquals(expectedMap.keySet(), actualMap.keySet());
+        for (Map.Entry<String, GenericRecord> entry : expectedMap.entrySet()) {
+            assertStructEquals(entry.getValue(), actualMap.get(entry.getKey()));
+        }
+    }
+
+    /** Compares every field of the expected Avro record with the Iceberg record after normalization. */
+    private static void assertStructEquals(GenericRecord expected, Record actual) {
+        assertNotNull(actual, "Actual struct record is null");
+        for (Schema.Field field : expected.getSchema().getFields()) {
+            Object expectedValue = normalizeValue(expected.get(field.name()));
+            Object actualValue = normalizeValue(actual.getField(field.name()));
+            assertEquals(expectedValue, actualValue, "Mismatch on field " + field.name());
+        }
+    }
+
+    /**
+     * Normalizes a value for equality checks: a {@link CharSequence} becomes a String, lists
+     * are normalized element by element, maps go through {@link #toStringKeyMap}, and
+     * everything else (including null) is returned unchanged.
+     */
+    private static Object normalizeValue(Object value) {
+        if (value == null) {
+            return null;
+        }
+        if (value instanceof CharSequence) {
+            return value.toString();
+        }
+        if (value instanceof List) {
+            List<?> list = (List<?>) value;
+            List<Object> normalized = new ArrayList<>(list.size());
+            for (Object element : list) {
+                normalized.add(normalizeValue(element));
+            }
+            return normalized;
+        }
+        if (value instanceof Map) {
+            return toStringKeyMap(value);
+        }
+        return value;
+    }
+
+}
diff --git a/core/src/test/java/kafka/automq/table/process/convert/ProtoToAvroConverterTest.java b/core/src/test/java/kafka/automq/table/process/convert/ProtoToAvroConverterTest.java
new file mode 100644
index 0000000000..ece4c043c3
--- /dev/null
+++ b/core/src/test/java/kafka/automq/table/process/convert/ProtoToAvroConverterTest.java
@@ -0,0 +1,168 @@
+/*
+ * Copyright 2025, AutoMQ HK Limited.
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.automq.table.process.convert;
+
+import kafka.automq.table.deserializer.proto.parse.ProtobufSchemaParser;
+import kafka.automq.table.deserializer.proto.parse.converter.ProtoConstants;
+import kafka.automq.table.deserializer.proto.schema.DynamicSchema;
+import kafka.automq.table.process.exception.ConverterException;
+
+import com.google.protobuf.Descriptors;
+import com.google.protobuf.DynamicMessage;
+
+import org.apache.avro.Schema;
+import org.apache.avro.SchemaBuilder;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.protobuf.ProtobufData;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Consumer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+/**
+ * Focused unit tests for {@link ProtoToAvroConverter} that exercise converter paths which are
+ * hard to reach through the higher-level registry converter integration tests.
+ */
+class ProtoToAvroConverterTest {
+
+    private static final String SIMPLE_PROTO = """
+        syntax = \"proto3\";
+
+        package kafka.automq.table.process.proto;
+
+        message SimpleRecord {
+            bool flag = 1;
+            Nested nested = 2;
+            optional int32 opt_scalar = 3;
+        }
+
+        message Nested {
+            string note = 1;
+        }
+        """;
+
+    // The Avro schema declares "ghost_field", which SimpleRecord does not have; the converted
+    // record is expected to carry the known field and leave the extra one null.
+    @Test
+    void skipsUnknownAvroFieldsWhenSchemaHasExtraColumns() throws Exception {
+        DynamicMessage message = buildSimpleRecord(b ->
+            b.setField(b.getDescriptorForType().findFieldByName("flag"), true)
+        );
+
+        Schema schema = SchemaBuilder.record("SimpleRecord")
+            .fields()
+            .name("flag").type().booleanType().noDefault()
+            .name("ghost_field").type().stringType().noDefault()
+            .endRecord();
+
+        GenericRecord record = ProtoToAvroConverter.convert(message, schema);
+        assertEquals(true, record.get("flag"));
+        assertNull(record.get("ghost_field"));
+    }
+
+    // Only "flag" is set on the message. The unset message field "nested" is expected to come
+    // back null even though the Avro schema is non-nullable, and the unset optional scalar
+    // is expected to come back as 0.
+    @Test
+    void leavesMissingPresenceFieldUnsetWhenAvroSchemaDisallowsNull() throws Exception {
+        DynamicMessage message = buildSimpleRecord(b ->
+            b.setField(b.getDescriptorForType().findFieldByName("flag"), false)
+        );
+
+        Schema nestedSchema = SchemaBuilder.record("Nested")
+            .fields()
+            .name("note").type().stringType().noDefault()
+            .endRecord();
+
+        Schema schema = SchemaBuilder.record("SimpleRecord")
+            .fields()
+            .name("nested").type(nestedSchema).noDefault()
+            .name("opt_scalar").type().intType().noDefault()
+            .endRecord();
+
+        GenericRecord record = ProtoToAvroConverter.convert(message, schema);
+        assertNull(record.get("nested"));
+        assertEquals(0, record.get("opt_scalar"));
+    }
+
+    // "nested" holds a message, but the Avro schema declares it as long; the converter is
+    // expected to yield null for the field instead of failing.
+    @Test
+    void messageSchemaMismatchYieldsNullWhenNonRecordTypeProvided() throws Exception {
+        DynamicMessage message = buildSimpleRecord(b -> {
+            Descriptors.Descriptor nestedDesc = b.getDescriptorForType().findFieldByName("nested").getMessageType();
+            b.setField(b.getDescriptorForType().findFieldByName("nested"),
+                DynamicMessage.newBuilder(nestedDesc)
+                    .setField(nestedDesc.findFieldByName("note"), "note-value")
+                    .build()
+            );
+        });
+
+        Schema schema = SchemaBuilder.record("SimpleRecord")
+            .fields()
+            .name("nested").type().longType().noDefault()
+            .endRecord();
+
+        GenericRecord record = ProtoToAvroConverter.convert(message, schema);
+        assertNull(record.get("nested"));
+    }
+
+    // Calls the non-public convertPrimitive via reflection (null target, i.e. as a static
+    // method) and expects a byte[] input to be returned as a ByteBuffer with the same bytes.
+    @Test
+    void convertPrimitiveWrapsByteArrayValues() throws Exception {
+        Method method = ProtoToAvroConverter.class.getDeclaredMethod("convertPrimitive", Object.class, Schema.class);
+        method.setAccessible(true);
+        byte[] source = new byte[]{1, 2, 3};
+        ByteBuffer buffer = (ByteBuffer) invoke(method, null, source, Schema.create(Schema.Type.BYTES));
+        // ByteBuffer.equals compares the remaining bytes of both buffers, so this checks the
+        // length as well as the content; duplicate() keeps the returned buffer's position intact.
+        assertEquals(ByteBuffer.wrap(new byte[]{1, 2, 3}), buffer.duplicate());
+    }
+
+    // Passing a List where a single STRING value is expected must raise ConverterException.
+    @Test
+    void convertSingleValueRejectsRawListsWhenFieldIsNotRepeated() throws Exception {
+        Method method = ProtoToAvroConverter.class.getDeclaredMethod("convertSingleValue", Object.class, Schema.class, ProtobufData.class);
+        method.setAccessible(true);
+        Schema schema = Schema.create(Schema.Type.STRING);
+        assertThrows(ConverterException.class, () -> invoke(method, null, List.of("unexpected"), schema, LogicalMapProtobufData.get()));
+    }
+
+    /**
+     * Invokes {@code method} reflectively and unwraps {@link InvocationTargetException} so that
+     * callers (and {@code assertThrows}) see the exception thrown by the target itself.
+     *
+     * @param method method to call, already made accessible
+     * @param target receiver, or null for a static method
+     * @param args   arguments passed to the method
+     * @return the method's result, cast to the type the caller expects
+     * @throws Exception the target's own exception; an {@link Error} from the target is
+     *                   rethrown as is, and the wrapper is rethrown if it has no usable cause
+     */
+    @SuppressWarnings("unchecked") // the caller chooses T to match the reflected method's return type
+    private static <T> T invoke(Method method, Object target, Object... args) throws Exception {
+        try {
+            return (T) method.invoke(target, args);
+        } catch (InvocationTargetException e) {
+            Throwable cause = e.getCause();
+            if (cause instanceof Exception) {
+                throw (Exception) cause;
+            }
+            if (cause instanceof Error) {
+                // e.g. AssertionError/OutOfMemoryError: a blind (Exception) cast would fail with
+                // ClassCastException and hide the real problem.
+                throw (Error) cause;
+            }
+            throw e;
+        }
+    }
+
+    /** Builds a {@code SimpleRecord} message from SIMPLE_PROTO after letting {@code configurer} populate the builder. */
+    private static DynamicMessage buildSimpleRecord(Consumer<DynamicMessage.Builder> configurer) throws Exception {
+        Descriptors.Descriptor descriptor = getDescriptor(SIMPLE_PROTO, "SimpleRecord");
+        DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor);
+        configurer.accept(builder);
+        return builder.build();
+    }
+
+    /** Parses {@code proto} and returns the descriptor of the message named {@code messageName}. */
+    private static Descriptors.Descriptor getDescriptor(String proto, String messageName) throws Exception {
+        com.squareup.wire.schema.internal.parser.ProtoFileElement fileElement =
+            com.squareup.wire.schema.internal.parser.ProtoParser.Companion.parse(ProtoConstants.DEFAULT_LOCATION, proto);
+        DynamicSchema dynamicSchema = ProtobufSchemaParser.toDynamicSchema(messageName, fileElement, Collections.emptyMap());
+        return dynamicSchema.getMessageDescriptor(messageName);
+    }
+}
diff --git a/core/src/test/java/kafka/automq/table/process/convert/ProtobufRegistryConverterTest.java b/core/src/test/java/kafka/automq/table/process/convert/ProtobufRegistryConverterTest.java
index 635592db18..0cf75b1cb5 100644
--- a/core/src/test/java/kafka/automq/table/process/convert/ProtobufRegistryConverterTest.java
+++ b/core/src/test/java/kafka/automq/table/process/convert/ProtobufRegistryConverterTest.java
@@ -1,3 +1,21 @@
+/*
+ * Copyright 2025, AutoMQ HK Limited.
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package kafka.automq.table.process.convert;
 
 import kafka.automq.table.binder.RecordBinder;
@@ -18,6 +36,8 @@
 import com.squareup.wire.schema.internal.parser.ProtoFileElement;
 import com.squareup.wire.schema.internal.parser.ProtoParser;
 
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericData;
 import org.apache.avro.generic.GenericRecord;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.avro.AvroSchemaUtil;
@@ -26,6 +46,7 @@
 import org.apache.iceberg.data.Record;
 import org.apache.iceberg.inmemory.InMemoryCatalog;
 import org.apache.iceberg.io.TaskWriter;
+import org.apache.iceberg.types.Type;
 import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
 
@@ -39,8 +60,9 @@
 
 import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient;
 
-import static kafka.automq.table.binder.AvroRecordBinderTest.createTableWriter;
+import static kafka.automq.table.binder.AvroRecordBinderTypeTest.createTableWriter;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertSame;
 
 @Tag("S3Unit")
@@ -106,6 +128,17 @@ enum SampleEnum {
         }
         """;
 
+    // Minimal schema whose only field is a map; used by testConvertStandaloneMapField.
+    private static final String MAP_ONLY_PROTO = """
+        syntax = \"proto3\";
+
+        package kafka.automq.table.process.proto;
+
+        message MapOnly {
+            map<string, int32> attributes = 1;
+        }
+        """;
+
     private void testSendRecord(org.apache.iceberg.Schema schema, org.apache.iceberg.data.Record record) {
         InMemoryCatalog catalog = new InMemoryCatalog();
         catalog.initialize("test", ImmutableMap.of());
@@ -233,6 +266,66 @@ private static DynamicMessage buildAllTypesMessage(Descriptors.Descriptor descri
         return builder.build();
     }
 
+    // A message whose only field is a proto map must convert to an Avro array of key/value
+    // records carrying the "map" logical type, which Iceberg then reads as a MAP column.
+    @Test
+    void testConvertStandaloneMapField() throws Exception {
+        String topic = "pb-map-only";
+        String subject = topic + "-value";
+
+        MockSchemaRegistryClient registryClient = new MockSchemaRegistryClient(List.of(new ProtobufSchemaProvider()));
+        CustomProtobufSchema schema = new CustomProtobufSchema(
+            "MapOnly",
+            -1,
+            null,
+            null,
+            MAP_ONLY_PROTO,
+            List.of(),
+            Map.of()
+        );
+        int schemaId = registryClient.register(subject, schema);
+
+        ProtoFileElement fileElement = ProtoParser.Companion.parse(ProtoConstants.DEFAULT_LOCATION, MAP_ONLY_PROTO);
+        DynamicSchema dynamicSchema = ProtobufSchemaParser.toDynamicSchema("MapOnly", fileElement, Collections.emptyMap());
+        Descriptors.Descriptor descriptor = dynamicSchema.getMessageDescriptor("MapOnly");
+
+        DynamicMessage message = buildMapOnlyMessage(descriptor);
+        ByteBuffer payload = buildConfluentPayload(schemaId, message.toByteArray(), 0);
+
+        ProtobufRegistryConverter converter = new ProtobufRegistryConverter(registryClient, "http://mock:8081", false);
+        ConversionResult result = converter.convert(topic, payload.asReadOnlyBuffer());
+
+        GenericRecord record = (GenericRecord) result.getValue();
+        List<?> attributeEntries = (List<?>) record.get("attributes");
+        Map<String, Integer> attributes = attributeEntries.stream()
+            .map(GenericRecord.class::cast)
+            .collect(Collectors.toMap(
+                entry -> entry.get("key").toString(),
+                entry -> (Integer) entry.get("value")
+            ));
+
+        assertEquals(Map.of("env", 1, "tier", 2), attributes);
+
+        Schema.Field attributesField = record.getSchema().getField("attributes");
+        Schema mapSchema = attributesField.schema();
+        assertNotNull(mapSchema.getLogicalType(), "Map field should have logical type");
+        assertEquals("map", mapSchema.getLogicalType().getName());
+        assertEquals(GenericData.Array.class, record.get("attributes").getClass());
+
+        org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(record.getSchema());
+        assertEquals(Type.TypeID.MAP, icebergSchema.findField("attributes").type().typeId());
+    }
+
+    // Builds a MapOnly message whose "attributes" map holds {"env": 1, "tier": 2}.
+    private static DynamicMessage buildMapOnlyMessage(Descriptors.Descriptor descriptor) {
+        DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor);
+        Descriptors.FieldDescriptor mapField = descriptor.findFieldByName("attributes");
+        Descriptors.Descriptor entryDescriptor = mapField.getMessageType();
+        builder.addRepeatedField(mapField, mapEntry(entryDescriptor, "env", 1));
+        builder.addRepeatedField(mapField, mapEntry(entryDescriptor, "tier", 2));
+        return builder.build();
+    }
+
     private static DynamicMessage mapEntry(Descriptors.Descriptor entryDescriptor, Object key, Object value) {
         return DynamicMessage.newBuilder(entryDescriptor)
             .setField(entryDescriptor.findFieldByName("key"), key)
diff --git a/core/src/test/java/kafka/automq/table/process/convert/ProtobufRegistryConverterUnitTest.java b/core/src/test/java/kafka/automq/table/process/convert/ProtobufRegistryConverterUnitTest.java
new file mode 100644
index 0000000000..a6db377ea3
--- /dev/null
+++ b/core/src/test/java/kafka/automq/table/process/convert/ProtobufRegistryConverterUnitTest.java
@@ -0,0 +1,719 @@
+/*
+ * Copyright 2025, AutoMQ HK Limited.
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.automq.table.process.convert; + +import kafka.automq.table.binder.RecordBinder; +import kafka.automq.table.deserializer.proto.CustomProtobufSchema; +import kafka.automq.table.deserializer.proto.ProtobufSchemaProvider; +import kafka.automq.table.deserializer.proto.parse.ProtobufSchemaParser; +import kafka.automq.table.deserializer.proto.parse.converter.ProtoConstants; +import kafka.automq.table.deserializer.proto.schema.DynamicSchema; +import kafka.automq.table.process.ConversionResult; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Timestamp; +import com.squareup.wire.schema.internal.parser.ProtoFileElement; +import com.squareup.wire.schema.internal.parser.ProtoParser; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericRecord; +import org.apache.iceberg.Table; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.inmemory.InMemoryCatalog; +import org.apache.iceberg.io.TaskWriter; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; + 
+import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; + +import static kafka.automq.table.binder.AvroRecordBinderTypeTest.createTableWriter; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@Tag("S3Unit") +public class ProtobufRegistryConverterUnitTest { + + private static final String BASIC_PROTO = """ + syntax = \"proto3\"; + + package kafka.automq.table.process.proto; + + message BasicRecord { + bool active = 1; + int32 score = 2; + uint32 quota = 3; + int64 total = 4; + uint64 big_total = 5; + float ratio = 6; + double precise = 7; + string name = 8; + bytes payload = 9; + Status status = 10; + Nested meta = 11; + } + + message Nested { + string note = 1; + } + + enum Status { + STATUS_UNSPECIFIED = 0; + STATUS_READY = 1; + } + """; + + private static final String COLLECTION_PROTO = """ + syntax = \"proto3\"; + + package kafka.automq.table.process.proto; + + message CollectionRecord { + repeated string tags = 1; + repeated Item notes = 2; + repeated Wrap wrappers = 3; + map counters = 4; + map keyed_items = 5; + map wrap_map = 6; + } + + message Item { + string value = 1; + } + + message Wrap { + Item item = 1; + repeated Item items = 2; + } + """; + + private static final String OPTIONAL_COLLECTION_PROTO = """ + syntax = \"proto3\"; + + package kafka.automq.table.process.proto; + + import \"google/protobuf/timestamp.proto\"; + + message OptionalCollectionRecord { + optional Wrapper opt_wrapper = 1; + optional IntStringMap opt_int_map = 2; + optional Item opt_item = 3; + optional WrapMapHolder opt_wrap_map = 4; + optional google.protobuf.Timestamp opt_ts = 5; + } + + message Item { + string value = 1; + } + + message Wrapper { + repeated Item items = 1; + } + + message WrapMapHolder { + map entries = 1; + } + + message IntStringMap { + 
map entries = 1; + } + """; + + private static final String ADVANCED_PROTO = """ + syntax = \"proto3\"; + + package kafka.automq.table.process.proto; + + import \"google/protobuf/timestamp.proto\"; + + message AdvancedRecord { + optional string opt_str = 1; + optional int32 opt_int = 2; + optional Ref opt_ref = 3; + + oneof selection { + string selection_str = 4; + int32 selection_int = 5; + Ref selection_ref = 6; + Bag selection_bag = 7; + MapHolder selection_map = 8; + IntMapHolder selection_int_map = 9; + } + + google.protobuf.Timestamp event_time = 10; + Ref direct = 11; + repeated Ref refs = 12; + } + + message Ref { + string name = 1; + } + + message Bag { + repeated Ref refs = 1; + } + + message MapHolder { + map entries = 1; + } + + message IntMapHolder { + map entries = 1; + } + """; + + private static final String RECURSIVE_PROTO = """ + syntax = \"proto3\"; + + package kafka.automq.table.process.proto; + + message Node { + string id = 1; + Child child = 2; + } + + message Child { + Node leaf = 1; + } + """; + + @Test + void convertBasicTypesRecord() throws Exception { + String topic = "proto-basic"; + ConversionResult result = convert(topic, BASIC_PROTO, "BasicRecord", builder -> { + builder.setField(builder.getDescriptorForType().findFieldByName("active"), true); + builder.setField(builder.getDescriptorForType().findFieldByName("score"), -10); + builder.setField(builder.getDescriptorForType().findFieldByName("quota"), -1); // uint32 max + builder.setField(builder.getDescriptorForType().findFieldByName("total"), -123456789L); + builder.setField(builder.getDescriptorForType().findFieldByName("big_total"), -1L); + builder.setField(builder.getDescriptorForType().findFieldByName("ratio"), 1.5f); + builder.setField(builder.getDescriptorForType().findFieldByName("precise"), 3.14159d); + builder.setField(builder.getDescriptorForType().findFieldByName("name"), "basic-name"); + builder.setField( + builder.getDescriptorForType().findFieldByName("payload"), + 
ByteString.copyFromUtf8("payload-bytes") + ); + builder.setField( + builder.getDescriptorForType().findFieldByName("status"), + builder.getDescriptorForType().getFile().findEnumTypeByName("Status").findValueByName("STATUS_READY") + ); + Descriptors.FieldDescriptor nestedField = builder.getDescriptorForType().findFieldByName("meta"); + builder.setField(nestedField, nestedMessage(nestedField.getMessageType(), "note-value")); + }); + + GenericRecord record = (GenericRecord) result.getValue(); + assertEquals(true, record.get("active")); + assertEquals(-10, record.get("score")); + int quotaSigned = (Integer) record.get("quota"); + assertEquals("4294967295", Long.toUnsignedString(Integer.toUnsignedLong(quotaSigned))); + assertEquals(-123456789L, record.get("total")); + long bigTotal = (Long) record.get("big_total"); + assertEquals("18446744073709551615", Long.toUnsignedString(bigTotal)); + assertEquals(1.5f, (Float) record.get("ratio"), 1e-6); + assertEquals(3.14159d, (Double) record.get("precise"), 1e-9); + assertEquals("basic-name", record.get("name").toString()); + assertEquals("payload-bytes", utf8(record.get("payload"))); + assertEquals("STATUS_READY", record.get("status").toString()); + assertEquals("note-value", ((GenericRecord) record.get("meta")).get("note").toString()); + + bindAndWrite(record); + } + + @Test + void convertCollectionsRecord() throws Exception { + String topic = "proto-collections"; + ConversionResult result = convert(topic, COLLECTION_PROTO, "CollectionRecord", builder -> { + Descriptors.FieldDescriptor tagsFd = builder.getDescriptorForType().findFieldByName("tags"); + builder.addRepeatedField(tagsFd, "alpha"); + builder.addRepeatedField(tagsFd, "beta"); + + Descriptors.FieldDescriptor notesFd = builder.getDescriptorForType().findFieldByName("notes"); + Descriptors.Descriptor itemDesc = notesFd.getMessageType(); + builder.addRepeatedField(notesFd, nestedMessage(itemDesc, "note-1")); + builder.addRepeatedField(notesFd, nestedMessage(itemDesc, 
"note-2")); + + Descriptors.FieldDescriptor wrappersFd = builder.getDescriptorForType().findFieldByName("wrappers"); + Descriptors.Descriptor wrapDesc = wrappersFd.getMessageType(); + Descriptors.Descriptor wrapItemDesc = wrapDesc.findFieldByName("item").getMessageType(); + builder.addRepeatedField(wrappersFd, wrapMessage(wrapDesc, wrapItemDesc, "w1-a", List.of("w1-b"))); + builder.addRepeatedField(wrappersFd, wrapMessage(wrapDesc, wrapItemDesc, "w2-a", List.of("w2-b", "w2-c"))); + + Descriptors.FieldDescriptor countersFd = builder.getDescriptorForType().findFieldByName("counters"); + Descriptors.Descriptor countersEntry = countersFd.getMessageType(); + builder.addRepeatedField(countersFd, mapEntry(countersEntry, "k1", 10)); + builder.addRepeatedField(countersFd, mapEntry(countersEntry, "k2", 20)); + + Descriptors.FieldDescriptor keyedFd = builder.getDescriptorForType().findFieldByName("keyed_items"); + Descriptors.Descriptor keyedEntry = keyedFd.getMessageType(); + builder.addRepeatedField(keyedFd, mapEntry(keyedEntry, 1, nestedMessage(keyedEntry.findFieldByName("value").getMessageType(), "v1"))); + builder.addRepeatedField(keyedFd, mapEntry(keyedEntry, 2, nestedMessage(keyedEntry.findFieldByName("value").getMessageType(), "v2"))); + + Descriptors.FieldDescriptor wrapMapFd = builder.getDescriptorForType().findFieldByName("wrap_map"); + Descriptors.Descriptor wrapMapEntry = wrapMapFd.getMessageType(); + Descriptors.Descriptor wrapMapValueDesc = wrapMapEntry.findFieldByName("value").getMessageType(); + Descriptors.Descriptor wrapMapItemDesc = wrapMapValueDesc.findFieldByName("item").getMessageType(); + builder.addRepeatedField(wrapMapFd, mapEntry( + wrapMapEntry, + "wm1", + wrapMessage(wrapMapValueDesc, wrapMapItemDesc, "wm1-a", List.of("wm1-b")) + )); + builder.addRepeatedField(wrapMapFd, mapEntry( + wrapMapEntry, + "wm2", + wrapMessage(wrapMapValueDesc, wrapMapItemDesc, "wm2-a", List.of("wm2-b", "wm2-c")) + )); + }); + + GenericRecord record = (GenericRecord) 
result.getValue(); + List tags = (List) record.get("tags"); + assertEquals(List.of("alpha", "beta"), tags.stream().map(Object::toString).collect(Collectors.toList())); + + List notes = (List) record.get("notes"); + assertEquals(List.of("note-1", "note-2"), notes.stream() + .map(GenericRecord.class::cast) + .map(r -> r.get("value").toString()) + .collect(Collectors.toList())); + + Map counters = logicalMapToMap(record.get("counters")); + assertEquals(Map.of("k1", 10, "k2", 20), counters); + + Map keyed = logicalMapToMap(record.get("keyed_items")); + assertEquals(Map.of(1, "v1", 2, "v2"), keyed); + + List wrappers = (List) record.get("wrappers"); + assertEquals(List.of("w1-a", "w2-a"), wrappers.stream() + .map(GenericRecord.class::cast) + .map(r -> ((GenericRecord) r.get("item")).get("value").toString()) + .collect(Collectors.toList())); + assertEquals(List.of( + List.of("w1-b"), + List.of("w2-b", "w2-c") + ), wrappers.stream() + .map(GenericRecord.class::cast) + .map(r -> (List) r.get("items")) + .map(lst -> lst.stream() + .map(GenericRecord.class::cast) + .map(it -> it.get("value").toString()) + .collect(Collectors.toList())) + .collect(Collectors.toList())); + + Map wrapMap = logicalMapToMap(record.get("wrap_map")); + assertEquals("wm1-a", ((GenericRecord) ((GenericRecord) wrapMap.get("wm1")).get("item")).get("value").toString()); + assertEquals("wm2-a", ((GenericRecord) ((GenericRecord) wrapMap.get("wm2")).get("item")).get("value").toString()); + + bindAndWrite(record); + } + + @Test + void convertAdvancedRecord() throws Exception { + String topic = "proto-advanced"; + ConversionResult result = convert(topic, ADVANCED_PROTO, "AdvancedRecord", builder -> { + builder.setField(builder.getDescriptorForType().findFieldByName("opt_str"), "optional-value"); + builder.setField(builder.getDescriptorForType().findFieldByName("opt_int"), 99); + Descriptors.FieldDescriptor optRefFd = builder.getDescriptorForType().findFieldByName("opt_ref"); + builder.setField(optRefFd, 
nestedMessage(optRefFd.getMessageType(), "opt-ref")); + + // choose oneof map branch via MapHolder; other branches should remain null + Descriptors.FieldDescriptor selMapFd = builder.getDescriptorForType().findFieldByName("selection_map"); + Descriptors.Descriptor mapHolderDesc = selMapFd.getMessageType(); + Descriptors.FieldDescriptor entriesFd = mapHolderDesc.findFieldByName("entries"); + Descriptors.Descriptor entryDesc = entriesFd.getMessageType(); + DynamicMessage.Builder holderBuilder = DynamicMessage.newBuilder(mapHolderDesc); + holderBuilder.addRepeatedField(entriesFd, mapEntry(entryDesc, "a", 1)); + holderBuilder.addRepeatedField(entriesFd, mapEntry(entryDesc, "b", 2)); + builder.setField(selMapFd, holderBuilder.build()); + Timestamp timestamp = Timestamp.newBuilder().setSeconds(1234L).setNanos(567000000).build(); + builder.setField(builder.getDescriptorForType().findFieldByName("event_time"), timestamp); + + Descriptors.FieldDescriptor refField = builder.getDescriptorForType().findFieldByName("direct"); + builder.setField(refField, nestedMessage(refField.getMessageType(), "parent")); + + Descriptors.FieldDescriptor refsField = builder.getDescriptorForType().findFieldByName("refs"); + Descriptors.Descriptor refDescriptor = refsField.getMessageType(); + builder.addRepeatedField(refsField, nestedMessage(refDescriptor, "child-1")); + builder.addRepeatedField(refsField, nestedMessage(refDescriptor, "child-2")); + }); + + GenericRecord record = (GenericRecord) result.getValue(); + Schema optionalSchema = record.getSchema().getField("opt_str").schema(); + assertEquals(Schema.Type.UNION, optionalSchema.getType()); + assertEquals(Schema.Type.STRING, optionalSchema.getTypes().get(0).getType()); + assertEquals("optional-value", record.get("opt_str").toString()); + assertEquals(99, record.get("opt_int")); + assertEquals("opt-ref", ((GenericRecord) record.get("opt_ref")).get("name").toString()); + + GenericRecord selMapRecord = (GenericRecord) 
record.get("selection_map"); + Map selMap = logicalMapToMap(selMapRecord.get("entries")); + assertEquals(Map.of("a", 1, "b", 2), selMap); + assertEquals(null, record.get("selection_ref")); + assertEquals(null, record.get("selection_str")); + assertEquals(null, record.get("selection_int")); + assertEquals(null, record.get("selection_bag")); + + long expectedMicros = 1234_000_000L + 567_000; + assertEquals(expectedMicros, record.get("event_time")); + assertEquals("parent", ((GenericRecord) record.get("direct")).get("name").toString()); + + List refs = (List) record.get("refs"); + assertEquals(List.of("child-1", "child-2"), refs.stream() + .map(GenericRecord.class::cast) + .map(r -> r.get("name").toString()) + .collect(Collectors.toList())); + + bindAndWrite(record); + } + + @Test + void convertAdvancedOneofStringIntRefBag() throws Exception { + // string branch + ConversionResult stringResult = convert("proto-adv-oneof-str", ADVANCED_PROTO, "AdvancedRecord", b -> + b.setField(b.getDescriptorForType().findFieldByName("selection_str"), "sel-str")); + GenericRecord stringRec = (GenericRecord) stringResult.getValue(); + assertEquals("sel-str", stringRec.get("selection_str")); + assertEquals(null, stringRec.get("selection_int")); + assertEquals(null, stringRec.get("selection_ref")); + assertEquals(null, stringRec.get("selection_bag")); + bindAndWrite((GenericRecord) stringResult.getValue()); + + // int branch + ConversionResult intResult = convert("proto-adv-oneof-int", ADVANCED_PROTO, "AdvancedRecord", b -> + b.setField(b.getDescriptorForType().findFieldByName("selection_int"), 123)); + GenericRecord intRec = (GenericRecord) intResult.getValue(); + assertEquals(123, intRec.get("selection_int")); + assertEquals(null, intRec.get("selection_str")); + assertEquals(null, intRec.get("selection_ref")); + assertEquals(null, intRec.get("selection_bag")); + bindAndWrite((GenericRecord) intResult.getValue()); + + // ref branch + ConversionResult refResult = 
convert("proto-adv-oneof-ref", ADVANCED_PROTO, "AdvancedRecord", b -> { + Descriptors.FieldDescriptor fd = b.getDescriptorForType().findFieldByName("selection_ref"); + b.setField(fd, nestedMessage(fd.getMessageType(), "sel-ref")); + }); + GenericRecord refRec = (GenericRecord) refResult.getValue(); + assertEquals("sel-ref", ((GenericRecord) refRec.get("selection_ref")).get("name").toString()); + assertEquals(null, refRec.get("selection_str")); + assertEquals(null, refRec.get("selection_int")); + assertEquals(null, refRec.get("selection_bag")); + bindAndWrite((GenericRecord) refResult.getValue()); + + // bag branch (contains repeated refs) + ConversionResult bagResult = convert("proto-adv-oneof-bag", ADVANCED_PROTO, "AdvancedRecord", b -> { + Descriptors.FieldDescriptor fd = b.getDescriptorForType().findFieldByName("selection_bag"); + Descriptors.Descriptor bagDesc = fd.getMessageType(); + Descriptors.FieldDescriptor refsFd = bagDesc.findFieldByName("refs"); + DynamicMessage.Builder bagBuilder = DynamicMessage.newBuilder(bagDesc); + bagBuilder.addRepeatedField(refsFd, nestedMessage(refsFd.getMessageType(), "b1")); + bagBuilder.addRepeatedField(refsFd, nestedMessage(refsFd.getMessageType(), "b2")); + b.setField(fd, bagBuilder.build()); + }); + GenericRecord bagRec = (GenericRecord) bagResult.getValue(); + List bagRefs = (List) ((GenericRecord) bagRec.get("selection_bag")).get("refs"); + assertEquals(List.of("b1", "b2"), bagRefs.stream() + .map(GenericRecord.class::cast) + .map(r -> r.get("name").toString()) + .collect(Collectors.toList())); + assertEquals(null, bagRec.get("selection_str")); + assertEquals(null, bagRec.get("selection_int")); + assertEquals(null, bagRec.get("selection_ref")); + bindAndWrite((GenericRecord) bagResult.getValue()); + + // int map branch (map) + ConversionResult intMapResult = convert("proto-adv-oneof-intmap", ADVANCED_PROTO, "AdvancedRecord", b -> { + Descriptors.FieldDescriptor fd = 
b.getDescriptorForType().findFieldByName("selection_int_map"); + Descriptors.Descriptor holderDesc = fd.getMessageType(); + Descriptors.FieldDescriptor entriesFd = holderDesc.findFieldByName("entries"); + Descriptors.Descriptor entryDesc = entriesFd.getMessageType(); + DynamicMessage.Builder holder = DynamicMessage.newBuilder(holderDesc); + holder.addRepeatedField(entriesFd, mapEntry(entryDesc, 10, "x")); + holder.addRepeatedField(entriesFd, mapEntry(entryDesc, 20, "y")); + b.setField(fd, holder.build()); + }); + GenericRecord intMapRec = (GenericRecord) intMapResult.getValue(); + Map intMaps = logicalMapToMap(((GenericRecord) intMapRec.get("selection_int_map")).get("entries")); + assertEquals(Map.of(10, "x", 20, "y"), intMaps); + assertEquals(null, intMapRec.get("selection_str")); + assertEquals(null, intMapRec.get("selection_int")); + assertEquals(null, intMapRec.get("selection_ref")); + assertEquals(null, intMapRec.get("selection_bag")); + bindAndWrite((GenericRecord) intMapResult.getValue()); + } + + @Test + void convertOptionalCollectionsRecord() throws Exception { + String topic = "proto-optional-collections"; + ConversionResult result = convert(topic, OPTIONAL_COLLECTION_PROTO, "OptionalCollectionRecord", builder -> { + Descriptors.FieldDescriptor wrapperFd = builder.getDescriptorForType().findFieldByName("opt_wrapper"); + Descriptors.Descriptor wrapperDesc = wrapperFd.getMessageType(); + Descriptors.Descriptor itemDesc = wrapperDesc.findFieldByName("items").getMessageType(); + DynamicMessage.Builder wrapperBuilder = DynamicMessage.newBuilder(wrapperDesc); + wrapperBuilder.addRepeatedField(wrapperDesc.findFieldByName("items"), nestedMessage(itemDesc, "i1")); + wrapperBuilder.addRepeatedField(wrapperDesc.findFieldByName("items"), nestedMessage(itemDesc, "i2")); + builder.setField(wrapperFd, wrapperBuilder.build()); + + // leave opt_int_map unset to validate optional-map -> null + + Descriptors.FieldDescriptor optItemFd = 
builder.getDescriptorForType().findFieldByName("opt_item"); + builder.setField(optItemFd, nestedMessage(optItemFd.getMessageType(), "single")); + }); + + GenericRecord record = (GenericRecord) result.getValue(); + // opt_wrapper union present + assertEquals(Schema.Type.UNION, record.getSchema().getField("opt_wrapper").schema().getType()); + List items = (List) ((GenericRecord) record.get("opt_wrapper")).get("items"); + assertEquals(List.of("i1", "i2"), items.stream().map(GenericRecord.class::cast).map(r -> r.get("value").toString()).collect(Collectors.toList())); + + assertEquals(null, record.get("opt_int_map")); + + GenericRecord optItem = (GenericRecord) record.get("opt_item"); + assertEquals("single", optItem.get("value").toString()); + } + + @Test + void convertOptionalCollectionsRecordWithMap() throws Exception { + String topic = "proto-optional-collections-map"; + ConversionResult result = convert(topic, OPTIONAL_COLLECTION_PROTO, "OptionalCollectionRecord", builder -> { + Descriptors.FieldDescriptor optMapFd = builder.getDescriptorForType().findFieldByName("opt_int_map"); + Descriptors.Descriptor holderDesc = optMapFd.getMessageType(); + Descriptors.FieldDescriptor entriesFd = holderDesc.findFieldByName("entries"); + Descriptors.Descriptor entryDesc = entriesFd.getMessageType(); + DynamicMessage.Builder holder = DynamicMessage.newBuilder(holderDesc); + holder.addRepeatedField(entriesFd, mapEntry(entryDesc, 7, "v7")); + holder.addRepeatedField(entriesFd, mapEntry(entryDesc, 8, "v8")); + builder.setField(optMapFd, holder.build()); + + Descriptors.FieldDescriptor wrapMapFd = builder.getDescriptorForType().findFieldByName("opt_wrap_map"); + Descriptors.Descriptor wrapMapDesc = wrapMapFd.getMessageType(); + Descriptors.FieldDescriptor wrapEntriesFd = wrapMapDesc.findFieldByName("entries"); + Descriptors.Descriptor wrapEntryDesc = wrapEntriesFd.getMessageType(); + Descriptors.Descriptor wrapValueDesc = wrapEntryDesc.findFieldByName("value").getMessageType(); + 
Descriptors.Descriptor wrapItemDesc = wrapValueDesc.findFieldByName("items").getMessageType(); + + DynamicMessage.Builder wrapHolder = DynamicMessage.newBuilder(wrapMapDesc); + wrapHolder.addRepeatedField(wrapEntriesFd, mapEntry( + wrapEntryDesc, + "wkey1", + wrapMessage(wrapValueDesc, wrapItemDesc, "wm1-a", List.of("wm1-b")) + )); + builder.setField(wrapMapFd, wrapHolder.build()); + + // optional timestamp + Timestamp ts = Timestamp.newBuilder().setSeconds(10L).setNanos(500_000_000).build(); + builder.setField(builder.getDescriptorForType().findFieldByName("opt_ts"), ts); + }); + + GenericRecord record = (GenericRecord) result.getValue(); + GenericRecord optIntMap = (GenericRecord) record.get("opt_int_map"); + Map map = logicalMapToMap(optIntMap.get("entries")); + assertEquals(Map.of(7, "v7", 8, "v8"), map); + + Schema.Field optWrapField = record.getSchema().getField("opt_wrap_map"); + assertEquals(Schema.Type.UNION, optWrapField.schema().getType()); + GenericRecord optWrapMap = (GenericRecord) record.get("opt_wrap_map"); + Map wrapEntries = logicalMapToMap(optWrapMap.get("entries")); + GenericRecord wrapper = wrapEntries.get("wkey1"); + List wrapItems = (List) wrapper.get("items"); + assertEquals(List.of("wm1-b"), wrapItems.stream() + .map(GenericRecord.class::cast) + .map(item -> item.get("value").toString()) + .collect(Collectors.toList())); + + assertEquals(10_500_000L, record.get("opt_ts")); + + + } + + @Test + void convertRecursiveRecord() throws Exception { + String topic = "proto-recursive"; + ConversionResult result = convert(topic, RECURSIVE_PROTO, "Node", builder -> { + builder.setField(builder.getDescriptorForType().findFieldByName("id"), "root"); + Descriptors.FieldDescriptor childFd = builder.getDescriptorForType().findFieldByName("child"); + Descriptors.Descriptor childDesc = childFd.getMessageType(); + Descriptors.FieldDescriptor leafFd = childDesc.findFieldByName("leaf"); + Descriptors.Descriptor nodeDesc = leafFd.getMessageType(); + + 
DynamicMessage leaf = DynamicMessage.newBuilder(nodeDesc) + .setField(nodeDesc.findFieldByName("id"), "leaf") + .build(); + DynamicMessage child = DynamicMessage.newBuilder(childDesc) + .setField(leafFd, leaf) + .build(); + builder.setField(childFd, child); + }); + + GenericRecord record = (GenericRecord) result.getValue(); + assertEquals("root", record.get("id").toString()); + GenericRecord child = (GenericRecord) record.get("child"); + GenericRecord leaf = (GenericRecord) child.get("leaf"); + assertEquals("leaf", leaf.get("id").toString()); + + assertThrows(IllegalStateException.class, () -> bindAndWrite(record)); + } + + private ConversionResult convert(String topic, String proto, String messageName, Consumer messageConfigurer) throws Exception { + MockSchemaRegistryClient registryClient = new MockSchemaRegistryClient(List.of(new ProtobufSchemaProvider())); + CustomProtobufSchema schema = new CustomProtobufSchema( + messageName, + -1, + null, + null, + proto, + List.of(), + Map.of() + ); + int schemaId = registryClient.register(topic + "-value", schema); + + ProtoFileElement fileElement = ProtoParser.Companion.parse(ProtoConstants.DEFAULT_LOCATION, proto); + DynamicSchema dynamicSchema = ProtobufSchemaParser.toDynamicSchema(messageName, fileElement, Collections.emptyMap()); + Descriptors.Descriptor descriptor = dynamicSchema.getMessageDescriptor(messageName); + + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + messageConfigurer.accept(builder); + DynamicMessage message = builder.build(); + + ByteBuffer payload = buildConfluentPayload(schemaId, message.toByteArray(), 0); + ProtobufRegistryConverter converter = new ProtobufRegistryConverter(registryClient, "http://mock:8081", false); + ConversionResult result = converter.convert(topic, payload.asReadOnlyBuffer()); + + ConversionResult cached = converter.convert(topic, payload.asReadOnlyBuffer()); + assertSame(result.getSchema(), cached.getSchema()); + return result; + } + + private void 
bindAndWrite(GenericRecord record) { + org.apache.iceberg.Schema iceberg = AvroSchemaUtil.toIceberg(record.getSchema()); + RecordBinder binder = new RecordBinder(iceberg, record.getSchema()); + Record icebergRecord = binder.bind(record); + assertDoesNotThrow(() -> testSendRecord(iceberg, icebergRecord)); + } + + private static String utf8(Object value) { + ByteBuffer buffer = ((ByteBuffer) value).duplicate(); + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + + private static DynamicMessage nestedMessage(Descriptors.Descriptor descriptor, String value) { + return DynamicMessage.newBuilder(descriptor) + .setField(descriptor.findFieldByName("note") != null ? descriptor.findFieldByName("note") : descriptor.findFieldByName("value") != null + ? descriptor.findFieldByName("value") : descriptor.findFieldByName("name"), value) + .build(); + } + + private static DynamicMessage wrapMessage(Descriptors.Descriptor wrapDesc, Descriptors.Descriptor itemDesc, String itemValue, List itemListValues) { + DynamicMessage.Builder wrapBuilder = DynamicMessage.newBuilder(wrapDesc); + Descriptors.FieldDescriptor itemField = wrapDesc.findFieldByName("item"); + if (itemField != null) { + wrapBuilder.setField(itemField, nestedMessage(itemDesc, itemValue)); + } + Descriptors.FieldDescriptor itemsFd = wrapDesc.findFieldByName("items"); + if (itemsFd != null) { + for (String v : itemListValues) { + wrapBuilder.addRepeatedField(itemsFd, nestedMessage(itemDesc, v)); + } + } + return wrapBuilder.build(); + } + + private static DynamicMessage mapEntry(Descriptors.Descriptor descriptor, Object key, Object value) { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + builder.setField(descriptor.findFieldByName("key"), key); + builder.setField(descriptor.findFieldByName("value"), value); + return builder.build(); + } + + private static Map logicalMapToMap(Object logicalMap) { + List entries = (List) 
logicalMap; + return entries.stream() + .map(GenericRecord.class::cast) + .collect(Collectors.toMap( + entry -> (K) entry.get("key"), + entry -> { + Object value = entry.get("value"); + if (value instanceof GenericRecord) { + GenericRecord record = (GenericRecord) value; + if (record.getSchema().getField("value") != null) { + return (V) record.get("value").toString(); + } + if (record.getSchema().getField("name") != null) { + return (V) record.get("name").toString(); + } + } + return (V) value; + } + )); + } + + private static ByteBuffer buildConfluentPayload(int schemaId, byte[] messageBytes, int... messageIndexes) { + byte[] indexBytes = encodeMessageIndexes(messageIndexes); + ByteBuffer buffer = ByteBuffer.allocate(1 + Integer.BYTES + indexBytes.length + messageBytes.length); + buffer.put((byte) 0); + buffer.putInt(schemaId); + buffer.put(indexBytes); + buffer.put(messageBytes); + buffer.flip(); + return buffer; + } + + private static byte[] encodeMessageIndexes(int... indexes) { + if (indexes == null || indexes.length == 0) { + return new byte[]{0}; + } + ByteBuffer buffer = ByteBuffer.allocate(5 * (indexes.length + 1)); + org.apache.kafka.common.utils.ByteUtils.writeVarint(indexes.length, buffer); + for (int index : indexes) { + org.apache.kafka.common.utils.ByteUtils.writeVarint(index, buffer); + } + buffer.flip(); + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + + private void testSendRecord(org.apache.iceberg.Schema schema, Record record) { + InMemoryCatalog catalog = new InMemoryCatalog(); + catalog.initialize("test", ImmutableMap.of()); + catalog.createNamespace(Namespace.of("default")); + Table table = catalog.createTable(TableIdentifier.of(Namespace.of("default"), "scenario"), schema); + TaskWriter writer = createTableWriter(table); + try { + writer.write(record); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +}