[SPARK-21818][ML][MLLIB] Fix bug of MultivariateOnlineSummarizer.variance generate negative result

## What changes were proposed in this pull request?

Because of numerical error, `MultivariateOnlineSummarizer.variance` can return a negative variance.
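
As a minimal sketch of how a rounding-sized discrepancy can push the summarizer's variance expression below zero, the function below mirrors the expression changed in this commit, with the per-feature quantities passed in as plain `Double`s. The object name, the input values, and treating `denominator` as a given positive number are assumptions made for illustration. They are not values observed in Spark, and this does not claim to be the exact path hit by the reproduction.

```scala
object SummarizerVarianceSketch {
  /** The variance expression as it appears in the diff, before clamping. */
  def rawVariance(
      currM2n: Double,
      deltaMean: Double,
      weightSum: Double,
      totalWeightSum: Double,
      denominator: Double): Double =
    (currM2n + deltaMean * deltaMean * weightSum *
      (totalWeightSum - weightSum) / totalWeightSum) / denominator

  /** The same expression clamped at zero, following the pattern used in the patch. */
  def clampedVariance(
      currM2n: Double,
      deltaMean: Double,
      weightSum: Double,
      totalWeightSum: Double,
      denominator: Double): Double =
    math.max(rawVariance(currM2n, deltaMean, weightSum, totalWeightSum, denominator), 0.0)

  def main(args: Array[String]): Unit = {
    // Hypothetical inputs: the per-feature weight sum is one ulp above the total.
    val total = 2.0
    val perFeature = 2.0000000000000004
    println(rawVariance(0.0, 3.0, perFeature, total, 1.5))     // slightly below zero
    println(clampedVariance(0.0, 3.0, perFeature, total, 1.5)) // 0.0
  }
}
```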

**This is a serious bug: many algorithms in MLlib**
**compute the standard deviation as** `sqrt(variance)`,
**which yields NaN for a negative variance and crashes the whole algorithm.**
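
To illustrate the failure mode, here is a small standalone sketch (not Spark code). The object name `SqrtOfNegativeVariance`, the helper `safeStd`, and the value `-1e-17` are hypothetical, chosen only to stand in for a rounding residue.

```scala
object SqrtOfNegativeVariance {
  /** Standard deviation from a variance that may be slightly negative due to rounding. */
  def safeStd(variance: Double): Double = math.sqrt(math.max(variance, 0.0))

  def main(args: Array[String]): Unit = {
    val tinyNegative = -1e-17 // hypothetical rounding residue, not a value taken from Spark

    // The square root of any negative double is NaN, and NaN propagates through later arithmetic.
    println(math.sqrt(tinyNegative).isNaN) // true

    // Clamping at zero first keeps the result a valid standard deviation.
    println(safeStd(tinyNegative)) // 0.0
  }
}
```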

We can reproduce this bug using the following code:
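```
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

    // Four single-sample summarizers over the same value (3.0) with different weights.
    val summarizer1 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.7)
    val summarizer2 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.4)
    val summarizer3 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.5)
    val summarizer4 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.4)

    val summarizer = summarizer1
      .merge(summarizer2)
      .merge(summarizer3)
      .merge(summarizer4)

    // All samples are identical, so the variance should be 0.0 and never negative.
    println(summarizer.variance(0))
```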
This PR fixes the bug in `mllib.stat.MultivariateOnlineSummarizer.variance` and `ml.stat.SummarizerBuffer.variance`, as well as in several places in `WeightedLeastSquares`, by clamping the computed variance at zero.
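
The `WeightedLeastSquares` changes apply the clamp to a variance computed from raw moments, in the form `bbSum / wSum - bBar * bBar`. The following standalone sketch shows that pattern under stated assumptions: the object `WeightedMoments` and its methods are hypothetical helpers written for illustration, and only the clamped expression itself is taken from the diff.

```scala
object WeightedMoments {
  /** Weighted population variance via sum(w * x^2) / sum(w) - mean^2, clamped at zero. */
  def variance(values: Seq[Double], weights: Seq[Double]): Double = {
    require(values.nonEmpty && values.length == weights.length, "inputs must align")
    val wSum = weights.sum
    val bSum = values.zip(weights).map { case (x, w) => w * x }.sum
    val bbSum = values.zip(weights).map { case (x, w) => w * x * x }.sum
    val bBar = bSum / wSum
    // Subtracting two nearly equal quantities can leave a small negative residue,
    // so the result is clamped before anything takes its square root.
    math.max(bbSum / wSum - bBar * bBar, 0.0)
  }

  /** Weighted population standard deviation; never NaN for finite inputs with positive weights. */
  def std(values: Seq[Double], weights: Seq[Double]): Double =
    math.sqrt(variance(values, weights))
}
```

With the weights from the reproduction, `WeightedMoments.std(Seq(3.0, 3.0, 3.0, 3.0), Seq(0.7, 0.4, 0.5, 0.4))` returns a non-negative number rather than NaN, whatever the sign of the unclamped difference.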

## How was this patch tested?

Test cases were added to `SummarizerSuite` and `MultivariateOnlineSummarizerSuite`.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #19029 from WeichenXu123/fix_summarizer_var_bug.
WeichenXu123 authored and srowen committed Aug 28, 2017
1 parent 07142cf commit 0456b40
Showing 5 changed files with 51 additions and 7 deletions.
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
 /**
  * Weighted population standard deviation of labels.
  */
-def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+def bStd: Double = {
+  // We prevent variance from negative value caused by numerical error.
+  val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
+  math.sqrt(variance)
+}

 /**
  * Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
 while (i < triK) {
   val l = j - 2
   val aw = aSum(l) / wSum
-  std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
+  // We prevent variance from negative value caused by numerical error.
+  std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
   i += j
   j += 1
 }
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
 while (i < triK) {
   val l = j - 2
   val aw = aSum(l) / wSum
-  variance(l) = aaValues(i) / wSum - aw * aw
+  // We prevent variance from negative value caused by numerical error.
+  variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
   i += j
   j += 1
 }
@@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
   var i = 0
   val len = currM2n.length
   while (i < len) {
-    realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-      (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+    // We prevent variance from negative value caused by numerical error.
+    realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+      (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
     i += 1
   }
 }
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
   var i = 0
   val len = currM2n.length
   while (i < len) {
-    realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-      (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+    // We prevent variance from negative value caused by numerical error.
+    realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+      (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
     i += 1
   }
 }
@@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(summarizer.count === 6)
   }

+  test("summarizer buffer zero variance test (SPARK-21818)") {
+    val summarizer1 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.7)
+    val summarizer2 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.4)
+    val summarizer3 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.5)
+    val summarizer4 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.4)
+
+    val summarizer = summarizer1
+      .merge(summarizer2)
+      .merge(summarizer3)
+      .merge(summarizer4)
+
+    assert(summarizer.variance(0) >= 0.0)
+  }
+
   test("summarizer buffer merging summarizer with empty summarizer") {
     // If one of two is non-empty, this should return the non-empty summarizer.
     // If both of them are empty, then just return the empty summarizer.
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
     assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
   }
+
+  test ("test zero variance (SPARK-21818)") {
+    val summarizer1 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.7)
+    val summarizer2 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+    val summarizer3 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.5)
+    val summarizer4 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+
+    val summarizer = summarizer1
+      .merge(summarizer2)
+      .merge(summarizer3)
+      .merge(summarizer4)
+
+    assert(summarizer.variance(0) >= 0.0)
+  }
 }
