From bf5d852cdb2b8d2754d67dfa4b21e7bbd6165edc Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Sep 2015 18:35:05 +0800 Subject: [PATCH 01/12] DecisionTreeRegressor: provide variance of prediction --- .../ml/param/shared/SharedParamsCodeGen.scala | 2 ++ .../spark/ml/param/shared/sharedParams.scala | 17 ++++++++++++ .../ml/regression/DecisionTreeRegressor.scala | 16 +++++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 27 ++++++++++++++++++- 4 files changed, 59 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c7bca1243092c..faa52e741fd03 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -44,6 +44,8 @@ private[shared] object SharedParamsCodeGen { " probabilities. Note: Not all models output well-calibrated probability estimates!" + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), + ParamDesc[String]("varianceCol", "Column name for the variance of prediction", + Some("\"variance\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index cb2a060a34dd6..0feb1625b84f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -138,6 +138,23 @@ private[ml] trait HasProbabilityCol extends Params { final def getProbabilityCol: String = $(probabilityCol) } +/** + * Trait for shared param varianceCol (default: "variance"). + */ +private[ml] trait HasVarianceCol extends Params { + + /** + * Param for Column name for the variance of prediction. + * @group param + */ + final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the variance of prediction") + + setDefault(varianceCol, "variance") + + /** @group getParam */ + final def getVarianceCol: String = $(varianceCol) +} + /** * Trait for shared param threshold (default: 0.5). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 477030d9ea3ee..bb81dcc61c5d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ /** * :: Experimental :: @@ -40,7 +41,7 @@ import org.apache.spark.sql.DataFrame @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams with TreeRegressorParams { + with TreeRegressorParams { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -113,7 +114,7 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with Serializable { + with DecisionTreeModel with TreeRegressorParams with Serializable { require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") @@ -129,6 +130,17 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).prediction } + def predictVariance(features: Vector): Double = { + rootNode.predictImpl(features).impurityStats.calculate() + } + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + .withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressionModel = { copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 1da97db9277d8..bb92113278713 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** * Parameters for Decision Tree-based algorithms. @@ -218,7 +220,7 @@ private[ml] object TreeClassifierParams { /** * Parameters for Decision Tree-based regression algorithms. */ -private[ml] trait TreeRegressorParams extends Params { +private[ml] trait TreeRegressorParams extends DecisionTreeParams with HasVarianceCol { /** * Criterion used for information gain calculation (case-insensitive). @@ -249,6 +251,29 @@ private[ml] trait TreeRegressorParams extends Params { s"TreeRegressorParams was given unrecognized impurity: $impurity") } } + + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @return output schema + */ + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + SchemaUtils.appendColumn(schema, $(varianceCol), DoubleType) + } } private[ml] object TreeRegressorParams { From 533be3346339db82960fbba0f7bfdf38cbde0109 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 14 Oct 2015 19:59:31 +0800 Subject: [PATCH 02/12] make predictVariance protected --- .../org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index bb81dcc61c5d8..2a2626a71b031 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -130,7 +130,8 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).prediction } - def predictVariance(features: Vector): Double = { + /** We need to update this function if we ever add other impurity measures. */ + protected def predictVariance(features: Vector): Double = { rootNode.predictImpl(features).impurityStats.calculate() } From 89e51766df16ddb3c74c48849505115120818fae Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 22 Oct 2015 01:12:11 +0800 Subject: [PATCH 03/12] make VarianceCol validation only when users explicitly specified --- .../ml/param/shared/SharedParamsCodeGen.scala | 3 +-- .../spark/ml/param/shared/sharedParams.scala | 4 +-- .../ml/regression/DecisionTreeRegressor.scala | 19 ++++++++++--- .../org/apache/spark/ml/tree/treeParams.scala | 27 +------------------ 4 files changed, 18 insertions(+), 35 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index faa52e741fd03..a71eeddfd68e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -44,8 +44,7 @@ private[shared] object SharedParamsCodeGen { " probabilities. Note: Not all models output well-calibrated probability estimates!" + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), - ParamDesc[String]("varianceCol", "Column name for the variance of prediction", - Some("\"variance\"")), + ParamDesc[String]("varianceCol", "Column name for the variance of prediction"), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 0feb1625b84f0..bdc97c28fcaf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param varianceCol (default: "variance"). + * Trait for shared param varianceCol. */ private[ml] trait HasVarianceCol extends Params { @@ -149,8 +149,6 @@ private[ml] trait HasVarianceCol extends Params { */ final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the variance of prediction") - setDefault(varianceCol, "variance") - /** @group getParam */ final def getVarianceCol: String = $(varianceCol) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 2a2626a71b031..96b10aa362d7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasVarianceCol import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} @@ -41,7 +42,7 @@ import org.apache.spark.sql.functions._ @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with TreeRegressorParams { + with DecisionTreeParams with TreeRegressorParams with HasVarianceCol { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -74,6 +75,9 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setSeed(value: Long): this.type = super.setSeed(value) + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -114,11 +118,14 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with TreeRegressorParams with Serializable { + with DecisionTreeModel with HasVarianceCol with Serializable { require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + /** * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. @@ -138,8 +145,12 @@ final class DecisionTreeRegressionModel private[ml] ( override protected def transformImpl(dataset: DataFrame): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) - .withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + if (isDefined(varianceCol)) { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + .withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + } else { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index bb92113278713..1da97db9277d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,11 +20,9 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} -import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** * Parameters for Decision Tree-based algorithms. @@ -220,7 +218,7 @@ private[ml] object TreeClassifierParams { /** * Parameters for Decision Tree-based regression algorithms. */ -private[ml] trait TreeRegressorParams extends DecisionTreeParams with HasVarianceCol { +private[ml] trait TreeRegressorParams extends Params { /** * Criterion used for information gain calculation (case-insensitive). @@ -251,29 +249,6 @@ private[ml] trait TreeRegressorParams extends DecisionTreeParams with HasVarianc s"TreeRegressorParams was given unrecognized impurity: $impurity") } } - - /** @group setParam */ - def setVarianceCol(value: String): this.type = set(varianceCol, value) - - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param fitting whether this is in fitting - * @param featuresDataType SQL DataType for FeaturesType. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. - * @return output schema - */ - override protected def validateAndTransformSchema( - schema: StructType, - fitting: Boolean, - featuresDataType: DataType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) - if (fitting) { - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) - } - SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) - SchemaUtils.appendColumn(schema, $(varianceCol), DoubleType) - } } private[ml] object TreeRegressorParams { From f2928e79c338314128b5f8f9a4ecb26aae28a8c6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 22 Oct 2015 01:51:32 +0800 Subject: [PATCH 04/12] fix validateAndTransformSchema --- .../ml/regression/DecisionTreeRegressor.scala | 12 +++---- .../org/apache/spark/ml/tree/treeParams.scala | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 96b10aa362d7f..244ca5a43dd3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -20,8 +20,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.HasVarianceCol -import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -42,7 +41,7 @@ import org.apache.spark.sql.functions._ @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams with TreeRegressorParams with HasVarianceCol { + with DecisionTreeRegressorParams { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -76,7 +75,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setSeed(value: Long): this.type = super.setSeed(value) /** @group setParam */ - def setVarianceCol(value: String): this.type = set(varianceCol, value) + override def setVarianceCol(value: String): this.type = set(varianceCol, value) override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = @@ -118,14 +117,11 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with HasVarianceCol with Serializable { + with DecisionTreeModel with DecisionTreeRegressorParams with Serializable { require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") - /** @group setParam */ - def setVarianceCol(value: String): this.type = set(varianceCol, value) - /** * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 1da97db9277d8..eec128e33f075 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** * Parameters for Decision Tree-based algorithms. @@ -256,6 +258,35 @@ private[ml] object TreeRegressorParams { final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) } +private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams + with TreeRegressorParams with HasVarianceCol { + + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @return output schema + */ + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + if (isDefined(varianceCol)) { + SchemaUtils.appendColumn(schema, $(varianceCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + /** * Parameters for Decision Tree-based ensemble algorithms. * From b51c76ba06cf6bedd1fcc8f289adf79fce193b20 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 30 Dec 2015 14:40:28 +0800 Subject: [PATCH 05/12] Move setVarianceCol from trait to class for Java compatibility --- .../org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 244ca5a43dd3f..97b3eeb118554 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -75,7 +75,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setSeed(value: Long): this.type = super.setSeed(value) /** @group setParam */ - override def setVarianceCol(value: String): this.type = set(varianceCol, value) + def setVarianceCol(value: String): this.type = set(varianceCol, value) override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index eec128e33f075..e91d622063f98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -261,9 +261,6 @@ private[ml] object TreeRegressorParams { private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams with TreeRegressorParams with HasVarianceCol { - /** @group setParam */ - def setVarianceCol(value: String): this.type = set(varianceCol, value) - /** * Validates and transforms the input schema with the provided param map. * @param schema input schema From b4b4eb44f758a3a28f67c3c1a2fb0915d20a2d7e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 30 Dec 2015 15:16:13 +0800 Subject: [PATCH 06/12] Update varianceCol doc --- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a71eeddfd68e9..4aff749ff75af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -44,7 +44,7 @@ private[shared] object SharedParamsCodeGen { " probabilities. Note: Not all models output well-calibrated probability estimates!" + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), - ParamDesc[String]("varianceCol", "Column name for the variance of prediction"), + ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index bdc97c28fcaf7..c088c16d1b05d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -144,10 +144,10 @@ private[ml] trait HasProbabilityCol extends Params { private[ml] trait HasVarianceCol extends Params { /** - * Param for Column name for the variance of prediction. + * Param for Column name for the biased sample variance of prediction. * @group param */ - final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the variance of prediction") + final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the biased sample variance of prediction") /** @group getParam */ final def getVarianceCol: String = $(varianceCol) From 7eb4febeb71c056a820bd8d9ed117148865d24f5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 30 Dec 2015 15:36:39 +0800 Subject: [PATCH 07/12] remove duplicated doc --- .../spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../main/scala/org/apache/spark/ml/tree/treeParams.scala | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 97b3eeb118554..7e06421f8e8af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -141,7 +141,7 @@ final class DecisionTreeRegressionModel private[ml] ( override protected def transformImpl(dataset: DataFrame): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - if (isDefined(varianceCol)) { + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) .withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index e91d622063f98..d05fc8ab34275 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -261,14 +261,6 @@ private[ml] object TreeRegressorParams { private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams with TreeRegressorParams with HasVarianceCol { - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param fitting whether this is in fitting - * @param featuresDataType SQL DataType for FeaturesType. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. - * @return output schema - */ override protected def validateAndTransformSchema( schema: StructType, fitting: Boolean, From 31a08cea13f62adeade61c53aa8a729fbeadccf7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 30 Dec 2015 16:12:15 +0800 Subject: [PATCH 08/12] add test case --- .../DecisionTreeRegressorSuite.scala | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 6999a910c34a4..37f2ede0be0dd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -73,6 +74,28 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex MLTestingUtils.checkCopy(model) } + test("predictVariance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setVarianceCol("variance") + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = dt.fit(df) + + val predictions = model.transform(df) + .select(model.getFeaturesCol, model.getVarianceCol) + .collect() + + predictions.foreach { case Row(features: Vector, variance: Double) => + val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() + assert(variance === expectedVariance, + s"Expected variance $expectedVariance but got $variance.") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// From ec0f8804cf58b1f858d4010f19ab46ef779e2ef0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 30 Dec 2015 16:25:48 +0800 Subject: [PATCH 09/12] add setVarianceCol for DecisionTreeRegressionModel --- .../org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 7e06421f8e8af..d961049cd3565 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -119,6 +119,9 @@ final class DecisionTreeRegressionModel private[ml] ( extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with Serializable { + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") From 132bf21c38386033b218b057db4f13b0cc8159cd Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 31 Dec 2015 11:33:45 +0800 Subject: [PATCH 10/12] Fix turn on/off for predictionCol and varianceCol of DecisionTreeRegressor --- .../ml/regression/DecisionTreeRegressor.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index d961049cd3565..18c94f36387b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -141,15 +141,22 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).impurityStats.calculate() } + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + transformImpl(dataset) + } + override protected def transformImpl(dataset: DataFrame): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + var output = dataset + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) - .withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) - } else { - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) } + output } @Since("1.4.0") From 010edc5387ff1c374ef00c7efb75d7ec255a6dad Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 31 Dec 2015 11:38:36 +0800 Subject: [PATCH 11/12] Reuse super.validateAndTransformSchema --- .../scala/org/apache/spark/ml/tree/treeParams.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d05fc8ab34275..7443097492d82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -265,14 +265,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) - if (fitting) { - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + SchemaUtils.appendColumn(newSchema, $(varianceCol), DoubleType) + } else { + newSchema } - if (isDefined(varianceCol)) { - SchemaUtils.appendColumn(schema, $(varianceCol), DoubleType) - } - SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } } From ad6b24c511a34670a18acfd5dc5731cfff0f7fd9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 31 Dec 2015 11:48:51 +0800 Subject: [PATCH 12/12] Add test --- .../apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 37f2ede0be0dd..0b39af5543e93 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -79,6 +79,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) + .setPredictionCol("") .setVarianceCol("variance") val categoricalFeatures = Map(0 -> 2, 1 -> 2)