[SPARK-10524][ML] Use the soft prediction to order categories' bins #8734

Closed
wants to merge 8 commits into from

Conversation

viirya
Member

@viirya viirya commented Sep 13, 2015

JIRA: https://issues.apache.org/jira/browse/SPARK-10524

Currently we use the hard prediction (ImpurityCalculator.predict) to order categories' bins. But we should use the soft prediction.
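To make the distinction concrete, here is a minimal sketch (not Spark's implementation) of a hard versus a soft prediction computed from the per-class label counts of one category bin. The object and method names (`PredictionSketch`, `hardPrediction`, `softPrediction`) are hypothetical and exist only for illustration.

```scala
object PredictionSketch {
  // Hard prediction: index of the most frequent class (ties go to the lowest index).
  def hardPrediction(classCounts: Array[Double]): Double =
    classCounts.indexOf(classCounts.max).toDouble

  // Soft prediction for one class: that class's share of the bin's total count.
  def softPrediction(classCounts: Array[Double], label: Int): Double = {
    val total = classCounts.sum
    if (total == 0.0) 0.0 else classCounts(label) / total
  }

  def main(args: Array[String]): Unit = {
    // Two bins with the same hard prediction but different class-1 proportions.
    val binA = Array(1.0, 9.0)
    val binB = Array(4.0, 6.0)
    println(hardPrediction(binA) == hardPrediction(binB))
    println(softPrediction(binA, 1))
    println(softPrediction(binB, 1))
  }
}
```

Both bins share the same hard prediction, so it cannot tell them apart when ordering; the soft prediction can.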

@SparkQA

SparkQA commented Sep 13, 2015

Test build #42382 has finished for PR 8734 at commit 84260ca.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

(topNode.id, new RandomForest.NodeIndexInfo(0, None))
)))
val nodeQueue = new mutable.Queue[(Int, Node)]()
DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
Member

Can you please update this test to call binsToBestSplit directly? You can change it to be private[tree] so that it's callable from this test suite.

Member

Ping

Member Author

To call binsToBestSplit directly, we would also need to expose many details of findBestSplits, e.g., binSeqOp, getNodeToFeatures, and partitionAggregates, because binsToBestSplit takes binAggregates, featuresForNode, and others as parameters. Is that a good idea?

@jkbradley
Member

I'll have bandwidth to get this merged now, so I'll watch for updates. Thanks!

@jkbradley
Member

Ping! Please let me know if you don't have time to work on this, and I can take it over. Thanks

@viirya
Member Author

viirya commented Jan 14, 2016

@jkbradley Sorry for replying late. I will try to finish this soon. Thanks.

@jkbradley
Member

OK thanks!

@SparkQA

SparkQA commented Jan 20, 2016

Test build #49757 has finished for PR 8734 at commit 2c32350.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -740,7 +740,7 @@ private[ml] object RandomForest extends Logging {
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
- categoryStats.predict
+ categoryStats.prob(categoryStats.predict)
Contributor

I don't believe this is correct. Ordering by the probability of the prediction is essentially the same as ordering by impurity. That's because when the impurity is low, the predicted value will have high probability and vice versa.

From Hastie, Tibshirani, and Friedman:
"We order the predictor classes according to the proportion falling in outcome class 1. Then we split this predictor as if it were an ordered predictor."

For binary category I think it should be as @jkbradley suggested categoryStats.stats(1)
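The point above can be checked with a small sketch for a binary outcome. It compares two sort keys: the probability of whichever class is predicted, and the proportion falling in class 1. Everything here (`OrderingSketch`, `probOfClass1`, `probOfPrediction`, `orderBy`, and the sample counts) is hypothetical illustration, not Spark code.

```scala
object OrderingSketch {
  // Class counts for one category: Array(count of class 0, count of class 1).
  type Counts = Array[Double]

  // Proportion of the category's points falling in class 1.
  def probOfClass1(c: Counts): Double = c(1) / c.sum

  // Probability of whichever class is predicted (the majority class).
  def probOfPrediction(c: Counts): Double = c.max / c.sum

  // Category ids sorted ascending by the given key (stable for ties).
  def orderBy(bins: Seq[(String, Counts)])(key: Counts => Double): Seq[String] =
    bins.sortWith((l, r) => key(l._2) < key(r._2)).map(_._1)

  def main(args: Array[String]): Unit = {
    val bins = Seq(
      "a" -> Array(9.0, 1.0), // mostly class 0
      "b" -> Array(5.0, 5.0), // evenly mixed
      "c" -> Array(1.0, 9.0)  // mostly class 1
    )
    println(orderBy(bins)(probOfClass1))
    println(orderBy(bins)(probOfPrediction))
  }
}
```

Under `probOfPrediction`, categories "a" and "c" receive the same key even though they favor opposite classes, so that key tracks purity rather than class membership. Under `probOfClass1` they land at opposite ends of the ordering.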

Member Author

From what I see in the implementation, categoryStats.stats(1) is just the count of class 1, not the proportion falling in outcome class 1. Are we going to order bins by that?

Contributor

Finding the proportion falling in outcome class 1 simply requires division of the counts by a constant. Since we're just using that number for an ordering, constant division won't matter. They are the same.

My initial comment has a typo. It should say for a "binary outcome", not "binary category".
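The ordering argument can be sketched as follows: dividing every key by the same positive constant leaves the ranking unchanged. The sketch also shows, under a purely hypothetical assumption of different totals per bin, that dividing by a per-bin number is not the same operation and can reorder the bins. Whether the divisor is in fact shared across bins depends on how the statistics are normalized, which this thread does not spell out. `ScalingSketch`, `ranking`, and all numbers are made up for illustration.

```scala
object ScalingSketch {
  // Indices of `keys` sorted ascending by key value (stable for ties).
  def ranking(keys: Seq[Double]): Seq[Int] =
    keys.zipWithIndex.sortWith((l, r) => l._1 < r._1).map(_._2)

  def main(args: Array[String]): Unit = {
    val class1Counts = Seq(30.0, 5.0, 12.0)

    // Same positive divisor for every bin: the ranking is unchanged.
    val sharedTotal = 100.0
    println(ranking(class1Counts) == ranking(class1Counts.map(_ / sharedTotal)))

    // Hypothetical per-bin totals: a different divisor per bin can reorder the bins.
    val binTotals = Seq(100.0, 6.0, 50.0)
    val proportions = class1Counts.zip(binTotals).map { case (c, t) => c / t }
    println(ranking(class1Counts))
    println(ranking(proportions))
  }
}
```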

Member Author

Yeah, I see. I thought we were going to order them by the soft prediction of each bin. What we actually want is to order them by the soft prediction of a certain class.

@SparkQA

SparkQA commented Jan 21, 2016

Test build #49865 has finished for PR 8734 at commit cd25214.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 21, 2016

Test build #49874 has finished for PR 8734 at commit a37d3d8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -740,7 +740,11 @@ private[ml] object RandomForest extends Logging {
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
- categoryStats.predict
+ if (categoryStats.count == 2) {
Contributor

I think you meant categoryStats.stats.length == 2. categoryStats.count is the number of data points falling into that particular bin. Since we are trying to determine here whether this is regression or binary classification, I think checking if (binAggregates.metadata.isClassification) is clearer.

Additionally, the code under the if and else statements of centroidForCategories is identical except for a single line. It seems cleaner to restructure to something like:

val centroidForCategories = Range(0, numCategories).map { case featureValue =>
  val categoryStats =
    binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
  val centroid = if (categoryStats.count != 0) {
    if (binAggregates.metadata.isMulticlass) {
      // multiclass classification
      categoryStats.calculate()
    } else if (binAggregates.metadata.isClassification) {
      // binary classification
      categoryStats.stats(1)
    } else {
      // regression
      categoryStats.predict
    }
  } else {
    Double.MaxValue
  }
  (featureValue, centroid)
}
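The three-way branch suggested above can be tried outside Spark with a self-contained sketch. It assumes a simplified model that is not Spark's: classification bins are plain arrays of class counts, regression bins are the raw label values, and the multiclass key is Gini impurity. `CentroidSketch`, `Task`, `gini`, and `centroid` are hypothetical names.

```scala
object CentroidSketch {
  sealed trait Task
  case object Regression extends Task
  case object BinaryClassification extends Task
  case object MulticlassClassification extends Task

  // Gini impurity of a vector of class counts: 1 - sum of squared class proportions.
  def gini(classCounts: Array[Double]): Double = {
    val total = classCounts.sum
    if (total == 0.0) 0.0
    else 1.0 - classCounts.map(c => (c / total) * (c / total)).sum
  }

  // Sort key for one category. `classCounts` is used for classification,
  // `labels` (the raw label values in the bin) for regression.
  def centroid(task: Task, classCounts: Array[Double], labels: Seq[Double]): Double = {
    val isEmpty = task match {
      case Regression => labels.isEmpty
      case _ => classCounts.sum == 0.0
    }
    if (isEmpty) Double.MaxValue // empty categories sort last
    else task match {
      case MulticlassClassification => gini(classCounts)    // order by impurity
      case BinaryClassification => classCounts(1)           // order by class-1 count
      case Regression => labels.sum / labels.size           // order by mean label
    }
  }
}
```

Note that the emptiness check looks at the number of points in the bin, while the task type is decided separately. Mixing up those two quantities is what the `categoryStats.count == 2` check did.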

Member Author

@sethah Thanks. You are right. I didn't read this part of the code thoroughly.

@SparkQA

SparkQA commented Jan 23, 2016

Test build #49934 has finished for PR 8734 at commit 5c44e23.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

@viirya Thanks for the updates.

I think the code is correct now, but I'm going to send you a PR (to update this PR) in order to improve the test. I agree with @sethah that the current test does not really test anything.

@sethah Does it look good to you, other than the test?

@sethah
Contributor

sethah commented Feb 9, 2016

Yes, LGTM pending the improved test, thanks!

@jkbradley
Member

Here it is: [https://github.com/viirya/spark-1/pull/1]

Fixed unit test and added one to spark.ml
@viirya
Member Author

viirya commented Feb 10, 2016

@jkbradley Great, thanks. I've merged your PR.

@SparkQA

SparkQA commented Feb 10, 2016

Test build #51012 has finished for PR 8734 at commit 2bbe037.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

LGTM! Thanks @viirya and @sethah

I'll merge with master and see how far back I can backport it easily.

@asfgit asfgit closed this in 9267bc6 Feb 10, 2016
asfgit pushed a commit that referenced this pull request Feb 10, 2016
JIRA: https://issues.apache.org/jira/browse/SPARK-10524

Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction.

Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8734 from viirya/dt-soft-centroids.

(cherry picked from commit 9267bc6)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
@viirya viirya deleted the dt-soft-centroids branch December 27, 2023 18:33
4 participants