[SPARK-25352][SQL] Perform ordered global limit when limit number is bigger than topKSortFallbackThreshold

## What changes were proposed in this pull request?

We have an optimization on global limit that evenly distributes limit rows across all partitions. This optimization does not preserve row order, so it cannot be applied to ordered results.

A query ending with sort + limit is, in most cases, executed by `TakeOrderedAndProjectExec`.

But if the limit number is bigger than `SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD`, a global limit is used instead. In that case we need to perform an ordered global limit so the returned rows still follow the sort order.
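
For illustration, here is a minimal sketch of how the two code paths are chosen. It mirrors the `DataFrameSuite` test added below and assumes an active `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.execution.TakeOrderedAndProjectExec
import org.apache.spark.sql.internal.SQLConf

// Lower the fallback threshold so both paths are easy to hit.
spark.conf.set(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key, "100")

val df = spark.range(1000).toDF.repartition(3).sort("id")

// limit < threshold: planned as a top-K sort.
val topK = df.limit(99).queryExecution.executedPlan
  .find(_.isInstanceOf[TakeOrderedAndProjectExec])
assert(topK.isDefined)

// limit >= threshold: no TakeOrderedAndProjectExec; with this patch the plan instead
// uses GlobalLimitExec(..., orderedLimit = true), which keeps the sort order.
val global = df.limit(100).queryExecution.executedPlan
  .find(_.isInstanceOf[TakeOrderedAndProjectExec])
assert(global.isEmpty)
```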

## How was this patch tested?

Unit tests: a new plan-selection check in `DataFrameSuite`, a new `LimitSuite`, and updated `TakeOrderedAndProjectSuite` tests.
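
As a rough illustration of what the new `LimitSuite` test verifies (a simplified sketch, not the exact test; it relies on the test-only helper `SparkPlanTest.executePlan` and an active `spark` session):

```scala
import org.apache.spark.sql.execution.{GlobalLimitExec, SparkPlanTest}

val baseDf = spark.range(1000).toDF.repartition(3).sort("id")
val plan = baseDf.queryExecution.sparkPlan

// With orderedLimit = true the limit keeps the first rows of the globally sorted
// output; the flat global limit may instead pull rows from every partition.
val orderedRows = SparkPlanTest.executePlan(
  GlobalLimitExec(3, plan, orderedLimit = true), spark.sqlContext).map(_.getLong(0))
val flatRows = SparkPlanTest.executePlan(
  GlobalLimitExec(3, plan, orderedLimit = false), spark.sqlContext).map(_.getLong(0))

// When SQLConf.LIMIT_FLAT_GLOBAL_LIMIT is enabled, orderedRows should be the three
// smallest ids, while flatRows can contain larger ids from later partitions.
```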

Closes #22344 from viirya/SPARK-25352.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
viirya authored and cloud-fan committed Sep 12, 2018
1 parent 79cc597 commit 2f42239
Showing 5 changed files with 192 additions and 56 deletions.
@@ -68,22 +68,42 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object SpecialLimits extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ReturnAnswer(rootPlan) => rootPlan match {
case Limit(IntegerLiteral(limit), Sort(order, true, child))
if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
if (limit < conf.topKSortFallbackThreshold) {
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
} else {
GlobalLimitExec(limit,
LocalLimitExec(limit, planLater(s)),
orderedLimit = true) :: Nil
}
case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
if (limit < conf.topKSortFallbackThreshold) {
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
} else {
GlobalLimitExec(limit,
LocalLimitExec(limit, planLater(p)),
orderedLimit = true) :: Nil
}
case Limit(IntegerLiteral(limit), child) =>
CollectLimitExec(limit, planLater(child)) :: Nil
case other => planLater(other) :: Nil
}
case Limit(IntegerLiteral(limit), Sort(order, true, child))
if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
if (limit < conf.topKSortFallbackThreshold) {
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
} else {
GlobalLimitExec(limit,
LocalLimitExec(limit, planLater(s)),
orderedLimit = true) :: Nil
}
case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
if (limit < conf.topKSortFallbackThreshold) {
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
} else {
GlobalLimitExec(limit,
LocalLimitExec(limit, planLater(p)),
orderedLimit = true) :: Nil
}
case _ => Nil
}
}
@@ -98,7 +98,8 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode wi
/**
* Take the `limit` elements of the child output.
*/
case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
case class GlobalLimitExec(limit: Int, child: SparkPlan,
orderedLimit: Boolean = false) extends UnaryExecNode {

override def output: Seq[Attribute] = child.output

@@ -126,7 +127,9 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
// When enabled, Spark goes to take rows at each partition repeatedly until reaching
// limit number. When disabled, Spark takes all rows at first partition, then rows
// at second partition ..., until reaching limit number.
val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit
// The optimization is disabled when it is needed to keep the original order of rows
// before global sort, e.g., select * from table order by col limit 10.
val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit

val shuffled = new ShuffledRowRDD(shuffleDependency)

@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
@@ -2552,6 +2552,26 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}

test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold ") {
withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
val baseDf = spark.range(1000).toDF.repartition(3).sort("id")

withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
val expected = baseDf.limit(99)
val takeOrderedNode1 = expected.queryExecution.executedPlan
.find(_.isInstanceOf[TakeOrderedAndProjectExec])
assert(takeOrderedNode1.isDefined)

val result = baseDf.limit(100)
val takeOrderedNode2 = result.queryExecution.executedPlan
.find(_.isInstanceOf[TakeOrderedAndProjectExec])
assert(takeOrderedNode2.isEmpty)

checkAnswer(expected, result.collect().take(99))
}
}
}

test("SPARK-25368 Incorrect predicate pushdown returns wrong result") {
def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = {
val df1 = spark.createDataFrame(Seq(
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext


class LimitSuite extends SparkPlanTest with SharedSQLContext {

private var rand: Random = _
private var seed: Long = 0

protected override def beforeAll(): Unit = {
super.beforeAll()
seed = System.currentTimeMillis()
rand = new Random(seed)
}

test("Produce ordered global limit if more than topKSortFallbackThreshold") {
withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
val df = LimitTest.generateRandomInputData(spark, rand).sort("a")

val globalLimit = df.limit(99).queryExecution.executedPlan.collect {
case g: GlobalLimitExec => g
}
assert(globalLimit.size == 0)

val topKSort = df.limit(99).queryExecution.executedPlan.collect {
case t: TakeOrderedAndProjectExec => t
}
assert(topKSort.size == 1)

val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect {
case g: GlobalLimitExec => g
}
assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit(0).orderedLimit == true)
}
}

test("Ordered global limit") {
val baseDf = LimitTest.generateRandomInputData(spark, rand)
.select("a").repartition(3).sort("a")

withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan,
orderedLimit = true)
val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext)
.map(_.getInt(0))

val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false)
val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext)
.map(_.getInt(0))

// Global limit without order takes values at each partition sequentially.
// After global sort, the values in second partition must be larger than the values
// in first partition.
assert(orderedGlobalLimitResult(0) == globalLimitResult(0))
assert(orderedGlobalLimitResult(1) < globalLimitResult(1))
assert(orderedGlobalLimitResult(2) < globalLimitResult(2))
}
}
}

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.internal.SQLConf
@@ -32,28 +32,10 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
private var rand: Random = _
private var seed: Long = 0

private val originalLimitFlatGlobalLimit = SQLConf.get.getConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT)

protected override def beforeAll(): Unit = {
super.beforeAll()
seed = System.currentTimeMillis()
rand = new Random(seed)

// Disable the optimization to make Sort-Limit match `TakeOrderedAndProject` semantics.
SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false)
}

protected override def afterAll() = {
SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit)
super.afterAll()
}

private def generateRandomInputData(): DataFrame = {
val schema = new StructType()
.add("a", IntegerType, nullable = false)
.add("b", IntegerType, nullable = false)
val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
}

/**
@@ -66,32 +48,62 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
val sortOrder = 'a.desc :: 'b.desc :: Nil

test("TakeOrderedAndProject.doExecute without project") {
withClue(s"seed = $seed") {
checkThatPlansAgree(
generateRandomInputData(),
input =>
noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
input =>
GlobalLimitExec(limit,
LocalLimitExec(limit,
SortExec(sortOrder, true, input))),
sortAnswers = false)
withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
withClue(s"seed = $seed") {
checkThatPlansAgree(
LimitTest.generateRandomInputData(spark, rand),
input =>
noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
input =>
GlobalLimitExec(limit,
LocalLimitExec(limit,
SortExec(sortOrder, true, input))),
sortAnswers = false)
}
}
}

test("TakeOrderedAndProject.doExecute with project") {
withClue(s"seed = $seed") {
checkThatPlansAgree(
generateRandomInputData(),
input =>
noOpFilter(
TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
input =>
GlobalLimitExec(limit,
LocalLimitExec(limit,
ProjectExec(Seq(input.output.last),
SortExec(sortOrder, true, input)))),
sortAnswers = false)
withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
withClue(s"seed = $seed") {
checkThatPlansAgree(
LimitTest.generateRandomInputData(spark, rand),
input =>
noOpFilter(
TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
input =>
GlobalLimitExec(limit,
LocalLimitExec(limit,
ProjectExec(Seq(input.output.last),
SortExec(sortOrder, true, input)))),
sortAnswers = false)
}
}
}

test("TakeOrderedAndProject.doExecute equals to ordered global limit") {
withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
withClue(s"seed = $seed") {
checkThatPlansAgree(
LimitTest.generateRandomInputData(spark, rand),
input =>
noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
input =>
GlobalLimitExec(limit,
LocalLimitExec(limit,
SortExec(sortOrder, true, input)), orderedLimit = true),
sortAnswers = false)
}
}
}
}

object LimitTest {
def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = {
val schema = new StructType()
.add("a", IntegerType, nullable = false)
.add("b", IntegerType, nullable = false)
val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
}
}
