New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-14599][ML] BaggedPoint should support sample weights. #12370
Conversation
In a previous "TODO" it was proposed that we could incorporate sample weights by simply multiplying the subsample counts by the sample weight and storing them in an array. I chose not to do this because of the need to have both raw counts and weighted counts when adding weights to decision trees. If we simply store the weighted counts, we lose the information about the raw counts, which makes it impossible to track |
Test build #55748 has finished for PR 12370 at commit
|
cc @jkbradley If you get a chance to review it would be much appreciated. |
cc @MLnick could you take a look? This is blocking SPARK-9478 which I have a PR ready to submit for. |
ping @jkbradley @MLnick I created this PR and #12374 to make SPARK-9478 easier to review. Alternatively, I could submit them all as one PR. It would be nice to get sample weights for trees into Spark 2.0. Thoughts? Also ping @holdenk |
@@ -60,20 +68,24 @@ private[spark] object BaggedPoint { | |||
subsamplingRate: Double, | |||
numSubsamples: Int, | |||
withReplacement: Boolean, | |||
extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just checking my understanding here, but is the intention to in future support something like WeightedTreePoint
(or amend TreePoint
to include a weight), which is constructed in turn from Instance
rather than LabeledPoint
, and then the function passed can be ... => point.weight
or similar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that is exactly the case for this. I could not think of a better way to implement this, while still keeping bagged point generic (i.e. not requiring Datum to have a weight property or something similar).
@sethah just to confirm, is SPARK-9478 about sample weights, or class weights? The title is for class weights but I think the actual idea and PR etc is for sample weights, yes? |
@MLnick Yes, SPARK-9478 is for sample weighting. |
/** | ||
* Subsample counts weighted by the sample weight. | ||
*/ | ||
def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a val
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this as a convenience method. If we make it a val then we add storage overhead in the class which is redundant. If preferable, we could remove it entirely.
Should be there a sanity check providing input RDD of instance objects and |
What changes were proposed in this pull request?
This PR changes BaggedPoint to store the number of subsamples AND the sample weight of
Datum
. Specifically:subsampleWeights: Array[Double]
is changed tosubsampleCounts: Array[Int]
sampleWeight: Double
field is added to the BaggedPoint constructordatum
is added toconvertToBaggedPointRDD
. This will be helpful when we add weights to decision trees, so that we can extract the instance weight from theRDD[Instance]
.How was this patch tested?
This PR does not introduce any new functional changes, so there are no tests added.