From 35d1f244f918bd8ea7fe7fdf10796a64e7a62fc9 Mon Sep 17 00:00:00 2001
From: Peng Meng
Date: Thu, 10 Aug 2017 15:19:59 +0800
Subject: [PATCH] Optimize RF communication cost

---
 .../ml/tree/impl/DTStatsAggregator.scala      | 21 +++++++++++++++------
 .../spark/ml/tree/impl/RandomForest.scala     | 10 ++++++----
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 5aeea1443d499..e81c75440f546 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.tree.impl
 
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.tree.impurity._
 
 
@@ -74,7 +75,15 @@ private[spark] class DTStatsAggregator(
    * Index for start of stats for a (feature, bin) is:
    *   index = featureOffsets(featureIndex) + binIndex * statsSize
    */
-  private val allStats: Array[Double] = new Array[Double](allStatsSize)
+  @transient private var allStats: Array[Double] = new Array[Double](allStatsSize)
+
+  // Compressed copy of allStats, used to reduce shuffle communication cost
+  private var compressedAllStats: Vector = null
+
+  def compressAllStats(): DTStatsAggregator = {
+    compressedAllStats = Vectors.dense(allStats).compressed
+    this
+  }
 
   /**
    * Array of parent node sufficient stats.
@@ -159,13 +168,13 @@ private[spark] class DTStatsAggregator(
     require(allStatsSize == other.allStatsSize,
       s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
         + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
-    var i = 0
-    // TODO: Test BLAS.axpy
-    while (i < allStatsSize) {
-      allStats(i) += other.allStats(i)
-      i += 1
+
+    if (allStats == null) {
+      allStats = compressedAllStats.toArray
     }
+    other.compressedAllStats.foreachActive((i, v) => allStats(i) += v)
+
     require(statsSize == other.statsSize,
       s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
         s"stats vectors. This aggregator's parent stats are length $statsSize, " +
         s"but the other is ${other.statsSize}.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 82e1ed85a0a14..bc311f37df92d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -526,9 +526,10 @@ private[spark] object RandomForest extends Logging {
         // iterator all instances in current partition and update aggregate stats
         points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
 
-        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // compress allStats of nodeAggregateStats,
+        // and transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
         // which can be combined with other partition using `reduceByKey`
-        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+        nodeStatsAggregators.view.map(_.compressAllStats()).zipWithIndex.map(_.swap).iterator
       }
     } else {
       input.mapPartitions { points =>
@@ -544,9 +545,10 @@ private[spark] object RandomForest extends Logging {
         // iterator all instances in current partition and update aggregate stats
         points.foreach(binSeqOp(nodeStatsAggregators, _))
 
-        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // compress allStats of nodeAggregateStats,
+        // and transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
         // which can be combined with other partition using `reduceByKey`
-        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+        nodeStatsAggregators.view.map(_.compressAllStats()).zipWithIndex.map(_.swap).iterator
       }
     }
 
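
The gain here comes from Vector.compressed: it returns a SparseVector whenever the sparse encoding needs less storage than the dense one, so mostly-zero allStats arrays serialize much smaller for the shuffle, and merge only has to visit active entries. Below is a minimal standalone sketch of that behavior; the object name is hypothetical and the code is not part of the patch, it only needs spark-mllib-local on the classpath.

import org.apache.spark.ml.linalg.{SparseVector, Vectors}

// Hypothetical demo object, not part of the patch.
object CompressedStatsSketch {
  def main(args: Array[String]): Unit = {
    // A mostly-zero stats buffer, like allStats for a partition that only
    // touched a few (feature, bin) cells of this node.
    val allStats = new Array[Double](1000)
    allStats(3) = 42.0
    allStats(700) = 7.0

    // compressed picks whichever representation needs less storage;
    // with 2 of 1000 entries non-zero it chooses the sparse one.
    val compressed = Vectors.dense(allStats).compressed
    assert(compressed.isInstanceOf[SparseVector])

    // Reduce-side pattern from the patch: restore one dense buffer,
    // then fold in other aggregators by visiting only active entries.
    val merged = compressed.toArray
    compressed.foreachActive((i, v) => merged(i) += v)
    println(s"non-zero entries after merge: ${merged.count(_ != 0.0)}")
  }
}

When the stats are dense, compressed keeps the dense representation, so the worst case adds only the cost of counting non-zeros on the map side.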