diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 1a58f45c07a29..3a35c08d594a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -220,7 +220,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Extract distinct aggregate expressions. val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => - val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet + val unfoldableChildren = ExpressionSet(e.aggregateFunction.children.filter(!_.foldable)) if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children unfoldableChildren @@ -231,7 +231,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // count(distinct 1) will be explained to count(1) after the rewrite function. // Generally, the distinct aggregateFunction should not run // foldable TypeCheck for the first child. - e.aggregateFunction.children.take(1).toSet + ExpressionSet(e.aggregateFunction.children.take(1)) } } @@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Setup unique distinct aggregate children. val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct - val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrMap = distinctAggChildren.map { e => + e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)() + } val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) // Setup all the filters in distinct aggregate. val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect { @@ -292,7 +294,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { af } else { patchAggregateFunctionChildren(af) { x => - distinctAggChildAttrLookup.get(x) + distinctAggChildAttrLookup.get(x.canonicalized) } } val newCondition = if (condition.isDefined) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 6e66c91b8a89a..cb4771dd92f80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -75,4 +75,37 @@ class RewriteDistinctAggregatesSuite extends PlanTest { .analyze checkRewrite(RewriteDistinctAggregates(input)) } + + test("SPARK-40382: eliminate multiple distinct groups due to superficial differences") { + val input = testRelation + .groupBy($"a")( + countDistinct($"b" + $"c").as("agg1"), + countDistinct($"c" + $"b").as("agg2"), + max($"c").as("agg3")) + .analyze + + val rewrite = RewriteDistinctAggregates(input) + rewrite match { + case Aggregate(_, _, LocalRelation(_, _, _)) => + case _ => fail(s"Plan is not as expected:\n$rewrite") + } + } + + test("SPARK-40382: reduce multiple distinct groups due to superficial differences") { + val input = testRelation + .groupBy($"a")( + countDistinct($"b" + $"c" + $"d").as("agg1"), + countDistinct($"d" + $"c" + $"b").as("agg2"), + countDistinct($"b" + $"c").as("agg3"), + countDistinct($"c" + $"b").as("agg4"), + max($"c").as("agg5")) + .analyze + + val rewrite = RewriteDistinctAggregates(input) + rewrite match { + case Aggregate(_, _, Aggregate(_, _, e: Expand)) => + assert(e.projections.size == 3) + case _ => fail(s"Plan is not rewritten:\n$rewrite") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c64a123e3a78c..03e722a86fb21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -527,8 +527,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map( - _.aggregateFunction.children.filterNot(_.foldable).toSet).distinct.length > 1) { + val distinctAggChildSets = functionsWithDistinct.map { ae => + ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable)) + }.distinct + if (distinctAggChildSets.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct // column sets. Our `RewriteDistinctAggregates` should take care this case. throw new IllegalStateException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 0849ab59f64d0..579a00c7996f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -219,14 +219,17 @@ object AggUtils { } // 3. Create an Aggregate operator for partial aggregation (for distinct) - val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes) + val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized), + distinctAttributes) val rewrittenDistinctFunctions = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) => - aggregateFunction.transformDown(distinctColumnAttributeLookup) - .asInstanceOf[AggregateFunction] + aggregateFunction.transformDown { + case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) => + distinctColumnAttributeLookup(e.canonicalized) + }.asInstanceOf[AggregateFunction] case agg => throw new IllegalArgumentException( "Non-distinct aggregate is found in functionsWithDistinct " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 90e2acfe5d688..54911d2a6fb61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1485,6 +1485,40 @@ class DataFrameAggregateSuite extends QueryTest val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id") checkAnswer(df, Row(2, 3, 1)) } + + test("SPARK-40382: Distinct aggregation expression grouping by semantic equivalence") { + Seq( + (1, 1, 3), + (1, 2, 3), + (1, 2, 3), + (2, 1, 1), + (2, 2, 5) + ).toDF("k", "c1", "c2").createOrReplaceTempView("df") + + // all distinct aggregation children are semantically equivalent + val res1 = sql( + """select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1) + |from df + |group by k + |""".stripMargin) + checkAnswer(res1, Row(1, 5, 2.5, 2) :: Row(2, 5, 2.5, 2) :: Nil) + + // some distinct aggregation children are semantically equivalent + val res2 = sql( + """select k, sum(distinct c1 + 2), avg(distinct 2 + c1), count(distinct c2) + |from df + |group by k + |""".stripMargin) + checkAnswer(res2, Row(1, 7, 3.5, 1) :: Row(2, 7, 3.5, 2) :: Nil) + + // no distinct aggregation children are semantically equivalent + val res3 = sql( + """select k, sum(distinct c1 + 2), avg(distinct 3 + c1), count(distinct c2) + |from df + |group by k + |""".stripMargin) + checkAnswer(res3, Row(1, 7, 4.5, 1) :: Row(2, 7, 4.5, 2) :: Nil) + } } case class B(c: Option[Double]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c7bd12c86a4d1..8c5a09cb1890d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { // 2 distinct columns with different order val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i") assertNoExpand(query3.queryExecution.executedPlan) + + // SPARK-40382: 1 distinct expression with cosmetic differences + val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i") + assertNoExpand(query4.queryExecution.executedPlan) } }