diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 6e9a4df828246..d1569a4ec2b40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -31,7 +31,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
     sqlContext.experimental.extraStrategies ++ (
       DataSourceStrategy ::
       DDLStrategy ::
-      TakeOrderedAndProject ::
+      SpecialLimits ::
       Aggregation ::
       LeftSemiJoin ::
       EquiJoinSelection ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ee392e4e8debe..598ddd71613b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -33,6 +33,31 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>

+  /**
+   * Plans special cases of limit operators.
+   */
+  object SpecialLimits extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case logical.ReturnAnswer(rootPlan) => rootPlan match {
+        case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+          execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+        case logical.Limit(
+            IntegerLiteral(limit),
+            logical.Project(projectList, logical.Sort(order, true, child))) =>
+          execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
+        case logical.Limit(IntegerLiteral(limit), child) =>
+          execution.CollectLimit(limit, planLater(child)) :: Nil
+        case other => planLater(other) :: Nil
+      }
+      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
+      case _ => Nil
+    }
+  }
+
   object LeftSemiJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(
@@ -264,18 +289,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

   protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)

-  object TakeOrderedAndProject extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
-      case logical.Limit(
-          IntegerLiteral(limit),
-          logical.Project(projectList, logical.Sort(order, true, child))) =>
-        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
-      case _ => Nil
-    }
-  }
-
   object InMemoryScans extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
@@ -338,8 +351,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         LocalTableScan(output, data) :: Nil
-      case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
-        execution.CollectLimit(limit, planLater(child)) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
         val perPartitionLimit = execution.LocalLimit(limit, planLater(child))
         val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
@@ -362,7 +373,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
-      case logical.ReturnAnswer(child) => planLater(child) :: Nil
       case _ => Nil
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 256f4228ae99e..04daf9d0ce2a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -83,8 +83,7 @@ case class TakeOrderedAndProject(
     child: SparkPlan) extends UnaryNode {

   override def output: Seq[Attribute] = {
-    val projectOutput = projectList.map(_.map(_.toAttribute))
-    projectOutput.getOrElse(child.output)
+    projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
   }

   override def outputPartitioning: Partitioning = SinglePartition
@@ -93,7 +92,7 @@ case class TakeOrderedAndProject(
   // and this ordering needs to be created on the driver in order to be passed into Spark core code.
   private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)

-  private def collectData(): Array[InternalRow] = {
+  override def executeCollect(): Array[InternalRow] = {
     val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
     if (projectList.isDefined) {
       val proj = UnsafeProjection.create(projectList.get, child.output)
@@ -103,13 +102,26 @@ case class TakeOrderedAndProject(
     }
   }

-  override def executeCollect(): Array[InternalRow] = {
-    collectData()
-  }
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)

-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+  protected override def doExecute(): RDD[InternalRow] = {
+    val localTopK: RDD[InternalRow] = {
+      child.execute().map(_.copy()).mapPartitions { iter =>
+        org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord)
+      }
+    }
+    val shuffled = new ShuffledRowRDD(
+      Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer))
+    shuffled.mapPartitions { iter =>
+      val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
+      if (projectList.isDefined) {
+        val proj = UnsafeProjection.create(projectList.get, child.output)
+        topK.map(r => proj(r))
+      } else {
+        topK
+      }
+    }
+  }

   override def outputOrdering: Seq[SortOrder] = sortOrder

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index a64ad4038c7c3..250ce8f86698f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -161,30 +162,37 @@ class PlannerSuite extends SharedSQLContext {
     }
   }

-  test("efficient limit -> project -> sort") {
-    {
-      val query =
-        testData.select('key, 'value).sort('key).limit(2).logicalPlan
-      val planned = sqlContext.planner.TakeOrderedAndProject(query)
-      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
-      assert(planned.head.output === testData.select('key, 'value).logicalPlan.output)
-    }
+  test("efficient terminal limit -> sort should use TakeOrderedAndProject") {
+    val query = testData.select('key, 'value).sort('key).limit(2)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
+    assert(planned.output === testData.select('key, 'value).logicalPlan.output)
+  }

-    {
-      // We need to make sure TakeOrderedAndProject's output is correct when we push a project
-      // into it.
-      val query =
-        testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan
-      val planned = sqlContext.planner.TakeOrderedAndProject(query)
-      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
-      assert(planned.head.output === testData.select('value, 'key).logicalPlan.output)
-    }
+  test("terminal limit -> project -> sort should use TakeOrderedAndProject") {
+    val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
+    assert(planned.output === testData.select('value, 'key).logicalPlan.output)
   }

-  test("terminal limits use CollectLimit") {
+  test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") {
     val query = testData.select('value).limit(2)
     val planned = query.queryExecution.sparkPlan
     assert(planned.isInstanceOf[CollectLimit])
+    assert(planned.output === testData.select('value).logicalPlan.output)
+  }
+
+  test("TakeOrderedAndProject can appear in the middle of plans") {
+    val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.find(_.isInstanceOf[TakeOrderedAndProject]).isDefined)
+  }
+
+  test("CollectLimit can appear in the middle of a plan when caching is used") {
+    val query = testData.select('key, 'value).limit(2).cache()
+    val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
+    assert(planned.child.isInstanceOf[CollectLimit])
   }

   test("PartitioningCollection") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
new file mode 100644
index 0000000000000..03cb04a5f7a03
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+
+
+class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
+
+  private var rand: Random = _
+  private var seed: Long = 0
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    seed = System.currentTimeMillis()
+    rand = new Random(seed)
+  }
+
+  private def generateRandomInputData(): DataFrame = {
+    val schema = new StructType()
+      .add("a", IntegerType, nullable = false)
+      .add("b", IntegerType, nullable = false)
+    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
+    sqlContext.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
+  }
+
+  /**
+   * Adds a no-op filter to the child plan in order to prevent executeCollect() from being
+   * called directly on the child plan.
+   */
+  private def noOpFilter(plan: SparkPlan): SparkPlan = Filter(Literal(true), plan)
+
+  val limit = 250
+  val sortOrder = 'a.desc :: 'b.desc :: Nil
+
+  test("TakeOrderedAndProject.doExecute without project") {
+    withClue(s"seed = $seed") {
+      checkThatPlansAgree(
+        generateRandomInputData(),
+        input =>
+          noOpFilter(TakeOrderedAndProject(limit, sortOrder, None, input)),
+        input =>
+          GlobalLimit(limit,
+            LocalLimit(limit,
+              Sort(sortOrder, global = true, input))),
+        sortAnswers = false)
+    }
+  }
+
+  test("TakeOrderedAndProject.doExecute with project") {
+    withClue(s"seed = $seed") {
+      checkThatPlansAgree(
+        generateRandomInputData(),
+        input =>
+          noOpFilter(TakeOrderedAndProject(limit, sortOrder, Some(Seq(input.output.last)), input)),
+        input =>
+          GlobalLimit(limit,
+            LocalLimit(limit,
+              Project(Seq(input.output.last),
+                Sort(sortOrder, global = true, input)))),
+        sortAnswers = false)
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 05863ae18350d..2433b54ffcb8e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -559,7 +559,7 @@ class HiveContext private[hive](
         HiveCommandStrategy(self),
         HiveDDLStrategy,
         DDLStrategy,
-        TakeOrderedAndProject,
+        SpecialLimits,
         InMemoryScans,
         HiveTableScans,
         DataSinks,