Skip to content

Commit

Permalink
Pick up commit from @andredasilvapinto
Browse files Browse the repository at this point in the history
  • Loading branch information
Constantin Muraru committed Feb 14, 2018
1 parent 5cf9248 commit a8bd704
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor
if (parquetType.asGroupType().containsField("list")) {
parquetSchema = parquetType.asGroupType().getType("list");
if (parquetSchema.asGroupType().containsField("element")) {
parquetSchema.asGroupType().getType("element");
parquetSchema = parquetSchema.asGroupType().getType("element");
}
} else {
throw new ParquetDecodingException("Expected list but got: " + parquetType);
Expand All @@ -403,10 +403,6 @@ public Converter getConverter(int fieldIndex) {
throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper");
}

if (listOfMessage) {
return converter;
}

return new GroupConverter() {
@Override
public Converter getConverter(int fieldIndex) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
Expand All @@ -19,6 +19,7 @@
package org.apache.parquet.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType;
import com.google.protobuf.Message;
import com.twitter.elephantbird.util.Protobufs;
Expand Down Expand Up @@ -59,8 +60,8 @@ public MessageType convert(Class<? extends Message> protobufClass) {
}

/** Iterates over the list of fields, adding each one to the group builder. */
private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<Descriptors.FieldDescriptor> fieldDescriptors) {
for (Descriptors.FieldDescriptor fieldDescriptor : fieldDescriptors) {
private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<FieldDescriptor> fieldDescriptors) {
for (FieldDescriptor fieldDescriptor : fieldDescriptors) {
groupBuilder =
addField(fieldDescriptor, groupBuilder)
.id(fieldDescriptor.getNumber())
Expand All @@ -69,7 +70,7 @@ private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<Des
return groupBuilder;
}

private Type.Repetition getRepetition(Descriptors.FieldDescriptor descriptor) {
private Type.Repetition getRepetition(FieldDescriptor descriptor) {
if (descriptor.isRequired()) {
return Type.Repetition.REQUIRED;
} else if (descriptor.isRepeated()) {
Expand All @@ -79,7 +80,7 @@ private Type.Repetition getRepetition(Descriptors.FieldDescriptor descriptor) {
}
}

private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addField(Descriptors.FieldDescriptor descriptor, final GroupBuilder<T> builder) {
private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addField(FieldDescriptor descriptor, final GroupBuilder<T> builder) {
if (descriptor.getJavaType() == JavaType.MESSAGE) {
return addMessageField(descriptor, builder);
}
Expand All @@ -92,7 +93,7 @@ private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addF
return builder.primitive(parquetType.primitiveType, getRepetition(descriptor)).as(parquetType.originalType);
}

private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addRepeatedPrimitive(Descriptors.FieldDescriptor descriptor,
private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addRepeatedPrimitive(FieldDescriptor descriptor,
PrimitiveTypeName primitiveType,
OriginalType originalType,
final GroupBuilder<T> builder) {
Expand All @@ -104,18 +105,19 @@ private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addR
.named("list");
}

private <T> GroupBuilder<GroupBuilder<T>> addRepeatedMessage(Descriptors.FieldDescriptor descriptor, GroupBuilder<T> builder) {
GroupBuilder<GroupBuilder<GroupBuilder<T>>> result =
private <T> GroupBuilder<GroupBuilder<T>> addRepeatedMessage(FieldDescriptor descriptor, GroupBuilder<T> builder) {
GroupBuilder<GroupBuilder<GroupBuilder<GroupBuilder<T>>>> result =
builder
.group(Type.Repetition.REQUIRED).as(OriginalType.LIST)
.group(Type.Repetition.REPEATED);
.group(Type.Repetition.REPEATED)
.group(Type.Repetition.OPTIONAL);

convertFields(result, descriptor.getMessageType().getFields());

return result.named("list");
return result.named("element").named("list");
}

private <T> GroupBuilder<GroupBuilder<T>> addMessageField(Descriptors.FieldDescriptor descriptor, final GroupBuilder<T> builder) {
private <T> GroupBuilder<GroupBuilder<T>> addMessageField(FieldDescriptor descriptor, final GroupBuilder<T> builder) {
if (descriptor.isMapField()) {
return addMapField(descriptor, builder);
} else if (descriptor.isRepeated()) {
Expand All @@ -128,24 +130,24 @@ private <T> GroupBuilder<GroupBuilder<T>> addMessageField(Descriptors.FieldDescr
return group;
}

private <T> GroupBuilder<GroupBuilder<T>> addMapField(Descriptors.FieldDescriptor descriptor, final GroupBuilder<T> builder) {
List<Descriptors.FieldDescriptor> fields = descriptor.getMessageType().getFields();
private <T> GroupBuilder<GroupBuilder<T>> addMapField(FieldDescriptor descriptor, final GroupBuilder<T> builder) {
List<FieldDescriptor> fields = descriptor.getMessageType().getFields();
if (fields.size() != 2) {
throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields);
}

ParquetType mapKeyParquetType = getParquetType(fields.get(0));

GroupBuilder<GroupBuilder<GroupBuilder<T>>> group = builder
.group(Type.Repetition.REQUIRED).as(OriginalType.MAP)
.group(Type.Repetition.OPTIONAL).as(OriginalType.MAP) // only optional maps are allowed in Proto3
.group(Type.Repetition.REPEATED) // key_value wrapper
.primitive(mapKeyParquetType.primitiveType, Type.Repetition.REQUIRED).as(mapKeyParquetType.originalType).named("key");

return addField(fields.get(1), group).named("value")
.named("key_value");
}

private ParquetType getParquetType(Descriptors.FieldDescriptor fieldDescriptor) {
private ParquetType getParquetType(FieldDescriptor fieldDescriptor) {

JavaType javaType = fieldDescriptor.getJavaType();
switch (javaType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,17 @@
*/
package org.apache.parquet.proto;

import com.google.protobuf.ByteString;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
import com.google.protobuf.MapEntry;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.TextFormat;
import com.google.protobuf.*;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.hadoop.BadConfigurationException;
import org.apache.parquet.hadoop.api.WriteSupport;
import org.apache.parquet.io.InvalidRecordException;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.IncompatibleSchemaModificationException;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.*;
import org.apache.parquet.schema.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -113,7 +106,7 @@ public WriteContext init(Configuration configuration) {
}

MessageType rootSchema = new ProtoSchemaConverter().convert(protoMessage);
Descriptors.Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage);
Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage);
validatedMapping(messageDescriptor, rootSchema);

this.messageWriter = new MessageWriter(messageDescriptor, rootSchema);
Expand Down Expand Up @@ -156,11 +149,11 @@ class MessageWriter extends FieldWriter {
final FieldWriter[] fieldWriters;

@SuppressWarnings("unchecked")
MessageWriter(Descriptors.Descriptor descriptor, GroupType schema) {
List<Descriptors.FieldDescriptor> fields = descriptor.getFields();
MessageWriter(Descriptor descriptor, GroupType schema) {
List<FieldDescriptor> fields = descriptor.getFields();
fieldWriters = (FieldWriter[]) Array.newInstance(FieldWriter.class, fields.size());

for (Descriptors.FieldDescriptor fieldDescriptor: fields) {
for (FieldDescriptor fieldDescriptor: fields) {
String name = fieldDescriptor.getName();
Type type = schema.getType(name);
FieldWriter writer = createWriter(fieldDescriptor, type);
Expand All @@ -176,7 +169,7 @@ class MessageWriter extends FieldWriter {
}
}

private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) {
private FieldWriter createWriter(FieldDescriptor fieldDescriptor, Type type) {

switch (fieldDescriptor.getJavaType()) {
case STRING: return new StringWriter() ;
Expand All @@ -193,7 +186,7 @@ private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Ty
return unknownType(fieldDescriptor);//should not be executed, always throws exception.
}

private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) {
private FieldWriter createMessageWriter(FieldDescriptor fieldDescriptor, Type type) {
if (fieldDescriptor.isMapField()) {
return createMapWriter(fieldDescriptor, type);
}
Expand All @@ -203,7 +196,7 @@ private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescrip

private GroupType getGroupType(Type type) {
if (type.getOriginalType() == OriginalType.LIST) {
return type.asGroupType().getType("list").asGroupType();
return type.asGroupType().getType("list").asGroupType().getType("element").asGroupType();
}

if (type.getOriginalType() == OriginalType.MAP) {
Expand All @@ -213,20 +206,20 @@ private GroupType getGroupType(Type type) {
return type.asGroupType();
}

private MapWriter createMapWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) {
List<Descriptors.FieldDescriptor> fields = fieldDescriptor.getMessageType().getFields();
private MapWriter createMapWriter(FieldDescriptor fieldDescriptor, Type type) {
List<FieldDescriptor> fields = fieldDescriptor.getMessageType().getFields();
if (fields.size() != 2) {
throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields);
}

// KeyFieldWriter
Descriptors.FieldDescriptor keyProtoField = fields.get(0);
FieldDescriptor keyProtoField = fields.get(0);
FieldWriter keyWriter = createWriter(keyProtoField, type);
keyWriter.setFieldName(keyProtoField.getName());
keyWriter.setIndex(0);

// ValueFieldWriter
Descriptors.FieldDescriptor valueProtoField = fields.get(1);
FieldDescriptor valueProtoField = fields.get(1);
FieldWriter valueWriter = createWriter(valueProtoField, type);
valueWriter.setFieldName(valueProtoField.getName());
valueWriter.setIndex(1);
Expand Down Expand Up @@ -257,10 +250,10 @@ final void writeField(Object value) {

private void writeAllFields(MessageOrBuilder pb) {
// Returns the changed (set) fields with their values. The map is ordered by field id.
Map<Descriptors.FieldDescriptor, Object> changedPbFields = pb.getAllFields();
Map<FieldDescriptor, Object> changedPbFields = pb.getAllFields();

for (Map.Entry<Descriptors.FieldDescriptor, Object> entry : changedPbFields.entrySet()) {
Descriptors.FieldDescriptor fieldDescriptor = entry.getKey();
for (Map.Entry<FieldDescriptor, Object> entry : changedPbFields.entrySet()) {
FieldDescriptor fieldDescriptor = entry.getKey();

if(fieldDescriptor.isExtension()) {
// Field index of an extension field might overlap with a base field.
Expand Down Expand Up @@ -295,13 +288,21 @@ final void writeField(Object value) {
recordConsumer.startField("list", 0); // This is the wrapper group for the array field
for (Object listEntry: list) {
recordConsumer.startGroup();
if (isPrimitive(listEntry)) {
recordConsumer.startField("element", 0);

recordConsumer.startField("element", 0); // This is the mandatory inner field

if (!isPrimitive(listEntry)) {
recordConsumer.startGroup();
}

fieldWriter.writeRawValue(listEntry);
if (isPrimitive(listEntry)) {
recordConsumer.endField("element", 0);

if (!isPrimitive(listEntry)) {
recordConsumer.endGroup();
}

recordConsumer.endField("element", 0);

recordConsumer.endGroup();
}
recordConsumer.endField("list", 0);
Expand All @@ -316,10 +317,10 @@ private boolean isPrimitive(Object listEntry) {
}

/** Validates the mapping between protocol buffer fields and Parquet fields. */
private void validatedMapping(Descriptors.Descriptor descriptor, GroupType parquetSchema) {
List<Descriptors.FieldDescriptor> allFields = descriptor.getFields();
private void validatedMapping(Descriptor descriptor, GroupType parquetSchema) {
List<FieldDescriptor> allFields = descriptor.getFields();

for (Descriptors.FieldDescriptor fieldDescriptor: allFields) {
for (FieldDescriptor fieldDescriptor: allFields) {
String fieldName = fieldDescriptor.getName();
int fieldIndex = fieldDescriptor.getIndex();
int parquetIndex = parquetSchema.getFieldIndex(fieldName);
Expand Down Expand Up @@ -370,10 +371,16 @@ final void writeRawValue(Object value) {
recordConsumer.startGroup();

recordConsumer.startField("key_value", 0); // This is the wrapper group for the map field
for(MapEntry<?, ?> entry : (Collection<MapEntry<?, ?>>) value) {
for (Message msg : (Collection<Message>) value) {
recordConsumer.startGroup();
keyWriter.writeField(entry.getKey());
valueWriter.writeField(entry.getValue());

final Descriptor descriptorForType = msg.getDescriptorForType();
final FieldDescriptor keyDesc = descriptorForType.findFieldByName("key");
final FieldDescriptor valueDesc = descriptorForType.findFieldByName("value");

keyWriter.writeField(msg.getField(keyDesc));
valueWriter.writeField(msg.getField(valueDesc));

recordConsumer.endGroup();
}

Expand Down Expand Up @@ -421,15 +428,15 @@ final void writeRawValue(Object value) {
}
}

private FieldWriter unknownType(Descriptors.FieldDescriptor fieldDescriptor) {
private FieldWriter unknownType(FieldDescriptor fieldDescriptor) {
String exceptionMsg = "Unknown type with descriptor \"" + fieldDescriptor
+ "\" and type \"" + fieldDescriptor.getJavaType() + "\".";
throw new InvalidRecordException(exceptionMsg);
}

/** Returns the message descriptor serialized as a protobuf text-format String (not JSON). */
private String serializeDescriptor(Class<? extends Message> protoClass) {
Descriptors.Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass);
Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass);
DescriptorProtos.DescriptorProto asProto = descriptor.toProto();
return TextFormat.printToString(asProto);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public void testProto3ConvertAllDatatypes() throws Exception {
" optional binary optionalEnum (ENUM) = 18;" +
" optional int32 someInt32 = 19;" +
" optional binary someString (UTF8) = 20;" +
" required group optionalMap (MAP) = 21 {\n" +
" optional group optionalMap (MAP) = 21 {\n" +
" repeated group key_value {\n" +
" required int64 key;\n" +
" optional group value {\n" +
Expand Down Expand Up @@ -135,7 +135,9 @@ public void testConvertRepetition() throws Exception {
" }\n" +
" required group repeatedMessage (LIST) = 9 {\n" +
" repeated group list {\n" +
" optional int32 someId = 3;\n" +
" optional group element {\n" +
" optional int32 someId = 3;\n" +
" }\n" +
" }\n" +
" }" +
"}";
Expand All @@ -158,7 +160,9 @@ public void testProto3ConvertRepetition() throws Exception {
" }\n" +
" required group repeatedMessage (LIST) = 9 {\n" +
" repeated group list {\n" +
" optional int32 someId = 3;\n" +
" optional group element {\n" +
" optional int32 someId = 3;\n" +
" }\n" +
" }\n" +
" }\n" +
"}";
Expand Down

0 comments on commit a8bd704

Please sign in to comment.