diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 928335ad414f5..98663a8d807be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3114,56 +3114,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, - * put them into an inner Project and finally project them away at the outer Project. - */ - object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case p if !p.resolved => p // Skip unresolved nodes. - case p: Project => p - case f: Filter => f - - case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) => - val nondeterToAttr = getNondeterToAttr(a.groupingExpressions) - val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child) - a.transformExpressions { case e => - nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) - }.copy(child = newChild) - - // Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail) - // and we want to retain them inside the aggregate functions. - case m: CollectMetrics => m - - // todo: It's hard to write a general rule to pull out nondeterministic expressions - // from LogicalPlan, currently we only do it for UnaryNode which has same output - // schema with its child. 
- case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => - val nondeterToAttr = getNondeterToAttr(p.expressions) - val newPlan = p.transformExpressions { case e => - nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) - } - val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child) - Project(p.output, newPlan.withNewChildren(newChild :: Nil)) - } - - private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = { - exprs.filterNot(_.deterministic).flatMap { expr => - val leafNondeterministic = expr.collect { - case n: Nondeterministic => n - case udf: UserDefinedExpression if !udf.deterministic => udf - } - leafNondeterministic.distinct.map { e => - val ne = e match { - case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")() - } - e -> ne - } - }.toMap - } - } - /** * Set the seed for random number generation. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala new file mode 100644 index 0000000000000..3431c9327f1d5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, + * put them into an inner Project and finally project them away at the outer Project. + */ +object PullOutNondeterministic extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp applyLocally + + val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { + case p if !p.resolved => p // Skip unresolved nodes. + case p: Project => p + case f: Filter => f + + case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) => + val nondeterToAttr = getNondeterToAttr(a.groupingExpressions) + val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child) + a.transformExpressions { case e => + nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) + }.copy(child = newChild) + + // Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail) + // and we want to retain them inside the aggregate functions. + case m: CollectMetrics => m + + // todo: It's hard to write a general rule to pull out nondeterministic expressions + // from LogicalPlan, currently we only do it for UnaryNode which has same output + // schema with its child. 
+ case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => + val nondeterToAttr = getNondeterToAttr(p.expressions) + val newPlan = p.transformExpressions { case e => + nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) + } + val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child) + Project(p.output, newPlan.withNewChildren(newChild :: Nil)) + } + + private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = { + exprs.filterNot(_.deterministic).flatMap { expr => + val leafNondeterministic = expr.collect { + case n: Nondeterministic => n + case udf: UserDefinedExpression if !udf.deterministic => udf + } + leafNondeterministic.distinct.map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + e -> ne + } + }.toMap + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1a0dd2d100522..3e3550d5da89b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -112,6 +112,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteCorrelatedScalarSubquery, EliminateSerialization, RemoveRedundantAliases, + RemoveRedundantAggregates, UnwrapCastInBinaryComparison, RemoveNoopOperators, OptimizeUpdateFields, @@ -496,6 +497,50 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = removeRedundantAliases(plan, AttributeSet.empty) } +/** + * Remove redundant aggregates from a query plan. A redundant aggregate is an aggregate whose + * only goal is to keep distinct values, while its parent aggregate would ignore duplicate values. 
+ */ +object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) => + val aliasMap = getAliasMap(lower) + + val newAggregate = upper.copy( + child = lower.child, + groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)), + aggregateExpressions = upper.aggregateExpressions.map( + replaceAliasButKeepName(_, aliasMap)) + ) + + // We might have introduced non-deterministic grouping expressions + if (newAggregate.groupingExpressions.exists(!_.deterministic)) { + PullOutNondeterministic.applyLocally.applyOrElse(newAggregate, identity[LogicalPlan]) + } else { + newAggregate + } + } + + private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = { + val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate) + + lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet( + lower + .aggregateExpressions + .filter(_.deterministic) + .filter(!isAggregate(_)) + .map(_.toAttribute) + )) + + upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg + } + + private def isAggregate(expr: Expression): Boolean = { + expr.find(e => e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupedAggPandasUDF(e)).isDefined + } +} + /** * Remove no-op operators from the query plan that do not make any modifications. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 0946773f3ace6..662360ce1b21f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -624,7 +624,7 @@ case class Range( * * @param groupingExpressions expressions for grouping keys * @param aggregateExpressions expressions for a project list, which could contain - * [[AggregateFunction]]s. + * [[AggregateExpression]]s. * * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before * separating projection from grouping and aggregate, we should avoid expression-level optimization diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala new file mode 100644 index 0000000000000..d376c31ef965f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.IntegerType + +class RemoveRedundantAggregatesSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("RemoveRedundantAggregates", FixedPoint(10), + RemoveRedundantAggregates) :: Nil + } + + private def aggregates(e: Expression): Seq[Expression] = { + Seq( + count(e), + PythonUDF("pyUDF", null, IntegerType, Seq(e), + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, udfDeterministic = true) + ) + } + + test("Remove redundant aggregate") { + val relation = LocalRelation('a.int, 'b.int) + for (agg <- aggregates('b)) { + val query = relation + .groupBy('a)('a, agg) + .groupBy('a)('a) + .analyze + val expected = relation + .groupBy('a)('a) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, expected) + } + } + + test("Remove 2 redundant aggregates") { + val relation = LocalRelation('a.int, 'b.int) + for (agg <- aggregates('b)) { + val query = relation + .groupBy('a)('a, agg) + .groupBy('a)('a) + .groupBy('a)('a) + .analyze + val expected = relation + .groupBy('a)('a) + .analyze + val 
optimized = Optimize.execute(query) + comparePlans(optimized, expected) + } + } + + test("Remove redundant aggregate with different grouping") { + val relation = LocalRelation('a.int, 'b.int) + val query = relation + .groupBy('a, 'b)('a) + .groupBy('a)('a) + .analyze + val expected = relation + .groupBy('a)('a) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, expected) + } + + test("Remove redundant aggregate with aliases") { + val relation = LocalRelation('a.int, 'b.int) + for (agg <- aggregates('b)) { + val query = relation + .groupBy('a + 'b)(('a + 'b) as 'c, agg) + .groupBy('c)('c) + .analyze + val expected = relation + .groupBy('a + 'b)(('a + 'b) as 'c) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, expected) + } + } + + test("Remove redundant aggregate with non-deterministic upper") { + val relation = LocalRelation('a.int, 'b.int) + val query = relation + .groupBy('a)('a) + .groupBy('a)('a, rand(0) as 'c) + .analyze + val expected = relation + .groupBy('a)('a, rand(0) as 'c) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, expected) + } + + test("Remove redundant aggregate with non-deterministic lower") { + val relation = LocalRelation('a.int, 'b.int) + val query = relation + .groupBy('a, 'c)('a, rand(0) as 'c) + .groupBy('a, 'c)('a, 'c) + .analyze + val expected = relation + .groupBy('a, 'c)('a, rand(0) as 'c) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, expected) + } + + test("Keep non-redundant aggregate - upper has agg expression") { + val relation = LocalRelation('a.int, 'b.int) + for (agg <- aggregates('b)) { + val query = relation + .groupBy('a, 'b)('a, 'b) + // The count would change if we remove the first aggregate + .groupBy('a)('a, agg) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, query) + } + } + + test("Keep non-redundant aggregate - upper references agg expression") { + val relation = 
LocalRelation('a.int, 'b.int) + for (agg <- aggregates('b)) { + val query = relation + .groupBy('a)('a, agg as 'c) + .groupBy('c)('c) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, query) + } + } + + test("Keep non-redundant aggregate - upper references non-deterministic non-grouping") { + val relation = LocalRelation('a.int, 'b.int) + val query = relation + .groupBy('a)('a, ('a + rand(0)) as 'c) + .groupBy('a, 'c)('a, 'c) + .analyze + val optimized = Optimize.execute(query) + comparePlans(optimized, query) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala index d0b657887e0bc..944aa963cc4be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala @@ -195,7 +195,7 @@ abstract class RemoveRedundantProjectsSuiteBase |) |""".stripMargin - Seq(("UNION", 2, 2), ("UNION ALL", 1, 2)).foreach { case (setOperation, enabled, disabled) => + Seq(("UNION", 1, 2), ("UNION ALL", 1, 2)).foreach { case (setOperation, enabled, disabled) => val query = queryTemplate.format(setOperation) assertProjectExec(query, enabled = enabled, disabled = disabled) }