Commit

add docs
mengxr committed Jun 17, 2014
1 parent eb71c33 commit 0f94490
Showing 1 changed file with 19 additions and 6 deletions.
core/src/main/scala/org/apache/spark/rdd/RDD.scala (25 changes: 19 additions & 6 deletions)
@@ -827,8 +827,15 @@ abstract class RDD[T: ClassTag](
     jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
-  def treeReduce(f: (T, T) => T, level: Int): T = {
-    require(level >= 1, s"Level must be greater than 1 but got $level.")
+  /**
+   * :: DeveloperApi ::
+   * Reduces the elements of this RDD in a tree pattern.
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  @DeveloperApi
+  def treeReduce(f: (T, T) => T, depth: Int): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     val cleanF = sc.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext) {
@@ -849,7 +856,7 @@
         None
       }
     }
-    local.treeAggregate(Option.empty[T])(op, op, level)
+    local.treeAggregate(Option.empty[T])(op, op, depth)
       .getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
@@ -888,12 +895,18 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Aggregates the elements of this RDD in a tree pattern.
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  @DeveloperApi
   def treeAggregate[U: ClassTag](zeroValue: U)(
       seqOp: (U, T) => U,
       combOp: (U, U) => U,
-      level: Int): U = {
-    require(level >= 1, s"Level must be greater than 1 but got $level.")
+      depth: Int): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (this.partitions.size == 0) {
       return Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
     }
@@ -902,7 +915,7 @@
     val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
     var local = this.mapPartitions(it => Iterator(aggregatePartition(it)))
    var numPartitions = local.partitions.size
-    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / level)).toInt, 2)
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
     while (numPartitions > scale + numPartitions / scale) {
       numPartitions /= scale
       local = local.mapPartitionsWithIndex { (i, iter) =>
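For context, a minimal sketch of calling the renamed treeReduce, assuming a local SparkContext; the app name, partition count, and data below are illustrative, not from this commit:

import org.apache.spark.{SparkConf, SparkContext}

object TreeReduceExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local setup for illustration only.
    val sc = new SparkContext(
      new SparkConf().setAppName("tree-reduce-example").setMaster("local[4]"))
    val rdd = sc.parallelize(1 to 1000, numSlices = 100)

    // reduce merges all 100 per-partition results directly on the driver.
    val flat = rdd.reduce(_ + _)

    // treeReduce(f, depth) adds intermediate combine rounds, so the driver
    // merges far fewer partial results; both calls yield the same answer.
    val tree = rdd.treeReduce(_ + _, depth = 2)
    assert(flat == tree)

    sc.stop()
  }
}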

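treeAggregate follows the same pattern, driven by the scale computation above: with depth = 2 and 100 input partitions, scale = max(ceil(100^(1/2)), 2) = 10, so one intermediate round shrinks 100 partial aggregates to 10 before the final combine on the driver. A minimal sketch, again with illustrative names and data:

import org.apache.spark.{SparkConf, SparkContext}

object TreeAggregateExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local setup for illustration only.
    val sc = new SparkContext(
      new SparkConf().setAppName("tree-aggregate-example").setMaster("local[4]"))
    val rdd = sc.parallelize(1 to 1000, numSlices = 100)

    // One pass computes (sum, count): seqOp folds elements into a partition's
    // accumulator, combOp merges accumulators across partitions.
    val (sum, count) = rdd.treeAggregate((0L, 0L))(
      seqOp = (acc, x) => (acc._1 + x, acc._2 + 1L),
      combOp = (a, b) => (a._1 + b._1, a._2 + b._2),
      depth = 2)

    println(s"mean = ${sum.toDouble / count}") // 500.5 for 1 to 1000
    sc.stop()
  }
}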