Permalink
Cannot retrieve contributors at this time
Fetching contributors…
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| package org.apache.spark.ml.regression | |
| import org.apache.hadoop.fs.Path | |
| import org.json4s.{DefaultFormats, JObject} | |
| import org.json4s.JsonDSL._ | |
| import org.apache.spark.annotation.Since | |
| import org.apache.spark.ml.{PredictionModel, Predictor} | |
| import org.apache.spark.ml.feature.LabeledPoint | |
| import org.apache.spark.ml.linalg.Vector | |
| import org.apache.spark.ml.param.ParamMap | |
| import org.apache.spark.ml.tree._ | |
| import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ | |
| import org.apache.spark.ml.tree.impl.RandomForest | |
| import org.apache.spark.ml.util._ | |
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} | |
| import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} | |
| import org.apache.spark.rdd.RDD | |
| import org.apache.spark.sql.{DataFrame, Dataset} | |
| import org.apache.spark.sql.functions._ | |
/**
 * Regression estimator based on
 * <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">decision tree</a> learning.
 * Both continuous and categorical features are supported.
 */
@Since("1.4.0")
class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
  with DecisionTreeRegressorParams with DefaultParamsWritable {

  /** Creates an estimator whose uid is generated from the "dtr" prefix. */
  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("dtr"))

  // The setters below re-declare the ones inherited from the parent trait so that they are
  // usable from the Java API.

  /**
   * Sets the `maxDepth` param.
   * @group setParam
   */
  @Since("1.4.0")
  override def setMaxDepth(value: Int): this.type = set(maxDepth, value)

  /**
   * Sets the `maxBins` param.
   * @group setParam
   */
  @Since("1.4.0")
  override def setMaxBins(value: Int): this.type = set(maxBins, value)

  /**
   * Sets the `minInstancesPerNode` param.
   * @group setParam
   */
  @Since("1.4.0")
  override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

  /**
   * Sets the `minInfoGain` param.
   * @group setParam
   */
  @Since("1.4.0")
  override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

  /**
   * Sets the `maxMemoryInMB` param.
   * @group expertSetParam
   */
  @Since("1.4.0")
  override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)

  /**
   * Sets the `cacheNodeIds` param.
   * @group expertSetParam
   */
  @Since("1.4.0")
  override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

  /**
   * Sets how frequently the cached node IDs are checkpointed.
   * For example, a value of 10 checkpoints the cache once every 10 iterations.
   * The setting only takes effect when cacheNodeIds is true and a checkpoint directory has
   * been configured in [[org.apache.spark.SparkContext]].
   * The value must be at least 1.
   * (default = 10)
   * @group setParam
   */
  @Since("1.4.0")
  override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

  /**
   * Sets the `impurity` param.
   * @group setParam
   */
  @Since("1.4.0")
  override def setImpurity(value: String): this.type = set(impurity, value)

  /**
   * Sets the `seed` param.
   * @group setParam
   */
  @Since("1.6.0")
  override def setSeed(value: Long): this.type = set(seed, value)

  /**
   * Sets the `varianceCol` param.
   * @group setParam
   */
  @Since("2.0.0")
  def setVarianceCol(value: String): this.type = set(varianceCol, value)

  /**
   * Fits a single regression tree to `dataset`.
   *
   * Categorical feature metadata is read from the schema of the features column, the rows
   * are converted to labeled points, and the actual fitting is handed to the RDD-based
   * `train` overload.
   *
   * @param dataset training data
   * @return the fitted model
   */
  override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
    val featureArity: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    val labeledPoints: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    train(labeledPoints, getOldStrategy(featureArity))
  }

  /**
   * (private[ml]) Fits a single regression tree to an RDD of labeled points.
   *
   * Runs [[RandomForest]] with exactly one tree and the "all" feature subset strategy, and
   * records the params and the resulting model through the instrumentation logger.
   *
   * @param data training points
   * @param oldStrategy strategy object from the old (spark.mllib) API
   * @return the fitted model
   */
  private[ml] def train(
      data: RDD[LabeledPoint],
      oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
    val instrumentation = Instrumentation.create(this, data)
    instrumentation.logParams(params: _*)

    val forest = RandomForest.run(
      data,
      oldStrategy,
      numTrees = 1,
      featureSubsetStrategy = "all",
      seed = $(seed),
      instr = Some(instrumentation),
      parentUID = Some(uid))

    // A one-tree forest is requested above, so the first element is the model we want.
    val model = forest.head.asInstanceOf[DecisionTreeRegressionModel]
    instrumentation.logSuccess(model)
    model
  }

  /**
   * (private[ml]) Builds the old-API Strategy for regression: `numClasses` is 0 and
   * `subsamplingRate` is 1.0.
   *
   * @param categoricalFeatures forwarded unchanged to the parent implementation
   */
  private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
    super.getOldStrategy(
      categoricalFeatures,
      numClasses = 0,
      OldAlgo.Regression,
      getOldImpurity,
      subsamplingRate = 1.0)
  }

  /** Returns a copy of this estimator with `extra` params applied, via `defaultCopy`. */
  @Since("1.4.0")
  override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra)
}
/** Companion of [[DecisionTreeRegressor]]: exposes the supported impurities and loading. */
@Since("1.4.0")
object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {

