From fe42a5e8f5d002d22bd53a4cbcb81607efa10ab1 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 17 Jun 2014 01:16:01 -0700
Subject: [PATCH] add treeAggregate

---
 .../main/scala/org/apache/spark/rdd/RDD.scala      | 24 +++++++++++++++++++
 .../scala/org/apache/spark/rdd/RDDSuite.scala      | 10 ++++++++
 2 files changed, 34 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 27cc60d775788..a963f34929108 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -862,6 +862,30 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
+  @DeveloperApi
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      level: Int): U = {
+    require(level >= 1, s"Level must be greater than or equal to 1 but got $level.")
+    if (this.partitions.size == 0) {
+      return Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = sc.clean(seqOp)
+    val cleanCombOp = sc.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var local = this.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = local.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / level)).toInt, 2)
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      local = local.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % numPartitions, _))
+      }.reduceByKey(new HashPartitioner(numPartitions), cleanCombOp).values
+    }
+    local.reduce(cleanCombOp)
+  }
+
   /**
    * Return the number of elements in the RDD.
    */
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e94a1e76d410c..28ad63c998e41 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -769,4 +769,14 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       mutableDependencies += dep
     }
   }
+
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (level <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, level)
+      assert(sum === -1000L)
+    }
+  }
 }
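
A minimal usage sketch for review context (a hypothetical standalone driver, not part of
the patch; it assumes a local master and the treeAggregate signature added above):

    import org.apache.spark.{SparkConf, SparkContext}

    object TreeAggregateExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("treeAggregateExample").setMaster("local[4]")
        val sc = new SparkContext(conf)
        // 100 partitions: with the flat aggregate, all 100 partial sums
        // are sent directly to the driver and combined there.
        val rdd = sc.parallelize(1 to 1000000, 100)
        val flatSum = rdd.aggregate(0L)((c, x) => c + x, (c1, c2) => c1 + c2)
        // With level = 2, scale = ceil(100^(1/2)) = 10, so one intermediate
        // reduceByKey round shuffles the 100 partial sums down to 10
        // partitions before the driver combines the remaining values.
        val treeSum = rdd.treeAggregate(0L)((c, x) => c + x, (c1, c2) => c1 + c2, 2)
        assert(flatSum == treeSum && treeSum == 500000500000L)
        sc.stop()
      }
    }

Tracing the loop for the level = 2 case above: numPartitions = 100 and scale = 10, so the
condition 100 > 10 + 100/10 holds once, the shuffle reduces to 10 partitions, and then
10 > 10 + 10/10 fails, leaving local.reduce to combine the final 10 values on the driver.
With level = 1, scale equals the partition count, the loop never runs, and the method
degenerates to the existing flat aggregate; larger levels trade extra shuffle rounds for
less pressure on the driver.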