
[SPARK-8484] [ML] Added TrainValidationSplit for hyper-parameter tuning. #7337

Conversation

zapletal-martin
Contributor

… randomly splits the input dataset into train and validation sets and uses an evaluation metric on the validation set to select the best model.
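
For readers skimming the thread, here is a minimal usage sketch of the API being added (written against the class name as finally merged, TrainValidationSplit; the toy data, grid values, and the spark-shell-style sqlContext are assumptions for illustration, not code from the PR):

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.linalg.Vectors

// Toy (label, features) DataFrame; in practice this would be real training data.
val data = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1)),
  (0.0, Vectors.dense(2.0, 1.0)),
  (1.5, Vectors.dense(0.5, 2.3)),
  (0.2, Vectors.dense(1.9, 0.8)),
  (2.1, Vectors.dense(0.1, 3.0)),
  (0.7, Vectors.dense(1.2, 1.2))
)).toDF("label", "features")

val lr = new LinearRegression()

// Hyper-parameter grid to search over.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept, Array(true, false))
  .build()

// 75% of the rows go to training, the rest to validation; each parameter
// combination is fit once and the best model by the evaluator's metric is kept.
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.75)

val model = tvs.fit(data)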
@SparkQA

SparkQA commented Jul 10, 2015

Test build #36991 has finished for PR 7337 at commit d699506.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class TrainValidatorSplit(override val uid: String) extends Estimator[TrainValidatorSplitModel]

@SparkQA

SparkQA commented Jul 10, 2015

Test build #36993 has finished for PR 7337 at commit 00c4f5a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class TrainValidatorSplit(override val uid: String) extends Estimator[TrainValidatorSplitModel]

* :: Experimental ::
* Validation for hyper-parameter tuning.
* Randomly splits the input dataset into train and validation sets.
* And uses evaluation metric on the validation set to select the best model.
Contributor

nit: "...validation sets, and uses ..." (comma instead of period)

@feynmanliang
Contributor

Done for now. This looks like it's in good shape.

@feynmanliang
Contributor

Actually, is there much difference between this and CrossValidator other than being able to specify trainRatio? There seems to be a lot of code duplication between the two (I think that CrossValidator is TrainValidatorSplit if numFolds=2 and trainRatio=0.5).

Unless I'm missing something, maybe we can simply extend CrossValidator with an additional trainRatio param allowing the user to specify the train/test split ratio, which we then pass into MLUtils#kFold.

@zapletal-martin
Contributor Author

@feynmanliang thanks for your comments. Yes, there is quite a lot of duplicated code. I attempted to refactor that slightly under #6996.

It is an interesting idea to just call CrossValidator instead of implementing the logic. I will have a look into that; it could simplify the code a bit more. But I assume we still need the Params and Model specific to TrainValidationSplit? We also need to decide if we want to do that as part of this review or separately.

cc @mengxr

@zapletal-martin
Contributor Author

Both CrossValidator and TrainValidationSplit use sampling to split the data into training and validation sets.

Currently CrossValidator does:

  • numFolds = 1 - not valid
  • numFolds = 2 - one fold trains on 0.0 to 0.5 and validates on 0.5 to 1.0, the other trains on 0.5 to 1.0 and validates on 0.0 to 0.5

TrainValidationSplit does:

  • 0.0 to trainRatio for training, trainRatio to 1.0 for validation

Therefore the logic is different and using TrainValidationSplit is not the same as just calling CrossValidator. Please let me know if the logic implemented by TrainValidationSplit is what was expected. We can then potentially address the code duplication.
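
For concreteness, a rough sketch of the two splitting strategies at the RDD level; the helper names and seed parameters are illustrative and this is not the PR's exact implementation:

import scala.reflect.ClassTag

import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

// CrossValidator-style: numFolds (training, validation) pairs, where each
// validation set is a disjoint slice of roughly 1/numFolds of the data.
def crossValidatorStyleSplits[T: ClassTag](
    data: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] =
  MLUtils.kFold(data, numFolds, seed)

// TrainValidationSplit-style: a single random (training, validation) pair,
// with roughly trainRatio of the data on the training side.
def trainValidationStyleSplit[T](
    data: RDD[T], trainRatio: Double, seed: Long): (RDD[T], RDD[T]) = {
  val Array(training, validation) =
    data.randomSplit(Array(trainRatio, 1.0 - trainRatio), seed)
  (training, validation)
}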

@feynmanliang
Contributor

CrossValidator currently doesn't accept numFolds = 1 because it uses MLUtils#kFold, which splits with an effective trainRatio = 1 - 1/numFolds; in this case that would mean nothing goes into the training set.

If we give a trainRatio param to CrossValidator and modify MLUtils#kFold accordingly, then each fold could be split using the user-specified trainRatio, so setting numFolds = 1, trainRatio = x would achieve the same functionality as TrainValidationSplit with trainRatio = x.

Note that in this case we would have to change the default trainRatio = 1 - 1/numFolds behavior for the special case numFolds = 1 to avoid the error case described in my first paragraph.
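
A hypothetical sketch of how the numFolds = 1 special case could be handled; note that MLUtils#kFold itself has no trainRatio parameter, and for numFolds > 1 this sketch simply keeps the existing k-fold behaviour and ignores trainRatio:

import scala.reflect.ClassTag

import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

// Hypothetical helper, not existing Spark API.
def kFoldWithTrainRatio[T: ClassTag](
    data: RDD[T],
    numFolds: Int,
    trainRatio: Double,
    seed: Int): Array[(RDD[T], RDD[T])] = {
  if (numFolds == 1) {
    // Single random split at the requested ratio instead of the
    // degenerate trainRatio = 1 - 1/1 = 0 that plain k-fold would give.
    val Array(training, validation) =
      data.randomSplit(Array(trainRatio, 1.0 - trainRatio), seed)
    Array((training, validation))
  } else {
    // Unchanged k-fold behaviour; trainRatio is ignored here.
    MLUtils.kFold(data, numFolds, seed)
  }
}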

@feynmanliang
Contributor

Sorry, didn't address your questions.

  • I'm proposing to add a trainRatio param to CrossValidatorParams, eliminating the need for the TrainValidatorSplit Params and Model
  • The logic in TrainValidationSplit looks fine to me and is what I expected

@zapletal-martin
Contributor Author

Thanks @feynmanliang. As I mentioned, I tried to address the code duplication in my previous PR differently than you propose, but we decided to go with the simplest option for now.

I agree that what you are proposing makes sense. The only thing that worries me is the unclear purpose of trainRatio when numFolds != 1. In that case CrossValidator splits the dataset into numFolds subsets of the same size, and the ratio of training to validation data is already determined (e.g. with numFolds set to 4 the training set is 0.75 and the validation set is 0.25), so the trainRatio param would not be used?

We could do that, document the approach, and essentially get rid of TrainValidatorSplit; the other option would be to keep TrainValidatorSplit as a wrapper around the shared functionality to avoid the confusion of CrossValidator having a trainRatio param.

@feynmanliang
Contributor

Ah, I overlooked your point that when 1 - trainRatio > 1/numFolds we will have overlapping folds. Also, I realized that what I'm proposing (adding a trainRatio for each fold) contradicts what Wikipedia defines for k-fold validation. Thanks for pushing back on this.

I agree with you that TrainValidatorSplit and CrossValidator provide two different pieces of functionality and should be separate classes. I like the idea of both wrapping a common (perhaps more confusing) implementation with both numFolds and trainRatio; it differentiates the concepts in the public API but shares code in the implementation.

@SparkQA

SparkQA commented Jul 10, 2015

Test build #37057 has finished for PR 7337 at commit f4fc9c4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class TrainValidatorSplit(override val uid: String) extends Estimator[TrainValidatorSplitModel]

@zapletal-martin
Contributor Author

I think we need a decision on how to approach this. I would prefer to focus on the public API and avoid the refactor in this review, and then address that in another review as discussed in #6996.

/**
* Params for [[TrainValidatorSplit]] and [[TrainValidatorSplitModel]].
*/
private[ml] trait TrainValidatorSplitParams extends ValidatorParams {
Contributor

TrainValidatorSplit -> TrainValidationSplit
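
For context, here is a sketch of what the renamed params trait could look like with a trainRatio param; the documentation text and the default of 0.75 are illustrative rather than quoted from the PR:

import org.apache.spark.ml.param.{DoubleParam, ParamValidators}

/**
 * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
 */
private[ml] trait TrainValidationSplitParams extends ValidatorParams {

  /**
   * Ratio of the input data used for training; the remainder is used for validation.
   * Must be between 0 and 1. Default: 0.75 (illustrative).
   */
  val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
    "ratio between training and validation data, must be between 0 and 1",
    ParamValidators.inRange(0, 1))

  /** @group getParam */
  def getTrainRatio: Double = $(trainRatio)

  setDefault(trainRatio -> 0.75)
}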

@mengxr
Contributor

mengxr commented Jul 22, 2015

@feynmanliang @zapletal-martin The changes in this PR look good to me except for a few minor comments. As discussed in #6996, let's focus on the public API to get this merged first.

We can have another PR for code reuse. There would be more discussion, e.g., having a base class handling arbitrary slicing of the input data and making CrossValidator and TrainValidationSplit extend it. I actually think there will be more lines of code if we implement it that way. It might not be worth the trade-off.

@SparkQA

SparkQA commented Jul 22, 2015

Test build #38015 has finished for PR 7337 at commit cafc949.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]

@asfgit closed this in a721ee5 on Jul 23, 2015
@mengxr
Contributor

mengxr commented Jul 23, 2015

LGTM. Merged into master. Please create JIRAs for follow-up work, e.g., Python API, user guide, and refactoring (if it is useful). Thanks!
