@@ -39,7 +39,12 @@ public static <P, T> T visit(
P partner, Schema schema, AvroWithPartnerByStructureVisitor<P, T> visitor) {
switch (schema.getType()) {
case RECORD:
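// a variant is written as an Avro record, so RECORD is checked both ways: the file
// schema may carry the variant logical type, or the partner type may be a variant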
if (schema.getLogicalType() instanceof VariantLogicalType
|| visitor.isVariantType(partner)) {
return visitVariant(partner, schema, visitor);
} else {
return visitRecord(partner, schema, visitor);
}

case UNION:
return visitUnion(partner, schema, visitor);
@@ -61,6 +66,23 @@ public static <P, T> T visit(

// ---------------------------------- Static helpers ---------------------------------------------

private static <P, R> R visitVariant(
P partner, Schema variant, AvroWithPartnerByStructureVisitor<P, R> visitor) {
// make sure this isn't a recursive schema: the record must not already be on the visit stack
String name = variant.getFullName();
Preconditions.checkState(
!visitor.recordLevels.contains(name), "Cannot process recursive Avro record %s", name);

visitor.recordLevels.push(name);

R metadataResult = visit(null, variant.getField("metadata").schema(), visitor);
R valueResult = visit(null, variant.getField("value").schema(), visitor);

visitor.recordLevels.pop();

return visitor.variant(partner, metadataResult, valueResult);
}
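// Illustration only: the Avro shape this helper expects is a two-field record, which
// is why it visits getField("metadata") and getField("value") above. A minimal sketch
// of such a schema using Avro's SchemaBuilder (the record name here is an assumption):
//
//   Schema variantSchema =
//       SchemaBuilder.record("variant")
//           .fields()
//           .requiredBytes("metadata")
//           .requiredBytes("value")
//           .endRecord();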

private static <P, T> T visitRecord(
P struct, Schema record, AvroWithPartnerByStructureVisitor<P, T> visitor) {
// check to make sure this hasn't been visited before
@@ -155,6 +177,10 @@ private static <P, T> T visitArray(
// ---------------------------------- Partner type methods
// ---------------------------------------------

protected boolean isVariantType(P type) {
return false;
}

protected abstract boolean isMapType(P type);

protected abstract boolean isStringType(P type);
@@ -191,6 +217,10 @@ public T map(P sMap, Schema map, T value) {
return null;
}

public T variant(P partner, T metadata, T value) {
throw new UnsupportedOperationException("Visitor does not support variant");
}

public T primitive(P type, Schema primitive) {
return null;
}
@@ -23,6 +23,11 @@
import org.apache.iceberg.util.Pair;

public class AvroWithTypeByStructureVisitor<T> extends AvroWithPartnerByStructureVisitor<Type, T> {
@Override
protected boolean isVariantType(Type type) {
return type.isVariantType();
}

@Override
protected boolean isMapType(Type type) {
return type.isMapType();
@@ -33,6 +33,7 @@
import org.apache.avro.specific.SpecificData;
import org.apache.iceberg.avro.AvroSchemaVisitor;
import org.apache.iceberg.avro.UUIDConversion;
import org.apache.iceberg.avro.VariantConversion;
import org.apache.iceberg.relocated.com.google.common.base.Objects;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
@@ -192,10 +193,12 @@ public GenericFixed toFixed(BigDecimal value, Schema schema, LogicalType type) {
private final Conversion<?> intDecimalConversion = new IntDecimalConversion();
private final Conversion<?> longDecimalConversion = new LongDecimalConversion();
private final Conversion<?> uuidConversion = new UUIDConversion();
private final Conversion<?> variantConversion = new VariantConversion();

{
addLogicalTypeConversion(fixedDecimalConversion);
addLogicalTypeConversion(uuidConversion);
addLogicalTypeConversion(variantConversion);
}

@Override
@@ -298,6 +301,11 @@ public Schema map(Schema map, Schema value) {
return map;
}

@Override
public Schema variant(Schema variant, Schema metadataResult, Schema valueResult) {
return variant;
}

@Override
public Schema primitive(Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
@@ -108,10 +108,10 @@ public static VariantValueReader asVariant(PhysicalType type, ParquetValueReader
return new ValueAsVariantReader<>(type, reader);
}

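// visibility widened from private so that readers outside this package, such as
// the Spark variant reader added in this change, can extend it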
public abstract static class DelegatingValueReader<S, T> implements ParquetValueReader<T> {
private final ParquetValueReader<S> reader;

protected DelegatingValueReader(ParquetValueReader<S> reader) {
this.reader = reader;
}

@@ -193,6 +193,11 @@ public Type map(Types.MapType map, Supplier<Type> keyResult, Supplier<Type> valu
}
}

@Override
public Type variant(Types.VariantType variant) {
return Types.VariantType.get();
}

@Override
public Type primitive(Type.PrimitiveType primitive) {
Set<Class<? extends DataType>> expectedType = TYPES.get(primitive.typeId());
@@ -42,6 +42,7 @@
import org.apache.spark.sql.types.TimestampNTZType;
import org.apache.spark.sql.types.TimestampType;
import org.apache.spark.sql.types.VarcharType;
import org.apache.spark.sql.types.VariantType;

class SparkTypeToType extends SparkTypeVisitor<Type> {
private final StructType root;
@@ -116,6 +117,11 @@ public Type map(MapType map, Type keyType, Type valueType) {
}
}

@Override
public Type variant(VariantType variant) {
return Types.VariantType.get();
}
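// Usage sketch: with SparkTypeVisitor now dispatching VariantType here, a Spark
// schema containing a variant column converts to Iceberg (assuming the existing
// SparkSchemaUtil.convert entry point, which drives these visitors):
//
//   StructType sparkSchema = new StructType().add("payload", VariantType$.MODULE$);
//   Schema icebergSchema = SparkSchemaUtil.convert(sparkSchema); // payload -> variant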

@SuppressWarnings("checkstyle:CyclomaticComplexity")
@Override
public Type atomic(DataType atomic) {
@@ -26,6 +26,7 @@
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UserDefinedType;
import org.apache.spark.sql.types.VariantType;

class SparkTypeVisitor<T> {
static <T> T visit(DataType type, SparkTypeVisitor<T> visitor) {
@@ -48,6 +49,9 @@ static <T> T visit(DataType type, SparkTypeVisitor<T> visitor) {
} else if (type instanceof ArrayType) {
return visitor.array((ArrayType) type, visit(((ArrayType) type).elementType(), visitor));

} else if (type instanceof VariantType) {
return visitor.variant((VariantType) type);

} else if (type instanceof UserDefinedType) {
throw new UnsupportedOperationException("User-defined types are not supported");

@@ -56,6 +60,10 @@ static <T> T visit(DataType type, SparkTypeVisitor<T> visitor) {
}
}

public T variant(VariantType variant) {
throw new UnsupportedOperationException("Not implemented for variant");
}

public T struct(StructType struct, List<T> fieldResults) {
return null;
}
@@ -43,6 +43,7 @@
import org.apache.spark.sql.types.StructType$;
import org.apache.spark.sql.types.TimestampNTZType$;
import org.apache.spark.sql.types.TimestampType$;
import org.apache.spark.sql.types.VariantType$;

class TypeToSparkType extends TypeUtil.SchemaVisitor<DataType> {
TypeToSparkType() {}
@@ -88,6 +89,11 @@ public DataType map(Types.MapType map, DataType keyResult, DataType valueResult)
return MapType$.MODULE$.apply(keyResult, valueResult, map.isValueOptional());
}

@Override
public DataType variant(Types.VariantType variant) {
return VariantType$.MODULE$;
}

@Override
public DataType primitive(Type.PrimitiveType primitive) {
switch (primitive.typeId()) {
@@ -28,9 +28,14 @@
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.VariantType;

public abstract class AvroWithSparkSchemaVisitor<T>
extends AvroWithPartnerByStructureVisitor<DataType, T> {
@Override
protected boolean isVariantType(DataType type) {
return type instanceof VariantType;
}

@Override
protected boolean isStringType(DataType dataType) {
@@ -35,6 +35,7 @@
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.VariantType;

/**
* Visitor for traversing a Parquet type with a companion Spark type.
@@ -152,6 +153,14 @@ public static <T> T visit(DataType sType, Type type, ParquetWithSparkSchemaVisit
} finally {
visitor.fieldNames.pop();
}
} else if (sType instanceof VariantType) {
// TODO: Use LogicalTypeAnnotation.variantType().equals(annotation) when VARIANT type is
// added to Parquet
// Preconditions.checkArgument(
// sType instanceof VariantType, "Invalid variant: %s is not a VariantType", sType);
VariantType variant = (VariantType) sType;

return visitor.variant(variant, group);
}

Preconditions.checkArgument(
@@ -211,6 +220,10 @@ public T primitive(DataType sPrimitive, PrimitiveType primitive) {
return null;
}

public T variant(VariantType sVariant, GroupType variant) {
throw new UnsupportedOperationException("Not implemented for variant");
}

protected String[] currentPath() {
return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]);
}
@@ -108,6 +108,11 @@ public ValueWriter<?> map(
keyWriter, mapKeyType(sMap), valueWriter, mapValueType(sMap));
}

@Override
public ValueWriter<?> variant(DataType partner, ValueWriter<?> metadata, ValueWriter<?> value) {
return SparkValueWriters.variants();
}

@Override
public ValueWriter<?> primitive(DataType type, Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
@@ -21,6 +21,7 @@
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -36,7 +37,10 @@
import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry;
import org.apache.iceberg.parquet.ParquetValueReaders.StructReader;
import org.apache.iceberg.parquet.ParquetValueReaders.UnboxedReader;
import org.apache.iceberg.parquet.ParquetVariantReaders.DelegatingValueReader;
import org.apache.iceberg.parquet.ParquetVariantVisitor;
import org.apache.iceberg.parquet.TypeWithSchemaVisitor;
import org.apache.iceberg.parquet.VariantReaderBuilder;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
@@ -45,6 +49,7 @@
import org.apache.iceberg.types.Type.TypeID;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.UUIDUtil;
import org.apache.iceberg.variants.Variant;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.GroupType;
@@ -220,6 +225,17 @@ public ParquetValueReader<?> map(
ParquetValueReaders.option(valueType, valueD, valueReader));
}

@Override
public ParquetVariantVisitor<ParquetValueReader<?>> variantVisitor() {
return new VariantReaderBuilder(type, Arrays.asList(currentPath()));
}

@Override
public ParquetValueReader<?> variant(
Types.VariantType iVariant, GroupType variant, ParquetValueReader<?> variantReader) {
return new VariantReader(variantReader);
}
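// How these fit together in this change: the parent visitor uses variantVisitor()
// to build a reader that produces Iceberg Variant values, then hands that reader
// to variant(), which wraps it so Spark receives VariantVal instead.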

@Override
@SuppressWarnings("checkstyle:CyclomaticComplexity")
public ParquetValueReader<?> primitive(
@@ -497,6 +513,28 @@ protected MapData buildMap(ReusableMapData map) {
}
}

/** Reader that converts an Iceberg Variant into a Spark VariantVal. */
private static class VariantReader extends DelegatingValueReader<Variant, VariantVal> {
@SuppressWarnings("unchecked")
private VariantReader(ParquetValueReader<?> reader) {
super((ParquetValueReader<Variant>) reader);
}

@Override
public VariantVal read(VariantVal reuse) {
Variant variant = super.readFromDelegate(null);
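// copy metadata and value into separate little-endian buffers; VariantVal keeps
// them as two byte arrays (its constructor takes value first, then metadata)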
byte[] metadataBytes = new byte[variant.metadata().sizeInBytes()];
ByteBuffer metadataBuffer = ByteBuffer.wrap(metadataBytes).order(ByteOrder.LITTLE_ENDIAN);
variant.metadata().writeTo(metadataBuffer, 0);

byte[] valueBytes = new byte[variant.value().sizeInBytes()];
ByteBuffer valueBuffer = ByteBuffer.wrap(valueBytes).order(ByteOrder.LITTLE_ENDIAN);
variant.value().writeTo(valueBuffer, 0);

return new VariantVal(valueBytes, metadataBytes);
}
}

private static class InternalRowReader extends StructReader<InternalRow, GenericInternalRow> {
private final int numFields;
