From 9072db7c28c6c0aae2ba319a08368a54cf7667ee Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Thu, 24 Oct 2019 20:32:31 +0800 Subject: [PATCH] merge with eliminate sorts --- .../expressions/aggregate/Average.scala | 3 +- .../aggregate/CentralMomentAgg.scala | 2 +- .../expressions/aggregate/Count.scala | 2 +- .../catalyst/expressions/aggregate/Max.scala | 2 +- .../catalyst/expressions/aggregate/Min.scala | 2 +- .../aggregate/OrderIrrelevantAggs.scala | 26 ----------- .../catalyst/expressions/aggregate/Sum.scala | 3 +- .../sql/catalyst/optimizer/Optimizer.scala | 44 +++++++++---------- .../optimizer/RemoveRedundantSortsSuite.scala | 4 +- 9 files changed, 30 insertions(+), 58 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/OrderIrrelevantAggs.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 5fbec716050ec..66ac73087b4d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -33,8 +33,7 @@ import org.apache.spark.sql.types._ 1.5 """, since = "1.0.0") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes - with OrderIrrelevantAggs { +case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def prettyName: String = "avg" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 9f101c324d0fa..8ce8dfa19c017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ * @param child to compute central moments of. */ abstract class CentralMomentAgg(child: Expression) - extends DeclarativeAggregate with ImplicitCastInputTypes with OrderIrrelevantAggs { + extends DeclarativeAggregate with ImplicitCastInputTypes { /** * The central moment order to be computed. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 7bf732f726194..2a8edac502c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.types._ """, since = "1.0.0") // scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends DeclarativeAggregate with OrderIrrelevantAggs { +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 700a08968bdf5..7520db146ba6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ 50 """, since = "1.0.0") -case class Max(child: Expression) extends DeclarativeAggregate with OrderIrrelevantAggs { +case class Max(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 1e4493e2a95c1..106eb968e3917 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ -1 """, since = "1.0.0") -case class Min(child: Expression) extends DeclarativeAggregate with OrderIrrelevantAggs { +case class Min(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/OrderIrrelevantAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/OrderIrrelevantAggs.scala deleted file mode 100644 index e932eba2c59d3..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/OrderIrrelevantAggs.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -/** - * An [[OrderIrrelevantAggs]] trait denotes those aggregate functions that its result - * has nothing to do with the order of input data. - * For example, [[Sum]] is [[OrderIrrelevantAggs]] while [[First]] is not. - */ -trait OrderIrrelevantAggs extends AggregateFunction { -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index c47fc63ea6848..c2ab8adfaef67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -36,8 +36,7 @@ import org.apache.spark.sql.types._ NULL """, since = "1.0.0") -case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes - with OrderIrrelevantAggs { +case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 17fc38261a1af..fd0f0d19f3038 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -97,7 +97,6 @@ abstract class Optimizer(catalogManager: CatalogManager) SimplifyBinaryComparison, ReplaceNullWithFalseInPredicate, PruneFilters, - EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, RewriteCorrelatedScalarSubquery, @@ -174,8 +173,8 @@ abstract class Optimizer(catalogManager: CatalogManager) // idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once. Batch("Join Reorder", FixedPoint(1), CostBasedJoinReorder) :+ - Batch("Remove Redundant Sorts", Once, - RemoveRedundantSorts) :+ + Batch("Eliminate Sorts", Once, + EliminateSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -953,29 +952,22 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { } /** - * Removes no-op SortOrder from Sort + * Removes Sort operation. This can happen: + * 1) if the sort is noop + * 2) if the child is already sorted + * 3) if there is another Sort operator separated by 0...n Project/Filter operators + * 4) if the Sort operator is within Join and without Limit + * 5) if the Sort operator is within GroupBy and the aggregate function is order irrelevant */ object EliminateSorts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => val newOrders = orders.filterNot(_.child.foldable) if (newOrders.isEmpty) child else s.copy(order = newOrders) - } -} - -/** - * Removes redundant Sort operation. This can happen: - * 1) if the child is already sorted - * 2) if there is another Sort operator separated by 0...n Project/Filter operators - * 3) if the Sort operator is within Join and without Limit - * 4) if the Sort operator is within GroupBy and the aggregate function is order irrelevant - */ -object RemoveRedundantSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => child case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) - case j @ Join(originLeft, originRight, _, _, _) => + case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) => j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight)) case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) => g.copy(child = recursiveRemoveSort(originChild)) @@ -995,13 +987,21 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] { } def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = { - val aggExpressions = aggs.flatMap { e => - e.collect { - case ae: AggregateExpression => ae - } + def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match { + case _: Sum => true + case _: Min => true + case _: Max => true + case _: Count => true + case _: Average => true + case _: CentralMomentAgg => true + case _ => false } - aggExpressions.forall(_.aggregateFunction.isInstanceOf[OrderIrrelevantAggs]) + aggs.flatMap { e => + e.collect { + case ae: AggregateExpression => ae.aggregateFunction + } + }.forall(isOrderIrrelevantAggFunction) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 54f0193de9182..626931203f785 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -30,8 +30,8 @@ class RemoveRedundantSortsSuite extends PlanTest { val batches = Batch("Limit PushDown", Once, LimitPushDown) :: - Batch("Remove Redundant Sorts", Once, - RemoveRedundantSorts) :: + Batch("Eliminate Sorts", Once, + EliminateSorts) :: Batch("Collapse Project", Once, CollapseProject) :: Nil }