Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-29343][SQL] Eliminate sorts without limit in the subquery of Join/Aggregation #26011

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
SimplifyBinaryComparison,
ReplaceNullWithFalseInPredicate,
PruneFilters,
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
RewriteCorrelatedScalarSubquery,
Expand Down Expand Up @@ -174,8 +173,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
// idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once.
Batch("Join Reorder", FixedPoint(1),
CostBasedJoinReorder) :+
Batch("Remove Redundant Sorts", Once,
RemoveRedundantSorts) :+
Batch("Eliminate Sorts", Once,
EliminateSorts) :+
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
Batch("Object Expressions Optimization", fixedPoint,
Expand Down Expand Up @@ -953,40 +952,60 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
}

/**
* Removes no-op SortOrder from Sort
* Removes Sort operation. This can happen:
* 1) if the sort order is empty or the sort order does not have any reference
* 2) if the child is already sorted
* 3) if there is another Sort operator separated by 0...n Project/Filter operators
* 4) if the Sort operator is within Join without Limit and having deterministic conditions only
* 5) if the Sort operator is within GroupBy and the aggregate function is order irrelevant
*/
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) child else s.copy(order = newOrders)
}
}

/**
* Removes redundant Sort operation. This can happen:
* 1) if the child is already sorted
* 2) if there is another Sort operator separated by 0...n Project/Filter operators
*/
object RemoveRedundantSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
child
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight))
case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
g.copy(child = recursiveRemoveSort(originChild))
}

def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
case Sort(_, _, child) => recursiveRemoveSort(child)
case other if canEliminateSort(other) =>
other.withNewChildren(other.children.map(recursiveRemoveSort))
case _ => plan
}

def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
private def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
case p: Project => p.projectList.forall(_.deterministic)
case f: Filter => f.condition.deterministic
case _ => false
}

private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = {
def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match {
case _: Sum => true
case _: Min => true
case _: Max => true
case _: Count => true
case _: Average => true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could still have a precision difference after eliminating the sort for the floating point data type. I am afraid some end users might prefer to adding a sort in these cases to ensure the results are consistent?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah that's a good point. AVG over floating values is order sensitive. Not sure if this can really affect queries in practice, but better to be conservative here. @WangGuangxin can you fix it in a followup?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll fix it in a followup

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WangGuangxin do you have no time for the follow-up now? Could I take this over?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same will be true of all central moments, BTW

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I revisit this and have a small question. In fact Avg is transformed to Sum and Count, so I think there should be no precision problem?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sum is the issue itself, actually; see the followup PR

case _: CentralMomentAgg => true
case _ => false
}

def checkValidAggregateExpression(expr: Expression): Boolean = expr match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: private?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot make it private because this is a nested function within private isOrderIrrelevantAggs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ur, this is an inner func. I see.

case _: AttributeReference => true
case ae: AggregateExpression => isOrderIrrelevantAggFunction(ae.aggregateFunction)
case _: UserDefinedExpression => false
case e => e.children.forall(checkValidAggregateExpression)
}

aggs.forall(checkValidAggregateExpression)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
Expand All @@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}
import org.apache.spark.sql.types.IntegerType

class EliminateSortsSuite extends PlanTest {
override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false)
Expand All @@ -35,12 +37,22 @@ class EliminateSortsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Eliminate Sorts", FixedPoint(10),
Batch("Default", FixedPoint(10),
FoldablePropagation,
EliminateSorts) :: Nil
LimitPushDown) ::
Batch("Eliminate Sorts", Once,
EliminateSorts) ::
Batch("Collapse Project", Once,
CollapseProject) :: Nil
}

object PushDownOptimizer extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Limit PushDown", FixedPoint(10), LimitPushDown) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelationB = LocalRelation('d.int)

test("Empty order by clause") {
val x = testRelation
Expand Down Expand Up @@ -83,4 +95,217 @@ class EliminateSortsSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("remove redundant order by") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
val optimized = Optimize.execute(unnecessaryReordered.analyze)
val correctAnswer = orderedPlan.limit(2).select('a).analyze
comparePlans(Optimize.execute(optimized), correctAnswer)
}

test("do not remove sort if the order is different") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(reorderedDifferently.analyze)
val correctAnswer = reorderedDifferently.analyze
comparePlans(optimized, correctAnswer)
}

test("filters don't affect order") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
comparePlans(optimized, correctAnswer)
}

test("limits don't affect order") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = orderedPlan.limit(Literal(10)).analyze
comparePlans(optimized, correctAnswer)
}

