[SPARK-36646][SQL] Push down group by partition column for aggregate #34445
Changes in `AggregatePushDownUtils`:

```diff
@@ -81,19 +81,22 @@ object AggregatePushDownUtils {
       }
     }

-    if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) {
+    if (dataFilters.nonEmpty) {
       // Parquet/ORC footer has max/min/count for columns
       // e.g. SELECT COUNT(col1) FROM t
       // but footer doesn't have max/min/count for a column if max/min/count
       // are combined with filter or group by
       // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
       //      SELECT COUNT(col1) FROM t GROUP BY col2
       // However, if the filter is on partition column, max/min/count can still be pushed down
-      // Todo: add support if groupby column is partition col
-      // (https://issues.apache.org/jira/browse/SPARK-36646)
       return None
     }

+    aggregation.groupByColumns.foreach { col =>
+      if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None
+      finalSchema = finalSchema.add(getStructFieldForCol(col))
+    }
+
     aggregation.aggregateExpressions.foreach {
       case max: Max =>
         if (!processMinOrMax(max)) return None
```
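To make the new behavior concrete, here is a small illustrative sketch, not part of the diff: the path is invented, `spark` is a spark-shell session, and the partition layout mirrors the tests added at the end of this PR.

```scala
// Illustrative only: a table partitioned by `p`.
spark.range(10).selectExpr("id", "id % 3 AS p")
  .write.partitionBy("p").parquet("/tmp/spark36646_demo")
spark.read.parquet("/tmp/spark36646_demo").createOrReplaceTempView("t")

// Answered from footer statistics alone, so it was already pushed down:
spark.sql("SELECT COUNT(id) FROM t")

// Footer statistics cannot answer an aggregate filtered on a data column, so no pushdown:
spark.sql("SELECT COUNT(id) FROM t WHERE id = 8")

// New with this PR: grouping by a partition column can be pushed down, because the
// group value comes from the partition directory, not from the file footer:
spark.sql("SELECT COUNT(id), p FROM t GROUP BY p")
```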
```diff
@@ -138,4 +141,19 @@ object AggregatePushDownUtils {
     converter.convert(aggregatesAsRow, columnVectors.toArray)
     new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
   }

+  /**
+   * Return the schema for aggregates only (exclude group by columns)
+   */
+  def getSchemaWithoutGroupingExpression(
+      aggregation: Aggregation,
+      aggSchema: StructType): StructType = {
+    val groupByColNums = aggregation.groupByColumns.length
+    if (groupByColNums > 0) {
+      new StructType(aggSchema.fields.drop(groupByColNums))
+    } else {
+      aggSchema
+    }
+  }
 }
```

Review comment (on the parameter list of `getSchemaWithoutGroupingExpression`): nit: maybe swap the order of …
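A hedged usage sketch of the new helper; the field names here are illustrative, and the layout (group-by fields first) follows from how `finalSchema` is assembled in the previous hunk, where group-by columns are added before aggregate expressions.

```scala
import org.apache.spark.sql.types._

// For SELECT MAX(id), MIN(id), p FROM t GROUP BY p, the pushed-down schema places
// the group-by (partition) columns before the aggregate columns:
val aggSchema = new StructType()
  .add("p", IntegerType)     // group-by column
  .add("max(id)", LongType)  // aggregate result columns
  .add("min(id)", LongType)

// Same logic as the helper: drop the leading group-by fields, leaving the schema
// of what the footer-based converter actually decodes.
val groupByColNums = 1 // stands in for aggregation.groupByColumns.length
val aggOnly = new StructType(aggSchema.fields.drop(groupByColNums))
// aggOnly now contains only max(id) and min(id)
```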
Changes in `ParquetUtils`:

```diff
@@ -31,8 +31,9 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.spark.SparkException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
-import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
 import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
 import org.apache.spark.sql.types.StructType
```
```diff
@@ -157,17 +158,22 @@ object ParquetUtils {
       partitionSchema: StructType,
       aggregation: Aggregation,
       aggSchema: StructType,
-      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
-      isCaseSensitive: Boolean): InternalRow = {
+      partitionValues: InternalRow,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value): InternalRow = {
     val (primitiveTypes, values) = getPushedDownAggResult(
-      footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+      footer, filePath, dataSchema, partitionSchema, aggregation)

     val builder = Types.buildMessage
     primitiveTypes.foreach(t => builder.addField(t))
     val parquetSchema = builder.named("root")

+    // if there are group by columns, we will build result row first,
+    // and then append group by columns values (partition columns values) to the result row.
+    val schemaWithoutGroupby =
+      AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggregation, aggSchema)
+
     val schemaConverter = new ParquetToSparkSchemaConverter
-    val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+    val converter = new ParquetRowConverter(schemaConverter, parquetSchema, schemaWithoutGroupby,
       None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
     val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
     primitiveTypeNames.zipWithIndex.foreach {
@@ -195,7 +201,12 @@ object ParquetUtils {
       case (_, i) =>
         throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
     }
-    converter.currentRecord
+
+    if (aggregation.groupByColumns.length > 0) {
+      new JoinedRow(partitionValues, converter.currentRecord)
+    } else {
+      converter.currentRecord
+    }
   }
```
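The comment in the diff ("build result row first, and then append group by columns values") is the key idea; below is a small sketch, with invented values, of how `JoinedRow` performs that append. The constructor order, partition values first, matches the pushed-down schema, which lists group-by columns before aggregates.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow

// Partition (group-by) values for the file being read, e.g. p = 2:
val partitionValues = InternalRow(2)
// Aggregate values decoded from the footer, e.g. max(id) = 8L, min(id) = 2L:
val aggRow = InternalRow(8L, 2L)

// JoinedRow presents both rows as one row without copying either side:
val result = new JoinedRow(partitionValues, aggRow)
assert(result.getInt(0) == 2 && result.getLong(1) == 8L && result.getLong(2) == 2L)
```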
```diff
@@ -211,16 +222,14 @@ object ParquetUtils {
       filePath: String,
       dataSchema: StructType,
       partitionSchema: StructType,
-      aggregation: Aggregation,
-      isCaseSensitive: Boolean)
+      aggregation: Aggregation)
     : (Array[PrimitiveType], Array[Any]) = {
     val footerFileMetaData = footer.getFileMetaData
     val fields = footerFileMetaData.getSchema.getFields
     val blocks = footer.getBlocks
     val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
     val valuesBuilder = mutable.ArrayBuilder.make[Any]

+    assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
     aggregation.aggregateExpressions.foreach { agg =>
       var value: Any = None
       var rowCount = 0L
@@ -250,8 +259,7 @@ object ParquetUtils {
               schemaName = "count(" + count.column.fieldNames.head + ")"
               rowCount += block.getRowCount
               var isPartitionCol = false
-              if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
-                .toSet.contains(count.column.fieldNames.head)) {
+              if (partitionSchema.fields.map(_.name).toSet.contains(count.column.fieldNames.head)) {
                 isPartitionCol = true
               }
               isCount = true
```

Review comment: Don't need check case sensitivity now?

Reply: seems to me no need to check case sensitivity because I have normalized aggregates and group by columns in …
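The reply above is truncated and the normalization itself is not part of this diff, so the following is only a hypothetical sketch of the idea; `normalizeColName` and its signature are invented for illustration.

```scala
import org.apache.spark.sql.types.StructType

// Hypothetical helper: rewrite a pushed-down column name to the schema's exact
// spelling once, up front, so later lookups can use plain string equality.
def normalizeColName(schema: StructType, name: String, caseSensitive: Boolean): Option[String] =
  if (caseSensitive) schema.fieldNames.find(_ == name)
  else schema.fieldNames.find(_.equalsIgnoreCase(name))

// After such normalization, partitionSchema.fields.map(_.name).toSet.contains(...)
// is safe at read time without re-checking case sensitivity.
```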
Changes in `OrcPartitionReaderFactory`:

```diff
@@ -86,7 +86,7 @@ case class OrcPartitionReaderFactory(
     val filePath = new Path(new URI(file.filePath))

     if (aggregation.nonEmpty) {
-      return buildReaderWithAggregates(filePath, conf)
+      return buildReaderWithAggregates(file, conf)
     }

     val resultedColPruneInfo =
```

Review comment: Seems:

```scala
if (aggregation.nonEmpty) {
  return buildReaderWithAggregates(file, conf)
}
val filePath = new Path(new URI(file.filePath))
```

(i.e., construct `filePath` only after the early-return aggregate branch.)

```diff
@@ -130,7 +130,7 @@ case class OrcPartitionReaderFactory(
     val filePath = new Path(new URI(file.filePath))

     if (aggregation.nonEmpty) {
-      return buildColumnarReaderWithAggregates(filePath, conf)
+      return buildColumnarReaderWithAggregates(file, conf)
     }

     val resultedColPruneInfo =
```

Review comment (on lines 129 to 131): ditto.
```diff
@@ -183,14 +183,16 @@ case class OrcPartitionReaderFactory(
   /**
    * Build reader with aggregate push down.
    */
   private def buildReaderWithAggregates(
-      filePath: Path,
+      file: PartitionedFile,
       conf: Configuration): PartitionReader[InternalRow] = {
+    val filePath = new Path(new URI(file.filePath))
     new PartitionReader[InternalRow] {
       private var hasNext = true
       private lazy val row: InternalRow = {
         Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
           OrcUtils.createAggInternalRowFromFooter(
-            reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema)
+            reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
+            readDataSchema, file.partitionValues)
         }
       }
@@ -209,15 +211,16 @@ case class OrcPartitionReaderFactory(
   /**
    * Build columnar reader with aggregate push down.
    */
   private def buildColumnarReaderWithAggregates(
-      filePath: Path,
+      file: PartitionedFile,
       conf: Configuration): PartitionReader[ColumnarBatch] = {
+    val filePath = new Path(new URI(file.filePath))
     new PartitionReader[ColumnarBatch] {
       private var hasNext = true
       private lazy val batch: ColumnarBatch = {
         Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
           val row = OrcUtils.createAggInternalRowFromFooter(
             reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
-            readDataSchema)
+            readDataSchema, file.partitionValues)
           AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false)
         }
       }
```
Changes in `FileSourceAggregatePushDownSuite`:

```diff
@@ -261,6 +261,63 @@ trait FileSourceAggregatePushDownSuite
     }
   }

+  test("aggregate with partition group by can be pushed down") {
+    withTempPath { dir =>
+      spark.range(10).selectExpr("id", "id % 3 as p")
+        .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
+      withTempView("tmp") {
+        spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+        Seq("false", "true").foreach { enableVectorizedReader =>
+          withSQLConf(aggPushDownEnabledKey -> "true",
+            vectorizedReaderEnabledKey -> enableVectorizedReader) {
+            val df = sql("SELECT count(*), count(id), p, max(id), p, count(p), max(id)," +
+              " min(id), p FROM tmp group by p")
+            df.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: [COUNT(*), COUNT(id), MAX(id), COUNT(p), MIN(id)], " +
+                    "PushedFilters: [], PushedGroupBy: [p]"
+                checkKeywordsExistsInExplain(df, expected_plan_fragment)
+            }
+            checkAnswer(df, Seq(Row(3, 3, 1, 7, 1, 3, 7, 1, 1), Row(3, 3, 2, 8, 2, 3, 8, 2, 2),
+              Row(4, 4, 0, 9, 0, 4, 9, 0, 0)))
+          }
+        }
+      }
+    }
+  }
```

Review comment (on the `withSQLConf` line): Hmm, can you test both …

```diff
+  test("aggregate with multi partition group by columns can be pushed down") {
+    withTempPath { dir =>
+      Seq((10, 1, 2), (2, 1, 2), (3, 2, 1), (4, 2, 1), (5, 2, 1), (6, 2, 1),
+        (1, 1, 2), (4, 1, 2), (3, 2, 2), (-4, 2, 2), (6, 2, 2))
+        .toDF("value", "p1", "p2")
+        .write
+        .partitionBy("p1", "p2")
+        .format(format)
+        .save(dir.getCanonicalPath)
+      withTempView("tmp") {
+        spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+        Seq("false", "true").foreach { enableVectorizedReader =>
+          withSQLConf(aggPushDownEnabledKey -> "true",
+            vectorizedReaderEnabledKey -> enableVectorizedReader) {
+            val df = sql("SELECT count(*), count(value), max(value), min(value), p1, p2 FROM tmp" +
+              " GROUP BY p1, p2")
+            df.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: [COUNT(*), COUNT(value), MAX(value), MIN(value)]," +
+                    " PushedFilters: [], PushedGroupBy: [p1, p2]"
+                checkKeywordsExistsInExplain(df, expected_plan_fragment)
+            }
+            checkAnswer(df, Seq(Row(4, 4, 10, 1, 1, 2), Row(4, 4, 6, 3, 2, 1),
+              Row(3, 3, 6, -4, 2, 2)))
+          }
+        }
+      }
+    }
+  }
```

Review comment (on the `withSQLConf` line): here too. We should make sure …

```diff
+
   test("push down only if all the aggregates can be pushed down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 7))
```
Review comment: nit: maybe also add some comments here - it's not that easy to understand and can help the maintenance of this code.