Skip to content

Commit

Permalink
optimize RF communication cost
Browse files Browse the repository at this point in the history
  • Loading branch information
Peng Meng committed Aug 10, 2017
1 parent 84454d7 commit 35d1f24
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tree.impl

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity._


Expand Down Expand Up @@ -74,7 +75,15 @@ private[spark] class DTStatsAggregator(
* Index for start of stats for a (feature, bin) is:
* index = featureOffsets(featureIndex) + binIndex * statsSize
*/
private val allStats: Array[Double] = new Array[Double](allStatsSize)
@transient private var allStats: Array[Double] = new Array[Double](allStatsSize)

// This is used for reducing shuffle communication cost
private var compressedAllStats: Vector = null

/**
 * Store a compressed vector copy of `allStats` in `compressedAllStats`.
 *
 * `allStats` is declared `@transient`; the compressed copy is intended to reduce
 * shuffle communication cost.
 *
 * @return this aggregator, so the call can be chained
 */
def compressAllStats(): DTStatsAggregator = {
  val denseStats = Vectors.dense(allStats)
  compressedAllStats = denseStats.compressed
  this
}

/**
* Array of parent node sufficient stats.
Expand Down Expand Up @@ -159,13 +168,13 @@ private[spark] class DTStatsAggregator(
require(allStatsSize == other.allStatsSize,
s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
+ s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
var i = 0
// TODO: Test BLAS.axpy
while (i < allStatsSize) {
allStats(i) += other.allStats(i)
i += 1

if(allStats == null) {
allStats = compressedAllStats.toArray
}

other.compressedAllStats.foreachActive((i, v) => allStats(i) += v)

require(statsSize == other.statsSize,
s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
s"stats vectors. This aggregator's parent stats are length $statsSize, " +
Expand Down
Expand Up @@ -526,9 +526,10 @@ private[spark] object RandomForest extends Logging {
// iterate over all instances in the current partition and update aggregate stats
points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// compress allStats of nodeAggregateStats,
// and transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
nodeStatsAggregators.view.map(_.compressAllStats()).zipWithIndex.map(_.swap).iterator
}
} else {
input.mapPartitions { points =>
Expand All @@ -544,9 +545,10 @@ private[spark] object RandomForest extends Logging {
// iterate over all instances in the current partition and update aggregate stats
points.foreach(binSeqOp(nodeStatsAggregators, _))

// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// compress allStats of nodeAggregateStats,
// and transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
nodeStatsAggregators.view.map(_.compressAllStats()).zipWithIndex.map(_.swap).iterator
}
}

Expand Down

0 comments on commit 35d1f24

Please sign in to comment.