From 8d3c9ecbd8f2bc2e1fcebb189e3bfd67ee7ab7a4 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 20 Nov 2015 11:22:23 -0800 Subject: [PATCH 01/11] partly done adding save/load to DecisionTreeClassifier and Model --- .../DecisionTreeClassifier.scala | 142 +++++++++++++++++- .../org/apache/spark/ml/tree/Split.scala | 4 +- 2 files changed, 136 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 7f0397f6bd65a..3b22e7032b85a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,17 +17,20 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} /** * :: Experimental :: @@ -41,7 +44,7 @@ import org.apache.spark.sql.DataFrame final class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] - with DecisionTreeParams with TreeClassifierParams { + with DecisionTreeParams with TreeClassifierParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtc")) @@ -107,10 +110,13 @@ final class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") @Experimental -object DecisionTreeClassifier { +object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] { /** Accessor for supported impurities: entropy, gini */ @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + @Since("1.6.0") + override def load(path: String): DecisionTreeClassifier = super.load(path) } /** @@ -127,7 +133,7 @@ final class DecisionTreeClassificationModel private[ml] ( @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] - with DecisionTreeModel with Serializable { + with DecisionTreeModel with MLWritable with Serializable { require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") @@ -192,12 +198,132 @@ final class DecisionTreeClassificationModel private[ml] ( private[ml] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) } + + @Since("1.6.0") + override def write: MLWriter = + new DecisionTreeClassificationModel.DecisionTreeClassificationModelWriter(this) } -private[ml] object 
DecisionTreeClassificationModel { +@Since("1.6.0") +object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel] { + + @Since("1.6.0") + override def read: MLReader[DecisionTreeClassificationModel] = + new DecisionTreeClassificationModelReader + + @Since("1.6.0") + override def load(path: String): DecisionTreeClassificationModel = super.load(path) + + private[DecisionTreeClassificationModel] + class DecisionTreeClassificationModelWriter(instance: DecisionTreeClassificationModel) + extends MLWriter { + + import org.json4s.jackson.JsonMethods._ + import org.json4s.JsonDSL._ + + /** + * Info for a [[org.apache.spark.ml.tree.Split]] + * @param featureIndex Index of feature split on + * @param leftCategoriesOrThreshold For categorical feature, set of leftCategories. + * For continuous feature, threshold. + * @param numCategories For categorical feature, number of categories. + * For continuous feature, -1. + */ + private[ml] case class SplitData( + featureIndex: Int, + leftCategoriesOrThreshold: Array[Double], + numCategories: Int) + + private[ml] object SplitData { + def apply(split: Split): SplitData = split match { + case s: CategoricalSplit => + SplitData(s.featureIndex, s.leftCategories, s.numCategories) + case s: ContinuousSplit => + SplitData(s.featureIndex, Array(s.threshold), -1) + } + } + + /** + * Info for a [[Node]] + * @param id Index used for tree reconstruction. Indices follow an in-order traversal. + * @param impurityStats Stats array. Impurity type is stored in metadata. + * @param gain Gain, or arbitrary value if leaf node. + * @param leftChild Left child index, or arbitrary value if leaf node. + * @param rightChild Right child index, or arbitrary value if leaf node. + * @param split Split info, or arbitrary value if leaf node. + */ + private[ml] case class NodeData( + id: Int, + prediction: Double, + impurity: Double, + impurityStats: Array[Double], + gain: Double, + leftChild: Int, + rightChild: Int, + split: SplitData) + + private[ml] object NodeData { + /** + * Create [[NodeData]] instances for this node and all children. + * @param id Current ID. IDs are assigned via an in-order traversal. + * @return (sequence of nodes in preorder traversal order, largest ID in subtree) + * The nodes are returned in preorder traversal (root first) so that it is easy to + * get the ID of the subtree's root node. 
+ */ + def build(node: Node, id: Int): (Seq[NodeData], Int) = node match { + case n: InternalNode => + val (leftNodeData, leftIdx) = build(n.leftChild, id) + val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 2) + val thisNodeData = NodeData(leftIdx + 1, n.prediction, n.impurity, n.impurityStats.stats, + n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split)) + (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx) + case _: LeafNode => + (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, + -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), + id) + } + + def reconstruct(df: DataFrame): Node = { + val nodeDatas = df.select("id", "prediction", "impurity", "impurityStats", "gain", + "leftChild", "rightChild", + "split.featureIndex", "split.leftCategoriesOrThreshold", "split.numCategories") + .map { + case Row(id: Int, prediction: Double, impurity: Double, impurityStats: Seq[Double], + gain: Double, leftChild: Int, rightChild: Int, splitFeatureIndex: Int, + splitLeftCategoriesOrThreshold: Seq[Double], splitNumCategories: Int) => + + } + } + } + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JValue = render(Map( + "numFeatures" -> instance.numFeatures, + "numClasses" -> instance.numClasses)) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val (nodeData, _) = NodeData.build(instance.rootNode, 0) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + } + } + + private class DecisionTreeClassificationModelReader + extends MLReader[DecisionTreeClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[DecisionTreeClassificationModel].getName + + override def load(path: String): DecisionTreeClassificationModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + + } + } /** (private[ml]) Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeClassifier, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 78199cc2df582..b143865ac1a6d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} @@ -76,7 +76,7 @@ private[tree] object Split { final class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], - private val numCategories: Int) + @Since("1.6.0") val numCategories: Int) extends Split { require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + From 06e16f61de87be54ae57f911db3306e25ee26b84 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 7 Mar 2016 22:41:25 -0800 Subject: [PATCH 02/11] DecisionTreeClassifier,Regressor and Models support save,load. 
Fixed bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite where it never called checkModelData function. --- .../DecisionTreeClassifier.scala | 117 +++------------ .../org/apache/spark/ml/param/params.scala | 34 +++-- .../ml/regression/DecisionTreeRegressor.scala | 70 +++++++-- .../org/apache/spark/ml/tree/treeModels.scala | 135 +++++++++++++++++- .../org/apache/spark/ml/tree/treeParams.scala | 3 + .../org/apache/spark/ml/util/ReadWrite.scala | 19 ++- .../spark/mllib/tree/impurity/Impurity.scala | 14 ++ .../JavaDecisionTreeClassifierSuite.java | 2 +- .../JavaGBTClassifierSuite.java | 2 +- .../JavaRandomForestClassifierSuite.java | 2 +- .../JavaDecisionTreeRegressorSuite.java | 2 +- .../ml/regression/JavaGBTRegressorSuite.java | 2 +- .../JavaRandomForestRegressorSuite.java | 2 +- .../DecisionTreeClassifierSuite.scala | 52 ++++--- .../classification/GBTClassifierSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- .../DecisionTreeRegressorSuite.scala | 36 ++++- .../ml/regression/GBTRegressorSuite.scala | 2 +- .../RandomForestRegressorSuite.scala | 2 +- .../ml/tree/impl/RandomForestSuite.scala | 1 - .../spark/ml/{ => tree}/impl/TreeTests.scala | 44 +++++- .../spark/ml/util/DefaultReadWriteTest.scala | 7 +- 22 files changed, 392 insertions(+), 160 deletions(-) rename mllib/src/test/scala/org/apache/spark/ml/{ => tree}/impl/TreeTests.scala (77%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 3b22e7032b85a..14b25fe1a086f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,11 +18,13 @@ package org.apache.spark.ml.classification import org.apache.hadoop.fs.Path -import org.json4s.JsonAST.JValue +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} @@ -30,7 +32,8 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.DataFrame + /** * :: Experimental :: @@ -44,7 +47,7 @@ import org.apache.spark.sql.{Row, DataFrame} final class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] - with DecisionTreeParams with TreeClassifierParams with DefaultParamsWritable { + with DecisionTreeClassifierParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtc")) @@ -115,7 +118,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities - @Since("1.6.0") + @Since("2.0.0") override def load(path: String): DecisionTreeClassifier = super.load(path) } @@ -133,14 +136,15 @@ 
final class DecisionTreeClassificationModel private[ml] ( @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] - with DecisionTreeModel with MLWritable with Serializable { + with DecisionTreeModel with DecisionTreeClassifierParams with MLWritable with Serializable { require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") /** * Construct a decision tree classification model. - * @param rootNode Root node of tree, with other nodes attached. + * + * @param rootNode Root node of tree, with other nodes attached. */ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) @@ -199,107 +203,29 @@ final class DecisionTreeClassificationModel private[ml] ( new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) } - @Since("1.6.0") + @Since("2.0.0") override def write: MLWriter = new DecisionTreeClassificationModel.DecisionTreeClassificationModelWriter(this) } -@Since("1.6.0") +@Since("2.0.0") object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel] { - @Since("1.6.0") + @Since("2.0.0") override def read: MLReader[DecisionTreeClassificationModel] = new DecisionTreeClassificationModelReader - @Since("1.6.0") + @Since("2.0.0") override def load(path: String): DecisionTreeClassificationModel = super.load(path) private[DecisionTreeClassificationModel] class DecisionTreeClassificationModelWriter(instance: DecisionTreeClassificationModel) extends MLWriter { - import org.json4s.jackson.JsonMethods._ - import org.json4s.JsonDSL._ - - /** - * Info for a [[org.apache.spark.ml.tree.Split]] - * @param featureIndex Index of feature split on - * @param leftCategoriesOrThreshold For categorical feature, set of leftCategories. - * For continuous feature, threshold. - * @param numCategories For categorical feature, number of categories. - * For continuous feature, -1. - */ - private[ml] case class SplitData( - featureIndex: Int, - leftCategoriesOrThreshold: Array[Double], - numCategories: Int) - - private[ml] object SplitData { - def apply(split: Split): SplitData = split match { - case s: CategoricalSplit => - SplitData(s.featureIndex, s.leftCategories, s.numCategories) - case s: ContinuousSplit => - SplitData(s.featureIndex, Array(s.threshold), -1) - } - } - - /** - * Info for a [[Node]] - * @param id Index used for tree reconstruction. Indices follow an in-order traversal. - * @param impurityStats Stats array. Impurity type is stored in metadata. - * @param gain Gain, or arbitrary value if leaf node. - * @param leftChild Left child index, or arbitrary value if leaf node. - * @param rightChild Right child index, or arbitrary value if leaf node. - * @param split Split info, or arbitrary value if leaf node. - */ - private[ml] case class NodeData( - id: Int, - prediction: Double, - impurity: Double, - impurityStats: Array[Double], - gain: Double, - leftChild: Int, - rightChild: Int, - split: SplitData) - - private[ml] object NodeData { - /** - * Create [[NodeData]] instances for this node and all children. - * @param id Current ID. IDs are assigned via an in-order traversal. - * @return (sequence of nodes in preorder traversal order, largest ID in subtree) - * The nodes are returned in preorder traversal (root first) so that it is easy to - * get the ID of the subtree's root node. 
- */ - def build(node: Node, id: Int): (Seq[NodeData], Int) = node match { - case n: InternalNode => - val (leftNodeData, leftIdx) = build(n.leftChild, id) - val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 2) - val thisNodeData = NodeData(leftIdx + 1, n.prediction, n.impurity, n.impurityStats.stats, - n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split)) - (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx) - case _: LeafNode => - (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, - -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), - id) - } - - def reconstruct(df: DataFrame): Node = { - val nodeDatas = df.select("id", "prediction", "impurity", "impurityStats", "gain", - "leftChild", "rightChild", - "split.featureIndex", "split.leftCategoriesOrThreshold", "split.numCategories") - .map { - case Row(id: Int, prediction: Double, impurity: Double, impurityStats: Seq[Double], - gain: Double, leftChild: Int, rightChild: Int, splitFeatureIndex: Int, - splitLeftCategoriesOrThreshold: Seq[Double], splitNumCategories: Int) => - - } - } - } - override protected def saveImpl(path: String): Unit = { - val extraMetadata: JValue = render(Map( + val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, - "numClasses" -> instance.numClasses)) + "numClasses" -> instance.numClasses) DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString @@ -314,11 +240,14 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica private val className = classOf[DecisionTreeClassificationModel].getName override def load(path: String): DecisionTreeClassificationModel = { + implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - - val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) - + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val root = loadTreeNodes(path, metadata, sqlContext) + val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(model, metadata) + model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d7d6c0f5fa16e..42411d2d8af9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -101,7 +101,26 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } /** Decodes a param value from JSON. */ - def jsonDecode(json: String): T = { + def jsonDecode(json: String): T = Param.jsonDecode[T](json) + + private[this] val stringRepresentation = s"${parent}__$name" + + override final def toString: String = stringRepresentation + + override final def hashCode: Int = toString.## + + override final def equals(obj: Any): Boolean = { + obj match { + case p: Param[_] => (p.parent == parent) && (p.name == name) + case _ => false + } + } +} + +private[ml] object Param { + + /** Decodes a param value from JSON. 
*/ + def jsonDecode[T](json: String): T = { parse(json) match { case JString(x) => x.asInstanceOf[T] @@ -116,19 +135,6 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali s"${this.getClass.getName} must override jsonDecode to support its value type.") } } - - private[this] val stringRepresentation = s"${parent}__$name" - - override final def toString: String = stringRepresentation - - override final def hashCode: Int = toString.## - - override final def equals(obj: Any): Boolean = { - obj match { - case p: Param[_] => (p.parent == parent) && (p.name == name) - case _ => false - } - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 897b23383c0cb..80a8e83119547 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,12 +17,17 @@ package org.apache.spark.ml.regression +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} @@ -31,6 +36,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ + /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm @@ -41,7 +47,7 @@ import org.apache.spark.sql.functions._ @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeRegressorParams { + with DecisionTreeRegressorParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -99,16 +105,20 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") @Experimental -object DecisionTreeRegressor { +object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + @Since("2.0.0") + override def load(path: String): DecisionTreeRegressor = super.load(path) } /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. * It supports both continuous and categorical features. 
- * @param rootNode Root of the decision tree + * + * @param rootNode Root of the decision tree */ @Since("1.4.0") @Experimental @@ -117,17 +127,18 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with DecisionTreeRegressorParams with Serializable { + with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { /** @group setParam */ def setVarianceCol(value: String): this.type = set(varianceCol, value) require(rootNode != null, - "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.") /** * Construct a decision tree regression model. - * @param rootNode Root node of tree, with other nodes attached. + * + * @param rootNode Root node of tree, with other nodes attached. */ private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) @@ -192,9 +203,52 @@ final class DecisionTreeRegressionModel private[ml] ( private[ml] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) } + + @Since("2.0.0") + override def write: MLWriter = + new DecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this) } -private[ml] object DecisionTreeRegressionModel { +@Since("2.0.0") +object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[DecisionTreeRegressionModel] = + new DecisionTreeRegressionModelReader + + @Since("2.0.0") + override def load(path: String): DecisionTreeRegressionModel = super.load(path) + + private[DecisionTreeRegressionModel] + class DecisionTreeRegressionModelWriter(instance: DecisionTreeRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val (nodeData, _) = NodeData.build(instance.rootNode, 0) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + } + } + + private class DecisionTreeRegressionModelReader + extends MLReader[DecisionTreeRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): DecisionTreeRegressionModel = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val root = loadTreeNodes(path, metadata, sqlContext) + val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } /** (private[ml]) Convert a model from the old API */ def fromOld( diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 40ed95773e149..b43e62dd70fcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,7 +17,15 @@ package org.apache.spark.ml.tree +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ 
+ +import org.apache.spark.ml.param.Param +import org.apache.spark.ml.util.DefaultParamsReader import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.sql.SQLContext /** * Abstraction for Decision Tree models. @@ -56,7 +64,8 @@ private[ml] trait DecisionTreeModel { /** * Trace down the tree, and return the largest feature index used in any split. - * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + * + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). */ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() } @@ -101,3 +110,127 @@ private[ml] trait TreeEnsembleModel { /** Total number of nodes, summed over all trees in the ensemble. */ lazy val totalNumNodes: Int = trees.map(_.numNodes).sum } + +/** Helper classes for tree model persistence */ +private[ml] object DecisionTreeModelReadWrite { + + /** + * Info for a [[org.apache.spark.ml.tree.Split]] + * + * @param featureIndex Index of feature split on + * @param leftCategoriesOrThreshold For categorical feature, set of leftCategories. + * For continuous feature, threshold. + * @param numCategories For categorical feature, number of categories. + * For continuous feature, -1. + */ + case class SplitData( + featureIndex: Int, + leftCategoriesOrThreshold: Array[Double], + numCategories: Int) { + + def getSplit: Split = { + if (numCategories != -1) { + new CategoricalSplit(featureIndex, leftCategoriesOrThreshold, numCategories) + } else { + assert(leftCategoriesOrThreshold.length == 1, s"DecisionTree split data expected" + + s" 1 threshold for ContinuousSplit, but found thresholds: " + + leftCategoriesOrThreshold.mkString(", ")) + new ContinuousSplit(featureIndex, leftCategoriesOrThreshold(0)) + } + } + } + + object SplitData { + def apply(split: Split): SplitData = split match { + case s: CategoricalSplit => + SplitData(s.featureIndex, s.leftCategories, s.numCategories) + case s: ContinuousSplit => + SplitData(s.featureIndex, Array(s.threshold), -1) + } + } + + /** + * Info for a [[Node]] + * + * @param id Index used for tree reconstruction. Indices follow an in-order traversal. + * @param impurityStats Stats array. Impurity type is stored in metadata. + * @param gain Gain, or arbitrary value if leaf node. + * @param leftChild Left child index, or arbitrary value if leaf node. + * @param rightChild Right child index, or arbitrary value if leaf node. + * @param split Split info, or arbitrary value if leaf node. + */ + case class NodeData( + id: Int, + prediction: Double, + impurity: Double, + impurityStats: Array[Double], + gain: Double, + leftChild: Int, + rightChild: Int, + split: SplitData) + + object NodeData { + /** + * Create [[NodeData]] instances for this node and all children. + * + * @param id Current ID. IDs are assigned via an in-order traversal. + * @return (sequence of nodes in in-order traversal order, largest ID in subtree) + * The nodes are returned in in-order traversal (root first) so that it is easy to + * get the ID of the subtree's root node. 
+ */ + def build(node: Node, id: Int): (Seq[NodeData], Int) = node match { + case n: InternalNode => + val (leftNodeData, leftIdx) = build(n.leftChild, id + 1) + val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1) + val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats, + n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split)) + (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx) + case _: LeafNode => + (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, + -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), + id) + } + } + + def loadTreeNodes( + path: String, + metadata: DefaultParamsReader.Metadata, + sqlContext: SQLContext): Node = { + import sqlContext.implicits._ + implicit val format = DefaultFormats + + // Get impurity to construct ImpurityCalculator for each node + val impurityType: String = { + val impurityJson: JValue = metadata.getParamValue("impurity") + Param.jsonDecode[String](compact(render(impurityJson))) + } + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath).as[NodeData] + + // Load all nodes, sorted by ID. + val nodes: Array[NodeData] = data.collect().sortBy(_.id) + // Sanity checks; could remove + assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," + + s" but found ${nodes.head.id}") + assert(nodes.last.id == nodes.length - 1, s"Decision Tree load failed. Expected largest" + + s" node ID to be ${nodes.length - 1}, but found ${nodes.last.id}") + // We fill `finalNodes` in reverse order. Since node IDs are assigned via an in-order + // traversal, this guarantees that child nodes will be built before parent nodes. + val finalNodes = new Array[Node](nodes.length) + nodes.reverseIterator.foreach { case n: NodeData => + val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats) + val node = if (n.leftChild != -1) { + val leftChild = finalNodes(n.leftChild) + val rightChild = finalNodes(n.rightChild) + new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, + n.split.getSplit, impurityStats) + } else { + new LeafNode(n.prediction, n.impurity, impurityStats) + } + finalNodes(n.id) = node + } + // Return the root node + finalNodes.head + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 7a651a37ac77e..3f2d0c7198c8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -217,6 +217,9 @@ private[ml] object TreeClassifierParams { final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) } +private[ml] trait DecisionTreeClassifierParams + extends DecisionTreeParams with TreeClassifierParams + /** * Parameters for Decision Tree-based regression algorithms. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7b2504361a6ea..b532cd3382b79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -273,7 +273,23 @@ private[ml] object DefaultParamsReader { sparkVersion: String, params: JValue, metadata: JValue, - metadataJson: String) + metadataJson: String) { + def getParamValue(paramName: String): JValue = { + implicit val format = DefaultFormats + params match { + case JObject(pairs) => + val values = pairs.filter { case (pName, jsonValue) => + pName == paramName + }.map(_._2) + assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + values.head + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + } /** * Load metadata from file. @@ -302,6 +318,7 @@ private[ml] object DefaultParamsReader { /** * Extract Params from metadata, and set them in the instance. * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * TODO: Move to [[Metadata]] method */ def getAndSetParams(instance: Params, metadata: Metadata): Unit = { implicit val format = DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 4637dcceea7f8..eb0a44d5199ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -179,3 +179,17 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten } } + +private[spark] object ImpurityCalculator { + + def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { + impurity match { + case "gini" => new GiniCalculator(stats) + case "entropy" => new EntropyCalculator(stats) + case "variance" => new VarianceCalculator(stats) + case _ => + throw new IllegalArgumentException( + s"ImpurityCalculator builder did not recognize impurity type: $impurity") + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 40b9c35adc431..51d03c85ba5fa 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 59b6fba7a928a..726eb40c761c0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -27,7 +27,7 @@ import 
org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 5485fcbf01bda..b07d3cb802f98 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index d5c9d120c592c..21acef686f9bc 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 38d15dc2b7c78..741d2d57cd03d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 31be8880c25e1..ea1ecf4554f10 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.classification.LogisticRegressionSuite; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 6d6836449979c..b2193dcff63b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -30,7 +30,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row -class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class DecisionTreeClassifierSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs @@ -338,25 +339,28 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification) - val newModel = DecisionTreeClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = DecisionTreeClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: DecisionTreeClassificationModel, + model2: DecisionTreeClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + assert(model.numClasses === model2.numClasses) } + + val dt = new DecisionTreeClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val categoricalData: DataFrame = + TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) + testEstimatorAndModelReadWrite(dt, categoricalData, + DecisionTreeClassifierSuite.allParamSettings, checkModelData) + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(dt, continuousData, + DecisionTreeClassifierSuite.allParamSettings, checkModelData) } - */ } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { @@ -381,4 +385,12 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { TreeTests.checkEqual(oldTreeAsNew, newTree) assert(newTree.numFeatures === numFeatures) } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = + Map("impurity" -> "entropy") ++ TreeTests.allParamSettings } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 039141aeb6f67..29efd675abdba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6b810ab9eaa17..328aa92026658 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 56b335a33a6b8..3ac13497b2554 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, @@ -28,7 +28,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class DecisionTreeRegressorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import DecisionTreeRegressorSuite.compareAPIs @@ -120,7 +121,27 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: test("model save/load") SPARK-6725 + test("read/write") { + def checkModelData( + model: DecisionTreeRegressionModel, + model2: DecisionTreeRegressionModel): 
Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + } + + val dt = new DecisionTreeRegressor() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val categoricalData: DataFrame = + TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) + testEstimatorAndModelReadWrite(dt, categoricalData, + DecisionTreeRegressorSuite.allParamSettings, checkModelData) + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(dt, continuousData, + DecisionTreeRegressorSuite.allParamSettings, checkModelData) + } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { @@ -144,4 +165,11 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { TreeTests.checkEqual(oldTreeAsNew, newTree) assert(newTree.numFeatures === numFeatures) } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = TreeTests.allParamSettings } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 244db8637bea0..db6860639794d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index efb117f8f9b16..6be0c8bca0227 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index d5c238e9ae164..9d922291a6986 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.tree.impurity.GiniCalculator diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala similarity index 77% rename from mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala rename to mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 5561f6f0ef3c4..e242e764ef6d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.ml.impl +package org.apache.spark.ml.tree.impl import scala.collection.JavaConverters._ -import org.apache.spark.SparkContext -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.tree._ @@ -33,7 +32,8 @@ private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. - * @param data Dataset. Categorical features and labels must already have 0-based indices. + * + * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. * @param categoricalFeatures Map: categorical feature index -> number of distinct values * @param numClasses Number of classes label can take. If 0, mark as continuous. @@ -129,7 +129,8 @@ private[ml] object TreeTests extends SparkFunSuite { /** * Helper method for constructing a tree for testing. * Given left, right children, construct a parent node. - * @param split Split for parent node + * + * @param split Split for parent node * @return Parent node with children attached */ def buildParentNode(left: Node, right: Node, split: Split): Node = { @@ -145,6 +146,7 @@ private[ml] object TreeTests extends SparkFunSuite { } /** +<<<<<<< HEAD:mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala * Create some toy data for testing feature importances. */ def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( @@ -154,4 +156,36 @@ private[ml] object TreeTests extends SparkFunSuite { new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) )) +======= + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + * + * This set of Params is for all Decision Tree-based models. + */ + val allParamSettings: Map[String, Any] = Map( + "checkpointInterval" -> 7, + "seed" -> 543L, + "maxDepth" -> 2, + "maxBins" -> 20, + "minInstancesPerNode" -> 2, + "minInfoGain" -> 1e-14, + "maxMemoryInMB" -> 257, + "cacheNodeIds" -> true + ) + + /** Data for tree read/write tests which produces a non-trivial tree. */ + def getTreeReadWriteData(sc: SparkContext): RDD[LabeledPoint] = { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 2.0))) + sc.parallelize(arr) + } +>>>>>>> DecisionTreeClassifier,Regressor and Models support save,load. 
Fixed bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite where it never called checkModelData function.:mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 8e5365af849a7..87afccd33b4c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -22,6 +22,7 @@ import java.io.{File, IOException} import org.scalatest.Suite import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -33,7 +34,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * Checks "overwrite" option and params. * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name * in order to avoid conflicts from multiple calls to this method. - * @param instance ML instance to test saving/loading + * + * @param instance ML instance to test saving/loading * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. * @tparam T ML instance type * @return Instance loaded from file @@ -85,7 +87,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * - Compare model data * * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. - * @param estimator Estimator to test + * + * @param estimator Estimator to test * @param dataset Dataset to pass to [[Estimator.fit()]] * @param testParams Set of [[Param]] values to set in estimator * @param checkModelData Method which takes the original and loaded [[Model]] and compares their From 315bab12436e8257519cded8abfa303e131dd62e Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 8 Mar 2016 00:03:43 -0800 Subject: [PATCH 03/11] Fixed issue in LDA not copying doc,topicConcentration values --- .../src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 6 +++++- .../scala/org/apache/spark/ml/clustering/LDASuite.scala | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 6304b20d544ad..0b8c3f76d73b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -786,7 +786,11 @@ class LDA @Since("1.6.0") ( case m: OldDistributedLDAModel => new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) } - copyValues(newModel).setParent(this) + copyValues(newModel, + ParamMap( + docConcentration -> oldModel.docConcentration.toArray, + topicConcentration -> oldModel.topicConcentration)) + .setParent(this) } @Since("1.6.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index a3a8f65eac176..c6ff9d51d1a1f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -240,6 +240,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.vocabSize === model2.vocabSize) assert(Vectors.dense(model.topicsMatrix.toArray) ~== Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + model.getDocConcentration + model2.getDocConcentration assert(Vectors.dense(model.getDocConcentration) ~== Vectors.dense(model2.getDocConcentration) absTol 1e-6) } From 2240490209529e574840a558cb8eaf27ce3eecb3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 8 Mar 2016 10:04:10 -0800 Subject: [PATCH 04/11] Reverted annoying style mistakes from IntelliJ. Fixed my fix to LDASuite. Some more docs. --- .../DecisionTreeClassifier.scala | 5 ++- .../org/apache/spark/ml/clustering/LDA.scala | 5 +++ .../ml/regression/DecisionTreeRegressor.scala | 7 ++-- .../org/apache/spark/ml/tree/Split.scala | 4 +-- .../org/apache/spark/ml/tree/treeModels.scala | 3 +- .../org/apache/spark/ml/util/ReadWrite.scala | 6 ++++ .../spark/mllib/tree/impurity/Impurity.scala | 6 ++++ .../DecisionTreeClassifierSuite.scala | 16 +++------- .../apache/spark/ml/clustering/LDASuite.scala | 32 ++++++++----------- .../DecisionTreeRegressorSuite.scala | 11 ++----- .../apache/spark/ml/tree/impl/TreeTests.scala | 6 ++-- .../spark/ml/util/DefaultReadWriteTest.scala | 8 ++--- 12 files changed, 51 insertions(+), 58 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 14b25fe1a086f..bf64878e767f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -143,8 +143,7 @@ final class DecisionTreeClassificationModel private[ml] ( /** * Construct a decision tree classification model. - * - * @param rootNode Root node of tree, with other nodes attached. + * @param rootNode Root node of tree, with other nodes attached. 
*/ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) @@ -251,7 +250,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica } } - /** (private[ml]) Convert a model from the old API */ + /** Convert a model from the old API */ private[ml] def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeClassifier, diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 0b8c3f76d73b1..16b98bbd0dae0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -786,11 +786,16 @@ class LDA @Since("1.6.0") ( case m: OldDistributedLDAModel => new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) } + copyValues(newModel).setParent(this) + // We copy the docConcentration, topicConcentration explicitly to handle their + // special default behavior when not set by the user. + /* copyValues(newModel, ParamMap( docConcentration -> oldModel.docConcentration.toArray, topicConcentration -> oldModel.topicConcentration)) .setParent(this) + */ } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 80a8e83119547..67af44408bc48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,8 +117,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. * It supports both continuous and categorical features. 
- * - * @param rootNode Root of the decision tree + * @param rootNode Root of the decision tree */ @Since("1.4.0") @Experimental @@ -250,8 +249,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode } } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeRegressor, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index b143865ac1a6d..9d895b8faca7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} @@ -76,7 +76,7 @@ private[tree] object Split { final class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], - @Since("1.6.0") val numCategories: Int) + @Since("2.0.0") val numCategories: Int) extends Split { require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index b43e62dd70fcb..f4ee70afb18ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -64,8 +64,7 @@ private[ml] trait DecisionTreeModel { /** * Trace down the tree, and return the largest feature index used in any split. - * - * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). */ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index b532cd3382b79..329548f95a669 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -274,6 +274,12 @@ private[ml] object DefaultParamsReader { params: JValue, metadata: JValue, metadataJson: String) { + + /** + * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. + * This can be useful for getting a Param value before an instance of [[Params]] + * is available. 
+ */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats params match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index eb0a44d5199ef..1e184e77ae0b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -182,6 +182,12 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten private[spark] object ImpurityCalculator { + /** + * Create + * @param impurity + * @param stats + * @return + */ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { impurity match { case "gini" => new GiniCalculator(stats) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index b2193dcff63b5..ebd9bc7669d8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -351,15 +351,15 @@ class DecisionTreeClassifierSuite val dt = new DecisionTreeClassifier() val rdd = TreeTests.getTreeReadWriteData(sc) + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy") + val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) - testEstimatorAndModelReadWrite(dt, categoricalData, - DecisionTreeClassifierSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData) val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(dt, continuousData, - DecisionTreeClassifierSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData) } } @@ -385,12 +385,4 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { TreeTests.checkEqual(oldTreeAsNew, newTree) assert(newTree.numFeatures === numFeatures) } - - /** - * Mapping from all Params to valid settings which differ from the defaults. - * This is useful for tests which need to exercise all Params, such as save/load. - * This excludes input columns to simplify some tests. 
- */ - val allParamSettings: Map[String, Any] = - Map("impurity" -> "entropy") ++ TreeTests.allParamSettings } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index c6ff9d51d1a1f..6814fb9f1d726 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -236,29 +236,25 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } test("read/write LocalLDAModel") { - def checkModelData(model: LDAModel, model2: LDAModel): Unit = { - assert(model.vocabSize === model2.vocabSize) - assert(Vectors.dense(model.topicsMatrix.toArray) ~== - Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) - model.getDocConcentration - model2.getDocConcentration - assert(Vectors.dense(model.getDocConcentration) ~== - Vectors.dense(model2.getDocConcentration) absTol 1e-6) - } val lda = new LDA() - testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.checkModelData) } test("read/write DistributedLDAModel") { - def checkModelData(model: LDAModel, model2: LDAModel): Unit = { - assert(model.vocabSize === model2.vocabSize) - assert(Vectors.dense(model.topicsMatrix.toArray) ~== - Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) - assert(Vectors.dense(model.getDocConcentration) ~== - Vectors.dense(model2.getDocConcentration) absTol 1e-6) - } val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, - LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), LDASuite.checkModelData) + } +} + +object LDASuite extends SparkFunSuite { + + /** Compare 2 models for persistence tests */ + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(model.estimatedDocConcentration ~== + model2.estimatedDocConcentration absTol 1e-6) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 3ac13497b2554..ab596143664fb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -135,12 +135,12 @@ class DecisionTreeRegressorSuite val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) testEstimatorAndModelReadWrite(dt, categoricalData, - DecisionTreeRegressorSuite.allParamSettings, checkModelData) + TreeTests.allParamSettings, checkModelData) val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) testEstimatorAndModelReadWrite(dt, continuousData, - DecisionTreeRegressorSuite.allParamSettings, checkModelData) + TreeTests.allParamSettings, checkModelData) } } @@ -165,11 +165,4 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { TreeTests.checkEqual(oldTreeAsNew, newTree) assert(newTree.numFeatures === numFeatures) } - - /** - * Mapping from all Params to valid settings which differ from the defaults. - * This is useful for tests which need to exercise all Params, such as save/load. 
- * This excludes input columns to simplify some tests. - */ - val allParamSettings: Map[String, Any] = TreeTests.allParamSettings } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index e242e764ef6d1..f53ddd4bf81fb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -32,8 +32,7 @@ private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. - * - * @param data Dataset. Categorical features and labels must already have 0-based indices. + * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. * @param categoricalFeatures Map: categorical feature index -> number of distinct values * @param numClasses Number of classes label can take. If 0, mark as continuous. @@ -129,8 +128,7 @@ private[ml] object TreeTests extends SparkFunSuite { /** * Helper method for constructing a tree for testing. * Given left, right children, construct a parent node. - * - * @param split Split for parent node + * @param split Split for parent node * @return Parent node with children attached */ def buildParentNode(left: Node, right: Node, split: Split): Node = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 87afccd33b4c6..284381f552c9e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -34,8 +34,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * Checks "overwrite" option and params. * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name * in order to avoid conflicts from multiple calls to this method. - * - * @param instance ML instance to test saving/loading + * + * @param instance ML instance to test saving/loading * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. * @tparam T ML instance type * @return Instance loaded from file @@ -87,8 +87,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * - Compare model data * * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. - * - * @param estimator Estimator to test + * + * @param estimator Estimator to test * @param dataset Dataset to pass to [[Estimator.fit()]] * @param testParams Set of [[Param]] values to set in estimator * @param checkModelData Method which takes the original and loaded [[Model]] and compares their From 17c8baf4d0c87a4f6ebce1f282a87764213a0f00 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 8 Mar 2016 10:36:59 -0800 Subject: [PATCH 05/11] tiny cleanups --- .../main/scala/org/apache/spark/ml/clustering/LDA.scala | 9 --------- .../org/apache/spark/mllib/tree/impurity/Impurity.scala | 8 +++----- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 16b98bbd0dae0..6304b20d544ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -787,15 +787,6 @@ class LDA @Since("1.6.0") ( new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) } copyValues(newModel).setParent(this) - // We copy the docConcentration, topicConcentration explicitly to handle their - // special default behavior when not set by the user. - /* - copyValues(newModel, - ParamMap( - docConcentration -> oldModel.docConcentration.toArray, - topicConcentration -> oldModel.topicConcentration)) - .setParent(this) - */ } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 1e184e77ae0b0..b2c6e2bba43b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -183,11 +183,9 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten private[spark] object ImpurityCalculator { /** - * Create - * @param impurity - * @param stats - * @return - */ + * Create an [[ImpurityCalculator]] instance of the given impurity type and with + * the given stats. + */ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { impurity match { case "gini" => new GiniCalculator(stats) From 45c4dd1cf3d70e89ef273015e5f81a06e1557e1b Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 8 Mar 2016 10:49:01 -0800 Subject: [PATCH 06/11] scala style fixes --- .../spark/ml/classification/DecisionTreeClassifierSuite.scala | 2 +- .../scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index ebd9bc7669d8c..e7d255c7798d9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 284381f552c9e..16280473c6ac1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -22,7 +22,6 @@ import java.io.{File, IOException} import org.scalatest.Suite import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext From b784efa7bbf374b774aa59584bb2ff0d4b48d76e Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 8 Mar 2016 13:30:10 -0800 Subject: [PATCH 07/11] reverted fix in LDASuite now that it is fixed in master --- .../apache/spark/ml/clustering/LDASuite.scala | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 6814fb9f1d726..a3a8f65eac176 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -236,25 +236,27 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } test("read/write LocalLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } val lda = new LDA() - testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.checkModelData) + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) } test("read/write DistributedLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, - LDASuite.allParamSettings ++ Map("optimizer" -> "em"), LDASuite.checkModelData) - } -} - -object LDASuite extends SparkFunSuite { - - /** Compare 2 models for persistence tests */ - def checkModelData(model: LDAModel, model2: LDAModel): Unit = { - assert(model.vocabSize === model2.vocabSize) - assert(Vectors.dense(model.topicsMatrix.toArray) ~== - Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) - assert(model.estimatedDocConcentration ~== - model2.estimatedDocConcentration absTol 1e-6) + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) } } From 76943bd9d109e4765e50d5b4c639f547f67215cb Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 10 Mar 2016 10:25:53 -0800 Subject: [PATCH 08/11] fix style issue --- .../org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 67af44408bc48..c981337b8fb14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -136,8 +136,7 @@ final class DecisionTreeRegressionModel private[ml] ( /** * Construct a decision tree regression model. - * - * @param rootNode Root node of tree, with other nodes attached. + * @param rootNode Root node of tree, with other nodes attached. */ private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) From 1ba85071464c3f27f77256c6987718a2ed888d35 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 10 Mar 2016 10:36:21 -0800 Subject: [PATCH 09/11] Fixed bad merge. 
Added unit tests for depth 0 --- .../ml/classification/DecisionTreeClassifierSuite.scala | 6 ++++++ .../spark/ml/regression/DecisionTreeRegressorSuite.scala | 6 ++++++ .../scala/org/apache/spark/ml/tree/impl/TreeTests.scala | 5 ++--- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index e7d255c7798d9..2b075248151d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -353,13 +353,19 @@ class DecisionTreeClassifierSuite val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy") + // Categorical splits with tree depth 2 val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData) + // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData) + + // Continuous splits with tree depth 0 + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), + checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index ab596143664fb..662e3fc67927d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -132,15 +132,21 @@ class DecisionTreeRegressorSuite val dt = new DecisionTreeRegressor() val rdd = TreeTests.getTreeReadWriteData(sc) + // Categorical splits with tree depth 2 val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) testEstimatorAndModelReadWrite(dt, categoricalData, TreeTests.allParamSettings, checkModelData) + // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) testEstimatorAndModelReadWrite(dt, continuousData, TreeTests.allParamSettings, checkModelData) + + // Continuous splits with tree depth 0 + testEstimatorAndModelReadWrite(dt, continuousData, + TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index f53ddd4bf81fb..12808b0305916 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -144,7 +144,6 @@ private[ml] object TreeTests extends SparkFunSuite { } /** -<<<<<<< HEAD:mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala * Create some toy data for testing feature importances. */ def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( @@ -154,7 +153,8 @@ private[ml] object TreeTests extends SparkFunSuite { new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) )) -======= + + /** * Mapping from all Params to valid settings which differ from the defaults. 
* This is useful for tests which need to exercise all Params, such as save/load. * This excludes input columns to simplify some tests. @@ -185,5 +185,4 @@ private[ml] object TreeTests extends SparkFunSuite { LabeledPoint(1.0, Vectors.dense(1.0, 2.0))) sc.parallelize(arr) } ->>>>>>> DecisionTreeClassifier,Regressor and Models support save,load. Fixed bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite where it never called checkModelData function.:mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala } From ce2d824b54d7c7ad45f133ff13f44c9aee12df28 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 14 Mar 2016 14:47:15 -0700 Subject: [PATCH 10/11] fixed incorrect doc --- .../scala/org/apache/spark/ml/tree/treeModels.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index f4ee70afb18ab..90c8ea9c3a601 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -151,7 +151,7 @@ private[ml] object DecisionTreeModelReadWrite { /** * Info for a [[Node]] * - * @param id Index used for tree reconstruction. Indices follow an in-order traversal. + * @param id Index used for tree reconstruction. Indices follow a pre-order traversal. * @param impurityStats Stats array. Impurity type is stored in metadata. * @param gain Gain, or arbitrary value if leaf node. * @param leftChild Left child index, or arbitrary value if leaf node. @@ -172,9 +172,9 @@ private[ml] object DecisionTreeModelReadWrite { /** * Create [[NodeData]] instances for this node and all children. * - * @param id Current ID. IDs are assigned via an in-order traversal. - * @return (sequence of nodes in in-order traversal order, largest ID in subtree) - * The nodes are returned in in-order traversal (root first) so that it is easy to + * @param id Current ID. IDs are assigned via a pre-order traversal. + * @return (sequence of nodes in pre-order traversal order, largest ID in subtree) + * The nodes are returned in pre-order traversal (root first) so that it is easy to * get the ID of the subtree's root node. */ def build(node: Node, id: Int): (Seq[NodeData], Int) = node match { @@ -214,7 +214,7 @@ private[ml] object DecisionTreeModelReadWrite { s" but found ${nodes.head.id}") assert(nodes.last.id == nodes.length - 1, s"Decision Tree load failed. Expected largest" + s" node ID to be ${nodes.length - 1}, but found ${nodes.last.id}") - // We fill `finalNodes` in reverse order. Since node IDs are assigned via an in-order + // We fill `finalNodes` in reverse order. Since node IDs are assigned via a pre-order // traversal, this guarantees that child nodes will be built before parent nodes. val finalNodes = new Array[Node](nodes.length) nodes.reverseIterator.foreach { case n: NodeData => From cfb770ba4da92de63e9f4a64517c4d9d17de7d02 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley"
Date: Wed, 16 Mar 2016 13:30:04 -0700
Subject: [PATCH 11/11] call parquet method directly

---
 mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 90c8ea9c3a601..3e72e85d10d2f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -205,7 +205,7 @@ private[ml] object DecisionTreeModelReadWrite {
     }
 
     val dataPath = new Path(path, "data").toString
-    val data = sqlContext.read.format("parquet").load(dataPath).as[NodeData]
+    val data = sqlContext.read.parquet(dataPath).as[NodeData]
 
     // Load all nodes, sorted by ID.
     val nodes: Array[NodeData] = data.collect().sortBy(_.id)
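
Usage note: a minimal, spark-shell-style sketch of the persistence API these patches add (DefaultParamsWritable/Readable on the estimator, MLWritable/MLReadable on the model). It assumes a running Spark session, an already-fitted `model: DecisionTreeClassificationModel`, and a writable path; `path` and `model` are placeholders, not values from the patches.

  import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}

  val path = "/tmp/dtc-readwrite-example"  // placeholder directory, not from the patches

  // Estimator round trip: Param values are persisted as JSON metadata.
  val dtc = new DecisionTreeClassifier().setMaxDepth(2).setImpurity("entropy")
  dtc.write.overwrite().save(path + "/estimator")
  val dtc2 = DecisionTreeClassifier.load(path + "/estimator")
  assert(dtc2.getMaxDepth == 2 && dtc2.getImpurity == "entropy")

  // Model round trip: assumes `model` is an already-fitted DecisionTreeClassificationModel.
  // The tree is persisted as flat NodeData rows under <path>/data plus JSON metadata.
  model.write.overwrite().save(path + "/model")
  val model2 = DecisionTreeClassificationModel.load(path + "/model")
  assert(model2.numClasses == model.numClasses && model2.numFeatures == model.numFeatures)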
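
For intuition about the encoding loaded above: a self-contained toy sketch (plain Scala with illustrative types, not Spark's private NodeData/SplitData classes) of the pre-order ID assignment and reverse-order rebuild described in PATCH 10's doc fix. Because every node's ID is smaller than every ID in its subtrees, filling an array from the largest ID down guarantees both children already exist when their parent row is reached.

  object PreOrderTreeCodecSketch {

    sealed trait SimpleNode
    case class SimpleLeaf(prediction: Double) extends SimpleNode
    case class SimpleInternal(prediction: Double, left: SimpleNode, right: SimpleNode)
      extends SimpleNode

    // Flattened row, loosely analogous to NodeData (split and impurity info omitted).
    case class NodeRow(id: Int, prediction: Double, leftChild: Int, rightChild: Int)

    // Assign IDs via a pre-order traversal.
    // Returns (rows for this subtree, largest ID used in this subtree).
    def build(node: SimpleNode, id: Int): (Seq[NodeRow], Int) = node match {
      case SimpleInternal(pred, left, right) =>
        val (leftRows, leftMaxId) = build(left, id + 1)
        val (rightRows, rightMaxId) = build(right, leftMaxId + 1)
        (NodeRow(id, pred, id + 1, leftMaxId + 1) +: (leftRows ++ rightRows), rightMaxId)
      case SimpleLeaf(pred) =>
        (Seq(NodeRow(id, pred, -1, -1)), id)
    }

    // Rebuild the tree from flat rows: sort by ID, then fill an array in reverse ID order
    // so that finalNodes(leftChild) and finalNodes(rightChild) are set before the parent.
    def reconstruct(rows: Seq[NodeRow]): SimpleNode = {
      val sorted = rows.sortBy(_.id).toArray
      val finalNodes = new Array[SimpleNode](sorted.length)
      sorted.reverseIterator.foreach { r =>
        finalNodes(r.id) =
          if (r.leftChild < 0) SimpleLeaf(r.prediction)
          else SimpleInternal(r.prediction, finalNodes(r.leftChild), finalNodes(r.rightChild))
      }
      finalNodes(0)  // the root always has ID 0 in a pre-order numbering
    }

    def main(args: Array[String]): Unit = {
      val tree: SimpleNode =
        SimpleInternal(0.5, SimpleLeaf(0.0), SimpleInternal(0.7, SimpleLeaf(1.0), SimpleLeaf(0.0)))
      val (rows, maxId) = build(tree, id = 0)
      assert(maxId == rows.length - 1)   // IDs are 0 .. numNodes - 1 with no gaps
      assert(reconstruct(rows) == tree)  // round trip recovers the same structure
    }
  }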