[SPARK-10524][ML] Use the soft prediction to order categories' bins #8734
```diff
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils
```
```diff
@@ -288,8 +289,12 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.impurity !== -1.0)

     // set impurity and predict for child nodes
-    assert(topNode.leftNode.get.predict.predict === 0.0)
-    assert(topNode.rightNode.get.predict.predict === 1.0)
+    if (topNode.leftNode.get.predict.predict === 0.0) {
+      assert(topNode.rightNode.get.predict.predict === 1.0)
+    } else {
+      assert(topNode.leftNode.get.predict.predict === 1.0)
+      assert(topNode.rightNode.get.predict.predict === 0.0)
+    }
     assert(topNode.leftNode.get.impurity === 0.0)
     assert(topNode.rightNode.get.impurity === 0.0)
   }
```

**Review comment:** What happened here? Did the leaf nodes switch because of the internal change?

**Author reply:** Previously we ranked the bins according to the hard prediction, so, for example, the bin with hard prediction 0 was always in front of the bin with hard prediction 1. Now the order of the bins can be switched.

**Review comment:** I think now that we are ordering by the correct thing, this change is not necessary. Since we order by the proportion falling in outcome class 1, hard prediction 0 will always come before hard prediction 1.
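The ordering question the reviewers are discussing can be sketched outside of Spark. For an ordered categorical feature, MLlib sorts the categories by a per-category centroid before searching contiguous splits; this PR changes that centroid from the hard prediction (majority class) to the soft prediction (proportion of points in class 1). A minimal Python sketch of the difference, with made-up label counts (not Spark's actual code):

```python
# Each category maps to (count of class 0, count of class 1).
stats = {0: (8, 2), 1: (1, 9), 2: (5, 5)}

# Old ordering key: the hard prediction (majority class). It only takes
# the coarse values 0 or 1, so categories with the same majority class
# tie and their relative order is arbitrary.
hard = sorted(stats, key=lambda c: 0 if stats[c][0] >= stats[c][1] else 1)

# New ordering key: the soft prediction, i.e. the proportion of points
# in class 1. For binary classification this total order is the one for
# which an optimal split is contiguous.
soft = sorted(stats, key=lambda c: stats[c][1] / (stats[c][0] + stats[c][1]))

print(soft)  # [0, 2, 1]: proportions 0.2 < 0.5 < 0.9
```

As the last review comment notes, soft ordering is consistent with hard ordering: a category whose class-1 proportion is below 0.5 (hard prediction 0) always sorts before one whose proportion is above 0.5 (hard prediction 1); the soft key only refines the order within each group.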
```diff
@@ -331,12 +336,62 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.impurity !== -1.0)

     // set impurity and predict for child nodes
-    assert(topNode.leftNode.get.predict.predict === 0.0)
-    assert(topNode.rightNode.get.predict.predict === 1.0)
+    if (topNode.leftNode.get.predict.predict === 0.0) {
+      assert(topNode.rightNode.get.predict.predict === 1.0)
+    } else {
+      assert(topNode.leftNode.get.predict.predict === 1.0)
+      assert(topNode.rightNode.get.predict.predict === 0.0)
+    }
     assert(topNode.leftNode.get.impurity === 0.0)
     assert(topNode.rightNode.get.impurity === 0.0)
   }
```
```diff
+  test("Use soft prediction for binary classification with ordered categorical features") {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),  // left node
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),  // right node
+      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),  // left node
+      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)),  // right node
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)),  // left node
+      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.0)))  // left node
+    val input = sc.parallelize(arr)
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+    val topNode = Node.emptyNode(nodeIndex = 1)
+    assert(topNode.predict.predict === Double.MinValue)
+    assert(topNode.impurity === -1.0)
+    assert(topNode.isLeaf === false)
+
+    val nodesForGroup = Map((0, Array(topNode)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+    )))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+    // don't enqueue leaf nodes into node queue
+    assert(nodeQueue.isEmpty)
+
+    // set impurity and predict for topNode
+    assert(topNode.predict.predict !== Double.MinValue)
+    assert(topNode.impurity !== -1.0)
+
+    val impurityForRightNode = Gini.calculate(Array(0.0, 3.0, 1.0), 4.0)
+
+    // set impurity and predict for child nodes
+    assert(topNode.leftNode.get.predict.predict === 0.0)
+    assert(topNode.rightNode.get.predict.predict === 1.0)
+    assert(topNode.leftNode.get.impurity ~== 0.44 absTol impurityForRightNode)
+    assert(topNode.rightNode.get.impurity === 0.0)
+  }
+
   test("Second level node building with vs. without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
```

**Review comment (on the new test):** What is the goal of this test? I guessed that the goal would be to find a case where ordering by hard predictions produces a different (suboptimal) tree than ordering by soft predictions. However, I did a quick simulation for this dataset and the results I got were the same either way. Just wanted to clarify.

**Author reply:** Hmm, I just want a test case to show it actually orders the bins by soft prediction. Although @jkbradley suggested we should use directly

**Review comment (on the `DecisionTree.findBestSplits` call):** Can you please update this test to call `binsToBestSplit` directly? You can change it to be

**Review comment:** Ping

**Author reply:** In order to call
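The expected impurities asserted above come from the Gini index, 1 − Σₖ pₖ², computed from per-class counts as in mllib's `Gini.calculate(counts, totalCount)`. A plain-Python sketch of the formula (the function name and inputs here are illustrative, not Spark's API):

```python
def gini(counts):
    """Gini impurity 1 - sum(p_k^2) for a list of per-class label counts."""
    total = float(sum(counts))
    if total == 0.0:
        # An empty node is conventionally treated as pure.
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts)

print(gini([2, 2]))  # 0.5: an evenly split two-class node
print(gini([0, 4]))  # 0.0: a pure node, as asserted for rightNode above
print(gini([2, 1]))  # 0.444...: a 2:1 split, consistent with the 0.44
                     # expected for leftNode (up to the test's tolerance)
```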
**Review comment:** I think you meant `categoryStats.stats.length == 2`. `categoryStats.count` is the count of data points falling into that particular bin. Since we are trying to determine here whether this is regression or binary classification, I think checking `if (binAggregates.metadata.isClassification)` is clearer. Additionally, the code under the if and else branches of `centroidForCategories` is identical except for a single line. It seems cleaner to restructure to something like:

**Author reply:** @sethah Thanks. You are right. I didn't read this part of the code thoroughly.
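The restructuring suggested above, where the classification and regression branches of `centroidForCategories` share everything except the centroid formula, could look like the following Python sketch. All names here (`category_stats` as a dict standing in for mllib's per-category `ImpurityCalculator`, the `centroid` helper) are hypothetical illustrations, not the actual Scala code:

```python
def centroid(category_stats, is_classification):
    # Shared handling for empty categories: sort them last so they are
    # never split off from real data. (Spark's actual sentinel value may
    # differ; this is an assumption for illustration.)
    if category_stats["count"] == 0:
        return float("inf")
    # The single line that differs between the two branches:
    if is_classification:
        # Binary classification: the soft prediction, i.e. the
        # proportion of points in class 1 (this PR's change).
        return category_stats["class_counts"][1] / category_stats["count"]
    # Regression: the centroid is the mean label.
    return category_stats["label_sum"] / category_stats["count"]

cls = {"count": 10, "class_counts": [8, 2]}
reg = {"count": 4, "label_sum": 6.0}
print(centroid(cls, True))   # 0.2
print(centroid(reg, False))  # 1.5
```

Factoring the branch this way keeps the empty-category handling and the sort in one place, so only the centroid definition varies by task type, which is the reviewer's point.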