From 6ad7676910609bf9e62524b158d9ff34a9e7c0fe Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 1 Jul 2015 15:26:29 -0700
Subject: [PATCH 01/40] Add HasThresholds
--- .../ml/param/shared/SharedParamsCodeGen.scala | 2 ++ .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++++++++ 2 files changed, 17 insertions(+)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b0a6af171c01f..9b2ff8145abaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -47,6 +47,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", isValid = "ParamValidators.inRange(0, 1)"), + ParamDesc[Array[Double]]("thresholds", + "thresholds in multi-class classification prediction, must be array with size of classes."), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index bbe08939b6d75..3356fb34810a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -154,6 +154,21 @@ private[ml] trait HasThreshold extends Params { final def getThreshold: Double = $(threshold) } +/** + * (private[ml]) Trait for shared param thresholds. + */ +private[ml] trait HasThresholds extends Params { + + /** + * Param for thresholds in multi-class classification prediction, must be array with size of classes.. + * @group param + */ + final val thresholds: Param[Array[Double]] = new Param[Array[Double]](this, "thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.") + + /** @group getParam */ + final def getThresholds: Array[Double] = $(thresholds) +} + /** * (private[ml]) Trait for shared param inputCol.
*/ From 0c9c8a80ee0fb5107901f20d6a4a48313f38ee32 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 1 Jul 2015 16:05:24 -0700 Subject: [PATCH 02/40] Start threading the threshold info through --- .../spark/ml/classification/RandomForestClassifier.scala | 3 ++- .../main/scala/org/apache/spark/ml/tree/treeParams.scala | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d3c67494a31e4..d60ced49ae5e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -122,7 +122,8 @@ object RandomForestClassifier { @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeClassificationModel]) + private val _trees: Array[DecisionTreeClassificationModel], + private val _threshold: Option[Array[Double]]=None) extends PredictionModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index a0c5238d966bf..4c6bbd202093f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -266,7 +266,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed with HasThresholds { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -288,6 +288,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @group setParam */ def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + + /** * Create a Strategy instance to use with the old API. * NOTE: The caller should set impurity and seed. 
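For reference, the decision rule the next commit implements can be sketched in a few lines of plain Scala: each class's raw score is divided by its per-class threshold before taking the argmax, so a larger threshold makes a class harder to predict. This is a standalone sketch with no Spark dependencies; the object and method names are illustrative, not part of the patch.

  object ThresholdedArgmax {
    // Divide each class's raw score by its threshold, then pick the best class.
    // With no thresholds supplied, this degrades to a plain argmax.
    def predict(rawScores: Array[Double], thresholds: Option[Array[Double]]): Int = {
      val scaled = thresholds match {
        case Some(t) =>
          require(t.length == rawScores.length, "need one threshold per class")
          rawScores.zip(t).map { case (score, th) => score / th }
        case None => rawScores
      }
      scaled.zipWithIndex.maxBy(_._1)._2
    }

    def main(args: Array[String]): Unit = {
      val votes = Array(2.0, 1.0, 1.0)                      // e.g. tree votes over three classes
      println(predict(votes, None))                         // 0: plain majority vote
      println(predict(votes, Some(Array(1.0, 0.25, 1.0))))  // 1: class 1 needs 4x fewer votes
    }
  }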
From 4abe7347e1f198e5c08577a91b3301d06dcd5cd4 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 1 Jul 2015 17:08:08 -0700
Subject: [PATCH 03/40] Use thresholds to scale scores in random forest classification
--- .../spark/ml/classification/RandomForestClassifier.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d60ced49ae5e1..5e68cecec927b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -123,7 +123,7 @@ object RandomForestClassifier { final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - private val _threshold: Option[Array[Double]]=None) + private val _thresholds: Option[Array[Double]]=None) extends PredictionModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -146,7 +146,12 @@ final class RandomForestClassificationModel private[ml] ( val prediction = tree.rootNode.predict(features).toInt votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight } - votes.maxBy(_._2)._1 + // Apply thresholding, or use votes if no thresholding + val scores = _thresholds.map{thresholds => + votes.map{case (index, count) => + (index, count/thresholds(index)) + }}.getOrElse(votes) + scores.maxBy(_._2)._1 } override def copy(extra: ParamMap): RandomForestClassificationModel = {
From 7d3172942c01a40a4c6df057788c4d590b7ec256 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Wed, 1 Jul 2015 18:06:51 -0700
Subject: [PATCH 04/40] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
--- .../ml/classification/RandomForestClassifier.scala | 8 +++++--- .../RandomForestClassifierSuite.scala | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 5e68cecec927b..551e33bbfa764 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) val oldModel = OldRandomForest.trainClassifier( oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) - RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures) + RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures, + Option($(thresholds))) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -174,7 +175,8 @@ private[ml] object RandomForestClassificationModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, - categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + categoricalFeatures: Map[Int, Int], + thresholds: Option[Array[Double]]): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo}
(old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -182,6 +184,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees) + new RandomForestClassificationModel(uid, newTrees, thresholds) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 1b6b69c7dc71e..6edc45a68e87b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -103,6 +103,20 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf, categoricalFeatures, numClasses) } + test("test thresholding") { + val rf = new RandomForestClassifier().setNumTrees(3) + val rfThreshold = new RandomForestClassifier().setNumTrees(3) + .setThresholds(Array(0.1, 100000.0, 0.2)) + val input = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), + LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + ) + + assert(false) + } + test("subsampling rate in RandomForest"){ val rdd = orderedLabeledPoints5_20 val categoricalFeatures = Map.empty[Int, Int] From 85e46f5180fbeb8b64537758f6c6eec47ba87a87 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 16:03:29 -0700 Subject: [PATCH 05/40] Move thresholding into Classifier trait --- .../spark/ml/classification/Classifier.scala | 15 +++++++++-- .../RandomForestClassifier.scala | 27 +++++++++---------- .../classification/ClassificationModel.scala | 6 ++--- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 85c097bc64a4f..29ee713eea8a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -153,8 +153,19 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** * Given a vector of raw predictions, select the predicted label. - * This may be overridden to support thresholds which favor particular labels. + * This may be overridden to support custom thresholds which favor particular labels. * @return predicted label */ - protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax + protected def raw2prediction(rawPrediction: Vector): Double = { + val modelScores = rawPrediction.toArray.zipWithIndex + // Apply thresholding, or use votes if no thresholding + val scores = _thresholds.map{thresholds => + modelScores.map{case (score, index) => + (score/thresholds(index), index) + }}.getOrElse(modelScores) + scores.maxBy(_._1)._2 + } + + // TODO: Leave this as undefined to force algs to implement. 
+ protected val _thresholds: Option[Array[Double]]=None } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 551e33bbfa764..94381846604fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -95,7 +95,7 @@ final class RandomForestClassifier(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) val oldModel = OldRandomForest.trainClassifier( oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) - RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures, + RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures, numClasses, Option($(thresholds))) } @@ -124,8 +124,9 @@ object RandomForestClassifier { final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - private val _thresholds: Option[Array[Double]]=None) - extends PredictionModel[Vector, RandomForestClassificationModel] + override val numClasses: Int, + override protected val _thresholds: Option[Array[Double]]=None) + extends ClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -137,26 +138,21 @@ final class RandomForestClassificationModel private[ml] ( override def treeWeights: Array[Double] = _treeWeights - override protected def predict(features: Vector): Double = { + override protected def predictRaw(features: Vector): Vector = { // TODO: Override transform() to broadcast model. SPARK-7127 // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. 
- val votes = mutable.Map.empty[Int, Double] + val votes = Array[Double](numClasses) _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt - votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } - // Apply thresholding, or use votes if no thresholding - val scores = _thresholds.map{thresholds => - votes.map{case (index, count) => - (index, count/thresholds(index)) - }}.getOrElse(votes) - scores.maxBy(_._2)._1 + Vectors.dense(votes) } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numClasses, _thresholds), extra) } override def toString: String = { @@ -176,6 +172,7 @@ private[ml] object RandomForestClassificationModel { oldModel: OldRandomForestModel, parent: RandomForestClassifier, categoricalFeatures: Map[Int, Int], + numClasses: Int, thresholds: Option[Array[Double]]): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") @@ -184,6 +181,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees, thresholds) + new RandomForestClassificationModel(uid, newTrees, numClasses, thresholds) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 35a0db76f3a8c..4b79819a84564 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -40,10 +40,10 @@ trait ClassificationModel extends Serializable { def predict(testData: RDD[Vector]): RDD[Double] /** - * Predict values for a single data point using the model trained. + * Predict values for the given data using the model trained. 
* - * @param testData array representing a single data point - * @return predicted category from the trained model + * @param testData RDD representing data points to be predicted + * @return an RDD[Double] where each entry contains the corresponding prediction */ def predict(testData: Vector): Double
From 9f28f4bae8c6cedf3c06f63ea6818ece515c40df Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Thu, 2 Jul 2015 16:44:15 -0700
Subject: [PATCH 06/40] Fix test compile issues
--- .../ml/classification/RandomForestClassifierSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6edc45a68e87b..f760496494d06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 1) ParamsSuite.checkParams(model) } @@ -181,7 +181,8 @@ private object RandomForestClassifierSuite { val newModel = rf.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures, + numClasses, None) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
From 4b96c949a148e2bf6efd25dd472f4fa56802ea29 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Thu, 2 Jul 2015 16:57:16 -0700
Subject: [PATCH 07/40] Start adding a ClassifierSuite
--- .../ml/classification/ClassifierSuite.scala | 48 +++++++++++++++++++ .../RandomForestClassifierSuite.scala | 14 ------ 2 files changed, 48 insertions(+), 14 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala new file mode 100644 index 0000000000000..b161e7df3b6be --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +class ClassifierSuite extends SparkFunSuite { + class TestClassificationModel( + override val numClasses: Int, + override protected val _thresholds: Option[Array[Double]]) + extends ClassificationModel[Vector, TestClassificationModel] { + override val uid = "1" + override def copy(extra: org.apache.spark.ml.param.ParamMap): + ClassifierSuite.this.TestClassificationModel = { + null + } + + override def predictRaw(input: Vector) = { + input + } + def friendlyPredict(input: Vector) = { + predict(input) + } + } + + test("test thresholding") { + val threshold = Array(0.5, 0.2) + val testModel = new TestClassificationModel(2, Some(threshold)) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) == 1.0) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) == 0.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index f760496494d06..dd0fb048b3d0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -103,20 +103,6 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf, categoricalFeatures, numClasses) } - test("test thresholding") { - val rf = new RandomForestClassifier().setNumTrees(3) - val rfThreshold = new RandomForestClassifier().setNumTrees(3) - .setThresholds(Array(0.1, 100000.0, 0.2)) - val input = Array( - LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), - LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) - ) - - assert(false) - } - test("subsampling rate in RandomForest"){ val rdd = orderedLabeledPoints5_20 val categoricalFeatures = Map.empty[Int, Int] From cacb802402c136c7d7f3349154ce423814487649 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 17:23:59 -0700 Subject: [PATCH 08/40] Move thresholds around some more (set on model not trainer) --- .../spark/ml/classification/Classifier.scala | 5 ++--- .../classification/LogisticRegression.scala | 1 + .../RandomForestClassifier.scala | 19 ++++++++++--------- .../org/apache/spark/ml/tree/treeParams.scala | 8 ++------ .../ml/classification/ClassifierSuite.scala | 4 +++- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 29ee713eea8a3..b0384922e0f6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -159,13 +159,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur protected def raw2prediction(rawPrediction: Vector): Double = { val modelScores = rawPrediction.toArray.zipWithIndex // Apply thresholding, or use votes if no thresholding - val scores = _thresholds.map{thresholds => + val scores = Option(getThresholds).map{thresholds => modelScores.map{case 
(score, index) => (score/thresholds(index), index) }}.getOrElse(modelScores) scores.maxBy(_._1)._2 } - // TODO: Leave this as undefined to force algs to implement. - protected val _thresholds: Option[Array[Double]]=None + protected def getThresholds: Array[Double] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 2e6eedd45ab07..fdf726ad9873d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -238,6 +238,7 @@ class LogisticRegressionModel private[ml] ( /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) + override protected def getThresholds = Array($(threshold)) /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 94381846604fc..4c22863f37301 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasThresholds import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -95,8 +96,7 @@ final class RandomForestClassifier(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) val oldModel = OldRandomForest.trainClassifier( oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) - RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures, numClasses, - Option($(thresholds))) + RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures, numClasses) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -124,13 +124,15 @@ object RandomForestClassifier { final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - override val numClasses: Int, - override protected val _thresholds: Option[Array[Double]]=None) + override val numClasses: Int) extends ClassificationModel[Vector, RandomForestClassificationModel] - with TreeEnsembleModel with Serializable { + with TreeEnsembleModel with HasThresholds with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. 
@@ -152,7 +154,7 @@ final class RandomForestClassificationModel private[ml] ( } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees, numClasses, _thresholds), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } override def toString: String = { @@ -172,8 +174,7 @@ private[ml] object RandomForestClassificationModel { oldModel: OldRandomForestModel, parent: RandomForestClassifier, categoricalFeatures: Map[Int, Int], - numClasses: Int, - thresholds: Option[Array[Double]]): RandomForestClassificationModel = { + numClasses: Int): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -181,6 +182,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees, numClasses, thresholds) + new RandomForestClassificationModel(uid, newTrees, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 4c6bbd202093f..a0c5238d966bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -266,7 +266,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed with HasThresholds { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -288,10 +288,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed wit /** @group setParam */ def setSeed(value: Long): this.type = set(seed, value) - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - - /** * Create a Strategy instance to use with the old API. * NOTE: The caller should set impurity and seed. 
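The Option(getThresholds) guard in the raw2prediction hunk above leans on Option.apply, which maps a null reference to None; the following commits (patches 09 and 12) give the thresholds param a null default so that null means "no scaling". A minimal standalone demonstration of the idiom (plain Scala; names are illustrative only):

  object NullGuardedScaling {
    // Option(x) is Some(x) for non-null x and None for null, so a null
    // thresholds array falls through to the unscaled scores.
    def scale(scores: Array[Double], thresholds: Array[Double]): Array[Double] =
      Option(thresholds)
        .map(t => scores.zip(t).map { case (s, th) => s / th })
        .getOrElse(scores)

    def main(args: Array[String]): Unit = {
      println(scale(Array(1.0, 0.2), Array(0.5, 0.2)).mkString(", ")) // 2.0, 1.0
      println(scale(Array(1.0, 0.2), null).mkString(", "))            // 1.0, 0.2
    }
  }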
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index b161e7df3b6be..80efe0bd6bc56 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} class ClassifierSuite extends SparkFunSuite { class TestClassificationModel( override val numClasses: Int, - override protected val _thresholds: Option[Array[Double]]) + val thresholds: Option[Array[Double]]) extends ClassificationModel[Vector, TestClassificationModel] { override val uid = "1" override def copy(extra: org.apache.spark.ml.param.ParamMap): @@ -31,6 +31,8 @@ class ClassifierSuite extends SparkFunSuite { null } + override def getThresholds = thresholds + override def predictRaw(input: Vector) = { input } From 9f086dd9c821fff16286287a646f288aa108afe3 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 18:14:51 -0700 Subject: [PATCH 09/40] Test passes again... little fnur --- .../spark/ml/classification/Classifier.scala | 13 ++++-- .../classification/LogisticRegression.scala | 1 - .../RandomForestClassifier.scala | 5 +-- .../org/apache/spark/ml/tree/treeParams.scala | 8 +++- .../ml/classification/ClassifierSuite.scala | 45 ++++++++++--------- .../RandomForestClassifierSuite.scala | 2 +- 6 files changed, 43 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index b0384922e0f6e..ce3192f43d9f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThresholds} import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * (private[spark]) Params for classification. */ private[spark] trait ClassifierParams - extends PredictorParams with HasRawPredictionCol { + extends PredictorParams with HasRawPredictionCol with HasThresholds { override protected def validateAndTransformSchema( schema: StructType, @@ -63,6 +63,9 @@ abstract class Classifier[ /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] + /** @group setParam */ + def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E] + // TODO: defaultEvaluator (follow-up PR) } @@ -82,6 +85,10 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** @group setParam */ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + /** @group setParam */ + def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] + setDefault(thresholds -> null) + /** Number of classes (values which the label can take). 
*/ def numClasses: Int @@ -165,6 +172,4 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur }}.getOrElse(modelScores) scores.maxBy(_._1)._2 } - - protected def getThresholds: Array[Double] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fdf726ad9873d..2e6eedd45ab07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -238,7 +238,6 @@ class LogisticRegressionModel private[ml] ( /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - override protected def getThresholds = Array($(threshold)) /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 4c22863f37301..06e6dd07c4378 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -74,6 +74,8 @@ final class RandomForestClassifier(override val uid: String) override def setSeed(value: Long): this.type = super.setSeed(value) + override def setThresholds(value: Array[Double]): this.type = super.set(thresholds, value) + // Parameters from RandomForestParams: override def setNumTrees(value: Int): this.type = super.setNumTrees(value) @@ -130,9 +132,6 @@ final class RandomForestClassificationModel private[ml] ( require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index a0c5238d966bf..d004ed781c25c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -266,7 +266,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed with HasThresholds { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. 
@@ -279,6 +279,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { setDefault(subsamplingRate -> 1.0) + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + + /** @group setParam */ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 80efe0bd6bc56..d44968dc48fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -18,33 +18,38 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} -class ClassifierSuite extends SparkFunSuite { - class TestClassificationModel( - override val numClasses: Int, - val thresholds: Option[Array[Double]]) - extends ClassificationModel[Vector, TestClassificationModel] { - override val uid = "1" - override def copy(extra: org.apache.spark.ml.param.ParamMap): - ClassifierSuite.this.TestClassificationModel = { - null - } - - override def getThresholds = thresholds - - override def predictRaw(input: Vector) = { - input - } - def friendlyPredict(input: Vector) = { - predict(input) - } +final class TestClassificationModel( + override val numClasses: Int) + extends ClassificationModel[Vector, TestClassificationModel] { + override val uid = null + override def copy(extra: org.apache.spark.ml.param.ParamMap): + TestClassificationModel = { + defaultCopy(extra) + } + + override def predictRaw(input: Vector) = { + input } + def friendlyPredict(input: Vector) = { + predict(input) + } +} + + +class ClassifierSuite extends SparkFunSuite { test("test thresholding") { val threshold = Array(0.5, 0.2) - val testModel = new TestClassificationModel(2, Some(threshold)) + val testModel = (new TestClassificationModel(2)).setThresholds(threshold) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) == 1.0) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) == 0.0) } + + test("test thresholding not required") { + val testModel = new TestClassificationModel(2) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) == 1.0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index dd0fb048b3d0d..a10623db14b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -168,7 +168,7 @@ private object RandomForestClassifierSuite { // Use parent from newTree since this is not checked anyways. 
val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures, - numClasses, None) + numClasses) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
From 6c014948b164c397274df8aa7fb9161ce310a3ee Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Thu, 2 Jul 2015 18:34:03 -0700
Subject: [PATCH 10/40] Some progress towards unifying threshold and thresholds
--- .../spark/ml/classification/Classifier.scala | 4 +- .../spark/ml/param/shared/multiParams.scala | 38 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index ce3192f43d9f6..47b2bde67f1fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThresholds} +import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThresholdInfo} import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * (private[spark]) Params for classification. */ private[spark] trait ClassifierParams - extends PredictorParams with HasRawPredictionCol with HasThresholds { + extends PredictorParams with HasRawPredictionCol with HasThresholdInfo { override protected def validateAndTransformSchema( schema: StructType,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala new file mode 100644 index 0000000000000..0560385fb7114 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +/** + * (private[ml]) Trait for supporting HasThreshold & HasThresholds in the + * same class.
+ */ +private[ml] trait HasThresholdInfo extends HasThreshold with HasThresholds { + /** @group getParam + * Return the threshold info if it is set (or any of the defaults are set) */ + def getThresholdInfo: Option[Array[Double]] = { + val thresholds = getThresholds + val threshold = getThreshold + if (thresholds != null) { + Some(thresholds) + } else if (threshold != null) { + Some(Array()) + } else { + None + } + } +} From 192f8e1ae785c39b69e58bf7920336ce4631d424 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 18:34:08 -0700 Subject: [PATCH 11/40] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds" This reverts commit f8538a65a265e86724922aa63b0cc602a3c7603f. --- .../spark/ml/classification/Classifier.scala | 4 +- .../spark/ml/param/shared/multiParams.scala | 38 ------------------- 2 files changed, 2 insertions(+), 40 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 47b2bde67f1fa..ce3192f43d9f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThresholdInfo} +import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThresholds} import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * (private[spark]) Params for classification. */ private[spark] trait ClassifierParams - extends PredictorParams with HasRawPredictionCol with HasThresholdInfo { + extends PredictorParams with HasRawPredictionCol with HasThresholds { override protected def validateAndTransformSchema( schema: StructType, diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala deleted file mode 100644 index 0560385fb7114..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/multiParams.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.param.shared - -/** - * (private[ml]) Trait for support HasThreshold & HashThresholds in the - * same class. 
- */ -private[ml] trait HasThresholdInfo extends HasThreshold with HasThresholds { - /** @group getParam - * Return the threshold info if it is set (or any of the defaults are set) */ - def getThresholdInfo: Option[Array[Double]] = { - val thresholds = getThresholds - val threshold = getThreshold - if (thresholds != null) { - Some(thresholds) - } else if (threshold != null) { - Some(Array()) - } else { - None - } - } -} From 6f662a5999a742c11ac359daba2bf1105a733591 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 18:38:46 -0700 Subject: [PATCH 12/40] Add a global default of null for thresholds param --- .../scala/org/apache/spark/ml/classification/Classifier.scala | 1 - .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 3 ++- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index ce3192f43d9f6..87cb298d0b117 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -87,7 +87,6 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** @group setParam */ def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] - setDefault(thresholds -> null) /** Number of classes (values which the label can take). */ def numClasses: Int diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 9b2ff8145abaa..4dc05ed0c9438 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -48,7 +48,8 @@ private[shared] object SharedParamsCodeGen { "threshold in binary classification prediction, in range [0, 1]", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Array[Double]]("thresholds", - "thresholds in multi-class classification prediction, must be array with size of classes."), + "thresholds in multi-class classification prediction, must be array with size of classes.", + Some("null")), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 3356fb34810a0..fbba3038ee280 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -155,7 +155,7 @@ private[ml] trait HasThreshold extends Params { } /** - * (private[ml]) Trait for shared param thresholds. + * (private[ml]) Trait for shared param thresholds (default: null). 
*/ private[ml] trait HasThresholds extends Params { @@ -165,6 +165,8 @@ private[ml] trait HasThresholds extends Params { */ final val thresholds: Param[Array[Double]] = new Param[Array[Double]](this, "thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.") + setDefault(thresholds, null) + /** @group getParam */ final def getThresholds: Array[Double] = $(thresholds) }
From f9e0100aaf7f053b68d1f5aece0d885b4a83c3e0 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Thu, 2 Jul 2015 18:42:39 -0700
Subject: [PATCH 13/40] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so let's push it down.
--- .../scala/org/apache/spark/ml/classification/Classifier.scala | 3 --- .../spark/ml/classification/RandomForestClassifier.scala | 3 +++ .../org/apache/spark/ml/classification/ClassifierSuite.scala | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 87cb298d0b117..99e3b1f7c8def 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -85,9 +85,6 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** @group setParam */ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] - /** @group setParam */ - def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] - /** Number of classes (values which the label can take). */ def numClasses: Int
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 117f99af2ee20..d0c78d01cc450 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -132,6 +132,9 @@ final class RandomForestClassificationModel private[ml] ( require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on.
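A side note on the bug that patch 14 below fixes: in Scala, Array[Double](numClasses) invokes Array.apply and builds a one-element array containing numClasses, while new Array[Double](numClasses) allocates numClasses zero-initialized slots; indexing the one-element array with any prediction greater than 0 would throw. A quick standalone check of the difference:

  object VoteArraySizes {
    def main(args: Array[String]): Unit = {
      val wrong = Array[Double](3)     // Array.apply: Array(3.0), length 1
      val right = new Array[Double](3) // constructor: Array(0.0, 0.0, 0.0), length 3
      println(wrong.length)            // 1
      println(right.length)            // 3
    }
  }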
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index d44968dc48fa8..bcb252e382f5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -30,6 +30,9 @@ final class TestClassificationModel( defaultCopy(extra) } + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + override def predictRaw(input: Vector) = { input } From d2920f3e3aa07f271b69f5956eeed9c1d25f13f0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 19:06:40 -0700 Subject: [PATCH 14/40] Fix creation of vote array --- .../apache/spark/ml/classification/RandomForestClassifier.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 117f99af2ee20..d0c78d01cc450 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -147,7 +147,7 @@ final class RandomForestClassificationModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. - val votes = Array[Double](numClasses) + val votes = new Array[Double](numClasses) _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight From aac0aebefded4103dc7a310800c3923cff8264e1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 19:07:01 -0700 Subject: [PATCH 15/40] Add a test with thresholding for the RFCS --- .../RandomForestClassifierSuite.scala | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index a10623db14b8c..6b5a6985af5b3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRando import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{SQLContext, DataFrame} /** * Test suite for [[RandomForestClassifier]]. 
@@ -103,6 +103,32 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf, categoricalFeatures, numClasses) } + test("ensure thresholding works") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), + LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + ) + val rdd = sc.parallelize(arr) + val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) + val numClasses = 3 + + val thresholds = Array(1.0, 10000.0, 0.01) + val rf = new RandomForestClassifier() + .setNumTrees(2) + .setSeed(12345) + .setThresholds(thresholds) + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(newData) + assert(model.getThresholds == thresholds) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val testData = rdd.toDF + val results = model.transform(testData).select("prediction") + results.count() + } + test("subsampling rate in RandomForest"){ val rdd = orderedLabeledPoints5_20 val categoricalFeatures = Map.empty[Int, Int] From cd532c8eb445614d9f89f5cd78c2387245e571af Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 19:13:43 -0700 Subject: [PATCH 16/40] move setThresholds only to where its used --- .../scala/org/apache/spark/ml/classification/Classifier.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 99e3b1f7c8def..ab34dd6722afe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -63,9 +63,6 @@ abstract class Classifier[ /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] - /** @group setParam */ - def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E] - // TODO: defaultEvaluator (follow-up PR) } From bd1e19180b4294ba435cc9a92e798eeb4fcbfaa9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 19:32:13 -0700 Subject: [PATCH 17/40] try and hide threshold but chainges the API so no dice there --- .../spark/ml/classification/Classifier.scala | 6 ++++++ .../ml/classification/LogisticRegression.scala | 15 +++++++++++---- .../classification/RandomForestClassifier.scala | 3 --- .../spark/ml/classification/ClassifierSuite.scala | 3 --- .../classification/LogisticRegressionSuite.scala | 4 ++-- .../spark/ml/classification/OneVsRestSuite.scala | 2 +- 6 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index ab34dd6722afe..87cb298d0b117 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -63,6 +63,9 @@ abstract class Classifier[ /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] + /** @group setParam */ + def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E] + // TODO: defaultEvaluator (follow-up PR) } @@ -82,6 +85,9 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur 
/** @group setParam */ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + /** @group setParam */ + def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] + /** Number of classes (values which the label can take). */ def numClasses: Int diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 2e6eedd45ab07..d97b8c5b765eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -41,7 +41,6 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasThreshold /** * :: Experimental :: @@ -99,8 +98,12 @@ class LogisticRegression(override val uid: String) setDefault(fitIntercept -> true) /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) - setDefault(threshold -> 0.5) + def setThreshold(value: Double): this.type = set(thresholds, Array(1.0, value)) + setDefault(thresholds -> Array(1.0, 0.5)) + def getThreshold: Double = { + val thresholds = getThresholds + thresholds(1)/thresholds(0) + } override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. @@ -237,7 +240,11 @@ class LogisticRegressionModel private[ml] ( with LogisticRegressionParams { /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) + def setThreshold(value: Double): this.type = set(thresholds, Array(1.0, value)) + def getThreshold: Double = { + val thresholds = getThresholds + thresholds(1) / thresholds(0) + } /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d0c78d01cc450..e6712b91c4114 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -132,9 +132,6 @@ final class RandomForestClassificationModel private[ml] ( require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index bcb252e382f5a..d44968dc48fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -30,9 +30,6 @@ final class TestClassificationModel( defaultCopy(extra) } - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - override def predictRaw(input: Vector) = { input } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index ba8fbee84197c..a52a11e93e5b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -122,14 +122,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") + model.transform(dataset, model.thresholds -> Array(1.0, 0.0), model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() .map { case Row(pred: Double, prob: Vector) => pred } assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. - val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.thresholds -> Array(0.5, 0.4), lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 75cf5bd4ead4f..4fe7a9b20876e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -127,7 +127,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, "copy should handle extra classifier params") - val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(1.0, 0.1))) ovrModel.models.foreach { case m: LogisticRegressionModel => require(m.getThreshold === 0.1, "copy should handle extra model params") } From d87e6cace3ac52649092586d5edb318698b43433 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jul 2015 19:32:19 -0700 Subject: [PATCH 18/40] Revert "try and hide threshold but changes the API so no dice there" This reverts commit 90ef80fbc7b1b55bdce24c107a0a30603ceb6f9a.
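For reference, the reverted approach encoded a binary threshold t as the pair Array(1.0, t) and read it back as the ratio thresholds(1)/thresholds(0). A minimal standalone sketch of that round trip (plain Scala, illustrative only, not the Params API):

    // Encode a binary threshold as a two-element thresholds array.
    def toThresholds(t: Double): Array[Double] = Array(1.0, t)
    // Recover it as the ratio the reverted getThreshold computed.
    def toThreshold(ts: Array[Double]): Double = ts(1) / ts(0)
    assert(toThreshold(toThresholds(0.4)) == 0.4) // exact round trip, since ts(0) == 1.0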
--- .../spark/ml/classification/Classifier.scala | 6 ------ .../ml/classification/LogisticRegression.scala | 15 ++++----------- .../classification/RandomForestClassifier.scala | 3 +++ .../spark/ml/classification/ClassifierSuite.scala | 3 +++ .../classification/LogisticRegressionSuite.scala | 4 ++-- .../spark/ml/classification/OneVsRestSuite.scala | 2 +- 6 files changed, 13 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 87cb298d0b117..ab34dd6722afe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -63,9 +63,6 @@ abstract class Classifier[ /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] - /** @group setParam */ - def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E] - // TODO: defaultEvaluator (follow-up PR) } @@ -85,9 +82,6 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** @group setParam */ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] - /** @group setParam */ - def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] - /** Number of classes (values which the label can take). */ def numClasses: Int diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d97b8c5b765eb..2e6eedd45ab07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -41,6 +41,7 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol + with HasThreshold /** * :: Experimental :: @@ -98,12 +99,8 @@ class LogisticRegression(override val uid: String) setDefault(fitIntercept -> true) /** @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(1.0, value)) - setDefault(thresholds -> Array(1.0, 0.5)) - def getThreshold: Double = { - val thresholds = getThresholds - thresholds(1)/thresholds(0) - } + def setThreshold(value: Double): this.type = set(threshold, value) + setDefault(threshold -> 0.5) override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. @@ -240,11 +237,7 @@ class LogisticRegressionModel private[ml] ( with LogisticRegressionParams { /** @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(1.0, value)) - def getThreshold: Double = { - val thresholds = getThresholds - thresholds(1) / thresholds(0) - } + def setThreshold(value: Double): this.type = set(threshold, value) /** Margin (rawPrediction) for class label 1. For binary classification only. 
*/ private val margin: Vector => Double = (features) => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index e6712b91c4114..d0c78d01cc450 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -132,6 +132,9 @@ final class RandomForestClassificationModel private[ml] ( require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index d44968dc48fa8..bcb252e382f5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -30,6 +30,9 @@ final class TestClassificationModel( defaultCopy(extra) } + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + override def predictRaw(input: Vector) = { input } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a52a11e93e5b5..ba8fbee84197c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -122,14 +122,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.thresholds -> Array(1.0, 0.0), model.probabilityCol -> "myProb") + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() .map { case Row(pred: Double, prob: Vector) => pred } assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. 
- val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.thresholds -> Array(0.5, 0.4), + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 4fe7a9b20876e..75cf5bd4ead4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -127,7 +127,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, "copy should handle extra classifier params") - val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(1.0, 0.1))) + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) ovrModel.models.foreach { case m: LogisticRegressionModel => require(m.getThreshold === 0.1, "copy should handle extra model params") } From ffc075740d83ee7ab651476d0b345e092b7c903b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 Jul 2015 12:15:53 -0700 Subject: [PATCH 19/40] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions --- .../spark/ml/classification/Classifier.scala | 25 ++++++++++++++++++- .../classification/LogisticRegression.scala | 1 - 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index ab34dd6722afe..154ac4661025f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * (private[spark]) Params for classification. */ private[spark] trait ClassifierParams - extends PredictorParams with HasRawPredictionCol with HasThresholds { + extends PredictorParams with HasRawPredictionCol with HasThresholds with HasThreshold { override protected def validateAndTransformSchema( schema: StructType, @@ -41,6 +41,29 @@ private[spark] trait ClassifierParams val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT) } + + /** + * Customized version of getThresholds that looks at both threshold & thresholds param. + * The priority order is thresholds assigned param, threshold assigned param + * thresholds default value, threshold default value. + * When converting from threshold to thresholds the threshold for class 0 will be 0.5 + * and the threshold for class 1 will be the assigned threshold value. + **/ + override protected def getThresholds: Array[Double] = { + + } + /** + * Customized version of getThreshold that looks at both threshold & thresholds param. + * The priority order is threshold assigned param, thresholds assigned param + * threshold default value, thresholds default value. + * When converting from thresholds to threshold the threshold will be the ratio between + * class 1 and class 0.
+ **/ + override protected def getThreshold: Double = { + if (isDefined(threshold)) { + getThreshold + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 2e6eedd45ab07..a79367be322ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -41,7 +41,6 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasThreshold /** * :: Experimental :: From 63e6137e988f9383e9cb07748b2a9c15bd258854 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 Jul 2015 20:10:13 -0700 Subject: [PATCH 20/40] Allow us to override the get methods selectively --- .../ml/param/shared/SharedParamsCodeGen.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 4dc05ed0c9438..3bf51b40195d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -46,10 +46,10 @@ private[shared] object SharedParamsCodeGen { Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", - isValid = "ParamValidators.inRange(0, 1)"), + isValid = "ParamValidators.inRange(0, 1)", finalMethods=false), ParamDesc[Array[Double]]("thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.", - Some("null")), + Some("null"), finalMethods=false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), @@ -78,7 +78,8 @@ private[shared] object SharedParamsCodeGen { name: String, doc: String, defaultValueStr: Option[String] = None, - isValid: String = "") { + isValid: String = "", + finalMethods: Boolean=true) { require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") require(doc.nonEmpty) // TODO: more rigorous on doc @@ -135,6 +136,11 @@ private[shared] object SharedParamsCodeGen { } else { "" } + val methodStr = if (param.finalMethods) { + "final def" + } else { + "def" + } s""" |/** @@ -149,7 +155,7 @@ private[shared] object SharedParamsCodeGen { | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) |$setDefault | /** @group getParam */ - | final def get$Name: $T = $$($name) + | $methodStr get$Name: $T = $$($name) |} |""".stripMargin } From 622ad4b6b7c00bc923fb778e47c1e8d63c7c96ef Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 Jul 2015 20:15:24 -0700 Subject: [PATCH 21/40] Update the sharedParams --- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index fbba3038ee280..5d86057ea4e92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala 
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -151,7 +151,7 @@ private[ml] trait HasThreshold extends Params { final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) /** @group getParam */ - final def getThreshold: Double = $(threshold) + def getThreshold: Double = $(threshold) } /** @@ -168,7 +168,7 @@ private[ml] trait HasThresholds extends Params { setDefault(thresholds, null) /** @group getParam */ - final def getThresholds: Array[Double] = $(thresholds) + def getThresholds: Array[Double] = $(thresholds) } /** From ab237ad7906d3f739931a89068b7dabd53f289a1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 Jul 2015 20:34:55 -0700 Subject: [PATCH 22/40] Since HasThreshold/HasThresholds are in the root classifier now --- .../spark/ml/classification/Classifier.scala | 30 ++++++++++++++----- .../RandomForestClassifier.scala | 3 +- .../org/apache/spark/ml/tree/treeParams.scala | 4 +-- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 154ac4661025f..4398bd2da746b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThresholds} +import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThreshold, HasThresholds} import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame @@ -45,23 +45,37 @@ private[spark] trait ClassifierParams /** * Customized version of getThresholds that looks at both threshold & thresholds param. * The priority order is thresholds assigned param, threshold assigned param - thresholds default value, threshold default value. + then thresholds default value. * When converting from threshold to thresholds the threshold for class 0 will be 0.5 * and the threshold for class 1 will be the assigned threshold value. **/ - override protected def getThresholds: Array[Double] = { - + override def getThresholds: Array[Double] = { + def thresholdToThresholds(threshold: Double): Array[Double] = { + Array[Double](0.5, threshold) + } + if (isDefined(thresholds) || !isDefined(threshold)) { + super.getThresholds + } else { + thresholdToThresholds(getThreshold) + } } /** * Customized version of getThreshold that looks at both threshold & thresholds param. * The priority order is threshold assigned param, thresholds assigned param - threshold default value, thresholds default value. + then the threshold default value. * When converting from thresholds to threshold the threshold will be the ratio between * class 1 and class 0.
**/ - override protected def getThreshold: Double = { - if (isDefined(threshold)) { - getThreshold + override def getThreshold(): Double = { + def thresholdsToThreshold(thresholds: Array[Double]): Double = { + assert(thresholds.size == 2, "Attempting to use threshold array for binary classification, size " + + "must be 2 instead of " + thresholds.size) + thresholds(1)/thresholds(0) + } + if (isDefined(threshold) || !isDefined(thresholds)) { + super.getThreshold + } else { + thresholdsToThreshold(getThresholds) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d0c78d01cc450..2bcef4a122139 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -22,7 +22,6 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.HasThresholds import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -128,7 +127,7 @@ final class RandomForestClassificationModel private[ml] ( private val _trees: Array[DecisionTreeClassificationModel], override val numClasses: Int) extends ClassificationModel[Vector, RandomForestClassificationModel] - with TreeEnsembleModel with HasThresholds with Serializable { + with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d004ed781c25c..cb9458a5d084d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -266,7 +266,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed with HasThresholds { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. 
From 55e40056c509529db7da7fe33606082e4cca4ccd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 Jul 2015 20:46:38 -0700 Subject: [PATCH 23/40] scala style fixes --- .../org/apache/spark/ml/classification/Classifier.scala | 8 ++++---- .../spark/ml/param/shared/SharedParamsCodeGen.scala | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 4398bd2da746b..1fdac4a956195 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -48,7 +48,7 @@ private[spark] trait ClassifierParams * then thresholds default value. * When converting from threshold to thresholds the threshold for class 0 will be 0.5 * and the threshold for class 1 will be the assigned threshold value. - **/ + */ override def getThresholds: Array[Double] = { def thresholdToThresholds(threshold: Double): Array[Double] = { Array[Double](0.5, threshold) @@ -65,11 +65,11 @@ private[spark] trait ClassifierParams * then the threshold default value. * When converting from thresholds to threshold the threshold will be the ratio between * class 1 and class 0. - **/ + */ override def getThreshold(): Double = { def thresholdsToThreshold(thresholds: Array[Double]): Double = { - assert(thresholds.size == 2, "Attempting to use threshold array for binary classification, size " + - "must be 2 instead of " + thresholds.size) + assert(thresholds.size == 2, "Attempting to use threshold array for binary classification, " + + "size must be 2 instead of " + thresholds.size) thresholds(1)/thresholds(0) } if (isDefined(threshold) || !isDefined(thresholds)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 3bf51b40195d6..6bd2c865117fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -46,10 +46,10 @@ private[shared] object SharedParamsCodeGen { Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", - isValid = "ParamValidators.inRange(0, 1)", finalMethods=false), + isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.", - Some("null"), finalMethods=false), + Some("null"), finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), @@ -79,7 +79,7 @@ private[shared] object SharedParamsCodeGen { doc: String, defaultValueStr: Option[String] = None, isValid: String = "", - finalMethods: Boolean=true) { + finalMethods: Boolean = true) { require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") require(doc.nonEmpty) // TODO: more rigorous on doc From 686484ac98c2d8e2fa11e4cfe4e01d16e019176d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 Jul 2015 21:09:05 -0700 Subject: [PATCH 24/40] Add explicit return types even though it's just a test --- .../org/apache/spark/ml/classification/ClassifierSuite.scala | 4 ++-- 1 file changed, 2
insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index bcb252e382f5a..b72e2d8234fcc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -33,10 +33,10 @@ final class TestClassificationModel( /** @group setParam */ def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - override def predictRaw(input: Vector) = { + override def predictRaw(input: Vector): Vector = { input } - def friendlyPredict(input: Vector) = { + def friendlyPredict(input: Vector): Vector = { predict(input) } } From fdb448301c622e346c15a77624f9dc38bc9c3bf0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 Jul 2015 21:44:07 -0700 Subject: [PATCH 25/40] Use ClassifierParams as the head --- .../scala/org/apache/spark/ml/tree/treeParams.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index cb9458a5d084d..8f2f8a84ca6da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree +import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} @@ -181,7 +182,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** * Parameters for Decision Tree-based classification algorithms. */ -private[ml] trait TreeClassifierParams extends Params { +private[ml] trait TreeClassifierParams extends ClassifierParams { /** * Criterion used for information gain calculation (case-insensitive). @@ -196,6 +197,9 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -279,10 +283,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { setDefault(subsamplingRate -> 1.0) - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - - /** @group setParam */ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) From 5d4b46d5fcf30c373f75a2c859c7a933e36c2c28 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 7 Jul 2015 14:03:28 -0700 Subject: [PATCH 26/40] Fix return type, I need more coffee.... 
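predict() returns the label as a Double (the argmax of the raw prediction by default), so friendlyPredict must return Double, not Vector. An illustrative expectation, assuming no thresholds are set on the model:

    new TestClassificationModel(2).friendlyPredict(Vectors.dense(0.0, 1.0)) // 1.0, the index of the largest raw score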
--- .../org/apache/spark/ml/classification/ClassifierSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index b72e2d8234fcc..2dadb2bd3f5de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -36,7 +36,7 @@ final class TestClassificationModel( override def predictRaw(input: Vector): Vector = { input } - def friendlyPredict(input: Vector): Vector = { + def friendlyPredict(input: Vector): Double = { predict(input) } } From db0609398c8d541f747897a82ce26ccfebbf77c2 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 7 Jul 2015 16:11:21 -0700 Subject: [PATCH 27/40] Add a scala RandomForestClassifierSuite test based on corresponding python test --- .../RandomForestClassifierSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6b5a6985af5b3..e1c5373a935ea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -147,6 +147,28 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf2, categoricalFeatures, numClasses) } + test("simple two input training test") { + val trainingInput = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.sparse(1, Array[Int](), Array[Double]()))) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + val trainingData = TreeTests.setMetadata(sc.parallelize(trainingInput), + categoricalFeatures, numClasses) + val rf = new RandomForestClassifier() + .setNumTrees(2) + .setMaxDepth(2) + .setSeed(42) + .fit(trainingData) + val testInput = Seq( + LabeledPoint(0.0, Vectors.dense(-1.0)), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(1.0))) + ) + val testData = sqlContext.createDataFrame(testInput) + val results = rf.transform(testData).select("prediction").map(_.getDouble(0)) + assert(results.collect() === Array(0.0, 1.0)) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// From 4d3081fdef882ecefb4c287a235ad3b1e09fa2e1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 7 Jul 2015 16:14:17 -0700 Subject: [PATCH 28/40] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok) --- .../spark/ml/classification/RandomForestClassifierSuite.scala | 2 +- python/pyspark/ml/classification.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index e1c5373a935ea..18f48671703b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ 
-156,7 +156,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val trainingData = TreeTests.setMetadata(sc.parallelize(trainingInput), categoricalFeatures, numClasses) val rf = new RandomForestClassifier() - .setNumTrees(2) + .setNumTrees(3) .setMaxDepth(2) .setSeed(42) .fit(trainingData) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 89117e492846b..7df61e75e2fc3 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -299,7 +299,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) - >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) + >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) >>> allclose(model.treeWeights, [1.0, 1.0]) True From ac187acd2b4b141aa20367323e264c71a1acab8f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 8 Jul 2015 15:10:39 -0700 Subject: [PATCH 29/40] Update the tree weights vector used for comparison (thought I had this committed already, oops) --- python/pyspark/ml/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 7df61e75e2fc3..5a82bc286d1e8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -301,7 +301,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> td = si_model.transform(df) >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) - >>> allclose(model.treeWeights, [1.0, 1.0]) + >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction From 3f94d6298c62c85b4632fecac2e97486825ab9e4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 15:47:54 -0700 Subject: [PATCH 30/40] move the thresholding around a bunch based on the design doc --- .../spark/ml/classification/Classifier.scala | 41 +------------------ .../classification/LogisticRegression.scala | 30 ++++++++++++-- .../ProbabilisticClassifier.scala | 21 +++++++--- .../RandomForestClassifier.scala | 18 ++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 10 ++--- .../JavaLogisticRegressionSuite.java | 7 +++- .../ml/classification/ClassifierSuite.scala | 7 +++- .../LogisticRegressionSuite.scala | 5 ++- .../ml/classification/OneVsRestSuite.scala | 2 +- 9 files changed, 78 insertions(+), 63 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 507dbe6e2e677..4daff43de0ce6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -import org.apache.spark.ml.param.shared.{HasRawPredictionCol, HasThreshold, HasThresholds} +import org.apache.spark.ml.param.shared.{HasRawPredictionCol} import
org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * (private[spark]) Params for classification. */ private[spark] trait ClassifierParams - extends PredictorParams with HasRawPredictionCol with HasThresholds with HasThreshold { + extends PredictorParams with HasRawPredictionCol { override protected def validateAndTransformSchema( schema: StructType, @@ -41,43 +41,6 @@ private[spark] trait ClassifierParams val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT) } - - /** - * Customized version of getThresholds that looks at both threshold & thresholds param. - * The priority order is thresholds assigned param, threshold assigned param - * then thresholds default value. - * When converting from threshold to thresholds the threshold for class 0 will be 0.5 - * and the threshold for class 1 will be the assigned threshold value. - */ - override def getThresholds: Array[Double] = { - def thresholdToThresholds(threshold: Double): Array[Double] = { - Array[Double](0.5, threshold) - } - if (isDefined(thresholds) || !isDefined(threshold)) { - super.getThresholds - } else { - thresholdToThresholds(getThreshold) - } - } - /** - * Customized version of getThreshold that looks at both threshold & thresholds param. - * The priority order is threshold assigned param, thresholds assigned param - * then the threshold default value. - * When converting from thresholds to threshold the threshold will be the ratio between - * class 1 and class 0. - */ - override def getThreshold(): Double = { - def thresholdsToThreshold(thresholds: Array[Double]): Double = { - assert(thresholds.size == 2, "Attempting to use threshold array for binary classification, " + - "size must be 2 instead of " + thresholds.size) - thresholds(1)/thresholds(0) - } - if (isDefined(threshold) || !isDefined(thresholds)) { - super.getThreshold - } else { - thresholdsToThreshold(getThresholds) - } - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c8764e639e575..35d8dd47888f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -111,8 +111,20 @@ class LogisticRegression(override val uid: String) setDefault(standardization -> true) /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) - setDefault(threshold -> 0.5) + def setThreshold(value: Double): this.type = set(thresholds, Array(0.5, value)) + setDefault(thresholds -> Array(0.5, 0.5)) + + /** + * Convert the thresholds to a threshold + * p/a > (1-p)/b + * p*(b/a) + p > 1 + * p > 1 / [1 + b/a] + * threshold = 1 / [1 + b/a] + */ + def getThreshold() = { + val thresholdValues = $(thresholds).toArray + 1 / (1 + thresholdValues(1) / thresholdValues(0)) + } override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset.
@@ -271,7 +283,7 @@ class LogisticRegressionModel private[ml] ( with LogisticRegressionParams { /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) + def setThreshold(value: Double): this.type = set(thresholds, Array(0.5, value)) /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { @@ -286,6 +298,18 @@ class LogisticRegressionModel private[ml] ( override val numClasses: Int = 2 + /** + * Convert the thresholds to a threshold + * p/a > (1-p)/b + * p*(b/a) + p > 1 + * p > 1 / [1 + b/a] + * threshold = 1 / [1 + b/a] + */ + def getThreshold() = { + val thresholdValues = $(thresholds).toArray + 1 / (1 + thresholdValues(1) / thresholdValues(0)) + } + /** * Predict label for the given feature vector. * The behavior of this can be adjusted using [[threshold]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index dad451108626d..a80256dc45477 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, DataType, StructType} @@ -29,8 +29,7 @@ import org.apache.spark.sql.types.{DoubleType, DataType, StructType} * (private[classification]) Params for probabilistic classification. */ private[classification] trait ProbabilisticClassifierParams - extends ClassifierParams with HasProbabilityCol { - + extends ClassifierParams with HasProbabilityCol with HasThresholds { override protected def validateAndTransformSchema( schema: StructType, fitting: Boolean, @@ -75,7 +74,8 @@ private[spark] abstract class ProbabilisticClassifier[ private[spark] abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] - extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { + extends ClassificationModel[FeaturesType, M] + with ProbabilisticClassifierParams { /** @group setParam */ def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] @@ -170,8 +170,17 @@ private[spark] abstract class ProbabilisticClassificationModel[ /** * Given a vector of class conditional probabilities, select the predicted label. - * This may be overridden to support thresholds which favor particular labels. + * This supports thresholds which favor particular labels. 
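+ * E.g. (illustrative numbers): probability = [0.2, 0.8] with thresholds = [0.1, 0.9] rescales to [2.0, 0.89], so class 0 is predicted despite its lower raw probability.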
* @return predicted label */ - protected def probability2prediction(probability: Vector): Double = probability.argmax + protected def probability2prediction(probability: Vector): Double = { + if (!isDefined(thresholds)) { + probability.argmax + } else { + val thresholds: Array[Double] = getThresholds.toArray + val normalizedProbability: Array[Double] = probability.toArray.zip(thresholds) + .map{case (x, y) => x / y} + Vectors.dense(normalizedProbability).argmax + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index f407a6bb3f4e6..ab050f0ceb9e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType */ @Experimental final class RandomForestClassifier(override val uid: String) - extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] + extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("rfc")) @@ -129,7 +129,7 @@ final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], override val numClasses: Int) - extends ProbabalisticClassificationModel[Vector, RandomForestClassificationModel] + extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -164,14 +164,24 @@ final class RandomForestClassificationModel private[ml] ( // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. 
val votes = new Array[Double](numClasses) - val weight = 1.0/_.trees.view.length.toDouble // For now all trees have the same weight _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt - votes(prediction) = votes(prediction) + weight + votes(prediction) = votes(prediction) + 1.0 // weight=1.0 } Vectors.dense(votes) } + override def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + val numTrees = _trees.view.size.toDouble + val values = rawPrediction.toArray // Since we are a dense vector not a copy + var i = 0 + while (i < values.size) { + values(i) = values(i) / numTrees + i += 1 + } + rawPrediction + } + override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 8f2f8a84ca6da..e90f25bae507b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -197,9 +197,6 @@ private[ml] trait TreeClassifierParams extends ClassifierParams { setDefault(impurity -> "gini") - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -270,7 +267,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed with HasThresholds { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -283,6 +280,9 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { setDefault(subsamplingRate -> 1.0) + /** @group setParam */ + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + /** @group setParam */ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index f75e024a713ee..5e6f529696b03 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -98,7 +98,9 @@ public void logisticRegressionWithSetters() { assert(r.getDouble(0) == 0.0); } // Call transform with params, and check that the params worked. 
- model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) + double[] thresholds = {0.5, 0.0}; + model.transform(dataset, + model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; @@ -108,8 +110,9 @@ public void logisticRegressionWithSetters() { assert(foundNonZero); // Call fit() with new params, and check as many params as we can. + double[] thresholds2 = {0.5, 0.4}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + lr.thresholds().w(thresholds), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 2dadb2bd3f5de..80b0f5ad99da6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} final class TestClassificationModel( override val numClasses: Int) - extends ClassificationModel[Vector, TestClassificationModel] { + extends ProbabilisticClassificationModel[Vector, TestClassificationModel] { override val uid = null override def copy(extra: org.apache.spark.ml.param.ParamMap): TestClassificationModel = { @@ -36,6 +36,11 @@ final class TestClassificationModel( override def predictRaw(input: Vector): Vector = { input } + + override def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction + } + def friendlyPredict(input: Vector): Double = { predict(input) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b7dd44753896a..7f6401ebc4094 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -123,14 +123,15 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") + model.transform(dataset, model.thresholds -> Array(0.5, 0.0), model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() .map { case Row(pred: Double, prob: Vector) => pred } assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. 
- val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, + lr.thresholds -> Array(0.5, 0.4), lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 3775292f6dca7..a5f20c82346c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -151,7 +151,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, "copy should handle extra classifier params") - val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.5, 0.1))) ovrModel.models.foreach { case m: LogisticRegressionModel => require(m.getThreshold === 0.1, "copy should handle extra model params") } From 3ebb2b5c729563c9e60f2c4c9e32f1b0f6e2b069 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 15:49:21 -0700 Subject: [PATCH 31/40] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic --- ...assifierSuite.scala => ProbabilisticClassifierSuite.scala} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename mllib/src/test/scala/org/apache/spark/ml/classification/{ClassifierSuite.scala => ProbabilisticClassifierSuite.scala} (95%) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala similarity index 95% rename from mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 80b0f5ad99da6..e9d5c702939bb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} -final class TestClassificationModel( +final class TestProbabilisticClassificationModel( override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, TestClassificationModel] { override val uid = null @@ -47,7 +47,7 @@ final class TestClassificationModel( } -class ClassifierSuite extends SparkFunSuite { +class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { val threshold = Array(0.5, 0.2) From 928fcf20dff7c5d7fd5f8a5f1e78a0d12528da3c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 15:58:49 -0700 Subject: [PATCH 32/40] Fix a java test bug, remove some unecessary changes --- .../scala/org/apache/spark/ml/classification/Classifier.scala | 4 ++-- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../spark/mllib/classification/ClassificationModel.scala | 2 +- .../spark/ml/classification/JavaLogisticRegressionSuite.java | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 4daff43de0ce6..581d8fa7749be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -import org.apache.spark.ml.param.shared.{HasRawPredictionCol} +import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame @@ -153,7 +153,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** * Given a vector of raw predictions, select the predicted label. - * This may be overridden to support custom thresholds which favor particular labels. + * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ab050f0ceb9e3..660c6d55fee47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -166,7 +166,7 @@ final class RandomForestClassificationModel private[ml] ( val votes = new Array[Double](numClasses) _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt - votes(prediction) = votes(prediction) + 1.0 // weight=1.0 + votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } Vectors.dense(votes) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 80040272ef25c..ba73024e3c04d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -41,7 +41,7 @@ trait ClassificationModel extends Serializable { def predict(testData: RDD[Vector]): RDD[Double] /** - * Predict values for the given data using the model trained. + * Predict values for a single data point using the model trained. * * @param testData array representing a single data point * @return predicted category from the trained model diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 5e6f529696b03..7e69c919b0fa0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -112,7 +112,7 @@ public void logisticRegressionWithSetters() { // Call fit() with new params, and check as many params as we can. 
double[] thresholds2 = {0.5, 0.4}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.thresholds().w(thresholds), lr.probabilityCol().w("theProb")); + lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); From b514d42ad26484ae393f5d94abbe7dd82540e6ab Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 16:04:19 -0700 Subject: [PATCH 33/40] Add explicit types to public methods, fix long line --- .../apache/spark/ml/classification/LogisticRegression.scala | 4 ++-- .../spark/ml/classification/LogisticRegressionSuite.scala | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 35d8dd47888f1..e254bc14caffe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -121,7 +121,7 @@ class LogisticRegression(override val uid: String) * p > 1 / [1 + b/a] * threshold = 1 / [1 + b/a] */ - def getThreshold() = { + def getThreshold(): Double = { val thresholdValues = $(thresholds).toArray 1 / (1 + thresholdValues(1) / thresholdValues(0)) } @@ -305,7 +305,7 @@ class LogisticRegressionModel private[ml] ( * p > 1 / [1 + b/a] * threshold = 1 / [1 + b/a] */ - def getThreshold() = { + def getThreshold(): Double = { val thresholdValues = $(thresholds).toArray 1 / (1 + thresholdValues(1) / thresholdValues(0)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 7f6401ebc4094..ac3ae19b0d1e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -123,7 +123,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. 
val predNotAllZero = - model.transform(dataset, model.thresholds -> Array(0.5, 0.0), model.probabilityCol -> "myProb") + model.transform(dataset, model.thresholds -> Array(0.5, 0.0), + model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() .map { case Row(pred: Double, prob: Vector) => pred } From 4ceeb9ee0acc41955297c6f3b9c1386ccae4f672 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 16:23:05 -0700 Subject: [PATCH 34/40] revert the changes to random forest :( --- .../examples/ml/JavaSimpleParamsExample.java | 3 +- .../examples/ml/SimpleParamsExample.scala | 2 +- .../RandomForestClassifier.scala | 20 +------- .../RandomForestClassifierSuite.scala | 50 +------------------ 4 files changed, 6 insertions(+), 69 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index dac649d1d5ae6..30f2c06ce571a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,8 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + double thresholds[] = {0.5, 0.55}; + paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index 58d7b67674ff7..49955a0750b33 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -70,7 +70,7 @@ object SimpleParamsExample { // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. - paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.5, 0.55)) // Specify multiple Params. // One can also combine ParamMaps. 
val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 6d39831617e22..0c7eb4a662fdb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType */ @Experimental final class RandomForestClassifier(override val uid: String) - extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] + extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("rfc")) @@ -75,8 +75,6 @@ final class RandomForestClassifier(override val uid: String) override def setSeed(value: Long): this.type = super.setSeed(value) - override def setThresholds(value: Array[Double]): this.type = super.set(thresholds, value) - // Parameters from RandomForestParams: override def setNumTrees(value: Int): this.type = super.setNumTrees(value) @@ -129,14 +127,11 @@ final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], override val numClasses: Int) - extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] + extends ClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - /** * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees @@ -171,17 +166,6 @@ final class RandomForestClassificationModel private[ml] ( Vectors.dense(votes) } - override def raw2probabilityInPlace(rawPrediction: Vector): Vector = { - val numTrees = _trees.view.size.toDouble - val values = rawPrediction.toArray // Since we are a dense vector not a copy - var i = 0 - while (i < values.size) { - values(i) = values(i) / numTrees - i += 1 - } - rawPrediction - } - override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 33020838b8554..dbb2577c6204d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRando import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. 
@@ -103,32 +103,6 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf, categoricalFeatures, numClasses) } - test("ensure thresholding works") { - val arr = Array( - LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), - LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) - ) - val rdd = sc.parallelize(arr) - val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) - val numClasses = 3 - - val thresholds = Array(1.0, 10000.0, 0.01) - val rf = new RandomForestClassifier() - .setNumTrees(2) - .setSeed(12345) - .setThresholds(thresholds) - val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) - val model = rf.fit(newData) - assert(model.getThresholds == thresholds) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - val testData = rdd.toDF - val results = model.transform(testData).select("prediction") - results.count() - } - test("subsampling rate in RandomForest"){ val rdd = orderedLabeledPoints5_20 val categoricalFeatures = Map.empty[Int, Int] @@ -147,28 +121,6 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf2, categoricalFeatures, numClasses) } - test("simple two input training test") { - val trainingInput = Seq( - LabeledPoint(1.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.sparse(1, Array[Int](), Array[Double]()))) - val categoricalFeatures = Map.empty[Int, Int] - val numClasses = 2 - val trainingData = TreeTests.setMetadata(sc.parallelize(trainingInput), - categoricalFeatures, numClasses) - val rf = new RandomForestClassifier() - .setNumTrees(3) - .setMaxDepth(2) - .setSeed(42) - .fit(trainingData) - val testInput = Seq( - LabeledPoint(0.0, Vectors.dense(-1.0)), - LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(1.0))) - ) - val testData = sqlContext.createDataFrame(testInput) - val results = rf.transform(testData).select("prediction").map(_.getDouble(0)) - assert(results.collect() === Array(0.0, 1.0)) - } - ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// From dcdc48a0a5190eb80e202ab718cb581d509689a9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 19:59:06 -0700 Subject: [PATCH 35/40] CR feedback and fixed the renamed test --- .../main/scala/org/apache/spark/ml/tree/treeParams.scala | 7 ++----- .../ml/classification/JavaLogisticRegressionSuite.java | 4 ++-- .../ml/classification/ProbabilisticClassifierSuite.scala | 8 ++++---- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index e90f25bae507b..e817090f8a16b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -182,7 +182,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** * Parameters for Decision Tree-based classification algorithms. */ -private[ml] trait TreeClassifierParams extends ClassifierParams { +private[ml] trait TreeClassifierParams extends Params { /** * Criterion used for information gain calculation (case-insensitive). 
@@ -267,7 +267,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed with HasThresholds { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -280,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed wit setDefault(subsamplingRate -> 1.0) - /** @group setParam */ - def setThresholds(value: Array[Double]): this.type = set(thresholds, value) - /** @group setParam */ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e69c919b0fa0..615f99f0d2107 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -99,8 +99,8 @@ public void logisticRegressionWithSetters() { } // Call transform with params, and check that the params worked. double[] thresholds = {0.5, 0.0}; - model.transform(dataset, - model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) + model.transform( + dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index e9d5c702939bb..6a466aa76cb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -23,10 +23,10 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} final class TestProbabilisticClassificationModel( override val numClasses: Int) - extends ProbabilisticClassificationModel[Vector, TestClassificationModel] { + extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] { override val uid = null override def copy(extra: org.apache.spark.ml.param.ParamMap): - TestClassificationModel = { + TestProbabilisticClassificationModel = { defaultCopy(extra) } @@ -51,13 +51,13 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { val threshold = Array(0.5, 0.2) - val testModel = (new TestClassificationModel(2)).setThresholds(threshold) + val testModel = (new TestProbabilisticClassificationModel(2)).setThresholds(threshold) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) == 1.0) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) == 0.0) } test("test thresholding not required") { - val testModel = new TestClassificationModel(2) + val testModel = new TestProbabilisticClassificationModel(2) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) == 1.0) } } From 4a378389e772ff3ab7a8a1e521cadb251af45c0e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 20:00:31 -0700 Subject: [PATCH 36/40] No default for thresholds --- 
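Below the fold: with the default removed, reading $(thresholds) throws unless the
param has been explicitly set, so call sites need an isDefined(thresholds) guard
(the next patch adds exactly that to getThreshold()). The binary threshold t and
the thresholds array are related by t <-> Array(t, 1 - t), with t recovered as
1 / (1 + thresholds(1) / thresholds(0)). A standalone Scala sketch of that round
trip, for illustration only (not code from this series):

    // thresholds = Array(t, 1 - t)  =>  1 / (1 + (1 - t) / t) = t
    val t = 0.6
    val thresholds = Array(t, 1 - t)                         // Array(0.6, 0.4)
    val recovered = 1 / (1 + thresholds(1) / thresholds(0))  // 0.6
    assert(math.abs(recovered - t) < 1e-12)
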
.../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index f231008dce12e..f0e57b32a22f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -49,7 +49,7 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.", - Some("null"), finalMethods = false), + finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 9630abca0e878..93550bf3576a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -154,7 +154,7 @@ private[ml] trait HasThreshold extends Params { } /** - * (private[ml]) Trait for shared param thresholds (default: null). + * Trait for shared param thresholds. */ private[ml] trait HasThresholds extends Params { @@ -164,8 +164,6 @@ private[ml] trait HasThresholds extends Params { */ final val thresholds: Param[Array[Double]] = new Param[Array[Double]](this, "thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.") - setDefault(thresholds, null) - /** @group getParam */ def getThresholds: Array[Double] = $(thresholds) } From d5b0a2f143ab0c3874adddf5366e1f6d5197c296 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 21:48:09 -0700 Subject: [PATCH 37/40] Fix handling of thresholds in LogisticRegression --- .../classification/LogisticRegression.scala | 26 +++++++++++++------ .../JavaLogisticRegressionSuite.java | 6 +++-- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e254bc14caffe..964e69da7c223 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -111,19 +111,23 @@ class LogisticRegression(override val uid: String) setDefault(standardization -> true) /** @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(0.5, value)) + def setThreshold(value: Double): this.type = set(thresholds, Array(value, 1-value)) setDefault(thresholds -> Array(0.5, 0.5)) /** - * Convert the thresholds to a threshold + * Convert the thresholds to a threshold, default of 0.5 * p/a > (1-p)/b * p*(b/a) + p > 1 * p > 1 / [1 + b/a] * threshold = 1 / [1 + b/a] */ def getThreshold(): Double = { - val thresholdValues = $(thresholds).toArray - 1 / (1 + thresholdValues(1) / thresholdValues(0)) + if (isDefined(thresholds)) { + val thresholdValues = $(thresholds) + 1 / (1 + 
thresholdValues(1) / thresholdValues(0)) + } else { + 0.5 + } } override protected def train(dataset: DataFrame): LogisticRegressionModel = { @@ -283,7 +287,8 @@ class LogisticRegressionModel private[ml] ( with LogisticRegressionParams { /** @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(0.5, value)) + def setThreshold(value: Double): this.type = set(thresholds, Array(value, 1-value)) + setDefault(thresholds -> Array(0.5, 0.5)) /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { @@ -299,15 +304,20 @@ class LogisticRegressionModel private[ml] ( override val numClasses: Int = 2 /** - * Convert the thresholds to a threshold + * Convert the thresholds to a threshold, default of 0.5 * p/a > (1-p)/b * p*(b/a) + p > 1 * p > 1 / [1 + b/a] * threshold = 1 / [1 + b/a] */ def getThreshold(): Double = { - val thresholdValues = $(thresholds).toArray - 1 / (1 + thresholdValues(1) / thresholdValues(0)) + if (isDefined(thresholds)) { + val thresholdValues = $(thresholds) + val ret = 1 / (1 + thresholdValues(1) / thresholdValues(0)) + ret + } else { + 0.5 + } } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 615f99f0d2107..db171a33ca19e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -87,6 +87,8 @@ public void logisticRegressionWithSetters() { LogisticRegression parent = (LogisticRegression) model.parent(); assert(parent.getMaxIter() == 10); assert(parent.getRegParam() == 1.0); + assert(parent.getThresholds()[0] == 0.6); + assert(parent.getThresholds()[1] == 0.4); assert(parent.getThreshold() == 0.6); assert(model.getThreshold() == 0.6); @@ -98,7 +100,7 @@ public void logisticRegressionWithSetters() { assert(r.getDouble(0) == 0.0); } // Call transform with params, and check that the params worked. - double[] thresholds = {0.5, 0.0}; + double[] thresholds = {0.0, 1.0}; model.transform( dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); @@ -110,7 +112,7 @@ public void logisticRegressionWithSetters() { assert(foundNonZero); // Call fit() with new params, and check as many params as we can. 
- double[] thresholds2 = {0.5, 0.4}; + double[] thresholds2 = {0.4, 0.6}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); From 4b1104143b68f01521b61c4904fca3384bcc3fb1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 2 Aug 2015 03:09:21 -0700 Subject: [PATCH 38/40] Override raw2prediction for ProbabilisticClassifier, fix some tests --- .../spark/ml/classification/ProbabilisticClassifier.scala | 8 ++++++++ .../spark/ml/classification/LogisticRegressionSuite.scala | 4 ++-- .../apache/spark/ml/classification/OneVsRestSuite.scala | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index a80256dc45477..0a8393474730f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -155,6 +155,14 @@ private[spark] abstract class ProbabilisticClassificationModel[ raw2probabilityInPlace(probs) } + override protected def raw2prediction(rawPrediction: Vector): Double = { + if (!isDefined(thresholds)) { + rawPrediction.argmax + } else { + probability2prediction(raw2probability(rawPrediction)) + } + } + /** * Predict the probability of each class given the features. * These predictions are also called class conditional probabilities. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index ac3ae19b0d1e4..c74c292be3ec1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -123,7 +123,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.thresholds -> Array(0.5, 0.0), + model.transform(dataset, model.thresholds -> Array(0.0, 1.0), model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() @@ -132,7 +132,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { // Call fit() with new params, and check as many params as we can. 
     val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
-      lr.thresholds -> Array(0.5, 0.4),
+      lr.thresholds -> Array(0.4, 0.6),
       lr.probabilityCol -> "theProb")
     val parent2 = model2.parent.asInstanceOf[LogisticRegression]
     assert(parent2.getMaxIter === 5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index a5f20c82346c2..9c8b58d101ef9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -151,7 +151,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
       "copy should handle extra classifier params")

-    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.5, 0.1)))
+    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.1, 0.9)))
     ovrModel.models.foreach { case m: LogisticRegressionModel =>
       require(m.getThreshold === 0.1, "copy should handle extra model params")
     }

From aa89af6635836bf80d0b8e69b15da858a4e3901a Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Sun, 2 Aug 2015 15:47:04 -0700
Subject: [PATCH 39/40] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared params codegen, etc.

---
 .../ml/param/shared/SharedParamsCodeGen.scala |  1 +
 .../spark/ml/param/shared/sharedParams.scala  |  2 +-
 python/pyspark/ml/classification.py           | 41 ++++++++++++++++++++++++++---------------
 3 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index f0e57b32a22f5..3a0fafb80b32d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -92,6 +92,7 @@ private[shared] object SharedParamsCodeGen {
       case _ if c == classOf[Double] => "DoubleParam"
       case _ if c == classOf[Boolean] => "BooleanParam"
       case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
+      case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam"
       case _ => s"Param[${getTypeString(c)}]"
     }
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 93550bf3576a0..bad6483576963 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -162,7 +162,7 @@ private[ml] trait HasThresholds extends Params {

  /**
   * Param for thresholds in multi-class classification prediction, must be array with size of classes..
   * @group param
   */
-  final val thresholds: Param[Array[Double]] = new Param[Array[Double]](this, "thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.")
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "thresholds in multi-class classification prediction, must be array with size of classes.")

   /** @group getParam */
   def getThresholds: Array[Double] = $(thresholds)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 93ffcd40949b3..0038e654e0e1b 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -64,17 +64,17 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
         "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
     fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
-    threshold = Param(Params._dummy(), "threshold",
-                      "threshold in binary classification prediction, in range [0, 1].")
+    thresholds = Param(Params._dummy(), "thresholds",
+                      "array of thresholds in classification")

     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                 threshold=0.5, probabilityCol="probability"):
+                 threshold=None, thresholds=None, probabilityCol="probability"):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                 threshold=0.5, probabilityCol="probability")
+                 threshold=None, thresholds=None, probabilityCol="probability")
         """
         super(LogisticRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -88,23 +88,27 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
         #: param for whether to fit an intercept term.
         self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
         #: param for threshold in binary classification prediction, in range [0, 1].
-        self.threshold = Param(self, "threshold",
-                               "threshold in binary classification prediction, in range [0, 1].")
+        self.thresholds = Param(self, "thresholds",
+                        "thresholds in classification")
         self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
-                         fitIntercept=True, threshold=0.5)
+                         fitIntercept=True, thresholds=[0.5, 0.5])
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

     @keyword_only
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                  threshold=0.5, probabilityCol="probability"):
+                  threshold=None, thresholds=None, probabilityCol="probability"):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                  threshold=0.5, probabilityCol="probability")
+                  threshold=None, thresholds=None, probabilityCol="probability")
         Sets params for logistic regression.
""" + # Under the hood we use thresholds so translate threshold to thresholds if applicable + if thresholds is None and threshold is not None: + kwargs[thresholds] = [threshold, 1-threshold] kwargs = self.setParams._input_kwargs return self._set(**kwargs) @@ -139,16 +142,23 @@ def getFitIntercept(self): def setThreshold(self, value): """ - Sets the value of :py:attr:`threshold`. + Sets the value of :py:attr:`thresholds` using [value, 1-value]. """ - self._paramMap[self.threshold] = value + return self.setThresholds([value, 1-value]) + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + """ + self._paramMap[self.thresholds] = value return self - def getThreshold(self): + + def getThresholds(self): """ - Gets the value of threshold or its default value. + Gets the value of thresholds or its default value. """ - return self.getOrDefault(self.threshold) + return self.getOrDefault(self.thresholds) class LogisticRegressionModel(JavaModel): From a7dc7b53ac952aee6e1276da78e8572710d3400a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 2 Aug 2015 16:33:16 -0700 Subject: [PATCH 40/40] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat --- python/pyspark/ml/classification.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 0038e654e0e1b..6040161caf7c3 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -65,7 +65,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") thresholds = Param(Params._dummy(), "thresholds", - "array of thresholds in classification") + "array of thresholds in classification") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -89,7 +89,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification prediction, in range [0, 1]. self.thresholds = Param(self, "thresholds", - "thresholds in classification") + "thresholds in classification") self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, fitIntercept=True, thresholds=[0.5, 0.5]) kwargs = self.__init__._input_kwargs @@ -153,13 +153,19 @@ def setThresholds(self, value): self._paramMap[self.thresholds] = value return self - def getThresholds(self): """ Gets the value of thresholds or its default value. """ return self.getOrDefault(self.thresholds) + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + thresholds = self.getOrDefault(self.thresholds) + return 1/(1+thresholds[1]/thresholds[0]) + class LogisticRegressionModel(JavaModel): """