-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-6885] [ML] decision tree support predict class probabilities #7694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
d746ffc
decision tree support predict class probabilities
yanboliang 227c91b
refactor LearningNode to store ImpurityCalculator
yanboliang 5ec3323
implement InformationGainAndImpurityStats
yanboliang 99e8943
code optimization
yanboliang beb1634
try to eliminate impurityStats for each LearningNode
yanboliang fbbe2ec
eliminate duplicated struct and code
yanboliang 6167fb0
optimize calculateImpurityStats function
yanboliang c32d6ce
optimize calculateImpurityStats function again
yanboliang ff043d3
raw2probabilityInPlace should operate in-place
yanboliang 33ae183
fix annotation
yanboliang 7e90ba8
fix typos
yanboliang 2174278
solve merge conflicts
yanboliang 08d5b7f
fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 …
yanboliang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,8 +19,9 @@ package org.apache.spark.ml.tree | |
|
|
||
| import org.apache.spark.annotation.DeveloperApi | ||
| import org.apache.spark.mllib.linalg.Vector | ||
| import org.apache.spark.mllib.tree.impurity.ImpurityCalculator | ||
| import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, | ||
| Node => OldNode, Predict => OldPredict} | ||
| Node => OldNode, Predict => OldPredict, ImpurityStats} | ||
|
|
||
| /** | ||
| * :: DeveloperApi :: | ||
|
|
@@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable { | |
| /** Impurity measure at this node (for training data) */ | ||
| def impurity: Double | ||
|
|
||
| /** | ||
| * Statistics aggregated from training data at this node, used to compute prediction, impurity, | ||
| * and probabilities. | ||
| * For classification, the array of class counts must be normalized to a probability distribution. | ||
| */ | ||
| private[tree] def impurityStats: ImpurityCalculator | ||
|
|
||
| /** Recursive prediction helper method */ | ||
| private[ml] def predict(features: Vector): Double = prediction | ||
| private[ml] def predictImpl(features: Vector): LeafNode | ||
|
|
||
| /** | ||
| * Get the number of nodes in tree below this node, including leaf nodes. | ||
|
|
@@ -75,7 +83,8 @@ private[ml] object Node { | |
| if (oldNode.isLeaf) { | ||
| // TODO: Once the implementation has been moved to this API, then include sufficient | ||
| // statistics here. | ||
| new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) | ||
| new LeafNode(prediction = oldNode.predict.predict, | ||
| impurity = oldNode.impurity, impurityStats = null) | ||
| } else { | ||
| val gain = if (oldNode.stats.nonEmpty) { | ||
| oldNode.stats.get.gain | ||
|
|
@@ -85,7 +94,7 @@ private[ml] object Node { | |
| new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, | ||
| gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), | ||
| rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), | ||
| split = Split.fromOld(oldNode.split.get, categoricalFeatures)) | ||
| split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -99,11 +108,13 @@ private[ml] object Node { | |
| @DeveloperApi | ||
| final class LeafNode private[ml] ( | ||
| override val prediction: Double, | ||
| override val impurity: Double) extends Node { | ||
| override val impurity: Double, | ||
| override val impurityStats: ImpurityCalculator) extends Node { | ||
|
|
||
| override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" | ||
| override def toString: String = | ||
| s"LeafNode(prediction = $prediction, impurity = $impurity)" | ||
|
|
||
| override private[ml] def predict(features: Vector): Double = prediction | ||
| override private[ml] def predictImpl(features: Vector): LeafNode = this | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you are removing this comment, please set "prob" in the OldNode. |
||
| override private[tree] def numDescendants: Int = 0 | ||
|
|
||
|
|
@@ -115,9 +126,8 @@ final class LeafNode private[ml] ( | |
| override private[tree] def subtreeDepth: Int = 0 | ||
|
|
||
| override private[ml] def toOld(id: Int): OldNode = { | ||
| // NOTE: We do NOT store 'prob' in the new API currently. | ||
| new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, | ||
| None, None, None, None) | ||
| new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), | ||
| impurity, isLeaf = true, None, None, None, None) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -139,17 +149,18 @@ final class InternalNode private[ml] ( | |
| val gain: Double, | ||
| val leftChild: Node, | ||
| val rightChild: Node, | ||
| val split: Split) extends Node { | ||
| val split: Split, | ||
| override val impurityStats: ImpurityCalculator) extends Node { | ||
|
|
||
| override def toString: String = { | ||
| s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" | ||
| } | ||
|
|
||
| override private[ml] def predict(features: Vector): Double = { | ||
| override private[ml] def predictImpl(features: Vector): LeafNode = { | ||
| if (split.shouldGoLeft(features)) { | ||
| leftChild.predict(features) | ||
| leftChild.predictImpl(features) | ||
| } else { | ||
| rightChild.predict(features) | ||
| rightChild.predictImpl(features) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -172,9 +183,8 @@ final class InternalNode private[ml] ( | |
| override private[ml] def toOld(id: Int): OldNode = { | ||
| assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" | ||
| + " since the old API does not support deep trees.") | ||
| // NOTE: We do NOT store 'prob' in the new API currently. | ||
| new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, | ||
| Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), | ||
| new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, | ||
| isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), | ||
| Some(rightChild.toOld(OldNode.rightChildIndex(id))), | ||
| Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, | ||
| new OldPredict(leftChild.prediction, prob = 0.0), | ||
|
|
@@ -223,36 +233,36 @@ private object InternalNode { | |
| * | ||
| * @param id We currently use the same indexing as the old implementation in | ||
| * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. | ||
| * @param predictionStats Predicted label + class probability (for classification). | ||
| * We will later modify this to store aggregate statistics for labels | ||
| * to provide all class probabilities (for classification) and maybe a | ||
| * distribution (for regression). | ||
| * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, | ||
| * so that we do not need to consider splitting it further. | ||
| * @param stats Old structure for storing stats about information gain, prediction, etc. | ||
| * This is legacy and will be modified in the future. | ||
| * @param stats Impurity statistics for this node. | ||
| */ | ||
| private[tree] class LearningNode( | ||
| var id: Int, | ||
| var predictionStats: OldPredict, | ||
| var impurity: Double, | ||
| var leftChild: Option[LearningNode], | ||
| var rightChild: Option[LearningNode], | ||
| var split: Option[Split], | ||
| var isLeaf: Boolean, | ||
| var stats: Option[OldInformationGainStats]) extends Serializable { | ||
| var stats: ImpurityStats) extends Serializable { | ||
|
|
||
| /** | ||
| * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. | ||
| */ | ||
| def toNode: Node = { | ||
| if (leftChild.nonEmpty) { | ||
| assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty, | ||
| assert(rightChild.nonEmpty && split.nonEmpty && stats != null, | ||
| "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") | ||
| new InternalNode(predictionStats.predict, impurity, stats.get.gain, | ||
| leftChild.get.toNode, rightChild.get.toNode, split.get) | ||
| new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, | ||
| leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) | ||
| } else { | ||
| new LeafNode(predictionStats.predict, impurity) | ||
| if (stats.valid) { | ||
| new LeafNode(stats.impurityCalculator.predict, stats.impurity, | ||
| stats.impurityCalculator) | ||
| } else { | ||
| // Here we want to keep same behavior with the old mllib.DecisionTreeModel | ||
| new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) | ||
| } | ||
|
|
||
| } | ||
| } | ||
|
|
||
|
|
@@ -263,16 +273,14 @@ private[tree] object LearningNode { | |
| /** Create a node with some of its fields set. */ | ||
| def apply( | ||
| id: Int, | ||
| predictionStats: OldPredict, | ||
| impurity: Double, | ||
| isLeaf: Boolean): LearningNode = { | ||
| new LearningNode(id, predictionStats, impurity, None, None, None, false, None) | ||
| isLeaf: Boolean, | ||
| stats: ImpurityStats): LearningNode = { | ||
| new LearningNode(id, None, None, None, false, stats) | ||
| } | ||
|
|
||
| /** Create an empty node with the given node index. Values must be set later on. */ | ||
| def emptyNode(nodeIndex: Int): LearningNode = { | ||
| new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN, | ||
| None, None, None, false, None) | ||
| new LearningNode(nodeIndex, None, None, None, false, null) | ||
| } | ||
|
|
||
| // The below indexing methods were copied from spark.mllib.tree.model.Node | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to be safe, we should handle the sum = 0 case.