-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10524][ML] Use the soft prediction to order categories' bins #8734
Conversation
Test build #42382 has finished for PR 8734 at commit
|
(topNode.id, new RandomForest.NodeIndexInfo(0, None)) | ||
))) | ||
val nodeQueue = new mutable.Queue[(Int, Node)]() | ||
DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please update this test to call binsToBestSplit directly? You can change it to be private[tree]
so that it's callable from this test suite.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ping
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In order to call binsToBestSplit
directly, we need to expose many details of findBestSplits
too, e.g., binSeqOp
, getNodeToFeatures
and partitionAggregates
...etc., because binsToBestSplit
needs binAggregates
and featuresForNode
..etc. as parameters. Is it a good idea?
I'll have bandwidth to get this merged now, so I'll watch for updates. Thanks! |
Ping! Please let me know if you don't have time to work on this, and I can take it over. Thanks |
@jkbradley Sorry for replying late. I will try to finish this soon. Thanks. |
OK thanks! |
Test build #49757 has finished for PR 8734 at commit
|
@@ -740,7 +740,7 @@ private[ml] object RandomForest extends Logging { | |||
val categoryStats = | |||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) | |||
val centroid = if (categoryStats.count != 0) { | |||
categoryStats.predict | |||
categoryStats.prob(categoryStats.predict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't believe this is correct. Ordering by the probability of the prediction is essentially the same as ordering by impurity. That's because when the impurity is low, the predicted value will have high probability and vice versa.
From Hastie, Tibshirani, and Friedman:
"We order the predictor classes according to the proportion falling in outcome class 1. Then we split this predictor as if it were an ordered predictor."
For binary category I think it should be as @jkbradley suggested categoryStats.stats(1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I saw from the implementation, categoryStats.stats(1)
is just the count of class 1, not the proportion falling in outcome class 1. Are we going to order bins by that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finding the proportion falling in outcome class 1 simply requires division of the counts by a constant. Since we're just using that number for an ordering, constant division won't matter. They are the same.
My initial comment has a typo. It should say for a "binary outcome", not "binary category".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I see. I was thinking we are going to order them by soft prediction of each bin. Actually what we want is to order them by soft prediction of certain class.
Test build #49865 has finished for PR 8734 at commit
|
Test build #49874 has finished for PR 8734 at commit
|
@@ -740,7 +740,11 @@ private[ml] object RandomForest extends Logging { | |||
val categoryStats = | |||
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) | |||
val centroid = if (categoryStats.count != 0) { | |||
categoryStats.predict | |||
if (categoryStats.count == 2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you meant categoryStats.stats.length == 2
. categoryStats.count
is the count of data points falling into that particular bin. Since we are trying to determine here whether this is regression or binary classification, I think checking if (binAggregates.metadata.isClassification)
is more clear.
Additionally, the code under the if and else statements of centroidForCategories
is identical except for a single line. It seems cleaner to restructure to something like:
val centroidForCategories = Range(0, numCategories).map { case featureValue =>
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
if (binAggregates.metadata.isMulticlass) {
// multiclass classification
categoryStats.calculate()
} else if (binAggregates.metadata.isClassification) {
// binary classification
categoryStats.stats(1)
} else {
// regression
categoryStats.predict
}
} else {
Double.MaxValue
}
(featureValue, centroid)
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sethah Thanks. You are right. I didn't read this part of codes thoroughly.
Test build #49934 has finished for PR 8734 at commit
|
Yes, LGTM pending the improved test, thanks! |
Fixed unit test and added one to spark.ml
@jkbradley Great thanks. I've merged your PR. |
Test build #51012 has finished for PR 8734 at commit
|
JIRA: https://issues.apache.org/jira/browse/SPARK-10524 Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction. Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Liang-Chi Hsieh <viirya@appier.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #8734 from viirya/dt-soft-centroids. (cherry picked from commit 9267bc6) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
JIRA: https://issues.apache.org/jira/browse/SPARK-10524
Currently we use the hard prediction (
ImpurityCalculator.predict
) to order categories' bins. But we should use the soft prediction.