Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-14599][ML] BaggedPoint should support sample weights. #12370

Closed
wants to merge 3 commits into from

Conversation

sethah
Copy link
Contributor

@sethah sethah commented Apr 13, 2016

What changes were proposed in this pull request?

This PR changes BaggedPoint to store the number of subsamples AND the sample weight of Datum. Specifically:

  • subsampleWeights: Array[Double] is changed to subsampleCounts: Array[Int]
  • A sampleWeight: Double field is added to the BaggedPoint constructor
  • A function to extract the sample weight from datum is added to convertToBaggedPointRDD. This will be helpful when we add weights to decision trees, so that we can extract the instance weight from the RDD[Instance].

How was this patch tested?

This PR does not introduce any new functional changes, so there are no tests added.

@sethah
Copy link
Contributor Author

sethah commented Apr 13, 2016

In a previous "TODO" it was proposed that we could incorporate sample weights by simply multiplying the subsample counts by the sample weight and storing them in an array. I chose not to do this because of the need to have both raw counts and weighted counts when adding weights to decision trees. If we simply store the weighted counts, we lose the information about the raw counts, which makes it impossible to track minInstancesPerNode for decision trees.

@SparkQA
Copy link

SparkQA commented Apr 13, 2016

Test build #55748 has finished for PR 12370 at commit a673658.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@sethah
Copy link
Contributor Author

sethah commented Apr 18, 2016

cc @jkbradley If you get a chance to review it would be much appreciated.

@sethah
Copy link
Contributor Author

sethah commented Apr 20, 2016

cc @MLnick could you take a look? This is blocking SPARK-9478 which I have a PR ready to submit for.

@sethah
Copy link
Contributor Author

sethah commented Apr 28, 2016

ping @jkbradley @MLnick

I created this PR and #12374 to make SPARK-9478 easier to review. Alternatively, I could submit them all as one PR. It would be nice to get sample weights for trees into Spark 2.0. Thoughts?

Also ping @holdenk

@@ -60,20 +68,24 @@ private[spark] object BaggedPoint {
subsamplingRate: Double,
numSubsamples: Int,
withReplacement: Boolean,
extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking my understanding here, but is the intention to in future support something like WeightedTreePoint (or amend TreePoint to include a weight), which is constructed in turn from Instance rather than LabeledPoint, and then the function passed can be ... => point.weight or similar?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is exactly the case for this. I could not think of a better way to implement this, while still keeping bagged point generic (i.e. not requiring Datum to have a weight property or something similar).

@MLnick
Copy link
Contributor

MLnick commented Apr 29, 2016

@sethah just to confirm, is SPARK-9478 about sample weights, or class weights? The title is for class weights but I think the actual idea and PR etc is for sample weights, yes?

@sethah
Copy link
Contributor Author

sethah commented Apr 29, 2016

@MLnick Yes, SPARK-9478 is for sample weighting.

/**
* Subsample counts weighted by the sample weight.
*/
def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a val?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this as a convenience method. If we make it a val then we add storage overhead in the class which is redundant. If preferable, we could remove it entirely.

@MechCoder
Copy link
Contributor

Should be there a sanity check providing input RDD of instance objects and extractSampleWeight as callable that just returns the weight for each instance?

@sethah sethah closed this Oct 10, 2016
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
4 participants