Permalink
Cannot retrieve contributors at this time
Fetching contributors…
| /* | |
| * Licensed to the Apache Software Foundation (ASF) under one or more | |
| * contributor license agreements. See the NOTICE file distributed with | |
| * this work for additional information regarding copyright ownership. | |
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |
| * (the "License"); you may not use this file except in compliance with | |
| * the License. You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| package org.apache.spark.ml.regression | |
| import com.github.fommil.netlib.BLAS.{getInstance => blas} | |
| import org.json4s.{DefaultFormats, JObject} | |
| import org.json4s.JsonDSL._ | |
| import org.apache.spark.annotation.Since | |
| import org.apache.spark.internal.Logging | |
| import org.apache.spark.ml.{PredictionModel, Predictor} | |
| import org.apache.spark.ml.feature.LabeledPoint | |
| import org.apache.spark.ml.linalg.Vector | |
| import org.apache.spark.ml.param.ParamMap | |
| import org.apache.spark.ml.tree._ | |
| import org.apache.spark.ml.tree.impl.GradientBoostedTrees | |
| import org.apache.spark.ml.util._ | |
| import org.apache.spark.ml.util.DefaultParamsReader.Metadata | |
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} | |
| import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} | |
| import org.apache.spark.rdd.RDD | |
| import org.apache.spark.sql.{DataFrame, Dataset} | |
| import org.apache.spark.sql.functions._ | |
| /** | |
| * <a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a> | |
| * learning algorithm for regression. | |
| * It supports both continuous and categorical features. | |
| * | |
| * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. | |
| * | |
| * Notes on Gradient Boosting vs. TreeBoost: | |
| * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. | |
| * - Both algorithms learn tree ensembles by minimizing loss functions. | |
| * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes | |
| * based on the loss function, whereas the original gradient boosting method does not. | |
| * - When the loss is SquaredError, these methods give the same result, but they could differ | |
| * for other loss functions. | |
| * - We expect to implement TreeBoost in the future: | |
| * [https://issues.apache.org/jira/browse/SPARK-4240] | |
| */ | |
| @Since("1.4.0") | |
| class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) | |
| extends Predictor[Vector, GBTRegressor, GBTRegressionModel] | |
| with GBTRegressorParams with DefaultParamsWritable with Logging { | |
| @Since("1.4.0") | |
| def this() = this(Identifiable.randomUID("gbtr")) | |
| // Override parameter setters from parent trait for Java API compatibility. | |
| // Parameters from TreeRegressorParams: | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setMaxDepth(value: Int): this.type = set(maxDepth, value) | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setMaxBins(value: Int): this.type = set(maxBins, value) | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) | |
| /** @group expertSetParam */ | |
| @Since("1.4.0") | |
| override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) | |
| /** @group expertSetParam */ | |
| @Since("1.4.0") | |
| override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) | |
| /** | |
| * Specifies how often to checkpoint the cached node IDs. | |
| * E.g. 10 means that the cache will get checkpointed every 10 iterations. | |
| * This is only used if cacheNodeIds is true and if the checkpoint directory is set in | |
| * [[org.apache.spark.SparkContext]]. | |
| * Must be at least 1. | |
| * (default = 10) | |
| * @group setParam | |
| */ | |
| @Since("1.4.0") | |
| override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) | |
| /** | |
| * The impurity setting is ignored for GBT models. | |
| * Individual trees are built using impurity "Variance." | |
| * | |
| * @group setParam | |
| */ | |
| @Since("1.4.0") | |
| override def setImpurity(value: String): this.type = { | |
| logWarning("GBTRegressor.setImpurity should NOT be used") | |
| this | |
| } | |
| // Parameters from TreeEnsembleParams: | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setSeed(value: Long): this.type = set(seed, value) | |
| // Parameters from GBTParams: | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setMaxIter(value: Int): this.type = set(maxIter, value) | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| override def setStepSize(value: Double): this.type = set(stepSize, value) | |
| // Parameters from GBTRegressorParams: | |
| /** @group setParam */ | |
| @Since("1.4.0") | |
| def setLossType(value: String): this.type = set(lossType, value) | |
| override protected def train(dataset: Dataset[_]): GBTRegressionModel = { | |
| val categoricalFeatures: Map[Int, Int] = | |
| MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) | |
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) | |
| val numFeatures = oldDataset.first().features.size | |
| val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) | |
| val instr = Instrumentation.create(this, oldDataset) | |
| instr.logParams(params: _*) | |
| instr.logNumFeatures(numFeatures) | |
| val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, | |
| $(seed)) | |
| val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) | |
| instr.logSuccess(m) | |
| m | |
| } | |
| @Since("1.4.0") | |
| override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) | |
| } | |
| @Since("1.4.0") | |
| object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { | |
| /** Accessor for supported loss settings: squared (L2), absolute (L1) */ | |
| @Since("1.4.0") | |
| final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes | |
| @Since("2.0.0") | |
| override def load(path: String): GBTRegressor = super.load(path) | |
| } | |
| /** | |
| * <a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a> | |
| * model for regression. | |
| * It supports both continuous and categorical features. | |
| * @param _trees Decision trees in the ensemble. | |
| * @param _treeWeights Weights for the decision trees in the ensemble. | |
| */ | |
| @Since("1.4.0") | |
| class GBTRegressionModel private[ml]( | |
| override val uid: String, | |
| private val _trees: Array[DecisionTreeRegressionModel], | |
| private val _treeWeights: Array[Double], | |
| override val numFeatures: Int) | |
| extends PredictionModel[Vector, GBTRegressionModel] | |
| with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] | |
| with MLWritable with Serializable { | |
| require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") | |
| require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + | |
| s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") | |
| /** | |
| * Construct a GBTRegressionModel | |
| * @param _trees Decision trees in the ensemble. | |
| * @param _treeWeights Weights for the decision trees in the ensemble. | |
| */ | |
| @Since("1.4.0") | |
| def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = | |
| this(uid, _trees, _treeWeights, -1) | |
| @Since("1.4.0") | |
| override def trees: Array[DecisionTreeRegressionModel] = _trees | |
| /** | |
| * Number of trees in ensemble | |
| */ | |
| @Since("2.0.0") | |
| val getNumTrees: Int = trees.length | |
| @Since("1.4.0") | |
| override def treeWeights: Array[Double] = _treeWeights | |
| override protected def transformImpl(dataset: Dataset[_]): DataFrame = { | |
| val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) | |
| val predictUDF = udf { (features: Any) => | |
| bcastModel.value.predict(features.asInstanceOf[Vector]) | |
| } | |
| dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) | |
| } | |
| override protected def predict(features: Vector): Double = { | |
| // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 | |
| // Classifies by thresholding sum of weighted tree predictions | |
| val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) | |
| blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) | |
| } | |
| /** Number of trees in ensemble */ | |
| val numTrees: Int = trees.length | |
| @Since("1.4.0") | |
| override def copy(extra: ParamMap): GBTRegressionModel = { | |
| copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures), | |
| extra).setParent(parent) | |
| } | |
| @Since("1.4.0") | |
| override def toString: String = { | |
| s"GBTRegressionModel (uid=$uid) with $numTrees trees" | |
| } | |
| /** | |
| * Estimate of the importance of each feature. | |
| * | |
| * Each feature's importance is the average of its importance across all trees in the ensemble | |
| * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. | |
| * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) | |
| * and follows the implementation from scikit-learn. | |
| * | |
| * @see `DecisionTreeRegressionModel.featureImportances` | |
| */ | |
| @Since("2.0.0") | |
| lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) | |
| /** (private[ml]) Convert to a model in the old API */ | |
| private[ml] def toOld: OldGBTModel = { | |
| new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) | |
| } | |
| @Since("2.0.0") | |
| override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) | |
| } | |
| @Since("2.0.0") | |
| object GBTRegressionModel extends MLReadable[GBTRegressionModel] { | |
| @Since("2.0.0") | |
| override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader | |
| @Since("2.0.0") | |
| override def load(path: String): GBTRegressionModel = super.load(path) | |
| private[GBTRegressionModel] | |
| class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { | |
| override protected def saveImpl(path: String): Unit = { | |
| val extraMetadata: JObject = Map( | |
| "numFeatures" -> instance.numFeatures, | |
| "numTrees" -> instance.getNumTrees) | |
| EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) | |
| } | |
| } | |
| private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { | |
| /** Checked against metadata when loading model */ | |
| private val className = classOf[GBTRegressionModel].getName | |
| private val treeClassName = classOf[DecisionTreeRegressionModel].getName | |
| override def load(path: String): GBTRegressionModel = { | |
| implicit val format = DefaultFormats | |
| val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = | |
| EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) | |
| val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] | |
| val numTrees = (metadata.metadata \ "numTrees").extract[Int] | |
| val trees: Array[DecisionTreeRegressionModel] = treesData.map { | |
| case (treeMetadata, root) => | |
| val tree = | |
| new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) | |
| DefaultParamsReader.getAndSetParams(tree, treeMetadata) | |
| tree | |
| } | |
| require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + | |
| s" trees based on metadata but found ${trees.length} trees.") | |
| val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) | |
| DefaultParamsReader.getAndSetParams(model, metadata) | |
| model | |
| } | |
| } | |
| /** Convert a model from the old API */ | |
| private[ml] def fromOld( | |
| oldModel: OldGBTModel, | |
| parent: GBTRegressor, | |
| categoricalFeatures: Map[Int, Int], | |
| numFeatures: Int = -1): GBTRegressionModel = { | |
| require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + | |
| s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") | |
| val newTrees = oldModel.trees.map { tree => | |
| // parent for each tree is null since there is no good way to set this. | |
| DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) | |
| } | |
| val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") | |
| new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures) | |
| } | |
| } |