New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[SPARK-36646][SQL] Push down group by partition column for aggregate #34445
Changes from 6 commits
9ea1c2d
03c2bd4
0e655a8
4fb313b
4eeae6d
b561d09
1f45a04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
package org.apache.spark.sql.execution.datasources | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.expressions.Expression | ||
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} | ||
import org.apache.spark.sql.connector.expressions.NamedReference | ||
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} | ||
import org.apache.spark.sql.execution.RowToColumnConverter | ||
|
@@ -81,19 +81,37 @@ object AggregatePushDownUtils { | |
} | ||
} | ||
|
||
if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) { | ||
if (dataFilters.nonEmpty) { | ||
// Parquet/ORC footer has max/min/count for columns | ||
// e.g. SELECT COUNT(col1) FROM t | ||
// but footer doesn't have max/min/count for a column if max/min/count | ||
// are combined with filter or group by | ||
// e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 | ||
// SELECT COUNT(col1) FROM t GROUP BY col2 | ||
// However, if the filter is on partition column, max/min/count can still be pushed down | ||
// TODO: add support if the group by column is a partition column | ||
// (https://issues.apache.org/jira/browse/SPARK-36646) | ||
return None | ||
} | ||
|
||
if (aggregation.groupByColumns.nonEmpty && | ||
partitionNames.size != aggregation.groupByColumns.length) { | ||
// If there are group by columns, we only push down if the group by columns are the same as | ||
// the partition columns. In theory, if group by columns are a subset of partition columns, | ||
// we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3, | ||
// SELECT MAX(c) FROM t GROUP BY p1, p2 should still be able to push down. However, the | ||
// partial aggregation pushed down to data source needs to be | ||
// SELECT p1, p2, p3, MAX(c) FROM t GROUP BY p1, p2, p3, and Spark layer | ||
// needs to have a final aggregation such as SELECT MAX(c) FROM t GROUP BY p1, p2, then the | ||
// pushed down query schema is different from the query schema at Spark. We will keep | ||
// aggregate push down simple and don't handle this complicated case for now. | ||
return None | ||
} | ||
aggregation.groupByColumns.foreach { col => | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: maybe also add some comments here - it's not that easy to understand, and comments can help the maintenance of this code. |
||
// don't push down if the group by columns are not the same as the partition columns (order | ||
// doesn't matter because reordering can be done at the data source layer) | ||
if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None | ||
finalSchema = finalSchema.add(getStructFieldForCol(col)) | ||
} | ||
|
||
aggregation.aggregateExpressions.foreach { | ||
case max: Max => | ||
if (!processMinOrMax(max)) return None | ||
|
@@ -138,4 +156,44 @@ object AggregatePushDownUtils { | |
converter.convert(aggregatesAsRow, columnVectors.toArray) | ||
new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) | ||
} | ||
|
||
/** | ||
* Return the schema for aggregates only (exclude group by columns) | ||
*/ | ||
def getSchemaWithoutGroupingExpression( | ||
sunchao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
aggSchema: StructType, | ||
aggregation: Aggregation): StructType = { | ||
val numOfGroupByColumns = aggregation.groupByColumns.length | ||
if (numOfGroupByColumns > 0) { | ||
new StructType(aggSchema.fields.drop(numOfGroupByColumns)) | ||
} else { | ||
aggSchema | ||
} | ||
} | ||
|
||
/** | ||
* Reorder partition cols if they are not in the same order as group by columns | ||
*/ | ||
def reOrderPartitionCol( | ||
partitionSchema: StructType, | ||
aggregation: Aggregation, | ||
partitionValues: InternalRow): InternalRow = { | ||
val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head) | ||
huaxingao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert(groupByColNames.length == partitionSchema.length && | ||
groupByColNames.length == partitionValues.numFields, "The number of group by columns " + | ||
s"${groupByColNames.length} should be the same as partition schema length " + | ||
s"${partitionSchema.length} and the number of fields ${partitionValues.numFields} " + | ||
s"in partitionValues") | ||
var reorderedPartColValues = Array.empty[Any] | ||
if (!partitionSchema.names.sameElements(groupByColNames)) { | ||
groupByColNames.foreach { col => | ||
val index = partitionSchema.names.indexOf(col) | ||
val v = partitionValues.asInstanceOf[GenericInternalRow].values(index) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. just curious: is this always guaranteed to be There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems to me that the |
||
reorderedPartColValues = reorderedPartColValues :+ v | ||
} | ||
new GenericInternalRow(reorderedPartColValues) | ||
} else { | ||
partitionValues | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,8 +31,9 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName | |
import org.apache.spark.SparkException | ||
import org.apache.spark.sql.SparkSession | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.expressions.JoinedRow | ||
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} | ||
import org.apache.spark.sql.execution.datasources.PartitioningUtils | ||
import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils | ||
import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} | ||
import org.apache.spark.sql.types.StructType | ||
|
||
|
@@ -157,17 +158,22 @@ object ParquetUtils { | |
partitionSchema: StructType, | ||
aggregation: Aggregation, | ||
aggSchema: StructType, | ||
datetimeRebaseMode: LegacyBehaviorPolicy.Value, | ||
isCaseSensitive: Boolean): InternalRow = { | ||
partitionValues: InternalRow, | ||
datetimeRebaseMode: LegacyBehaviorPolicy.Value): InternalRow = { | ||
val (primitiveTypes, values) = getPushedDownAggResult( | ||
footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive) | ||
footer, filePath, dataSchema, partitionSchema, aggregation) | ||
|
||
val builder = Types.buildMessage | ||
primitiveTypes.foreach(t => builder.addField(t)) | ||
val parquetSchema = builder.named("root") | ||
|
||
// if there are group by columns, we will build result row first, | ||
// and then append group by columns values (partition columns values) to the result row. | ||
val schemaWithoutGroupBy = | ||
AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggSchema, aggregation) | ||
|
||
val schemaConverter = new ParquetToSparkSchemaConverter | ||
val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, | ||
val converter = new ParquetRowConverter(schemaConverter, parquetSchema, schemaWithoutGroupBy, | ||
None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) | ||
val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName) | ||
primitiveTypeNames.zipWithIndex.foreach { | ||
|
@@ -195,7 +201,14 @@ object ParquetUtils { | |
case (_, i) => | ||
throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i)) | ||
} | ||
converter.currentRecord | ||
|
||
if (aggregation.groupByColumns.nonEmpty) { | ||
val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol( | ||
partitionSchema, aggregation, partitionValues) | ||
new JoinedRow(reorderedPartitionValues, converter.currentRecord) | ||
} else { | ||
converter.currentRecord | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -211,16 +224,14 @@ object ParquetUtils { | |
filePath: String, | ||
dataSchema: StructType, | ||
partitionSchema: StructType, | ||
aggregation: Aggregation, | ||
isCaseSensitive: Boolean) | ||
aggregation: Aggregation) | ||
: (Array[PrimitiveType], Array[Any]) = { | ||
val footerFileMetaData = footer.getFileMetaData | ||
val fields = footerFileMetaData.getSchema.getFields | ||
val blocks = footer.getBlocks | ||
val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType] | ||
val valuesBuilder = mutable.ArrayBuilder.make[Any] | ||
|
||
assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down") | ||
aggregation.aggregateExpressions.foreach { agg => | ||
var value: Any = None | ||
var rowCount = 0L | ||
|
@@ -250,8 +261,7 @@ object ParquetUtils { | |
schemaName = "count(" + count.column.fieldNames.head + ")" | ||
rowCount += block.getRowCount | ||
var isPartitionCol = false | ||
if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) | ||
.toSet.contains(count.column.fieldNames.head)) { | ||
if (partitionSchema.fields.map(_.name).toSet.contains(count.column.fieldNames.head)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't we need to check case sensitivity now? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems to me there is no need to check case sensitivity, because I have normalized aggregates and group by columns in |
||
isPartitionCol = true | ||
} | ||
isCount = true | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -83,11 +83,10 @@ case class OrcPartitionReaderFactory( | |
|
||
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { | ||
val conf = broadcastedConf.value.value | ||
val filePath = new Path(new URI(file.filePath)) | ||
|
||
if (aggregation.nonEmpty) { | ||
return buildReaderWithAggregates(filePath, conf) | ||
return buildReaderWithAggregates(file, conf) | ||
} | ||
val filePath = new Path(new URI(file.filePath)) | ||
|
||
val resultedColPruneInfo = | ||
Utils.tryWithResource(createORCReader(filePath, conf)) { reader => | ||
|
@@ -127,11 +126,10 @@ case class OrcPartitionReaderFactory( | |
|
||
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { | ||
val conf = broadcastedConf.value.value | ||
val filePath = new Path(new URI(file.filePath)) | ||
|
||
if (aggregation.nonEmpty) { | ||
return buildColumnarReaderWithAggregates(filePath, conf) | ||
return buildColumnarReaderWithAggregates(file, conf) | ||
} | ||
Comment on lines
129
to
131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. |
||
val filePath = new Path(new URI(file.filePath)) | ||
|
||
val resultedColPruneInfo = | ||
Utils.tryWithResource(createORCReader(filePath, conf)) { reader => | ||
|
@@ -183,14 +181,16 @@ case class OrcPartitionReaderFactory( | |
* Build reader with aggregate push down. | ||
*/ | ||
private def buildReaderWithAggregates( | ||
filePath: Path, | ||
file: PartitionedFile, | ||
conf: Configuration): PartitionReader[InternalRow] = { | ||
val filePath = new Path(new URI(file.filePath)) | ||
new PartitionReader[InternalRow] { | ||
private var hasNext = true | ||
private lazy val row: InternalRow = { | ||
Utils.tryWithResource(createORCReader(filePath, conf)) { reader => | ||
OrcUtils.createAggInternalRowFromFooter( | ||
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema) | ||
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, | ||
readDataSchema, file.partitionValues) | ||
} | ||
} | ||
|
||
|
@@ -209,15 +209,16 @@ case class OrcPartitionReaderFactory( | |
* Build columnar reader with aggregate push down. | ||
*/ | ||
private def buildColumnarReaderWithAggregates( | ||
filePath: Path, | ||
file: PartitionedFile, | ||
conf: Configuration): PartitionReader[ColumnarBatch] = { | ||
val filePath = new Path(new URI(file.filePath)) | ||
new PartitionReader[ColumnarBatch] { | ||
private var hasNext = true | ||
private lazy val batch: ColumnarBatch = { | ||
Utils.tryWithResource(createORCReader(filePath, conf)) { reader => | ||
val row = OrcUtils.createAggInternalRowFromFooter( | ||
reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, | ||
readDataSchema) | ||
readDataSchema, file.partitionValues) | ||
AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false) | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe add some comments explaining the reasoning why we have this check and only support the case when the group by columns are the same as the partition columns. What if the number of group by columns is smaller than the number of partition columns?