  /** Impurity names accepted by the regressor (variance), taken from TreeRegressorParams. */
  final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities

  /**
   * Loads a [[DecisionTreeRegressor]] from `path` by deferring to the inherited loader.
   *
   * @param path location to load from
   */
  @Since("2.0.0")
  override def load(path: String): DecisionTreeRegressor = {
    super.load(path)
  }
}
/**
 * Regression model backed by a single
 * <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">decision tree (Wikipedia)</a>.
 * Both continuous and categorical features are supported.
 *
 * @param rootNode Root of the decision tree
 */
@Since("1.4.0")
class DecisionTreeRegressionModel private[ml] (
    override val uid: String,
    override val rootNode: Node,
    override val numFeatures: Int)
  extends PredictionModel[Vector, DecisionTreeRegressionModel]
  with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {

  // A model cannot exist without a tree, so reject a null root at construction time.
  require(rootNode != null,
    "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.")

  /**
   * Builds a decision tree regression model with a generated uid ("dtr" prefix).
   *
   * @param rootNode Root node of the tree, with the remaining nodes attached to it.
   * @param numFeatures value stored as this model's `numFeatures`
   */
  private[ml] def this(rootNode: Node, numFeatures: Int) =
    this(Identifiable.randomUID("dtr"), rootNode, numFeatures)

  /**
   * Sets the `varianceCol` param.
   * @group setParam
   */
  def setVarianceCol(value: String): this.type = set(varianceCol, value)

  /** Returns the `prediction` held by the node that `features` resolve to. */
  override protected def predict(features: Vector): Double = {
    val reached = rootNode.predictImpl(features)
    reached.prediction
  }

  /**
   * Returns the value calculated from the impurity statistics of the node that `features`
   * resolve to. This function must be revisited if other impurity measures are ever added.
   */
  protected def predictVariance(features: Vector): Double = {
    val reached = rootNode.predictImpl(features)
    reached.impurityStats.calculate()
  }

  /**
   * Validates the schema of `dataset` (with logging enabled) and then applies
   * `transformImpl` to it.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    transformImpl(dataset)
  }

  /**
   * Appends the output columns to `dataset`.
   *
   * The prediction column is added when `predictionCol` is non-empty; the variance column is
   * added when `varianceCol` is defined and non-empty. Both are computed from `featuresCol`.
   */
  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    val predictUDF = udf { (v: Vector) => predict(v) }
    val predictVarianceUDF = udf { (v: Vector) => predictVariance(v) }

    val input = dataset.toDF()

    val withPrediction =
      if ($(predictionCol).nonEmpty) {
        input.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      } else {
        input
      }

    // `isDefined` is checked first so that an unset varianceCol is never dereferenced.
    val wantsVariance = isDefined(varianceCol) && $(varianceCol).nonEmpty
    if (wantsVariance) {
      withPrediction.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol))))
    } else {
      withPrediction
    }
  }

  /**
   * Returns a new model with the same uid, tree and feature count, carrying over this
   * model's param values plus `extra`, and the same parent.
   */
  @Since("1.4.0")
  override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
    val duplicate = new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
    copyValues(duplicate, extra).setParent(parent)
  }

  /** One-line summary containing the uid, the depth and the number of nodes. */
  @Since("1.4.0")
  override def toString: String =
    s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes"

  /**
   * Estimate of the importance of each feature.
   *
   * This generalizes the idea of "Gini" importance to other losses,
   * following the explanation of Gini importance from "Random Forests" documentation
   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
   *
   * The importance is computed in two steps:
   *  - importance(feature j) = sum (over nodes which split on feature j) of the gain,
   *    where gain is scaled by the number of instances passing through the node
   *  - the importances of the tree are normalized so that they sum to 1.
   *
   * @note Feature importance for single decision trees can have high variance due to
   *       correlated predictor variables. Consider using a [[RandomForestRegressor]]
   *       to determine feature importance instead.
   */
  @Since("2.0.0")
  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)

  /** Converts this model to a spark.mllib DecisionTreeModel (some information is lost). */
  override private[spark] def toOld: OldDecisionTreeModel = {
    val oldRoot = rootNode.toOld(1)
    new OldDecisionTreeModel(oldRoot, OldAlgo.Regression)
  }

  /** Returns a writer that persists this model. */
  @Since("2.0.0")
  override def write: MLWriter = {
    new DecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this)
  }
}
/** Companion of [[DecisionTreeRegressionModel]]: persistence and old-API conversion. */
@Since("2.0.0")
object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionModel] {

