From 08dbe43bbf961186ee94432a13f9f6cfc221e4a6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 3 Sep 2016 15:26:26 +0100 Subject: [PATCH 1/2] Revise semantics of ProbabilisticClassifierModel thresholds so that classes can only be predicted if they exceed their threshold (meaning no class may be predicted), and otherwise ordering by highest probability, then lowest threshold, then by class index --- .../ProbabilisticClassifier.scala | 27 +++++++----- .../ProbabilisticClassifierSuite.scala | 43 +++++++++++++++---- 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 19df8f7edd43c..a83d98e246fe1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -193,19 +193,24 @@ abstract class ProbabilisticClassificationModel[ /** * Given a vector of class conditional probabilities, select the predicted label. - * This supports thresholds which favor particular labels. - * @return predicted label + * This returns the class, if any, whose probability is equal to or greater than its + * threshold (if specified), and whose probability is highest. If several classes meet + * their thresholds and are equally probable, the one with lower threshold is selected. + * If several have equal thresholds, the one with lower class index is selected. + * + * @return predicted label, or NaN if no label is predicted */ protected def probability2prediction(probability: Vector): Double = { - if (!isDefined(thresholds)) { - probability.argmax + val prob = probability.toArray + if (isDefined(thresholds)) { + val candidates = prob.zip(getThresholds).zipWithIndex.filter { case ((p, t), _) => p >= t } + if (candidates.isEmpty) { + Double.NaN + } else { + candidates.maxBy { case ((p, t), i) => (p, -t, -i) }._2 + } } else { - val thresholds: Array[Double] = getThresholds - val scaledProbability: Array[Double] = - probability.toArray.zip(thresholds).map { case (p, t) => - if (t == 0.0) Double.PositiveInfinity else p / t - } - Vectors.dense(scaledProbability).argmax + prob.zipWithIndex.maxBy { case (p, i) => (p, -i) }._2 } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index b3bd2b3e57b36..7f92e449a0b5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel( rawPrediction } - def friendlyPredict(input: Vector): Double = { - predict(input) + def friendlyPredict(input: Double*): Double = { + predict(Vectors.dense(input.toArray)) } } @@ -45,17 +45,44 @@ final class TestProbabilisticClassificationModel( class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { - val thresholds = Array(0.5, 0.2) val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - .setThresholds(thresholds) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) + .setThresholds(Array(0.5, 0.2)) + // Both exceed threshold; pick more probable one + assert(testModel.friendlyPredict(0.8, 0.9) === 1.0) + assert(testModel.friendlyPredict(1.0, 0.2) === 0.0) + // Tie; take one with lower threshold + assert(testModel.friendlyPredict(0.8, 0.8) === 1.0) + // Tie at 1 + assert(testModel.friendlyPredict(1.0, 1.0) === 1.0) + // Class 0 more probable but doesn't meet threshold + assert(testModel.friendlyPredict(0.4, 0.3) === 1.0) + // Neither meets threshold + assert(testModel.friendlyPredict(0.4, 0.1).isNaN) + assert(testModel.friendlyPredict(0.0, 0.0).isNaN) } - test("test thresholding not required") { + test("test equals thresholds") { val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) + .setThresholds(Array(0.5, 0.5)) + // Both exceed threshold; pick more probable one + assert(testModel.friendlyPredict(0.8, 0.9) === 1.0) + // Tie; take one with lower class + assert(testModel.friendlyPredict(0.8, 0.8) === 0.0) + assert(testModel.friendlyPredict(0.5, 0.5) === 0.0) + // Neither meets threshold + assert(testModel.friendlyPredict(0.4, 0.1).isNaN) } + + test("test no thresholding") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + // Pick more probable class + assert(testModel.friendlyPredict(1.0, 2.0) === 1.0) + // Tie, pick first class + assert(testModel.friendlyPredict(1.0, 1.0) === 0.0) + assert(testModel.friendlyPredict(0.5, 0.5) === 0.0) + assert(testModel.friendlyPredict(0.0, 0.0) === 0.0) + } + } object ProbabilisticClassifierSuite { From 2fa331ebdf3df31b08e3299c8410c29efa645bf1 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 3 Sep 2016 17:19:29 +0100 Subject: [PATCH 2/2] Update MultinomialLogisticRegression test output to match new threshold meaning --- .../MultinomialLogisticRegressionSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala index 0913fe559c562..9952578d1abd2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala @@ -988,22 +988,22 @@ class MultinomialLogisticRegressionSuite val basePredictions = model.transform(dataset).select("prediction").collect() // should predict all zeros - model.setThresholds(Array(1, 1000, 1000)) + model.setThresholds(Array(0, 1, 1)) val zeroPredictions = model.transform(dataset).select("prediction").collect() assert(zeroPredictions.forall(_.getDouble(0) === 0.0)) // should predict all ones - model.setThresholds(Array(1000, 1, 1000)) + model.setThresholds(Array(1, 0, 1)) val onePredictions = model.transform(dataset).select("prediction").collect() assert(onePredictions.forall(_.getDouble(0) === 1.0)) // should predict all twos - model.setThresholds(Array(1000, 1000, 1)) + model.setThresholds(Array(1, 1, 0)) val twoPredictions = model.transform(dataset).select("prediction").collect() assert(twoPredictions.forall(_.getDouble(0) === 2.0)) // constant threshold scaling is the same as no thresholds - model.setThresholds(Array(1000, 1000, 1000)) + model.setThresholds(Array(0.1, 0.1, 0.1)) val scaledPredictions = model.transform(dataset).select("prediction").collect() assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => scaled.getDouble(0) === base.getDouble(0)