Skip to content

Commit

Permalink
merge with eliminate sorts
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGuangxin committed Oct 26, 2019
1 parent 425f76d commit 9072db7
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ import org.apache.spark.sql.types._
1.5
""",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with OrderIrrelevantAggs {
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {

override def prettyName: String = "avg"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._
* @param child to compute central moments of.
*/
abstract class CentralMomentAgg(child: Expression)
extends DeclarativeAggregate with ImplicitCastInputTypes with OrderIrrelevantAggs {
extends DeclarativeAggregate with ImplicitCastInputTypes {

/**
* The central moment order to be computed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.sql.types._
""",
since = "1.0.0")
// scalastyle:on line.size.limit
case class Count(children: Seq[Expression]) extends DeclarativeAggregate with OrderIrrelevantAggs {
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false

// Return data type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
50
""",
since = "1.0.0")
case class Max(child: Expression) extends DeclarativeAggregate with OrderIrrelevantAggs {
case class Max(child: Expression) extends DeclarativeAggregate {

override def children: Seq[Expression] = child :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
-1
""",
since = "1.0.0")
case class Min(child: Expression) extends DeclarativeAggregate with OrderIrrelevantAggs {
case class Min(child: Expression) extends DeclarativeAggregate {

override def children: Seq[Expression] = child :: Nil

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ import org.apache.spark.sql.types._
NULL
""",
since = "1.0.0")
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with OrderIrrelevantAggs {
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {

override def children: Seq[Expression] = child :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
SimplifyBinaryComparison,
ReplaceNullWithFalseInPredicate,
PruneFilters,
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
RewriteCorrelatedScalarSubquery,
Expand Down Expand Up @@ -174,8 +173,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
// idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once.
Batch("Join Reorder", FixedPoint(1),
CostBasedJoinReorder) :+
Batch("Remove Redundant Sorts", Once,
RemoveRedundantSorts) :+
Batch("Eliminate Sorts", Once,
EliminateSorts) :+
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
Batch("Object Expressions Optimization", fixedPoint,
Expand Down Expand Up @@ -953,29 +952,22 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
}

/**
* Removes no-op SortOrder from Sort
* Removes Sort operation. This can happen:
* 1) if the sort is noop
* 2) if the child is already sorted
* 3) if there is another Sort operator separated by 0...n Project/Filter operators
* 4) if the Sort operator is within Join and without Limit
* 5) if the Sort operator is within GroupBy and the aggregate function is order irrelevant
*/
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) child else s.copy(order = newOrders)
}
}

/**
* Removes redundant Sort operation. This can happen:
* 1) if the child is already sorted
* 2) if there is another Sort operator separated by 0...n Project/Filter operators
* 3) if the Sort operator is within Join and without Limit
* 4) if the Sort operator is within GroupBy and the aggregate function is order irrelevant
*/
object RemoveRedundantSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
child
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
case j @ Join(originLeft, originRight, _, _, _) =>
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight))
case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
g.copy(child = recursiveRemoveSort(originChild))
Expand All @@ -995,13 +987,21 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] {
}

def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = {
val aggExpressions = aggs.flatMap { e =>
e.collect {
case ae: AggregateExpression => ae
}
def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match {
case _: Sum => true
case _: Min => true
case _: Max => true
case _: Count => true
case _: Average => true
case _: CentralMomentAgg => true
case _ => false
}

aggExpressions.forall(_.aggregateFunction.isInstanceOf[OrderIrrelevantAggs])
aggs.flatMap { e =>
e.collect {
case ae: AggregateExpression => ae.aggregateFunction
}
}.forall(isOrderIrrelevantAggFunction)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class RemoveRedundantSortsSuite extends PlanTest {
val batches =
Batch("Limit PushDown", Once,
LimitPushDown) ::
Batch("Remove Redundant Sorts", Once,
RemoveRedundantSorts) ::
Batch("Eliminate Sorts", Once,
EliminateSorts) ::
Batch("Collapse Project", Once,
CollapseProject) :: Nil
}
Expand Down

0 comments on commit 9072db7

Please sign in to comment.