Skip to content

Commit

Permalink
[SPARK-22009][ML] Using treeAggregate improve some algs
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

I test on a dataset of about 13M instances, and found that using `treeAggregate` give a speedup in following algs:

|Algs| SpeedUp |
|------|-----------|
|OneHotEncoder| 5% |
|StatFunctions.calculateCov| 7% |
|StatFunctions.multipleApproxQuantiles|  9% |
|RegressionEvaluator| 8% |

## How was this patch tested?
existing tests

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #19232 from zhengruifeng/use_treeAggregate.
  • Loading branch information
zhengruifeng authored and srowen committed Sep 21, 2017
1 parent b21b806 commit a8a5cd2
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
Expand Up @@ -142,7 +142,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
if (outputAttrGroup.size < 0) {
// If the number of attributes is unknown, we check the values from the input column.
val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
.aggregate(0.0)(
.treeAggregate(0.0)(
(m, x) => {
assert(x <= Int.MaxValue,
s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
Expand Down
Expand Up @@ -54,7 +54,7 @@ class RegressionMetrics @Since("2.0.0") (
private lazy val summary: MultivariateStatisticalSummary = {
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
case (prediction, observation) => Vectors.dense(observation, observation - prediction)
}.aggregate(new MultivariateOnlineSummarizer())(
}.treeAggregate(new MultivariateOnlineSummarizer())(
(summary, v) => summary.add(v),
(sum1, sum2) => sum1.merge(sum2)
)
Expand Down
Expand Up @@ -95,7 +95,7 @@ object FrequentItems extends Logging {
(name, originalSchema.fields(index).dataType)
}.toArray

val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
Expand Down
Expand Up @@ -99,7 +99,7 @@ object StatFunctions extends Logging {
sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = {
sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) }
}
val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge)
val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)

summaries.map { summary => probabilities.flatMap(summary.query) }
}
Expand Down Expand Up @@ -160,7 +160,7 @@ object StatFunctions extends Logging {
s"for columns with dataType ${data.get.dataType} not supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)(
df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)(
seqOp = (counter, row) => {
counter.add(row.getDouble(0), row.getDouble(1))
},
Expand Down

0 comments on commit a8a5cd2

Please sign in to comment.