Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23973][SQL] Remove consecutive Sorts #21072

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -736,12 +736,22 @@ object EliminateSorts extends Rule[LogicalPlan] {
}

/**
* Removes Sort operation if the child is already sorted
* Removes redundant Sort operation. This can happen:
* 1) if the child is already sorted
* 2) if there is another Sort operator separated by 0...n Project/Filter operators
*/
object RemoveRedundantSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: now it's more efficient to do transformDown

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it the same?

Copy link
Contributor

@cloud-fan cloud-fan Apr 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assume the plan is

Sort
  Filter
    Sort
      Filter
        Sort
          OtherPlan

If we do transformUp, we hit the rule 3 times, which has some unnecessary transformation(OtherPlan is transformed 3 times). If it's transformDown, it's one-shot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but I saw that transfrom actually does transformDown. Anyway, I see that this might change and here we best have transformDown

case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
child
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
}

def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
case Project(fields, child) => Project(fields, recursiveRemoveSort(child))
case Filter(condition, child) => Filter(condition, recursiveRemoveSort(child))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should at least add ResolvedHint. To easily expand the white list in the future, I'd like to change the code style to

def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
  case s: Sort => recursiveRemoveSort(s.child)
  case other if canEliminateSort(other) => other.withNewChildren(other.children.map(recursiveRemoveSort))
  case _ => plan
}

def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
  case p: Project => p.projectList.forall(_.deterministic)
  case f: Filter => f.condition.deterministic
  case _: ResolvedHint => true
  ...
  case _ => false
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you think we should check for the filter condition and the projected items to be deterministic?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by the definition of deterministic, the entire input is the stats of the expression. It's very likely we will get a different result if we remove sort before filter, e.g. rowId() < 10 will get the first 10 rows, if you sort the input, the first 10 rows changed.

I think we should be conservative about deterministic expressions.

case Sort(_, _, child) => recursiveRemoveSort(child)
case _ => plan
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}

class RemoveRedundantSortsSuite extends PlanTest {

Expand All @@ -42,15 +38,15 @@ class RemoveRedundantSortsSuite extends PlanTest {

test("remove redundant order by") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst)
val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
val optimized = Optimize.execute(unnecessaryReordered.analyze)
val correctAnswer = orderedPlan.select('a).analyze
val correctAnswer = orderedPlan.limit(2).select('a).analyze
comparePlans(Optimize.execute(optimized), correctAnswer)
}

test("do not remove sort if the order is different") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc)
val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(reorderedDifferently.analyze)
val correctAnswer = reorderedDifferently.analyze
comparePlans(optimized, correctAnswer)
Expand All @@ -72,6 +68,14 @@ class RemoveRedundantSortsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("different sorts are not simplified if limit is in between") {
val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
.orderBy('a.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = orderedPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("range is already sorted") {
val inputPlan = Range(1L, 1000L, 1, 10)
val orderedPlan = inputPlan.orderBy('id.asc)
Expand All @@ -98,4 +102,37 @@ class RemoveRedundantSortsSuite extends PlanTest {
val correctAnswer = groupedAndResorted.analyze
comparePlans(optimized, correctAnswer)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test which explicitly confirms that sort.limit.sort is not simplified? I know the above two tests cover that case, but it's good to have one dedicated to testing this important property.

test("remove two consecutive sorts") {
val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc)
val optimized = Optimize.execute(orderedTwice.analyze)
val correctAnswer = testRelation.orderBy('b.desc).analyze
comparePlans(optimized, correctAnswer)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for three consecutive sorts? Two is the base case, three will help us show the inductive case :)


test("remove sorts separated by Filter/Project operators") {
val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc)
val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze)
val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze
comparePlans(optimizedWithProject, correctAnswerWithProject)

val orderedTwiceWithFilter =
testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc)
val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze)
val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze
comparePlans(optimizedWithFilter, correctAnswerWithFilter)

val orderedTwiceWithBoth =
testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc)
val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze)
val correctAnswerWithBoth =
testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze
comparePlans(optimizedWithBoth, correctAnswerWithBoth)

val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc)
val optimizedThrice = Optimize.execute(orderedThrice.analyze)
val correctAnswerThrice = testRelation.select('b).where('b > Literal(0))
.select(('b + 1).as('c)).orderBy('c.asc).analyze
comparePlans(optimizedThrice, correctAnswerThrice)
}
}