
[SPARK-19591][ML][MLlib] Add sample weights to decision trees #21632

Closed · wants to merge 13 commits

Conversation

@imatiach-msft (Contributor) commented Jun 25, 2018

This updates PR #16722 to the latest master.

What changes were proposed in this pull request?

This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier.

Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow-up PR.

How was this patch tested?

The algorithms are tested to ensure that:
1. Arbitrary scaling of constant weights has no effect
2. Outliers with small weights do not affect the learned model
3. Oversampling and weighting are equivalent

Unit tests are also added to test other smaller components.
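As a hedged illustration of invariant 3 above (not the suite's actual code; `base` is an assumed small Seq[Instance] and `trainTree` an assumed helper that fits a DecisionTreeRegressor and returns the model):

```scala
import org.apache.spark.ml.feature.Instance

// `base` and `trainTree` are hypothetical names used only for this sketch.
val k = 3
// One copy of each row, repeated k times with unit weight...
val oversampled = base.flatMap(i => Seq.fill(k)(Instance(i.label, 1.0, i.features)))
// ...versus a single copy of each row with weight k.
val weighted = base.map(i => Instance(i.label, k.toDouble, i.features))
// Invariant 3: the two learned trees should be identical.
assert(trainTree(oversampled).toDebugString == trainTree(weighted).toDebugString)
```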

Summary of changes

  • Impurity aggregators now store weighted sufficient statistics. They also store the raw (unweighted) count, since it is needed to enforce minInstancesPerNode.

  • This patch preserves the meaning of minInstancesPerNode: the parameter still corresponds to raw, unweighted counts. It also adds a new parameter, minWeightFractionPerNode, which requires that each node contain at least minWeightFractionPerNode * weightedNumExamples total weight.

  • This patch modifies findSplitsForContinuousFeatures to use weighted sums. Unit tests are added.

  • TreePoint is modified to hold a sample weight.

  • BaggedPoint is modified from:

```scala
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable
```

to

```scala
private[spark] class BaggedPoint[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int],
    val sampleWeight: Double) extends Serializable
```

We do not simply multiply the counts by the weight and store the product, because both the raw counts and the weight are needed to enforce minInstancesPerNode and minWeightFractionPerNode (see the sketch below).
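As a hedged sketch of how the two minimums combine at split time (the name isSplitValid and its arguments are illustrative assumptions, not the PR's actual code):

```scala
// The fraction-based parameter is translated into an absolute weight threshold;
// a candidate split must satisfy both the raw-count and the weight minimums.
def isSplitValid(
    leftRawCount: Long, leftWeight: Double,
    rightRawCount: Long, rightWeight: Double,
    minInstancesPerNode: Int,
    minWeightFractionPerNode: Double,
    weightedNumExamples: Double): Boolean = {
  val minWeightPerNode = minWeightFractionPerNode * weightedNumExamples
  leftRawCount >= minInstancesPerNode && rightRawCount >= minInstancesPerNode &&
    leftWeight >= minWeightPerNode && rightWeight >= minWeightPerNode
}
```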

Note: many of the changed files are touched simply because they now use Instance instead of LabeledPoint.

@imatiach-msft (Contributor, Author)

@holdenk @sethah I've updated the PR to latest master (hopefully all of the tests still pass :) )

@SparkQA commented Jun 25, 2018

Test build #92283 has finished for PR 21632 at commit b5278e5.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@imatiach-msft (Contributor, Author)

jenkins retest this pretty please :)

@SparkQA commented Jun 25, 2018

Test build #92314 has finished for PR 21632 at commit 64576d6.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 2, 2018

Test build #93926 has finished for PR 21632 at commit 30424da.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 2, 2018

Test build #93928 has finished for PR 21632 at commit 4ad2833.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 2, 2018

Test build #93930 has finished for PR 21632 at commit 263b343.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 2, 2018

Test build #93940 has finished for PR 21632 at commit cf77ab2.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 2, 2018

Test build #93942 has finished for PR 21632 at commit 0ad3b08.

  • This patch fails to generate documentation.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 2, 2018

Test build #93948 has finished for PR 21632 at commit 3189259.

  • This patch fails to generate documentation.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 2, 2018

Test build #94015 has finished for PR 21632 at commit 981d707.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 3, 2018

Test build #94165 has finished for PR 21632 at commit 6326bdf.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@imatiach-msft (Contributor, Author)

Jenkins retest this please

@imatiach-msft (Contributor, Author)

looks like a random failure

@SparkQA commented Aug 3, 2018

Test build #94164 has finished for PR 21632 at commit a34b3cd.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 4, 2018

Test build #94180 has finished for PR 21632 at commit 6326bdf.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@imatiach-msft (Contributor, Author)

Looks like a random test failure in the Hive client suite (not related to this PR); I'll try updating to the latest master and rebuilding...

@SparkQA commented Aug 6, 2018

Test build #94288 has finished for PR 21632 at commit ad28e44.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@imatiach-msft (Contributor, Author)

@holdenk @sethah @HyukjinKwon I have a successful build. I need to look into 2-3 wacky test results that changed since @sethah opened his PR (see comments in my PR). In the meantime, would anyone be able to review the PR - are there any comments from the previous PR that were still not resolved and need to be addressed?

@HyukjinKwon (Member)

cc also @jkbradley

@SparkQA commented Jan 15, 2019

Test build #101230 has finished for PR 21632 at commit 7d2f131.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen (Member) left a comment

I still had a question about modelPredictionEquals and one other minor thing here, but this is quite close now.

@SparkQA commented Jan 22, 2019

Test build #101511 has finished for PR 21632 at commit a8ebf22.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen (Member) left a comment

Some more minor comments. It's up to your judgment whether to add a new overload to DecisionTreeMetadata to simplify the test code. It seems fine to me either way.

@imatiach-msft (Contributor, Author)

"up to your judgment on whether to add a new overload to DecisionTreeMetadata to simplify the test code"
This is a tough decision; I would prefer not to modify the source code for the sake of tests, but changing a lot of test code to call DecisionTreeMetadata.buildMetadata with LabeledPoints converted to Instances is not great either.
There are other options as well. I could make DecisionTreeMetadata.buildMetadata accept an RDD[_] and figure out the element type dynamically, but that doesn't seem like a good choice either.
I could also create a wrapper around buildMetadata in the test code and call that wrapper from all tests, which would make the code easier to maintain in the future (e.g. the conversion could be done in the wrapper), but that would only introduce more changes - not fewer - to the PR, and it creates another level of indirection that may make the test code more confusing.
The current code seems the slightly better choice of the four options listed above (and there may be others); a sketch of the wrapper option is below. If there is a strong preference for one of the other choices I would be glad to update the PR.
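A minimal sketch of that wrapper option, under stated assumptions: the helper name buildMetadataFromLabeledPoints is hypothetical, and the buildMetadata signature shown is my reading of the post-PR API, not necessarily exact.

```scala
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.configuration.{Strategy => OldStrategy}
import org.apache.spark.rdd.RDD

// Test-only helper: convert LabeledPoints to unit-weight Instances and delegate,
// so existing suites can keep constructing LabeledPoint inputs.
def buildMetadataFromLabeledPoints(
    input: RDD[LabeledPoint],
    strategy: OldStrategy,
    numTrees: Int,
    featureSubsetStrategy: String): DecisionTreeMetadata = {
  val instances = input.map(lp => Instance(lp.label, 1.0, lp.features))
  DecisionTreeMetadata.buildMetadata(instances, strategy, numTrees, featureSubsetStrategy)
}
```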

@SparkQA commented Jan 23, 2019

Test build #101569 has finished for PR 21632 at commit a993ce3.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Jan 23, 2019

Test build #101570 has finished for PR 21632 at commit 6adeda8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Jan 24, 2019

Test build #101612 has finished for PR 21632 at commit 7d6654e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen (Member) commented Jan 25, 2019

Merged to master

@srowen closed this in b2d36f6 on Jan 25, 2019
@imatiach-msft (Contributor, Author)

@srowen thank you for the merge and the thorough review. I have some doubts about the tolerance we settled on for zero values:
val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples

https://github.com/apache/spark/pull/21632/files#diff-1fd1bc8d3fc9306c83cd65fbf3ca4bbeR1054

For a large number of unweighted samples I am worried that it might be too high. Note that EPSILON = 2.2e-16. I am wondering if I should change the tolerance to:
val tolerance = Utils.EPSILON * unweightedNumSamples * (some constant)
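For scale, a quick hedged sketch (plain Scala, not code from the PR) comparing the two choices:

```scala
// Machine epsilon for Double, matching the EPSILON = 2.2e-16 noted above.
val EPSILON = 2.220446049250313e-16
for (n <- Seq(1e4, 1e6, 1e8)) {
  val quadratic = EPSILON * n * n // tolerance as merged
  val linear = EPSILON * n        // proposed alternative (before the constant factor)
  println(f"n=$n%.0e quadratic=$quadratic%.3e linear=$linear%.3e")
}
// At n = 1e8 the quadratic tolerance is already ~2.2, i.e. no longer a tiny epsilon.
```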
What are your thoughts?

@srowen (Member) commented Jan 28, 2019

Is there a good reason to scale it by the square of the number of samples? If not, yeah, it's worth a follow-up. If there is a good reason, then is there a case in the tests here where the tolerance becomes really large, like the same order of magnitude as the expected values? I don't think the tests have ~1e8 samples. Up to your judgment.

@imatiach-msft (Contributor, Author)

I think I made a mistake and it should actually be:

val tolerance = Utils.EPSILON * (unweightedNumSamples + unweightedNumSamples)

or perhaps a larger threshold:

val tolerance = Utils.EPSILON * unweightedNumSamples * SomeLargeConstant

but I will need to verify, by adding some debug output, that no zero features slip through in the sample tests; otherwise that tolerance would still be too low and the factor would need to be increased. My worry is that, with the square of the number of samples, the tolerance becomes too high for very large inputs, and some values would then be counted as zero feature values, which we don't want.

@srowen (Member) commented Jan 28, 2019

That's fine @imatiach-msft, just open another PR for the same JIRA. We usually put [FOLLOWUP] in the title and link to the previous PR for discoverability.

@imatiach-msft (Contributor, Author)

@srowen thanks for the quick response, I've created a follow-up PR here:
#23682
In testing I've found the tolerance:
val tolerance = Utils.EPSILON * unweightedNumSamples * 100
to be good enough; I'm not sure whether it needs to be larger.

srowen pushed a commit that referenced this pull request Jan 31, 2019
…es - fix tolerance

This is a follow-up to PR:
#21632

## What changes were proposed in this pull request?

This PR tunes the tolerance used for deciding whether to add zero feature values to a value-count map (where the key is the feature value and the value is the weighted count of those feature values).
In the previous PR the tolerance scaled with the square of the unweighted number of samples, which is too aggressive for a large number of unweighted samples. Unfortunately, using just "Utils.EPSILON * unweightedNumSamples" is not enough either, so I multiplied that by a factor tuned by the testing procedure below.

## How was this patch tested?

This involved manually running the sample weight tests for decision tree regressor to see whether the tolerance was large enough to exclude zero feature values.

E.g. in SBT:
```
./build/sbt
> project mllib
> testOnly *DecisionTreeRegressorSuite -- -z "training with sample weights"
```

For validation, I added a print inside the if in the code below and validated that the tolerance was large enough so that we would not include zero features (which don't exist in that test):
```scala
// If the weighted total exceeds the weighted count of observed values by more than
// the tolerance, the difference is attributed to implicit zero feature values.
val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
  print("should not print this") // debug: fires only if zeros are (wrongly) added
  partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
} else {
  partValueCountMap
}
```

Closes #23682 from imatiach-msft/ilmat/sample-weights-tol.

Authored-by: Ilya Matiach <ilmat@microsoft.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
jackylee-ch pushed two commits to jackylee-ch/spark that referenced this pull request Feb 18, 2019
```diff
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
 private[ml] trait DecisionTreeParams extends PredictorParams
-  with HasCheckpointInterval with HasSeed {
+  with HasCheckpointInterval with HasSeed with HasWeightCol {
```

@zhengruifeng (Contributor) commented Sep 4, 2019

@imatiach-msft @srowen Here the params weightCol and minWeightFractionPerNode are introduced into DecisionTreeParams and so are also exposed to RF and GBT.
But RF and GBT do not support sample weighting for now. Is there any plan to support it, or should we move these params into DecisionTreeRegressorParams and DecisionTreeClassifierParams?

@imatiach-msft (Contributor, Author)

"Is there any plan to support it?"
Yes, definitely; we should support it eventually, and I will look into adding it when I get a chance. There's already a JIRA ticket for that as well.

@imatiach-msft (Contributor, Author)

(tagging @zhengruifeng)

@zhengruifeng (Contributor)

Thanks @imatiach-msft! I just read the corresponding tickets.
My concern is that we need to support weighting in RF & GBT in 3.0.0; otherwise two unused params will be added.

@ghost commented Sep 9, 2019

Could someone link to the JIRA ticket for class weights in RF / GBT?

I've been trying to track it down.

I found this: https://issues.apache.org/jira/browse/SPARK-9478

However, it is marked resolved, and I do not see where the RF / GBT feature is tracked or implemented.

Any pointers would be appreciated.

@srowen (Member)

You can see it's resolved as a duplicate of https://issues.apache.org/jira/browse/SPARK-19591, targeted for 3.0.0.

zhengruifeng added a commit that referenced this pull request Jan 6, 2020
### What changes were proposed in this pull request?
1. Fix `BaggedPoint.convertToBaggedRDD` when `subsamplingRate < 1.0`
2. Reorganize `RandomForest.runWithMetadata` along the way

### Why are the changes needed?
In GBT, Instance weights are discarded if subsamplingRate < 1.0.

1. `baggedPoint: BaggedPoint[TreePoint]` is used during tree growth to find the best split;
2. `BaggedPoint[TreePoint]` contains two weights:
```scala
class BaggedPoint[Datum](val datum: Datum, val subsampleCounts: Array[Int], val sampleWeight: Double = 1.0)
class TreePoint(val label: Double, val binnedFeatures: Array[Int], val weight: Double)
```
3. Only the var `sampleWeight` in `BaggedPoint` is used; the var `weight` in `TreePoint` is never used when finding splits;
4. The method `BaggedPoint.convertToBaggedRDD` was changed in #21632; at the time it was only used for decision trees, so only the following code path was changed:
```scala
if (numSubsamples == 1 && subsamplingRate == 1.0) {
  convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
}
```
5. In #25926 I made GBT support weights, but only tested it with the default `subsamplingRate == 1.0`.
GBT with `subsamplingRate < 1.0` converts treePoints to baggedPoints via
```scala
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
```
in which the original weights from `weightCol` are discarded and every `sampleWeight` is assigned the default 1.0.
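As a hedged sketch of the intended fix (assumed names; not the patch as merged), the subsampling path should keep each instance's extracted weight instead of the 1.0 default:

```scala
import org.apache.spark.rdd.RDD

// Subsample without replacement while preserving each datum's original sample weight.
def sampleWithoutReplacement[Datum](
    input: RDD[Datum],
    subsamplingRate: Double,
    extractSampleWeight: Datum => Double,
    seed: Long): RDD[BaggedPoint[Datum]] = {
  input.mapPartitionsWithIndex { (partitionIndex, instances) =>
    val rng = new scala.util.Random(seed + partitionIndex)
    instances.map { datum =>
      // Points excluded from the subsample keep a count of 0 rather than being dropped.
      val count = if (rng.nextDouble() < subsamplingRate) 1 else 0
      // Before the fix, sampleWeight here effectively defaulted to 1.0.
      new BaggedPoint(datum, Array(count), extractSampleWeight(datum))
    }
  }
}
```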

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Updated test suites.

Closes #27070 from zhengruifeng/gbt_sampling.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>