[SPARK-23005][CORE] Improve RDD.take on small number of partitions
## What changes were proposed in this pull request?
In the current implementation of RDD.take, we overestimate the number of partitions we need to try by 50%:
`(1.5 * num * partsScanned / buf.size).toInt`
However, when this number is small, the truncation done by `.toInt` is not what we want.
E.g., 2.9 becomes 2, when it should be 3.
Use `Math.ceil` to fix the problem.
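As a quick illustration (a standalone Scala sketch with made-up numbers, not code from the patch), plain truncation can undershoot the estimate by a partition, while `Math.ceil` cannot:

```scala
object TakeEstimateDemo {
  def main(args: Array[String]): Unit = {
    // Truncation drops the fractional part: an estimate of 2.9 partitions
    // becomes 2, one partition short of what is actually needed.
    println(2.9.toInt)             // 2
    println(Math.ceil(2.9).toInt)  // 3

    // Hypothetical loop state: 3 rows still missing (left), 2 partitions
    // scanned, 7 rows buffered. The patched formula rounds up, never down.
    val (left, partsScanned, bufSize) = (3, 2, 7)
    val numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufSize).toInt
    println(numPartsToTry)         // ceil(1.2857...) = 2; plain .toInt would give 1
  }
}
```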

Also clean up the code in RDD.scala.

## How was this patch tested?

Unit test

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #20200 from gengliangwang/Take.
gengliangwang authored and cloud-fan committed Jan 10, 2018
1 parent 2250cb7 commit 96ba217
Showing 2 changed files with 16 additions and 16 deletions.
27 changes: 13 additions & 14 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag](
val id: Int = sc.newRddId()

/** A friendly name for this RDD */
- @transient var name: String = null
+ @transient var name: String = _

/** Assign a name to this RDD */
def setName(_name: String): this.type = {
@@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag](

// Our dependencies and partitions will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
- private var dependencies_ : Seq[Dependency[_]] = null
- @transient private var partitions_ : Array[Partition] = null
+ private var dependencies_ : Seq[Dependency[_]] = _
+ @transient private var partitions_ : Array[Partition] = _

/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -297,7 +297,7 @@ abstract class RDD[T: ClassTag](
private[spark] def getNarrowAncestors: Seq[RDD[_]] = {
val ancestors = new mutable.HashSet[RDD[_]]

- def visit(rdd: RDD[_]) {
+ def visit(rdd: RDD[_]): Unit = {
val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
val narrowParents = narrowDependencies.map(_.rdd)
val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
@@ -449,7 +449,7 @@ abstract class RDD[T: ClassTag](
if (shuffle) {
/** Distributes elements evenly across output partitions, starting from a random partition. */
val distributePartition = (index: Int, items: Iterator[T]) => {
- var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions)
+ var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
items.map { t =>
// Note that the hash code of the key will just be the key itself. The HashPartitioner
// will mod it with the number of total partitions.
@@ -951,7 +951,7 @@ abstract class RDD[T: ClassTag](
def collectPartition(p: Int): Array[T] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head
}
- (0 until partitions.length).iterator.flatMap(i => collectPartition(i))
+ partitions.indices.iterator.flatMap(i => collectPartition(i))
}

/**
@@ -1338,20 +1338,20 @@ abstract class RDD[T: ClassTag](
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
+ val left = num - buf.size
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
if (buf.isEmpty) {
numPartsToTry = partsScanned * scaleUpFactor
} else {
- // the left side of max is >=1 whenever partsScanned >= 2
- numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
+ // As left > 0, numPartsToTry is always >= 1
+ numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
}
}

- val left = num - buf.size
val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

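The hunk above is the core of the fix. As a rough, self-contained sketch of how the patched estimate drives the scan loop (local arrays stand in for RDD partitions and for the sc.runJob call, and the growth factor of 4 simply mirrors the "quadruple and retry" comment; this is an illustration, not Spark's implementation):

  object TakeLoopSketch {
    // Simplified model of RDD.take: each "partition" is a local Array and the
    // inner foreach stands in for the sc.runJob call in the real implementation.
    def take[T](partitions: Array[Array[T]], num: Int, scaleUpFactor: Int = 4): Seq[T] = {
      val totalParts = partitions.length
      val buf = scala.collection.mutable.ArrayBuffer.empty[T]
      var partsScanned = 0
      while (buf.size < num && partsScanned < totalParts) {
        var numPartsToTry = 1
        val left = num - buf.size
        if (partsScanned > 0) {
          if (buf.isEmpty) {
            // Nothing found yet: grow the scan geometrically.
            numPartsToTry = partsScanned * scaleUpFactor
          } else {
            // Interpolate with a 50% overestimate; ceil keeps it >= 1 since left > 0.
            numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
            numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
          }
        }
        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
        p.foreach { i => if (buf.size < num) buf ++= partitions(i).take(num - buf.size) }
        partsScanned += p.size
      }
      buf.take(num).toList
    }

    def main(args: Array[String]): Unit = {
      val parts = Array(Array(1), Array(2, 3), Array(4, 5, 6), Array.empty[Int], Array(7, 8))
      println(take(parts, 5).mkString(", "))  // 1, 2, 3, 4, 5
    }
  }
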
@@ -1677,16 +1677,15 @@
// an RDD and its parent in every batch, in which case the parent may never be checkpointed
// and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
private val checkpointAllMarkedAncestors =
- Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
- .map(_.toBoolean).getOrElse(false)
+ Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean)

/** Returns the first parent RDD */
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
}

/** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */
- protected[spark] def parent[U: ClassTag](j: Int) = {
+ protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = {
dependencies(j).rdd.asInstanceOf[RDD[U]]
}

@@ -1754,7 +1753,7 @@ abstract class RDD[T: ClassTag](
* collected. Subclasses of RDD may override this method for implementing their own cleaning
* logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
*/
- protected def clearDependencies() {
+ protected def clearDependencies(): Unit = {
dependencies_ = null
}

@@ -1790,7 +1789,7 @@ abstract class RDD[T: ClassTag](
val lastDepStrings =
debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true)

- (frontDepStrings ++ lastDepStrings)
+ frontDepStrings ++ lastDepStrings
}
}
// The first RDD in the dependency stack has no parents, so no need for a +-
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -351,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
if (buf.isEmpty) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
- // the left side of max is >=1 whenever partsScanned >= 2
- numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1)
+ val left = n - buf.size
+ // As left > 0, numPartsToTry is always >= 1
+ numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
}
}