  /** Returns a reader able to load a saved [[DecisionTreeRegressionModel]]. */
  @Since("2.0.0")
  override def read: MLReader[DecisionTreeRegressionModel] = {
    new DecisionTreeRegressionModelReader
  }

  /**
   * Loads a model from `path` by deferring to the inherited loader.
   *
   * @param path location to load from
   */
  @Since("2.0.0")
  override def load(path: String): DecisionTreeRegressionModel = {
    super.load(path)
  }

  /** Writer that stores the model metadata and the tree nodes of `instance`. */
  private[DecisionTreeRegressionModel]
  class DecisionTreeRegressionModelWriter(instance: DecisionTreeRegressionModel)
    extends MLWriter {

    /**
     * Saves the metadata (with `numFeatures` as an extra entry) and then writes the node
     * data as Parquet to the "data" child of `path`.
     */
    override protected def saveImpl(path: String): Unit = {
      val extraMetadata: JObject = Map("numFeatures" -> instance.numFeatures)
      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))

      // Only the first component of the pair returned by `NodeData.build` is needed here.
      val nodeData = NodeData.build(instance.rootNode, 0)._1
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
    }
  }

  /** Reader that rebuilds a model from its saved metadata and tree nodes. */
  private class DecisionTreeRegressionModelReader
    extends MLReader[DecisionTreeRegressionModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[DecisionTreeRegressionModel].getName

    /**
     * Loads the metadata, reads `numFeatures` from it, loads the tree nodes, builds the
     * model and finally applies the saved params to it.
     */
    override def load(path: String): DecisionTreeRegressionModel = {
      implicit val jsonFormats = DefaultFormats
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val featureCount = (metadata.metadata \ "numFeatures").extract[Int]
      val rootNode = loadTreeNodes(path, metadata, sparkSession)

      val loaded = new DecisionTreeRegressionModel(metadata.uid, rootNode, featureCount)
      DefaultParamsReader.getAndSetParams(loaded, metadata)
      loaded
    }
  }

  /**
   * Converts a model from the old API.
   *
   * @param oldModel old-API model; its algo must be Regression
   * @param parent estimator whose uid is reused; when null a fresh "dtr" uid is generated
   * @param categoricalFeatures passed through to `Node.fromOld`
   * @param numFeatures stored on the resulting model (defaults to -1)
   */
  private[ml] def fromOld(
      oldModel: OldDecisionTreeModel,
      parent: DecisionTreeRegressor,
      categoricalFeatures: Map[Int, Int],
      numFeatures: Int = -1): DecisionTreeRegressionModel = {
    require(oldModel.algo == OldAlgo.Regression,
      "Cannot convert non-regression DecisionTreeModel (old API) to" +
        s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")

    val convertedRoot = Node.fromOld(oldModel.topNode, categoricalFeatures)
    // `Option(...)` maps a null parent to None, in which case a new uid is generated.
    val modelUid = Option(parent).map(_.uid).getOrElse(Identifiable.randomUID("dtr"))
    new DecisionTreeRegressionModel(modelUid, convertedRoot, numFeatures)
  }
}