[HUDI-4904] Add support for unraveling proto schemas in ProtoClassBasedSchemaProvider (#6761)

If a user provides a recursive proto schema, writing to Parquet will fail. We need to allow the user to specify how many levels of recursion they want before truncating the remaining data.

Main changes to existing code:

- ProtoClassBasedSchemaProvider tracks the number of times a message descriptor is seen within a branch of the schema traversal.
- Once the number of times that descriptor is seen exceeds the user-provided limit, the field is set to a preset record containing two fields: 1) the remaining data serialized as a proto byte array, 2) the descriptor's full name for context about what is in that byte array.
- Converting from a proto to an Avro record now accounts for this truncation of the input (see the sketch below).
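
For illustration (not part of this commit), a minimal Java sketch of the effect, assuming a hypothetical generated proto class Node whose child field refers back to Node:

import org.apache.avro.Schema;
import org.apache.hudi.utilities.sources.helpers.ProtoConversionUtil;

// Unravel the hypothetical recursive Node schema to a depth of 2 before truncating.
Schema avroSchema = ProtoConversionUtil.getAvroSchemaForMessageClass(Node.class, false, 2);
// The root Node and one nested child are expanded; child.child is replaced by the shared
// "recursion_overflow" record, whose two fields hold the descriptor's full name and the
// remaining message serialized as proto bytes.
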
the-other-tim-brown committed Sep 27, 2022
1 parent 76b9354 commit db03e1f
Showing 14 changed files with 1,551 additions and 272 deletions.
@@ -20,9 +20,10 @@
package org.apache.hudi.utilities.schema;

import org.apache.hudi.DataSourceUtils;
import org.apache.hudi.common.config.ConfigProperty;
import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.common.util.ReflectionUtils;
import org.apache.hudi.exception.HoodieException;
import org.apache.hudi.internal.schema.HoodieSchemaException;
import org.apache.hudi.utilities.sources.helpers.ProtoConversionUtil;

import org.apache.avro.Schema;
@@ -38,27 +39,44 @@ public class ProtoClassBasedSchemaProvider extends SchemaProvider {
* Configs supported.
*/
public static class Config {
public static final String PROTO_SCHEMA_CLASS_NAME = "hoodie.deltastreamer.schemaprovider.proto.className";
public static final String PROTO_SCHEMA_FLATTEN_WRAPPED_PRIMITIVES = "hoodie.deltastreamer.schemaprovider.proto.flattenWrappers";
private static final String PROTO_SCHEMA_PROVIDER_PREFIX = "hoodie.deltastreamer.schemaprovider.proto";
public static final ConfigProperty<String> PROTO_SCHEMA_CLASS_NAME = ConfigProperty.key(PROTO_SCHEMA_PROVIDER_PREFIX + ".class.name")
.noDefaultValue()
.sinceVersion("0.13.0")
.withDocumentation("The Protobuf Message class used as the source for the schema.");

public static final ConfigProperty<Boolean> PROTO_SCHEMA_FLATTEN_WRAPPED_PRIMITIVES = ConfigProperty.key(PROTO_SCHEMA_PROVIDER_PREFIX + ".flatten.wrappers")
.defaultValue(false)
.sinceVersion("0.13.0")
.withDocumentation("When set to false wrapped primitives like Int64Value are translated to a record with a single 'value' field instead of simply a nullable value");

public static final ConfigProperty<Integer> PROTO_SCHEMA_MAX_RECURSION_DEPTH = ConfigProperty.key(PROTO_SCHEMA_PROVIDER_PREFIX + ".max.recursion.depth")
.defaultValue(5)
.sinceVersion("0.13.0")
.withDocumentation("The max depth to unravel the Proto schema when translating into an Avro schema. Setting this depth allows the user to convert a schema that is recursive in proto into "
+ "something that can be represented in their lake format like Parquet. After a given class has been seen N times within a single branch, the schema provider will create a record with a "
+ "byte array to hold the remaining proto data and a string to hold the message descriptor's name for context.");
}

private final String schemaString;

/**
* To be lazily inited on executors.
* To be lazily initialized on executors.
*/
private transient Schema schema;

public ProtoClassBasedSchemaProvider(TypedProperties props, JavaSparkContext jssc) {
super(props, jssc);
DataSourceUtils.checkRequiredProperties(props, Collections.singletonList(
Config.PROTO_SCHEMA_CLASS_NAME));
String className = config.getString(Config.PROTO_SCHEMA_CLASS_NAME);
boolean flattenWrappedPrimitives = props.getBoolean(ProtoClassBasedSchemaProvider.Config.PROTO_SCHEMA_FLATTEN_WRAPPED_PRIMITIVES, false);
Config.PROTO_SCHEMA_CLASS_NAME.key()));
String className = config.getString(Config.PROTO_SCHEMA_CLASS_NAME.key());
boolean flattenWrappedPrimitives = props.getBoolean(ProtoClassBasedSchemaProvider.Config.PROTO_SCHEMA_FLATTEN_WRAPPED_PRIMITIVES.key(),
Config.PROTO_SCHEMA_FLATTEN_WRAPPED_PRIMITIVES.defaultValue());
int maxRecursionDepth = props.getInteger(Config.PROTO_SCHEMA_MAX_RECURSION_DEPTH.key(), Config.PROTO_SCHEMA_MAX_RECURSION_DEPTH.defaultValue());
try {
schemaString = ProtoConversionUtil.getAvroSchemaForMessageClass(ReflectionUtils.getClass(className), flattenWrappedPrimitives).toString();
schemaString = ProtoConversionUtil.getAvroSchemaForMessageClass(ReflectionUtils.getClass(className), flattenWrappedPrimitives, maxRecursionDepth).toString();
} catch (Exception e) {
throw new HoodieException(String.format("Error reading proto source schema for class: %s", className), e);
throw new HoodieSchemaException(String.format("Error reading proto source schema for class: %s", className), e);
}
}

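As a hedged illustration of the config keys defined in the provider above (the message class name is a placeholder, and the JavaSparkContext is not used during schema generation in this sketch):

import org.apache.avro.Schema;
import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.utilities.schema.ProtoClassBasedSchemaProvider;
import org.apache.spark.api.java.JavaSparkContext;

TypedProperties props = new TypedProperties();
props.setProperty("hoodie.deltastreamer.schemaprovider.proto.class.name", "com.example.proto.Node"); // hypothetical generated class
props.setProperty("hoodie.deltastreamer.schemaprovider.proto.flatten.wrappers", "true");
props.setProperty("hoodie.deltastreamer.schemaprovider.proto.max.recursion.depth", "3");
JavaSparkContext jssc = null; // sufficient for a local schema lookup; DeltaStreamer supplies a real context
ProtoClassBasedSchemaProvider schemaProvider = new ProtoClassBasedSchemaProvider(props, jssc);
Schema sourceSchema = schemaProvider.getSourceSchema();
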
@@ -53,10 +53,10 @@ public ProtoKafkaSource(TypedProperties props, JavaSparkContext sparkContext,
SparkSession sparkSession, SchemaProvider schemaProvider, HoodieDeltaStreamerMetrics metrics) {
super(props, sparkContext, sparkSession, schemaProvider, SourceType.PROTO, metrics);
DataSourceUtils.checkRequiredProperties(props, Collections.singletonList(
ProtoClassBasedSchemaProvider.Config.PROTO_SCHEMA_CLASS_NAME));
ProtoClassBasedSchemaProvider.Config.PROTO_SCHEMA_CLASS_NAME.key()));
props.put(NATIVE_KAFKA_KEY_DESERIALIZER_PROP, StringDeserializer.class);
props.put(NATIVE_KAFKA_VALUE_DESERIALIZER_PROP, ByteArrayDeserializer.class);
className = props.getString(ProtoClassBasedSchemaProvider.Config.PROTO_SCHEMA_CLASS_NAME);
className = props.getString(ProtoClassBasedSchemaProvider.Config.PROTO_SCHEMA_CLASS_NAME.key());
this.offsetGen = new KafkaOffsetGen(props);
}

@@ -37,6 +37,7 @@
import org.apache.avro.generic.GenericFixed;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.kafka.common.utils.CopyOnWriteMap;

import java.nio.ByteBuffer;
import java.util.ArrayList;
@@ -45,6 +46,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

@@ -57,10 +59,11 @@ public class ProtoConversionUtil {
* Creates an Avro {@link Schema} for the provided class. Assumes that the class is a protobuf {@link Message}.
* @param clazz The protobuf class
* @param flattenWrappedPrimitives set to true to treat wrapped primitives like nullable fields instead of nested messages.
* @param maxRecursionDepth the number of times to unravel a recursive proto schema before spilling the rest to bytes
* @return An Avro schema
*/
public static Schema getAvroSchemaForMessageClass(Class clazz, boolean flattenWrappedPrimitives) {
return AvroSupport.get().getSchema(clazz, flattenWrappedPrimitives);
public static Schema getAvroSchemaForMessageClass(Class clazz, boolean flattenWrappedPrimitives, int maxRecursionDepth) {
return AvroSupport.get().getSchema(clazz, flattenWrappedPrimitives, maxRecursionDepth);
}

/**
@@ -80,17 +83,19 @@ public static GenericRecord convertToAvro(Schema schema, Message message) {
* 2. Convert directly from a protobuf {@link Message} to a {@link GenericRecord} while properly handling enums and wrapped primitives mentioned above.
*/
private static class AvroSupport {
private static final Schema STRING_SCHEMA = Schema.create(Schema.Type.STRING);
private static final Schema NULL_SCHEMA = Schema.create(Schema.Type.NULL);
private static final String OVERFLOW_DESCRIPTOR_FIELD_NAME = "descriptor_full_name";
private static final String OVERFLOW_BYTES_FIELD_NAME = "proto_bytes";
private static final Schema RECURSION_OVERFLOW_SCHEMA = Schema.createRecord("recursion_overflow", null, "org.apache.hudi.proto", false,
Arrays.asList(new Schema.Field(OVERFLOW_DESCRIPTOR_FIELD_NAME, STRING_SCHEMA, null, ""),
new Schema.Field(OVERFLOW_BYTES_FIELD_NAME, Schema.create(Schema.Type.BYTES), null, "".getBytes())));
private static final AvroSupport INSTANCE = new AvroSupport();
// A cache of the proto class name paired with whether wrapped primitives should be flattened as the key and the generated avro schema as the value
private static final Map<Pair<Class, Boolean>, Schema> SCHEMA_CACHE = new ConcurrentHashMap<>();
private static final Map<SchemaCacheKey, Schema> SCHEMA_CACHE = new ConcurrentHashMap<>();
// A cache keyed by the pair of target Avro schema and source proto descriptor, whose value is an array of proto field descriptors ordered to match the Avro field ordering.
// When converting from proto to avro, we want to be able to iterate over the fields in the proto in the same order as they appear in the avro schema.
private static final Map<Pair<Schema, Descriptors.Descriptor>, Descriptors.FieldDescriptor[]> FIELD_CACHE = new ConcurrentHashMap<>();


private static final Schema STRINGS = Schema.create(Schema.Type.STRING);

private static final Schema NULL = Schema.create(Schema.Type.NULL);
private static final Map<Descriptors.Descriptor, Schema.Type> WRAPPER_DESCRIPTORS_TO_TYPE = getWrapperDescriptorsToType();

private static Map<Descriptors.Descriptor, Schema.Type> getWrapperDescriptorsToType() {
@@ -118,14 +123,15 @@ public GenericRecord convert(Schema schema, Message message) {
return (GenericRecord) convertObject(schema, message);
}

public Schema getSchema(Class c, boolean flattenWrappedPrimitives) {
return SCHEMA_CACHE.computeIfAbsent(Pair.of(c, flattenWrappedPrimitives), key -> {
public Schema getSchema(Class c, boolean flattenWrappedPrimitives, int maxRecursionDepth) {
return SCHEMA_CACHE.computeIfAbsent(new SchemaCacheKey(c, flattenWrappedPrimitives, maxRecursionDepth), key -> {
try {
Object descriptor = c.getMethod("getDescriptor").invoke(null);
if (c.isEnum()) {
return getEnumSchema((Descriptors.EnumDescriptor) descriptor);
} else {
return getMessageSchema((Descriptors.Descriptor) descriptor, new HashMap<>(), flattenWrappedPrimitives);
Descriptors.Descriptor castedDescriptor = (Descriptors.Descriptor) descriptor;
return getMessageSchema(castedDescriptor, new CopyOnWriteMap<>(), flattenWrappedPrimitives, getNamespace(castedDescriptor.getFullName()), maxRecursionDepth);
}
} catch (Exception e) {
throw new RuntimeException(e);
@@ -141,24 +147,40 @@ private Schema getEnumSchema(Descriptors.EnumDescriptor enumDescriptor) {
return Schema.createEnum(enumDescriptor.getName(), null, getNamespace(enumDescriptor.getFullName()), symbols);
}

private Schema getMessageSchema(Descriptors.Descriptor descriptor, Map<Descriptors.Descriptor, Schema> seen, boolean flattenWrappedPrimitives) {
if (seen.containsKey(descriptor)) {
return seen.get(descriptor);
/**
* Translates a Proto Message descriptor into an Avro Schema
* @param descriptor the descriptor for the proto message
* @param recursionDepths a map of the descriptor to the number of times it has been encountered in this depth first traversal of the schema.
* This is used to cap the number of times we recurse on a schema.
* @param flattenWrappedPrimitives if true, treat wrapped primitives as nullable primitives, if false, treat them as proto messages
* @param path a string prefixed with the namespace of the original message being translated to avro and containing the current dot separated path tracking progress through the schema.
* This value is used for a namespace when creating Avro records to avoid an error when reusing the same class name when unraveling a recursive schema.
* @param maxRecursionDepth the number of times to unravel a recursive proto schema before spilling the rest to bytes
* @return an avro schema
*/
private Schema getMessageSchema(Descriptors.Descriptor descriptor, CopyOnWriteMap<Descriptors.Descriptor, Integer> recursionDepths, boolean flattenWrappedPrimitives, String path,
int maxRecursionDepth) {
// Parquet does not handle recursive schemas so we "unravel" the proto N levels
Integer currentRecursionCount = recursionDepths.getOrDefault(descriptor, 0);
if (currentRecursionCount >= maxRecursionDepth) {
return RECURSION_OVERFLOW_SCHEMA;
}
Schema result = Schema.createRecord(descriptor.getName(), null,
getNamespace(descriptor.getFullName()), false);
// The current path is used as a namespace to avoid record name collisions within recursive schemas
Schema result = Schema.createRecord(descriptor.getName(), null, path, false);

seen.put(descriptor, result);
recursionDepths.put(descriptor, ++currentRecursionCount);

List<Schema.Field> fields = new ArrayList<>(descriptor.getFields().size());
for (Descriptors.FieldDescriptor f : descriptor.getFields()) {
fields.add(new Schema.Field(f.getName(), getFieldSchema(f, seen, flattenWrappedPrimitives), null, getDefault(f)));
// each branch of the schema traversal requires its own recursion depth tracking so copy the recursionDepths map
fields.add(new Schema.Field(f.getName(), getFieldSchema(f, new CopyOnWriteMap<>(recursionDepths), flattenWrappedPrimitives, path, maxRecursionDepth), null, getDefault(f)));
}
result.setFields(fields);
return result;
}

private Schema getFieldSchema(Descriptors.FieldDescriptor f, Map<Descriptors.Descriptor, Schema> seen, boolean flattenWrappedPrimitives) {
private Schema getFieldSchema(Descriptors.FieldDescriptor f, CopyOnWriteMap<Descriptors.Descriptor, Integer> recursionDepths, boolean flattenWrappedPrimitives, String path,
int maxRecursionDepth) {
Function<Schema, Schema> schemaFinalizer = f.isRepeated() ? Schema::createArray : Function.identity();
switch (f.getType()) {
case BOOL:
@@ -188,16 +210,18 @@ private Schema getFieldSchema(Descriptors.FieldDescriptor f, Map<Descriptors.Des
case SFIXED64:
return schemaFinalizer.apply(Schema.create(Schema.Type.LONG));
case MESSAGE:
String updatedPath = appendFieldNameToPath(path, f.getName());
if (flattenWrappedPrimitives && WRAPPER_DESCRIPTORS_TO_TYPE.containsKey(f.getMessageType())) {
// all wrapper types have a single field, so we can get the first field in the message's schema
return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL, getFieldSchema(f.getMessageType().getFields().get(0), seen, flattenWrappedPrimitives))));
return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL_SCHEMA, getFieldSchema(f.getMessageType().getFields().get(0), recursionDepths, flattenWrappedPrimitives, updatedPath,
maxRecursionDepth))));
}
// if message field is repeated (like a list), elements are non-null
if (f.isRepeated()) {
return schemaFinalizer.apply(getMessageSchema(f.getMessageType(), seen, flattenWrappedPrimitives));
return schemaFinalizer.apply(getMessageSchema(f.getMessageType(), recursionDepths, flattenWrappedPrimitives, updatedPath, maxRecursionDepth));
}
// otherwise we create a nullable field schema
return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL, getMessageSchema(f.getMessageType(), seen, flattenWrappedPrimitives))));
return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL_SCHEMA, getMessageSchema(f.getMessageType(), recursionDepths, flattenWrappedPrimitives, updatedPath, maxRecursionDepth))));
case GROUP: // groups are deprecated
default:
throw new RuntimeException("Unexpected type: " + f.getType());
@@ -255,6 +279,14 @@ private Object convertObject(Schema schema, Object value) {
if (value == null) {
return null;
}
// if we've reached max recursion depth in the provided schema, write out message to bytes
if (RECURSION_OVERFLOW_SCHEMA.getFullName().equals(schema.getFullName())) {
GenericData.Record overflowRecord = new GenericData.Record(schema);
Message messageValue = (Message) value;
overflowRecord.put(OVERFLOW_DESCRIPTOR_FIELD_NAME, messageValue.getDescriptorForType().getFullName());
overflowRecord.put(OVERFLOW_BYTES_FIELD_NAME, ByteBuffer.wrap(messageValue.toByteArray()));
return overflowRecord;
}

switch (schema.getType()) {
case ARRAY:
@@ -305,7 +337,7 @@ private Object convertObject(Schema schema, Object value) {
Map<Object, Object> mapValue = (Map) value;
Map<Object, Object> mapCopy = new HashMap<>(mapValue.size());
for (Map.Entry<Object, Object> entry : mapValue.entrySet()) {
mapCopy.put(convertObject(STRINGS, entry.getKey()), convertObject(schema.getValueType(), entry.getValue()));
mapCopy.put(convertObject(STRING_SCHEMA, entry.getKey()), convertObject(schema.getValueType(), entry.getValue()));
}
return mapCopy;
case NULL:
@@ -355,5 +387,38 @@ private String getNamespace(String descriptorFullName) {
int lastDotIndex = descriptorFullName.lastIndexOf('.');
return descriptorFullName.substring(0, lastDotIndex);
}

private String appendFieldNameToPath(String existingPath, String fieldName) {
return existingPath + "." + fieldName;
}

private static class SchemaCacheKey {
private final String className;
private final boolean flattenWrappedPrimitives;
private final int maxRecursionDepth;

SchemaCacheKey(Class clazz, boolean flattenWrappedPrimitives, int maxRecursionDepth) {
this.className = clazz.getName();
this.flattenWrappedPrimitives = flattenWrappedPrimitives;
this.maxRecursionDepth = maxRecursionDepth;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SchemaCacheKey that = (SchemaCacheKey) o;
return flattenWrappedPrimitives == that.flattenWrappedPrimitives && maxRecursionDepth == that.maxRecursionDepth && className.equals(that.className);
}

@Override
public int hashCode() {
return Objects.hash(className, flattenWrappedPrimitives, maxRecursionDepth);
}
}
}
}
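
A hedged end-to-end sketch of the conversion path for data beyond the depth limit, reusing the hypothetical Node class and an avroSchema built with a max recursion depth of 2 (deepNode is an illustrative message nested more than two levels deep):

import java.nio.ByteBuffer;
import org.apache.avro.generic.GenericRecord;

GenericRecord converted = ProtoConversionUtil.convertToAvro(avroSchema, deepNode);
// With a depth of 2, the second nested child is written as the overflow record rather than a regular Node record.
GenericRecord overflow = (GenericRecord) ((GenericRecord) converted.get("child")).get("child");
String descriptorName = overflow.get("descriptor_full_name").toString();
ByteBuffer remainingBytes = (ByteBuffer) overflow.get("proto_bytes");
// The truncated tail can be recovered by re-parsing the bytes with the original proto class.
Node tail = Node.parseFrom(remainingBytes);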
