From 4eeae6d3a2d3ea7c22c579e09a83a88e13f91a70 Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Tue, 16 Nov 2021 15:07:21 -0800
Subject: [PATCH] address comments

---
 .../datasources/AggregatePushDownUtils.scala     | 15 +++++++++++++--
 .../sql/execution/datasources/orc/OrcUtils.scala |  2 +-
 .../datasources/parquet/ParquetUtils.scala       |  2 +-
 .../FileSourceAggregatePushDownSuite.scala       |  2 +-
 4 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
index f00217e8eb0f4..e7069137f31cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
@@ -94,9 +94,20 @@ object AggregatePushDownUtils {
 
     if (aggregation.groupByColumns.nonEmpty &&
       partitionNames.size != aggregation.groupByColumns.length) {
+      // If there are group by columns, we only push down if the group by columns are the same as
+      // the partition columns. In theory, if the group by columns are a subset of the partition
+      // columns, we should still be able to push down. E.g. if table t has partition columns p1,
+      // p2 and p3, SELECT MAX(c) FROM t GROUP BY p1, p2 should still be able to push down.
+      // However, the partial aggregation pushed down to the data source needs to be
+      // SELECT p1, p2, p3, MAX(c) FROM t GROUP BY p1, p2, p3, and the Spark layer needs a final
+      // aggregation such as SELECT MAX(c) FROM t GROUP BY p1, p2, so the pushed down query schema
+      // is different from the query schema at the Spark layer. We will keep aggregate push down
+      // simple and not handle this complicated case for now.
       return None
     }
     aggregation.groupByColumns.foreach { col =>
+      // Don't push down if the group by columns are not the same as the partition columns (order
+      // doesn't matter because reordering can be done at the data source layer).
       if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None
       finalSchema = finalSchema.add(getStructFieldForCol(col))
     }
@@ -150,8 +161,8 @@ object AggregatePushDownUtils {
    * Return the schema for aggregates only (exclude group by columns)
    */
   def getSchemaWithoutGroupingExpression(
-      aggregation: Aggregation,
-      aggSchema: StructType): StructType = {
+      aggSchema: StructType,
+      aggregation: Aggregation): StructType = {
     val numOfGroupByColumns = aggregation.groupByColumns.length
     if (numOfGroupByColumns > 0) {
       new StructType(aggSchema.fields.drop(numOfGroupByColumns))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index fb599d0b21c54..17aab36c6d7d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -460,7 +460,7 @@ object OrcUtils extends Logging {
       // if there are group by columns, we will build result row first,
       // and then append group by columns values (partition columns values) to the result row.
       val schemaWithoutGroupBy =
-        AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggregation, aggSchema)
+        AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggSchema, aggregation)
 
       val aggORCValues: Seq[WritableComparable[_]] =
         aggregation.aggregateExpressions.zipWithIndex.map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 4dff5d89f3ad2..5bd712e9c583c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -170,7 +170,7 @@ object ParquetUtils {
     // if there are group by columns, we will build result row first,
     // and then append group by columns values (partition columns values) to the result row.
     val schemaWithoutGroupBy =
-      AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggregation, aggSchema)
+      AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggSchema, aggregation)
 
     val schemaConverter = new ParquetToSparkSchemaConverter
     val converter = new ParquetRowConverter(schemaConverter, parquetSchema, schemaWithoutGroupBy,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
index 9891cfa875940..7015a648f910e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
@@ -309,7 +309,7 @@ trait FileSourceAggregatePushDownSuite
       val expected_plan_fragment =
         "PushedAggregation: [COUNT(*), COUNT(value), MAX(value), MIN(value)]," +
           " PushedFilters: [], PushedGroupBy: [p1, p2, p3, p4]"
-      // checkKeywordsExistsInExplain(df, expected_plan_fragment)
+      checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(1, 1, 5, 5, 8, 1, 5, 2), Row(1, 1, 4, 4, 9, 1, 4, 2),
       Row(2, 2, 6, 3, 8, 1, 4, 2), Row(4, 4, 10, 1, 6, 2, 5, 1),
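
Note on the reordered helper: getSchemaWithoutGroupingExpression(aggSchema, aggregation) takes the
pushed-down aggregate schema (group by / partition columns first, aggregate result columns after) and
returns only the aggregate part, which is the schema the ORC and Parquet readers use to build the
result row before the partition values are appended. The following is a minimal, self-contained Scala
sketch of that behaviour; it models schemas as plain string sequences rather than Spark's
StructType/Aggregation types, and the names used are illustrative only, not taken from this patch:

    // Simplified model of AggregatePushDownUtils.getSchemaWithoutGroupingExpression:
    // drop the leading group by fields, keep only the aggregate result columns.
    object GroupBySchemaSketch {
      def schemaWithoutGroupingExpression(
          aggSchema: Seq[String],
          groupByColumns: Seq[String]): Seq[String] = {
        // The group by (partition) columns occupy the first positions of the agg schema,
        // so dropping that many fields leaves the aggregates-only schema.
        aggSchema.drop(groupByColumns.length)
      }

      def main(args: Array[String]): Unit = {
        val aggSchema = Seq("p1", "p2", "MAX(c)", "MIN(c)")
        val groupBy = Seq("p1", "p2")
        // Prints: List(MAX(c), MIN(c))
        println(schemaWithoutGroupingExpression(aggSchema, groupBy))
      }
    }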