From 25d2d9e4a720f1be026ba2a5a59e504d006f8322 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 1 Aug 2024 09:28:05 +0200 Subject: [PATCH 1/3] Initial commit --- .../optimizer/RewriteDistinctAggregates.scala | 16 ++- .../spark/sql/DataFrameAggregateSuite.scala | 114 ++++++++++++++++++ 2 files changed, 127 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index da3cf782f6682..24947e5fe40e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -197,6 +197,17 @@ import org.apache.spark.util.collection.Utils * techniques. */ object RewriteDistinctAggregates extends Rule[LogicalPlan] { + private def mustRewrite( + distinctAggs: Seq[AggregateExpression], + groupingExpressions: Seq[Expression]): Boolean = { + // If there are any distinct AggregateExpressions with filter, we need to rewrite the query. + // Also, if there are no grouping expressions and all aggregate expressions are foldable, + // we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1). Without this condition, + // non-grouping aggregation queries with distinct aggregate expressions will be incorrectly + // handled by the aggregation strategy, causing wrong results when working with empty tables. 
+ distinctAggs.exists(_.filter.isDefined) || (groupingExpressions.isEmpty && + distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable))) + } private def mayNeedtoRewrite(a: Aggregate): Boolean = { val aggExpressions = collectAggregateExprs(a) @@ -204,8 +215,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // We need at least two distinct aggregates or the single distinct aggregate group exists filter // clause for this rule because aggregation strategy can handle a single distinct aggregate // group without filter clause. - // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) + distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( @@ -236,7 +246,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group without filter clause. - if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) { + if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) { // Create the attributes for the grouping id and the group by clause. 
val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 1ba3f6c84d0ad..d8e3a046655f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkException, SparkThrowable} +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -2150,6 +2151,119 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil) } } + + test("aggregating with various distinct expressions") { + abstract class AggregateTestCaseBase( + val query: String, + val resultSeq: Seq[Seq[Row]], + val hasExpandNodeInPlan: Boolean) + case class AggregateTestCase( + override val query: String, + override val resultSeq: Seq[Seq[Row]], + override val hasExpandNodeInPlan: Boolean) + extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan) + case class AggregateTestCaseDefault( + override val query: String) + extends AggregateTestCaseBase( + query, + Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = true) + + val t = "t" + val testCases: Seq[AggregateTestCaseBase] = Seq( + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT "col") FROM $t""" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1) FROM $t" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1 + 2) FROM $t" + ), + AggregateTestCaseDefault( + s"SELECT 
COUNT(DISTINCT 1, 2, 1 + 2) FROM $t" + ), + AggregateTestCase( + s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))), + hasExpandNodeInPlan = true + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT 1, "col") FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT current_date()) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t""" + ), + AggregateTestCase( + s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col", + Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1" + ), + AggregateTestCase( + s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0", + Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCaseDefault( + s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"), + AggregateTestCase( + s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))), + hasExpandNodeInPlan = true + ), + AggregateTestCase( + s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""", + Seq(Seq(Row(0, 0)), 
Seq(Row(1, 1)), Seq(Row(1, 2))), + hasExpandNodeInPlan = true + ) + ) + withTable(t) { + sql(s"create table $t(col int) using parquet") + Seq(0, 1, 2).foreach(columnValue => { + if (columnValue != 0) { + sql(s"insert into $t(col) values($columnValue)") + } + testCases.foreach(testCase => { + val query = sql(testCase.query) + checkAnswer(query, testCase.resultSeq(columnValue)) + val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst { + case _: Expand => true + }.nonEmpty + assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan) + }) + }) + } + } } case class B(c: Option[Double]) From 3f68e9c109c68cc1fc5bce42b1fdcfe0249ed256 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 1 Aug 2024 09:48:36 +0200 Subject: [PATCH 2/3] Remove collation --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d8e3a046655f6..764b7a9719d29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2191,9 +2191,6 @@ class DataFrameAggregateSuite extends QueryTest AggregateTestCaseDefault( s"""SELECT COUNT(DISTINCT 1, "col") FROM $t""" ), - AggregateTestCaseDefault( - s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t""" - ), AggregateTestCaseDefault( s"""SELECT COUNT(DISTINCT current_date()) FROM $t""" ), From 08d738874f0cf7d0140889535349e17b37dc25a1 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 1 Aug 2024 22:32:03 +0200 Subject: [PATCH 3/3] Update comment --- .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 24947e5fe40e4..801bd2693af42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -201,8 +201,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { distinctAggs: Seq[AggregateExpression], groupingExpressions: Seq[Expression]): Boolean = { // If there are any distinct AggregateExpressions with filter, we need to rewrite the query. - // Also, if there are no grouping expressions and all aggregate expressions are foldable, - // we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1). Without this condition, + // Also, if there are no grouping expressions and some distinct aggregate expression has only + // foldable children, we must rewrite the query, e.g. SELECT COUNT(DISTINCT 1). Without this, // non-grouping aggregation queries with distinct aggregate expressions will be incorrectly // handled by the aggregation strategy, causing wrong results when working with empty tables. distinctAggs.exists(_.filter.isDefined) || (groupingExpressions.isEmpty &&