[SPARK-21770][ML] ProbabilisticClassificationModel fix corner case: normalization of all-zero raw predictions

## What changes were proposed in this pull request?

Fix a corner case in ProbabilisticClassificationModel: normalizing all-zero raw predictions now throws an IllegalArgumentException with a descriptive message.
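
The behavior described above can be illustrated with a minimal standalone sketch. It assumes a plain `Array[Double]` in place of Spark's `DenseVector`, and the names `NormalizeSketch` and `normalizeInPlace` are hypothetical, chosen only for this illustration; it is not the Spark implementation itself.

```scala
// Standalone sketch: uses Array[Double] rather than Spark's DenseVector.
// NormalizeSketch and normalizeInPlace are hypothetical names for illustration.
object NormalizeSketch {

  /** Scales `values` in place so they sum to 1; rejects negative or all-zero input. */
  def normalizeInPlace(values: Array[Double]): Unit = {
    // require throws IllegalArgumentException when its condition is false.
    values.foreach(x => require(x >= 0, "The input raw predictions should be nonnegative."))
    val sum = values.sum
    require(sum > 0, "Can't normalize the 0-vector.")
    var i = 0
    while (i < values.length) {
      values(i) /= sum
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    val counts = Array(1.0, 2.0, 3.0)
    normalizeInPlace(counts)
    println(counts.mkString(", ")) // the three entries now sum to 1

    // All-zero input is rejected instead of silently staying all-zero.
    val allZero = scala.util.Try(normalizeInPlace(Array(0.0, 0.0, 0.0)))
    println(allZero.isFailure) // prints "true"
  }
}
```

Failing fast here means a caller cannot mistake an all-zero vector for a valid probability distribution.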

## How was this patch tested?

Test case added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19106 from WeichenXu123/SPARK-21770.
WeichenXu123 authored and srowen committed Oct 10, 2017
1 parent af8a34c commit 3b5c2a8
Showing 2 changed files with 29 additions and 9 deletions.
@@ -230,21 +230,23 @@ private[ml] object ProbabilisticClassificationModel {
    * Normalize a vector of raw predictions to be a multinomial probability vector, in place.
    *
    * The input raw predictions should be nonnegative.
-   * The output vector sums to 1, unless the input vector is all-0 (in which case the output is
-   * all-0 too).
+   * The output vector sums to 1.
    *
    * NOTE: This is NOT applicable to all models, only ones which effectively use class
    *       instance counts for raw predictions.
+   *
+   * @throws IllegalArgumentException if the input vector is all-0 or including negative values
    */
   def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = {
+    v.values.foreach(value => require(value >= 0,
+      "The input raw predictions should be nonnegative."))
     val sum = v.values.sum
-    if (sum != 0) {
-      var i = 0
-      val size = v.size
-      while (i < size) {
-        v.values(i) /= sum
-        i += 1
-      }
+    require(sum > 0, "Can't normalize the 0-vector.")
+    var i = 0
+    val size = v.size
+    while (i < size) {
+      v.values(i) /= sum
+      i += 1
     }
   }
 }
@@ -80,6 +80,24 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
       new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1))
     }
   }
+
+  test("normalizeToProbabilitiesInPlace") {
+    val vec1 = Vectors.dense(1.0, 2.0, 3.0).toDense
+    ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec1)
+    assert(vec1 ~== Vectors.dense(1.0 / 6, 2.0 / 6, 3.0 / 6) relTol 1e-3)
+
+    // all-0 input test
+    val vec2 = Vectors.dense(0.0, 0.0, 0.0).toDense
+    intercept[IllegalArgumentException] {
+      ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec2)
+    }
+
+    // negative input test
+    val vec3 = Vectors.dense(1.0, -1.0, 2.0).toDense
+    intercept[IllegalArgumentException] {
+      ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec3)
+    }
+  }
 }

 object ProbabilisticClassifierSuite {