Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs #21097

Closed
wants to merge 5 commits into from

Conversation

WeichenXu123
Copy link
Contributor

What changes were proposed in this pull request?

Provide evaluateEachIteration method or equivalent for spark.ml GBTs.

How was this patch tested?

UT.

Please review http://spark.apache.org/contributing.html before opening a pull request.

@SparkQA
Copy link

SparkQA commented Apr 18, 2018

Test build #89499 has finished for PR 21097 at commit 836d760.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Apr 18, 2018

Test build #89500 has finished for PR 21097 at commit 16fd4d6.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -365,6 +365,20 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
assert(mostImportantFeature !== mostIF)
}

test("model evaluateEachIteration") {
for (lossType <- Seq("logistic")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is only one lossType. for is not necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. But I think it can fit for future, if we add more loss type for GBT classifier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. It makes sense.

Copy link
Member

@jkbradley jkbradley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just a few comments.

.setLossType(lossType)
val model = gbt.fit(trainData.toDF)
val eval1 = model.evaluateEachIteration(validationData.toDF)
val eval2 = GradientBoostedTrees.evaluateEachIteration(validationData,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is testing the spark.ml implementation against itself. I was about to recommend using the old spark.mllib implementation as a reference. However, the old implementation is not tested at all. Would you be able to test against a standard implementation in R or scikit-learn (following the patterns used elsewhere in MLlib)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I search scikit-learn doc, there seems no similar method like evaluateEachIteration, we can only use staged_predict in sklearn.ensemble.GradientBoostingRegressor and then use metric functions to evaluate them. And I doubt the implementation differ slightly in other library will be troublesome. In R package I also do not find this method.

Now I update the unit test, to just compare with hardcoded result.

* @param dataset Dataset for validation.
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to support evaluation on other losses, as in the old API? It might be nice to be able to without having to modify the Model's loss Param value.

@SparkQA
Copy link

SparkQA commented May 4, 2018

Test build #90188 has finished for PR 21097 at commit a2af286.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Copy link
Member

@jkbradley jkbradley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For unit tests, what about this?

  • Used a fixed random seed.
  • Run for maxIter = 3
  • Create models with 1 and 2 trees by manually getting the trees and constructing new GBT models.
  • Check to make sure the loss for a model with 1 tree matches the first value returned by evaluateEachIteration for the other 2 models.
  • Check to make sure the loss for a model with 2 trees matches the second value returned by evaluateEachIteration for the model with 3 trees.

/**
* Method to compute error or loss for every iteration of gradient boosting.
*
* @param dataset Dataset for validation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add doc for "loss" arg, including what the options are

@SparkQA
Copy link

SparkQA commented May 8, 2018

Test build #90351 has finished for PR 21097 at commit c32b5a8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Copy link
Member

@jkbradley jkbradley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a tiny comment left. Thanks!

val model2 = new GBTClassificationModel("gbt-cls-model-test2",
model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)

for (evalLossType <- GBTClassifier.supportedLossTypes) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

evalLossType is not used, so I'd remove this loop.

@SparkQA
Copy link

SparkQA commented May 9, 2018

Test build #90404 has finished for PR 21097 at commit 0e7311f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Copy link
Member

LGTM
Merging with master
Thanks @WeichenXu123 ! Would you mind creating & linking a JIRA for the Python API update?

@asfgit asfgit closed this in 7aaa148 May 9, 2018
@WeichenXu123 WeichenXu123 deleted the GBTeval branch May 9, 2018 23:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
4 participants