test("different sorts are not simplified if limit is in between") {
val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
.orderBy('a.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = orderedPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("range is already sorted") {
val inputPlan = Range(1L, 1000L, 1, 10)
val orderedPlan = inputPlan.orderBy('id.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = inputPlan.analyze
comparePlans(optimized, correctAnswer)

val reversedPlan = inputPlan.orderBy('id.desc)
val reversedOptimized = Optimize.execute(reversedPlan.analyze)
val reversedCorrectAnswer = reversedPlan.analyze
comparePlans(reversedOptimized, reversedCorrectAnswer)

val negativeStepInputPlan = Range(10L, 1L, -1, 10)
val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
}

test("sort should not be removed when there is a node which doesn't guarantee any order") {
val orderedPlan = testRelation.select('a, 'b)
val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
val optimized = Optimize.execute(groupedAndResorted.analyze)
val correctAnswer = groupedAndResorted.analyze
comparePlans(optimized, correctAnswer)
}

test("remove two consecutive sorts") {
val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc)
val optimized = Optimize.execute(orderedTwice.analyze)
val correctAnswer = testRelation.orderBy('b.desc).analyze
comparePlans(optimized, correctAnswer)
}

test("remove sorts separated by Filter/Project operators") {
val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc)
val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze)
val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze
comparePlans(optimizedWithProject, correctAnswerWithProject)

val orderedTwiceWithFilter =
testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc)
val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze)
val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze
comparePlans(optimizedWithFilter, correctAnswerWithFilter)

val orderedTwiceWithBoth =
testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc)
val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze)
val correctAnswerWithBoth =
testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze
comparePlans(optimizedWithBoth, correctAnswerWithBoth)

val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc)
val optimizedThrice = Optimize.execute(orderedThrice.analyze)
val correctAnswerThrice = testRelation.select('b).where('b > Literal(0))
.select(('b + 1).as('c)).orderBy('c.asc).analyze
comparePlans(optimizedThrice, correctAnswerThrice)
}

test("remove orderBy in groupBy clause with count aggs") {
val projectPlan = testRelation.select('a, 'b)
val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(count(1))
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = projectPlan.groupBy('a)(count(1)).analyze
comparePlans(optimized, correctAnswer)
}

test("remove orderBy in groupBy clause with sum aggs") {
val projectPlan = testRelation.select('a, 'b)
val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(sum('a) + 10 as "sum")
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = projectPlan.groupBy('a)(sum('a) + 10 as "sum").analyze
comparePlans(optimized, correctAnswer)
}

test("should not remove orderBy in groupBy clause with first aggs") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(first('a))
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = groupByPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("should not remove orderBy in groupBy clause with first and count aggs") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(first('a), count(1))
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = groupByPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("should not remove orderBy in groupBy clause with PythonUDF as aggs") {
val pythonUdf = PythonUDF("pyUDF", null,
IntegerType, Seq.empty, PythonEvalType.SQL_BATCHED_UDF, true)
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(pythonUdf)
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = groupByPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("should not remove orderBy in groupBy clause with ScalaUDF as aggs") {
val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil, true :: Nil)
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(scalaUdf)
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = groupByPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("should not remove orderBy with limit in groupBy clause") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc).limit(10)
val groupByPlan = orderByPlan.groupBy('a)(count(1))
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = groupByPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("remove orderBy in join clause") {
val projectPlan = testRelation.select('a, 'b)
val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val projectPlanB = testRelationB.select('d)
val joinPlan = unnecessaryOrderByPlan.join(projectPlanB).select('a, 'd)
val optimized = Optimize.execute(joinPlan.analyze)
val correctAnswer = projectPlan.join(projectPlanB).select('a, 'd).analyze
comparePlans(optimized, correctAnswer)
}

test("should not remove orderBy with limit in join clause") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc).limit(10)
val projectPlanB = testRelationB.select('d)
val joinPlan = orderByPlan.join(projectPlanB).select('a, 'd)
val optimized = Optimize.execute(joinPlan.analyze)
val correctAnswer = joinPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("should not remove orderBy in left join clause if there is an outer limit") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val projectPlanB = testRelationB.select('d)
val joinPlan = orderByPlan
.join(projectPlanB, LeftOuter)
.limit(10)
val optimized = Optimize.execute(joinPlan.analyze)
val correctAnswer = PushDownOptimizer.execute(joinPlan.analyze)
comparePlans(optimized, correctAnswer)
}

test("remove orderBy in right join clause event if there is an outer limit") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val projectPlanB = testRelationB.select('d)
val joinPlan = orderByPlan
.join(projectPlanB, RightOuter)
.limit(10)
val optimized = Optimize.execute(joinPlan.analyze)
val noOrderByPlan = projectPlan
.join(projectPlanB, RightOuter)
.limit(10)
val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze)
comparePlans(optimized, correctAnswer)
}
}
Loading