@@ -21,14 +21,10 @@
import java.util.List;
import java.util.Set;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.catalyst.types.DataTypeUtils;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructType;
@@ -69,16 +65,9 @@ final class ParquetColumnVector {
ParquetColumn column,
WritableColumnVector vector,
int capacity,
MemoryMode memoryMode,
Set<ParquetColumn> missingColumns,
boolean isTopLevel,
Object defaultValue) {
DataType sparkType = column.sparkType();
if (!DataTypeUtils.sameType(sparkType, vector.dataType())) {
throw new IllegalArgumentException("Spark type: " + sparkType +
" doesn't match the type: " + vector.dataType() + " in column vector");
}

this.column = column;
this.vector = vector;
this.children = new ArrayList<>();
@@ -111,35 +100,41 @@ final class ParquetColumnVector {

if (column.variantFileType().isDefined()) {
ParquetColumn fileContentCol = column.variantFileType().get();
WritableColumnVector fileContent = memoryMode == MemoryMode.OFF_HEAP
? new OffHeapColumnVector(capacity, fileContentCol.sparkType())
: new OnHeapColumnVector(capacity, fileContentCol.sparkType());
ParquetColumnVector contentVector = new ParquetColumnVector(fileContentCol,
fileContent, capacity, memoryMode, missingColumns, false, null);
WritableColumnVector fileContent = vector.reserveNewColumn(
capacity, fileContentCol.sparkType());
ParquetColumnVector contentVector = new ParquetColumnVector(fileContentCol, fileContent,
capacity, missingColumns, /* isTopLevel= */ false, /* defaultValue= */ null);
children.add(contentVector);
variantSchema = SparkShreddingUtils.buildVariantSchema(fileContentCol.sparkType());
fieldsToExtract = SparkShreddingUtils.getFieldsToExtract(column.sparkType(), variantSchema);
repetitionLevels = contentVector.repetitionLevels;
definitionLevels = contentVector.definitionLevels;
} else if (isPrimitive) {
if (column.repetitionLevel() > 0) {
repetitionLevels = allocateLevelsVector(capacity, memoryMode);
repetitionLevels = vector.reserveNewColumn(capacity, DataTypes.IntegerType);
}
// We don't need to create and store definition levels if the column is top-level.
if (!isTopLevel) {
definitionLevels = allocateLevelsVector(capacity, memoryMode);
definitionLevels = vector.reserveNewColumn(capacity, DataTypes.IntegerType);
}
} else {
JavaUtils.checkArgument(column.children().size() == vector.getNumChildren(),
"The number of column children is different from the number of vector children");
// If a child is not present in the allocated vectors, we don't care about its data; we only
// read its levels to help assemble some parent struct, so a dummy vector is created below to
// hold the child's data. There can be at most one such child.
JavaUtils.checkArgument(column.children().size() == vector.getNumChildren() ||
column.children().size() == vector.getNumChildren() + 1,
"The number of column children must equal the number of vector children, or exceed it by one");
boolean allChildrenAreMissing = true;

for (int i = 0; i < column.children().size(); i++) {
ParquetColumnVector childCv = new ParquetColumnVector(column.children().apply(i),
vector.getChild(i), capacity, memoryMode, missingColumns, false, null);
ParquetColumn childColumn = column.children().apply(i);
WritableColumnVector childVector = i < vector.getNumChildren()
? vector.getChild(i)
: vector.reserveNewColumn(capacity, childColumn.sparkType());
ParquetColumnVector childCv = new ParquetColumnVector(childColumn, childVector, capacity,
missingColumns, /* isTopLevel= */ false, /* defaultValue= */ null);
children.add(childCv);


// Only use levels from non-missing child, this can happen if only some but not all
// fields of a struct are missing.
if (!childCv.vector.isAllNull()) {
@@ -375,13 +370,6 @@ private void assembleStruct() {
vector.addElementsAppended(rowId);
}

private static WritableColumnVector allocateLevelsVector(int capacity, MemoryMode memoryMode) {
return switch (memoryMode) {
case ON_HEAP -> new OnHeapColumnVector(capacity, DataTypes.IntegerType);
case OFF_HEAP -> new OffHeapColumnVector(capacity, DataTypes.IntegerType);
};
}

/**
* For a collection (i.e., array or map) element at index 'idx', returns the starting index of
* the next collection after it.
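Taken together, the hunks above show the core change in ParquetColumnVector: it no longer takes a MemoryMode, and child and levels vectors are carved out of the parent WritableColumnVector via reserveNewColumn, so they inherit the parent's on-heap or off-heap mode. A minimal sketch of that pattern, not part of the diff, with arbitrary names and capacity:

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;

class ReserveNewColumnSketch {
  // Allocate an integer levels vector from an existing parent vector. Because reserveNewColumn
  // is overridden by OnHeapColumnVector and OffHeapColumnVector to return their own kind (and is
  // made public further down in this diff), the levels vector stays in the parent's memory mode.
  static WritableColumnVector allocateLevels(WritableColumnVector parent, int capacity) {
    return parent.reserveNewColumn(capacity, DataTypes.IntegerType);
  }

  public static void main(String[] args) {
    WritableColumnVector parent = new OnHeapColumnVector(4096, DataTypes.IntegerType);
    WritableColumnVector levels = allocateLevels(parent, 4096);
    System.out.println(levels.dataType()); // prints IntegerType
  }
}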
@@ -204,6 +204,7 @@ protected void initialize(String path, List<String> columns) throws IOException
this.parquetColumn = new ParquetToSparkSchemaConverter(config)
.convertParquetColumn(requestedSchema, Option.empty());
this.sparkSchema = (StructType) parquetColumn.sparkType();
this.sparkRequestedSchema = this.sparkSchema;
this.totalRowCount = fileReader.getFilteredRecordCount();
}

@@ -225,6 +226,7 @@ protected void initialize(
this.parquetColumn = new ParquetToSparkSchemaConverter(config)
.convertParquetColumn(requestedSchema, Option.empty());
this.sparkSchema = (StructType) parquetColumn.sparkType();
this.sparkRequestedSchema = this.sparkSchema;
this.totalRowCount = totalRowCount;
}

@@ -48,10 +48,9 @@
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
* A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the
@@ -265,7 +264,7 @@ private void initBatch(
MemoryMode memMode,
StructType partitionColumns,
InternalRow partitionValues) {
StructType batchSchema = new StructType(sparkSchema.fields());
StructType batchSchema = (StructType) truncateType(sparkSchema, sparkRequestedSchema);

int constantColumnLength = 0;
if (partitionColumns != null) {
@@ -287,7 +286,7 @@ private void initBatch(
defaultValue = ResolveDefaultColumns.existenceDefaultValues(sparkRequestedSchema)[i];
}
columnVectors[i] = new ParquetColumnVector(parquetColumn.children().apply(i),
(WritableColumnVector) vectors[i], capacity, memMode, missingColumns, true, defaultValue);
(WritableColumnVector) vectors[i], capacity, missingColumns, /* isTopLevel= */ true,
defaultValue);
}

if (partitionColumns != null) {
@@ -309,6 +309,58 @@ public void initBatch(StructType partitionColumns, InternalRow partitionValues)
initBatch(MEMORY_MODE, partitionColumns, partitionValues);
}

/**
* Keeps the hierarchy and fields of readType, recursively truncating struct fields from the end
* of each field list so that every struct has the same number of fields as the corresponding
* struct in requestedType. This removes the extra fields appended to structs whose requested
* fields were all missing from the file schema, so the result is the type we would have read if
* everything were present in the file, matching Spark's expected schema.
*
* <p> Example: <pre>{@code
* readType: array<struct<a:int,b:long,c:int>>
* requestedType: array<struct<a:int,b:long>>
* returns: array<struct<a:int,b:long>>
* }</pre>
* We cannot simply return requestedType here, because it may differ slightly from the read type,
* e.g. in field nullability or type precision (smallint vs. int).
*/
@VisibleForTesting
static DataType truncateType(DataType readType, DataType requestedType) {
if (requestedType instanceof UserDefinedType<?> requestedUDT) {
requestedType = requestedUDT.sqlType();
}

if (readType instanceof StructType readStruct &&
requestedType instanceof StructType requestedStruct) {
StructType result = new StructType();
for (int i = 0; i < requestedStruct.fields().length; i++) {
StructField readField = readStruct.fields()[i];
StructField requestedField = requestedStruct.fields()[i];
DataType truncatedType = truncateType(readField.dataType(), requestedField.dataType());
result = result.add(readField.copy(
readField.name(), truncatedType, readField.nullable(), readField.metadata()));
}
return result;
}

if (readType instanceof ArrayType readArray &&
requestedType instanceof ArrayType requestedArray) {
DataType truncatedElementType = truncateType(
readArray.elementType(), requestedArray.elementType());
return readArray.copy(truncatedElementType, readArray.containsNull());
}

if (readType instanceof MapType readMap && requestedType instanceof MapType requestedMap) {
DataType truncatedKeyType = truncateType(readMap.keyType(), requestedMap.keyType());
DataType truncatedValueType = truncateType(readMap.valueType(), requestedMap.valueType());
return readMap.copy(truncatedKeyType, truncatedValueType, readMap.valueContainsNull());
}

assert !ParquetSchemaConverter.isComplexType(readType);
assert !ParquetSchemaConverter.isComplexType(requestedType);
return readType;
}
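A hedged usage sketch of truncateType, mirroring the Javadoc example above. It assumes same-package (e.g. test) access to the package-private method and its enclosing reader class; otherwise only public Spark SQL type APIs are used:

import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// readType: what the clipped file schema produced; "c" is the extra field appended when all
// requested fields of the struct were missing in the file.
StructType readStruct = new StructType()
    .add("a", DataTypes.IntegerType)
    .add("b", DataTypes.LongType)
    .add("c", DataTypes.IntegerType);
ArrayType readType = DataTypes.createArrayType(readStruct, true);

// requestedType: what Spark actually asked for.
StructType requestedStruct = new StructType()
    .add("a", DataTypes.IntegerType)
    .add("b", DataTypes.LongType);
ArrayType requestedType = DataTypes.createArrayType(requestedStruct, true);

// Expected result: array<struct<a:int,b:long>>. Field "c" is dropped, while nullability and the
// exact field types still come from readType rather than requestedType.
DataType truncated = truncateType(readType, requestedType);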

/**
* Returns the ColumnarBatch object that will be used for all rows returned by this reader.
* This object is reused. Calling this enables the vectorized reader. This should be called
@@ -633,7 +633,7 @@ protected void reserveInternal(int newCapacity) {
}

@Override
protected OffHeapColumnVector reserveNewColumn(int capacity, DataType type) {
public OffHeapColumnVector reserveNewColumn(int capacity, DataType type) {
return new OffHeapColumnVector(capacity, type);
}
}
@@ -646,7 +646,7 @@ protected void reserveInternal(int newCapacity) {
}

@Override
protected OnHeapColumnVector reserveNewColumn(int capacity, DataType type) {
public OnHeapColumnVector reserveNewColumn(int capacity, DataType type) {
return new OnHeapColumnVector(capacity, type);
}
}
@@ -944,7 +944,7 @@ public final boolean isAllNull() {
/**
* Reserve a new column.
*/
protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type);
public abstract WritableColumnVector reserveNewColumn(int capacity, DataType type);

protected boolean isArray() {
return type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType ||
@@ -211,13 +211,13 @@ object ParquetReadSupport extends Logging {
caseSensitive: Boolean,
useFieldId: Boolean): Type = {
val newParquetType = catalystType match {
case t: ArrayType if !isPrimitiveCatalystType(t.elementType) =>
case t: ArrayType if ParquetSchemaConverter.isComplexType(t.elementType) =>
// Only clips array types with nested type as element type.
clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive, useFieldId)

case t: MapType
if !isPrimitiveCatalystType(t.keyType) ||
!isPrimitiveCatalystType(t.valueType) =>
if ParquetSchemaConverter.isComplexType(t.keyType) ||
ParquetSchemaConverter.isComplexType(t.valueType) =>
// Only clips map types with nested key type or value type
clipParquetMapType(
parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId)
@@ -241,18 +241,6 @@
}
}

/**
* Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to
* [[AtomicType]]. For example, [[CalendarIntervalType]] is primitive, but it's not an
* [[AtomicType]].
*/
private def isPrimitiveCatalystType(dataType: DataType): Boolean = {
dataType match {
case _: ArrayType | _: MapType | _: StructType => false
case _ => true
}
}

/**
* Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type
* of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a
@@ -264,7 +252,7 @@
caseSensitive: Boolean,
useFieldId: Boolean): Type = {
// Precondition of this method, should only be called for lists with nested element types.
assert(!isPrimitiveCatalystType(elementType))
assert(ParquetSchemaConverter.isComplexType(elementType))

// Unannotated repeated group should be interpreted as required list of required element, so
// list element type is just the group itself. Clip it.
@@ -343,7 +331,8 @@
caseSensitive: Boolean,
useFieldId: Boolean): GroupType = {
// Precondition of this method, only handles maps with nested key types or value types.
assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType))
assert(ParquetSchemaConverter.isComplexType(keyType) ||
ParquetSchemaConverter.isComplexType(valueType))

val repeatedGroup = parquetMap.getType(0).asGroupType()
val parquetKeyType = repeatedGroup.getType(0)
@@ -418,11 +407,15 @@
parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT))
lazy val idToParquetFieldMap =
parquetRecord.getFields.asScala.filter(_.getId != null).groupBy(f => f.getId.intValue())
var isStructWithMissingFields = true

def matchCaseSensitiveField(f: StructField): Type = {
caseSensitiveParquetFieldMap
.get(f.name)
.map(clipParquetType(_, f.dataType, caseSensitive, useFieldId))
.map { parquetType =>
isStructWithMissingFields = false
clipParquetType(parquetType, f.dataType, caseSensitive, useFieldId)
}
.getOrElse(toParquet.convertField(f, inShredded = false))
}

@@ -437,6 +430,7 @@
throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError(
f.name, parquetTypesString)
} else {
isStructWithMissingFields = false
clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId)
}
}.getOrElse(toParquet.convertField(f, inShredded = false))
@@ -453,6 +447,7 @@
throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(
fieldId, parquetTypesString)
} else {
isStructWithMissingFields = false
clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId)
}
}.getOrElse {
@@ -463,7 +458,7 @@
}

val shouldMatchById = useFieldId && ParquetUtils.hasFieldIds(structType)
structType.map { f =>
val clippedType = structType.map { f =>
if (shouldMatchById && ParquetUtils.hasFieldId(f)) {
matchIdField(f)
} else if (caseSensitive) {
Expand All @@ -472,6 +467,63 @@ object ParquetReadSupport extends Logging {
matchCaseInsensitiveField(f)
}
}
// Ignore MessageType, because it is the root of the schema, not a struct.
if (!isStructWithMissingFields || parquetRecord.isInstanceOf[MessageType]) {
clippedType
} else {
// Read one arbitrary field so the reader can still determine whether the struct value is null.
clippedType :+ findCheapestGroupField(parquetRecord)
}
}
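To make the new fallback concrete, here is a hedged worked example using parquet-mr's Java schema builder; the schema and field names are hypothetical, and since clipParquetGroupFields is internal, the clipped result is stated in comments rather than computed:

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Types;

// File schema for a struct column "s"; none of the requested fields exist in it.
GroupType fileGroupForS = Types.optionalGroup()
    .optional(PrimitiveTypeName.INT32).named("a")
    .optional(PrimitiveTypeName.BINARY).named("b")
    .named("s");
// Requested Catalyst type for s: struct<x: int>. Every lookup misses, so each requested field
// falls back to toParquet.convertField and isStructWithMissingFields stays true. The clipped
// group therefore becomes
//   optional group s { optional int32 x; optional int32 a; }
// where "x" is converted from the requested schema and "a" is the cheapest real field, appended
// only so the reader can recover s's definition levels (i.e. whether s itself is null per row).
// The extra field is later removed from the batch schema by truncateType in the vectorized reader.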

/**
* Finds the leaf node under a given file schema node that is likely to be the cheapest to fetch,
* keeping it inside its original parent hierarchy. This is used when all requested struct fields
* are missing from the file. Uses a simple heuristic based on the Parquet primitive type.
*/
private def findCheapestGroupField(parentGroupType: GroupType): Type = {
def findCheapestGroupFieldRecurse(curType: Type, repLevel: Int = 0): (Type, Int, Int) = {
curType match {
case groupType: GroupType =>
var (bestType, bestRepLevel, bestCost) = (Option.empty[Type], 0, 0)
for (field <- groupType.getFields.asScala) {
val newRepLevel = repLevel + (if (field.isRepetition(Repetition.REPEATED)) 1 else 0)
// Never take a field at a deeper repetition level, since it's likely to have more data.
// Don't do safety checks because we should already have done them when traversing the
// schema for the first time.
if (bestType.isEmpty || newRepLevel <= bestRepLevel) {
val (childType, childRepLevel, childCost) =
findCheapestGroupFieldRecurse(field, newRepLevel)
// Always prefer elements with a lower repetition level, since more nesting of arrays
// is likely to result in more data. At the same repetition level, prefer the smaller
// type.
if (bestType.isEmpty || childRepLevel < bestRepLevel ||
(childRepLevel == bestRepLevel && childCost < bestCost)) {
// This is the new best path.
bestType = Some(childType)
bestRepLevel = childRepLevel
bestCost = childCost
}
}
}
(groupType.withNewFields(bestType.get), bestRepLevel, bestCost)
case primitiveType: PrimitiveType =>
val cost = primitiveType.getPrimitiveTypeName match {
case PrimitiveType.PrimitiveTypeName.BOOLEAN => 1
case PrimitiveType.PrimitiveTypeName.INT32 => 4
case PrimitiveType.PrimitiveTypeName.INT64 => 8
case PrimitiveType.PrimitiveTypeName.FLOAT => 4
case PrimitiveType.PrimitiveTypeName.DOUBLE => 8
// Strings seem undesirable, since they don't have a fixed size. Give them a high cost.
case PrimitiveType.PrimitiveTypeName.BINARY |
PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => 32
case PrimitiveType.PrimitiveTypeName.INT96 => 12
}
(primitiveType, repLevel, cost)
}
}
// Ignore the highest level of the hierarchy since we are interested only in the subfield.
findCheapestGroupFieldRecurse(parentGroupType)._1.asGroupType().getType(0)
}
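A hedged illustration of the cost heuristic, again with parquet-mr's Java Types builder and a hypothetical schema; the helper is private, so the selected field is stated in a comment rather than computed:

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Types;

GroupType person = Types.optionalGroup()
    .optional(PrimitiveTypeName.BINARY).named("name")       // cost 32: variable length
    .repeated(PrimitiveTypeName.INT32).named("tags")         // repetition level 1: avoided
    .requiredGroup()
        .optional(PrimitiveTypeName.INT64).named("millis")   // cost 8
        .optional(PrimitiveTypeName.BOOLEAN).named("flag")   // cost 1
        .named("meta")
    .named("person");
// Candidates at repetition level 0: name (32), meta.millis (8), meta.flag (1); "tags" is skipped
// because it sits at a deeper repetition level. The cheapest leaf is meta.flag, and the returned
// field keeps its parent hierarchy: required group meta { optional boolean flag }.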

/**