Remove more return statements from scopes
Andrew Or committed Apr 29, 2015
1 parent 5e388ea commit 5f07e9c
Showing 1 changed file with 21 additions and 17 deletions.
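For context on why `return` and `withScope` don't mix: a block handed to a wrapper like `withScope` is compiled as a by-name closure, and a `return` inside a closure is implemented by throwing `scala.runtime.NonLocalReturnControl` back to the enclosing method, so any wrapper that catches `Throwable` will intercept that control-flow exception. The sketch below is illustrative only; `withScope` here is a hypothetical stand-in, not Spark's implementation:

```scala
object ReturnInScopeDemo {
  // Hypothetical stand-in for a scoping wrapper: run a body with some
  // bookkeeping around it. Not Spark's actual RDD.withScope.
  def withScope[U](body: => U): U = {
    try {
      body
    } catch {
      case t: Throwable =>
        // This broad catch also intercepts the compiler-generated
        // NonLocalReturnControl thrown by the `return` below.
        println(s"intercepted: ${t.getClass.getName}")
        throw t
    }
  }

  def firstEven(xs: Seq[Int]): Option[Int] = withScope {
    for (x <- xs) {
      if (x % 2 == 0) return Some(x) // compiles to a thrown exception
    }
    None
  }

  def main(args: Array[String]): Unit = {
    // Prints "intercepted: scala.runtime.NonLocalReturnControl" and
    // then Some(4): the "return" really travelled as an exception.
    println(firstEven(Seq(1, 3, 4)))
  }
}
```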
core/src/main/scala/org/apache/spark/rdd/RDD.scala (21 additions, 17 deletions)
@@ -434,10 +434,11 @@ abstract class RDD[T: ClassTag](
    * @param seed seed for the random number generator
    * @return sample of specified size in an array
    */
+  // TODO: rewrite this without return statements so we can wrap it in a scope
   def takeSample(
       withReplacement: Boolean,
       num: Int,
-      seed: Long = Utils.random.nextLong): Array[T] = withScope {
+      seed: Long = Utils.random.nextLong): Array[T] = {
     val numStDev = 10.0
 
     if (num < 0) {
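Per the TODO above, `takeSample` is left unwrapped for now because its body still contains `return` statements. The rewrite the TODO asks for is to make the whole body a single expression; a hedged sketch of that pattern with illustrative names, not Spark's sampling logic:

```scala
// Before: early returns. This body cannot safely be passed to a
// by-name wrapper, because each `return` would become a thrown
// NonLocalReturnControl.
def clampBefore(x: Int, lo: Int, hi: Int): Int = {
  if (x < lo) return lo
  if (x > hi) return hi
  x
}

// After: the if/else chain is itself the result value, so the body
// could be wrapped, e.g. in withScope { ... }, with no `return` at all.
def clampAfter(x: Int, lo: Int, hi: Int): Int = {
  if (x < lo) lo
  else if (x > hi) hi
  else x
}
```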
@@ -1027,23 +1028,26 @@ abstract class RDD[T: ClassTag](
       depth: Int = 2): U = withScope {
     require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (partitions.length == 0) {
-      return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
-    }
-    val cleanSeqOp = context.clean(seqOp)
-    val cleanCombOp = context.clean(combOp)
-    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = partiallyAggregated.partitions.length
-    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
-    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
-    while (numPartitions > scale + numPartitions / scale) {
-      numPartitions /= scale
-      val curNumPartitions = numPartitions
-      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % curNumPartitions, _))
-      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+      Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+    } else {
+      val cleanSeqOp = context.clean(seqOp)
+      val cleanCombOp = context.clean(combOp)
+      val aggregatePartition =
+        (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+      var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+      var numPartitions = partiallyAggregated.partitions.length
+      val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+      // If creating an extra level doesn't help reduce
+      // the wall-clock time, we stop tree aggregation.
+      while (numPartitions > scale + numPartitions / scale) {
+        numPartitions /= scale
+        val curNumPartitions = numPartitions
+        partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
+          (i, iter) => iter.map((i % curNumPartitions, _))
+        }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+      }
+      partiallyAggregated.reduce(cleanCombOp)
     }
-    partiallyAggregated.reduce(cleanCombOp)
   }
 
   /**
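The `treeAggregate` hunk above applies the same pattern: the empty-partitions early `return` becomes the `if` branch and the aggregation logic becomes the `else` branch, so the whole method is one expression and can stay inside `withScope`. To make the loop's stopping rule concrete, here is a small standalone sketch of the level-planning arithmetic, using the same formulas as the diff with illustrative names:

```scala
object TreeAggregateLevels {
  // Returns the partition count after each extra aggregation level,
  // using the same `scale` and stopping rule as treeAggregate above.
  def levels(initialPartitions: Int, depth: Int): List[Int] = {
    val scale = math.max(math.ceil(math.pow(initialPartitions, 1.0 / depth)).toInt, 2)
    var numPartitions = initialPartitions
    var acc = List(numPartitions)
    // Add a tree level only while it still shrinks the number of
    // partial results enough to beat the cost of an extra round.
    while (numPartitions > scale + numPartitions / scale) {
      numPartitions /= scale
      acc ::= numPartitions
    }
    acc.reverse
  }

  def main(args: Array[String]): Unit = {
    // 1000 partitions at depth 2: scale = ceil(sqrt(1000)) = 32, one
    // level reduces 1000 partials to 31, then the driver reduces 31.
    println(levels(1000, 2)) // List(1000, 31)
  }
}
```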
