Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
Expand All @@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame
*/
@Experimental
final class DecisionTreeClassifier(override val uid: String)
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {

def this() = this(Identifiable.randomUID("dtc"))
Expand Down Expand Up @@ -106,8 +105,9 @@ object DecisionTreeClassifier {
@Experimental
final class DecisionTreeClassificationModel private[ml] (
override val uid: String,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeClassificationModel]
override val rootNode: Node,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {

require(rootNode != null,
Expand All @@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
def this(rootNode: Node, numClasses: Int) =
this(Identifiable.randomUID("dtc"), rootNode, numClasses)

override protected def predict(features: Vector): Double = {
rootNode.predict(features)
rootNode.predictImpl(features).prediction
}

override protected def predictRaw(features: Vector): Vector = {
Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
}

override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
case dv: DenseVector =>
var i = 0
val size = dv.size
val sum = dv.values.sum
while (i < size) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be safe, we should handle the sum = 0 case.

dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
i += 1
}
dv
case sv: SparseVector =>
throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
" raw2probabilityInPlace encountered SparseVector")
}
}

override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
}

override def toString: String = {
Expand All @@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel {
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
new DecisionTreeClassificationModel(uid, rootNode)
new DecisionTreeClassificationModel(uid, rootNode, -1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ final class GBTClassificationModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
if (prediction > 0.0) 1.0 else 0.0
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] (
// Ignore the weights since all are 1.0 for now.
val votes = new Array[Double](numClasses)
_trees.view.foreach { tree =>
val prediction = tree.rootNode.predict(features).toInt
val prediction = tree.rootNode.predictImpl(features).prediction.toInt
votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
}
Vectors.dense(votes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] (
def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)

override protected def predict(features: Vector): Double = {
rootNode.predict(features)
rootNode.predictImpl(features).prediction
}

override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ final class GBTRegressionModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Predicts the sum of weighted tree predictions (regression: no thresholding is applied)
val treePredictions = _trees.map(_.rootNode.predict(features))
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
_trees.map(_.rootNode.predict(features)).sum / numTrees
_trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
}

override def copy(extra: ParamMap): RandomForestRegressionModel = {
Expand Down
80 changes: 44 additions & 36 deletions mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ package org.apache.spark.ml.tree

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
Node => OldNode, Predict => OldPredict}
Node => OldNode, Predict => OldPredict, ImpurityStats}

/**
* :: DeveloperApi ::
Expand All @@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
/** Impurity measure at this node (for training data) */
def impurity: Double

/**
* Statistics aggregated from training data at this node, used to compute prediction, impurity,
* and probabilities.
* For classification, this holds the array of class counts, which must be normalized to obtain a
* probability distribution.
*/
private[tree] def impurityStats: ImpurityCalculator

/** Recursive prediction helper method */
private[ml] def predict(features: Vector): Double = prediction
private[ml] def predictImpl(features: Vector): LeafNode

/**
* Get the number of nodes in tree below this node, including leaf nodes.
Expand Down Expand Up @@ -75,7 +83,8 @@ private[ml] object Node {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then include sufficient
// statistics here.
new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
new LeafNode(prediction = oldNode.predict.predict,
impurity = oldNode.impurity, impurityStats = null)
} else {
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
Expand All @@ -85,7 +94,7 @@ private[ml] object Node {
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
split = Split.fromOld(oldNode.split.get, categoricalFeatures))
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
}
}
}
Expand All @@ -99,11 +108,13 @@ private[ml] object Node {
@DeveloperApi
final class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double) extends Node {
override val impurity: Double,
override val impurityStats: ImpurityCalculator) extends Node {

override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
override def toString: String =
s"LeafNode(prediction = $prediction, impurity = $impurity)"

override private[ml] def predict(features: Vector): Double = prediction
override private[ml] def predictImpl(features: Vector): LeafNode = this

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are removing this comment, please set "prob" in the OldNode.

override private[tree] def numDescendants: Int = 0

Expand All @@ -115,9 +126,8 @@ final class LeafNode private[ml] (
override private[tree] def subtreeDepth: Int = 0

override private[ml] def toOld(id: Int): OldNode = {
// NOTE: We do NOT store 'prob' in the new API currently.
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
None, None, None, None)
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
impurity, isLeaf = true, None, None, None, None)
}
}

Expand All @@ -139,17 +149,18 @@ final class InternalNode private[ml] (
val gain: Double,
val leftChild: Node,
val rightChild: Node,
val split: Split) extends Node {
val split: Split,
override val impurityStats: ImpurityCalculator) extends Node {

override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
}

override private[ml] def predict(features: Vector): Double = {
override private[ml] def predictImpl(features: Vector): LeafNode = {
if (split.shouldGoLeft(features)) {
leftChild.predict(features)
leftChild.predictImpl(features)
} else {
rightChild.predict(features)
rightChild.predictImpl(features)
}
}

Expand All @@ -172,9 +183,8 @@ final class InternalNode private[ml] (
override private[ml] def toOld(id: Int): OldNode = {
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ " since the old API does not support deep trees.")
// NOTE: We do NOT store 'prob' in the new API currently.
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
new OldPredict(leftChild.prediction, prob = 0.0),
Expand Down Expand Up @@ -223,36 +233,36 @@ private object InternalNode {
*
* @param id We currently use the same indexing as the old implementation in
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
* @param predictionStats Predicted label + class probability (for classification).
* We will later modify this to store aggregate statistics for labels
* to provide all class probabilities (for classification) and maybe a
* distribution (for regression).
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
* so that we do not need to consider splitting it further.
* @param stats Old structure for storing stats about information gain, prediction, etc.
* This is legacy and will be modified in the future.
* @param stats Impurity statistics for this node.
*/
private[tree] class LearningNode(
var id: Int,
var predictionStats: OldPredict,
var impurity: Double,
var leftChild: Option[LearningNode],
var rightChild: Option[LearningNode],
var split: Option[Split],
var isLeaf: Boolean,
var stats: Option[OldInformationGainStats]) extends Serializable {
var stats: ImpurityStats) extends Serializable {

/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
def toNode: Node = {
if (leftChild.nonEmpty) {
assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
new InternalNode(predictionStats.predict, impurity, stats.get.gain,
leftChild.get.toNode, rightChild.get.toNode, split.get)
new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
} else {
new LeafNode(predictionStats.predict, impurity)
if (stats.valid) {
new LeafNode(stats.impurityCalculator.predict, stats.impurity,
stats.impurityCalculator)
} else {
// Match the old mllib.DecisionTreeModel behavior: a leaf with invalid stats gets impurity -1.0
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
}

}
}

Expand All @@ -263,16 +273,14 @@ private[tree] object LearningNode {
/** Create a node with some of its fields set. */
def apply(
id: Int,
predictionStats: OldPredict,
impurity: Double,
isLeaf: Boolean): LearningNode = {
new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
isLeaf: Boolean,
stats: ImpurityStats): LearningNode = {
new LearningNode(id, None, None, None, false, stats)
}

/** Create an empty node with the given node index. Values must be set later on. */
def emptyNode(nodeIndex: Int): LearningNode = {
new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
None, None, None, false, None)
new LearningNode(nodeIndex, None, None, None, false, null)
}

// The below indexing methods were copied from spark.mllib.tree.model.Node
Expand Down
Loading