From 6ba8b284ec8f43a76c9ba54349438e484a097223 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 18 Jan 2017 08:21:17 +0000
Subject: [PATCH 1/2] Make GlobalLimit work without shuffling data to a
 single partition.

---
 .../apache/spark/sql/execution/limit.scala    | 91 +++++++++++++++----
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  9 ++
 2 files changed, 84 insertions(+), 16 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 757fe2185d302..fdd246af456bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import scala.collection.mutable.HashMap
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.InternalRow
@@ -47,17 +49,19 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
 }
 
 /**
- * Helper trait which defines methods that are shared by both
- * [[LocalLimitExec]] and [[GlobalLimitExec]].
+ * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
  */
-trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
-  val limit: Int
+case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport {
   override def output: Seq[Attribute] = child.output
 
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
     iter.take(limit)
   }
 
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
   }
@@ -90,21 +94,76 @@
 }
 
 /**
- * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
+ * Take the first `limit` elements of the child's partitions.
  */
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
-
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-}
+case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
+  override def output: Seq[Attribute] = child.output
 
-/**
- * Take the first `limit` elements of the child's single output partition.
- */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+  protected override def doExecute(): RDD[InternalRow] = {
+    // This logic is mainly copied from `SparkPlan.executeTake`.
+    // TODO: combine this with `SparkPlan.executeTake`, if possible.
+    val childRDD = child.execute()
+    val totalParts = childRDD.partitions.length
+    var partsScanned = 0
+    var totalNum = 0
+    var resultRDD: RDD[InternalRow] = null
+    while (totalNum < limit && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is ok for this number to be
+      // greater than totalParts because we actually cap it at totalParts in runJob.
+      var numPartsToTry = 1L
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the previous iteration, quadruple and retry.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate
+        // it by 50%. We also cap the estimation in the end.
+        val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
+        if (totalNum == 0) {
+          numPartsToTry = partsScanned * limitScaleUpFactor
+        } else {
+          // the left side of max is >= 1 whenever partsScanned >= 2
+          numPartsToTry = Math.max((1.5 * limit * partsScanned / totalNum).toInt - partsScanned, 1)
+          numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
+        }
+      }
 
-  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+      val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
+      val sc = sqlContext.sparkContext
+      val res = sc.runJob(childRDD,
+        (it: Iterator[InternalRow]) => Array[Int](it.size), p)
+
+      totalNum += res.map(_.head).sum
+      partsScanned += p.size
+
+      if (totalNum >= limit) {
+        // If we scanned more rows than the limit, the excess has to be dropped from the
+        // scanned partitions. We calculate how many rows need to be dropped from each
+        // partition, until all redundant rows are accounted for.
+        var numToReduce = totalNum - limit
+        val reduceAmounts = new HashMap[Int, Int]()
+        p.zip(res.map(_.head)).foreach { case (part, size) =>
+          val toReduce = if (size > numToReduce) numToReduce else size
+          reduceAmounts += ((part, toReduce))
+          numToReduce -= toReduce
+        }
+        resultRDD = childRDD.mapPartitionsWithIndexInternal { case (index, iter) =>
+          if (index < partsScanned) {
+            // These partitions have been scanned; drop their share of the redundant rows.
+            reduceAmounts.get(index).map { size =>
+              iter.drop(size)
+            }.getOrElse(iter)
+          } else {
+            // These partitions have not been scanned, so they contribute no rows.
+            Array.empty[InternalRow].toIterator
+          }
+        }
+      }
+    }
+    // If totalNum is less than limit after we scan all partitions, just return all the data.
+    if (resultRDD == null) {
+      childRDD
+    } else {
+      resultRDD
+    }
+  }
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 605dec4a1ef90..56458189e2f94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -530,6 +530,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert(e.contains(expected))
   }
 
+  test("limit for skewed dataframe") {
+    // Create a skewed DataFrame.
+    val df = testData.repartition(100).union(testData).limit(50)
+    // Because calling `rdd` on a DataFrame adds a `DeserializeToObject` on top of `GlobalLimit`,
+    // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test whether
+    // `GlobalLimit` works on skewed partitions.
+    assert(df.rdd.count() == 50L)
+  }
+
   test("CTE feature") {
     checkAnswer(
       sql("with q1 as (select * from testData limit 10) select * from q1"),

From 3cbd6ee19a994d368a4130da47a2554bd0019679 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 19 Jan 2017 04:06:19 +0000
Subject: [PATCH 2/2] Fix test because we use fewer shuffle exchanges now.

---
 .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 0bfc92fdb6218..0ebb3eb32fac3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -216,7 +216,7 @@ class PlannerSuite extends SharedSQLContext {
       ).queryExecution.executedPlan.collect {
         case exchange: ShuffleExchange => exchange
       }.length
-      assert(numExchanges === 5)
+      assert(numExchanges === 3)
     }
 
     {
@@ -231,7 +231,7 @@ class PlannerSuite extends SharedSQLContext {
       ).queryExecution.executedPlan.collect {
         case exchange: ShuffleExchange => exchange
       }.length
-      assert(numExchanges === 5)
+      assert(numExchanges === 3)
     }
   }
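
The heart of the change is the loop in `GlobalLimitExec.doExecute`: count the rows in a growing prefix of the child's partitions until at least `limit` rows have been seen, then drop the excess rows from the scanned partitions and return nothing from the unscanned ones, so no shuffle to a single partition is needed. Below is a minimal standalone sketch of that logic, using plain Scala collections in place of RDDs and cluster jobs; the object `GlobalLimitSketch`, the method `takeGlobalLimit`, and the default scale-up factor of 4 are illustrative assumptions, not part of the patch.

import scala.collection.mutable.HashMap

object GlobalLimitSketch {
  // Count rows in a growing prefix of `partitions` until at least `limit` rows are seen,
  // then drop the excess from the scanned partitions. Which rows get dropped is arbitrary,
  // just as a LIMIT without ORDER BY gives no guarantee about which rows survive.
  def takeGlobalLimit[T](
      partitions: IndexedSeq[Seq[T]],
      limit: Int,
      limitScaleUpFactor: Int = 4): Seq[T] = {
    val totalParts = partitions.length
    var partsScanned = 0
    var totalNum = 0
    val sizes = new HashMap[Int, Int]()  // partition index -> counted row number

    while (totalNum < limit && partsScanned < totalParts) {
      // Start with one partition; afterwards either quadruple (nothing found yet) or
      // interpolate from the rows seen so far, overestimating by 50% and capping the growth.
      var numPartsToTry = 1L
      if (partsScanned > 0) {
        if (totalNum == 0) {
          numPartsToTry = partsScanned.toLong * limitScaleUpFactor
        } else {
          numPartsToTry = Math.max((1.5 * limit * partsScanned / totalNum).toInt - partsScanned, 1)
          numPartsToTry = Math.min(numPartsToTry, partsScanned.toLong * limitScaleUpFactor)
        }
      }
      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts).toInt
      p.foreach { i =>
        sizes(i) = partitions(i).size  // stands in for the row-counting job run on the cluster
        totalNum += partitions(i).size
      }
      partsScanned += p.size
    }

    // Distribute the excess over the scanned partitions so that exactly `limit` rows remain
    // (or all rows, if the data holds fewer than `limit`).
    var numToReduce = math.max(totalNum - limit, 0)
    val reduceAmounts = new HashMap[Int, Int]()
    (0 until partsScanned).foreach { i =>
      val toReduce = math.min(sizes.getOrElse(i, 0), numToReduce)
      reduceAmounts(i) = toReduce
      numToReduce -= toReduce
    }

    partitions.indices.flatMap { i =>
      if (i < partsScanned) partitions(i).drop(reduceAmounts.getOrElse(i, 0))  // scanned: drop excess
      else Seq.empty  // unscanned: contribute nothing
    }
  }

  def main(args: Array[String]): Unit = {
    // A skewed layout: one big partition, two empty ones, one small one.
    val parts = IndexedSeq(Seq.tabulate(40)(identity), Seq.empty[Int], Seq.empty[Int], Seq.tabulate(20)(_ + 100))
    assert(takeGlobalLimit(parts, limit = 50).size == 50)
    assert(takeGlobalLimit(parts, limit = 500).size == 60)  // fewer rows than the limit: return all
  }
}

Note that the sketch mirrors only the arithmetic: in the patch the counting job returns per-partition sizes to the driver, and the dropping is applied lazily per partition through mapPartitionsWithIndexInternal, which is what lets the skewed-DataFrame test pass without a single-partition shuffle.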