diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a963f34929108..dbde4a8db0011 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -827,6 +827,32 @@ abstract class RDD[T: ClassTag]( jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } + def treeReduce(f: (T, T) => T, level: Int): T = { + require(level >= 1, s"Level must be greater than 1 but got $level.") + val cleanF = sc.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val local = this.mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + local.treeAggregate(Option.empty[T])(op, op, level) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 28ad63c998e41..db46691990ce3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -779,4 +779,12 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sum === -1000L) } } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (level <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, level) + assert(sum === -1000) + } + } }