From 2f422398b524eacc89ab58e423bb134ae3ca3941 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 12 Sep 2018 22:54:05 +0800
Subject: [PATCH] [SPARK-25352][SQL] Perform ordered global limit when limit number is bigger than topKSortFallbackThreshold

## What changes were proposed in this pull request?

We have an optimization on global limit that distributes the limit rows evenly across all partitions. This optimization does not work for ordered results. A query ending with sort + limit is in most cases executed by `TakeOrderedAndProjectExec`, but if the limit number is bigger than `SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD`, a global limit is used instead. In that case, we need to perform an ordered global limit that preserves the sort order (see the illustrative sketch after the patch).

## How was this patch tested?

Unit tests.

Closes #22344 from viirya/SPARK-25352.

Authored-by: Liang-Chi Hsieh
Signed-off-by: Wenchen Fan
---
 .../spark/sql/execution/SparkStrategies.scala | 44 ++++++---
 .../apache/spark/sql/execution/limit.scala    |  7 +-
 .../org/apache/spark/sql/DataFrameSuite.scala | 22 ++++-
 .../spark/sql/execution/LimitSuite.scala      | 81 ++++++++++++++++
 .../TakeOrderedAndProjectSuite.scala          | 94 +++++++++++--------
 5 files changed, 192 insertions(+), 56 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dbc6db62bd820..7c8ce316f9647 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -68,22 +68,44 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object SpecialLimits extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ReturnAnswer(rootPlan) => rootPlan match {
-        case Limit(IntegerLiteral(limit), Sort(order, true, child))
-            if limit < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
-        case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
-            if limit < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+        case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
+          if (limit < conf.topKSortFallbackThreshold) {
+            TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+          } else {
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit, planLater(s)),
+              orderedLimit = true) :: Nil
+          }
+        case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
+          if (limit < conf.topKSortFallbackThreshold) {
+            TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+          } else {
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit, planLater(p)),
+              orderedLimit = true) :: Nil
+          }
         case Limit(IntegerLiteral(limit), child) =>
           CollectLimitExec(limit, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
-      case Limit(IntegerLiteral(limit), Sort(order, true, child))
-          if limit < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
-      case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
-          if limit < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
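+      // The same two patterns as above, matched when the sort + limit is not wrapped
+      // in ReturnAnswer, i.e. it is not the root of the query plan.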
+      case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
+        if (limit < conf.topKSortFallbackThreshold) {
+          TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+        } else {
+          GlobalLimitExec(limit,
+            LocalLimitExec(limit, planLater(s)),
+            orderedLimit = true) :: Nil
+        }
+      case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
+        if (limit < conf.topKSortFallbackThreshold) {
+          TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+        } else {
+          GlobalLimitExec(limit,
+            LocalLimitExec(limit, planLater(p)),
+            orderedLimit = true) :: Nil
+        }
       case _ => Nil
     }
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index fb46970e38f3c..1a09632f93ca1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -98,7 +98,8 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode wi
 /**
  * Take the `limit` elements of the child output.
  */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
+case class GlobalLimitExec(limit: Int, child: SparkPlan,
+    orderedLimit: Boolean = false) extends UnaryExecNode {

   override def output: Seq[Attribute] = child.output

@@ -126,7 +127,12 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
     // When enabled, Spark goes to take rows at each partition repeatedly until reaching
     // limit number. When disabled, Spark takes all rows at first partition, then rows
     // at second partition ..., until reaching limit number.
-    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit
+    // The optimization is disabled when we need to keep the order of rows produced by
+    // a global sort, e.g., select * from table order by col limit 10.
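+    // Taking rows evenly from every shuffle partition would interleave rows coming
+    // from different ranges of the sorted output and break the global ordering, so
+    // an ordered limit takes rows partition by partition, starting from the first.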
+    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit

     val shuffled = new ShuffledRowRDD(shuffleDependency)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 279b7b8d49f52..f001b138f4b8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.functions._
@@ -2552,6 +2552,26 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }

+  test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold") {
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+      val baseDf = spark.range(1000).toDF.repartition(3).sort("id")
+
+      withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
+        val expected = baseDf.limit(99)
+        val takeOrderedNode1 = expected.queryExecution.executedPlan
+          .find(_.isInstanceOf[TakeOrderedAndProjectExec])
+        assert(takeOrderedNode1.isDefined)
+
+        val result = baseDf.limit(100)
+        val takeOrderedNode2 = result.queryExecution.executedPlan
+          .find(_.isInstanceOf[TakeOrderedAndProjectExec])
+        assert(takeOrderedNode2.isEmpty)
+
+        checkAnswer(expected, result.collect().take(99))
+      }
+    }
+  }
+
   test("SPARK-25368 Incorrect predicate pushdown returns wrong result") {
     def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = {
       val df1 = spark.createDataFrame(Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
new file mode 100644
index 0000000000000..a7840a5fcfae0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+
+class LimitSuite extends SparkPlanTest with SharedSQLContext {
+
+  private var rand: Random = _
+  private var seed: Long = 0
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    seed = System.currentTimeMillis()
+    rand = new Random(seed)
+  }
+
+  test("Produce ordered global limit if more than topKSortFallbackThreshold") {
+    withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
+      val df = LimitTest.generateRandomInputData(spark, rand).sort("a")
+
+      val globalLimit = df.limit(99).queryExecution.executedPlan.collect {
+        case g: GlobalLimitExec => g
+      }
+      assert(globalLimit.isEmpty)
+
+      val topKSort = df.limit(99).queryExecution.executedPlan.collect {
+        case t: TakeOrderedAndProjectExec => t
+      }
+      assert(topKSort.size == 1)
+
+      val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect {
+        case g: GlobalLimitExec => g
+      }
+      assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit.head.orderedLimit)
+    }
+  }
+
+  test("Ordered global limit") {
+    val baseDf = LimitTest.generateRandomInputData(spark, rand)
+      .select("a").repartition(3).sort("a")
+
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+      val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan,
+        orderedLimit = true)
+      val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext)
+        .map(_.getInt(0))
+
+      val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false)
+      val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext)
+        .map(_.getInt(0))
+
+      // A global limit without ordering takes values from each partition sequentially.
+      // After a global sort, the values in the second partition must be larger than
+      // the values in the first partition.
+      assert(orderedGlobalLimitResult(0) == globalLimitResult(0))
+      assert(orderedGlobalLimitResult(1) < globalLimitResult(1))
+      assert(orderedGlobalLimitResult(2) < globalLimitResult(2))
+    }
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index f076959dfdf7b..9322204063af3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution

 import scala.util.Random

-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.internal.SQLConf
@@ -32,28 +32,10 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
   private var rand: Random = _
   private var seed: Long = 0

-  private val originalLimitFlatGlobalLimit = SQLConf.get.getConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT)
-
   protected override def beforeAll(): Unit = {
     super.beforeAll()
     seed = System.currentTimeMillis()
     rand = new Random(seed)
-
-    // Disable the optimization to make Sort-Limit match `TakeOrderedAndProject` semantics.
-    SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false)
-  }
-
-  protected override def afterAll() = {
-    SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit)
-    super.afterAll()
-  }
-
-  private def generateRandomInputData(): DataFrame = {
-    val schema = new StructType()
-      .add("a", IntegerType, nullable = false)
-      .add("b", IntegerType, nullable = false)
-    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
-    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
   }

   /**
@@ -66,32 +48,62 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
   val sortOrder = 'a.desc :: 'b.desc :: Nil

   test("TakeOrderedAndProject.doExecute without project") {
-    withClue(s"seed = $seed") {
-      checkThatPlansAgree(
-        generateRandomInputData(),
-        input =>
-          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
-        input =>
-          GlobalLimitExec(limit,
-            LocalLimitExec(limit,
-              SortExec(sortOrder, true, input))),
-        sortAnswers = false)
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
+      withClue(s"seed = $seed") {
+        checkThatPlansAgree(
+          LimitTest.generateRandomInputData(spark, rand),
+          input =>
+            noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
+          input =>
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit,
+                SortExec(sortOrder, true, input))),
+          sortAnswers = false)
+      }
     }
   }

   test("TakeOrderedAndProject.doExecute with project") {
-    withClue(s"seed = $seed") {
-      checkThatPlansAgree(
-        generateRandomInputData(),
-        input =>
-          noOpFilter(
-            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
-        input =>
-          GlobalLimitExec(limit,
-            LocalLimitExec(limit,
-              ProjectExec(Seq(input.output.last),
-                SortExec(sortOrder, true, input)))),
-        sortAnswers = false)
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
+      withClue(s"seed = $seed") {
+        checkThatPlansAgree(
+          LimitTest.generateRandomInputData(spark, rand),
+          input =>
+            noOpFilter(
+              TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
+          input =>
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit,
+                ProjectExec(Seq(input.output.last),
+                  SortExec(sortOrder, true, input)))),
+          sortAnswers = false)
+      }
     }
   }
+
+  test("TakeOrderedAndProject.doExecute matches ordered global limit") {
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+      withClue(s"seed = $seed") {
+        checkThatPlansAgree(
+          LimitTest.generateRandomInputData(spark, rand),
+          input =>
+            noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
+          input =>
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit,
+                SortExec(sortOrder, true, input)), orderedLimit = true),
+          sortAnswers = false)
+      }
+    }
+  }
+}
+
+object LimitTest {
+  def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = {
+    val schema = new StructType()
+      .add("a", IntegerType, nullable = false)
+      .add("b", IntegerType, nullable = false)
+    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
+    spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
+  }
 }
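
For illustration, a minimal `spark-shell` sketch of the planner behavior this patch introduces. It mirrors the new `DataFrameSuite` test above (threshold 100, limit 99 vs. 100); the conf key is the one behind `SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD`.

```scala
// Below the threshold, sort + limit is planned as a top-K sort; at or above it,
// it is now planned as an ordered global limit that preserves the sorted order.
spark.conf.set("spark.sql.execution.topKSortFallbackThreshold", 100)

val df = spark.range(1000).toDF.repartition(3).sort("id")

df.limit(99).explain()   // plan contains TakeOrderedAndProject
df.limit(100).explain()  // plan contains GlobalLimit over LocalLimit over Sort
```

Both queries return the same leading 99 rows: the ordered global limit takes rows from the sorted partitions in order rather than evenly from every partition, so the sort order survives the limit.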