[SPARK-9911] [DOC] [ML] Update Userguide for Evaluator #8304
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -645,6 +645,13 @@ An important task in ML is *model selection*, or using data to find the best mod | |
Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). | ||
`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. | ||
`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. | ||
|
||
The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) | ||
for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) | ||
for binary data or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) | ||
nit: Oxford comma ("...binary data, or a...") |
||
for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the setMetric | ||
|
||
method in each of these evaluators. | ||
|
||
The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. | ||
`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. | ||
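The selection procedure described in this hunk can be sketched without Spark as follows. This is a minimal illustration, not Spark's implementation: `KFoldSketch`, `Scorer`, `folds`, and `selectBest` are hypothetical names, rows are assigned to folds round-robin as a simplification, and the sketch assumes that a larger metric value is better.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class KFoldSketch {
    /** Hypothetical stand-in for "fit on train, evaluate on test"; assumes higher is better. */
    interface Scorer<P> {
        double score(P params, List<Integer> trainRows, List<Integer> testRows);
    }

    /** Splits row indices 0..numRows-1 into k folds, round-robin (a simplification). */
    static List<List<Integer>> folds(int numRows, int k) {
        List<List<Integer>> out = new ArrayList<>();
        for (int f = 0; f < k; f++) {
            out.add(new ArrayList<>());
        }
        for (int i = 0; i < numRows; i++) {
            out.get(i % k).add(i);
        }
        return out;
    }

    /** Returns the index of the candidate with the highest metric averaged over the k folds. */
    static <P> int selectBest(List<P> candidates, int numRows, int k, Scorer<P> scorer) {
        List<List<Integer>> folds = folds(numRows, k);
        int best = -1;
        double bestAvg = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < candidates.size(); c++) {
            double total = 0.0;
            for (int f = 0; f < k; f++) {
                // Fold f is the test set; the other k-1 folds form the training set.
                List<Integer> test = folds.get(f);
                List<Integer> train = new ArrayList<>();
                for (int g = 0; g < k; g++) {
                    if (g != f) {
                        train.addAll(folds.get(g));
                    }
                }
                total += scorer.score(candidates.get(c), train, test);
            }
            double avg = total / k;
            if (avg > bestAvg) {
                bestAvg = avg;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Toy scorer: pretends the candidate value itself is the metric.
        List<Double> candidates = Arrays.asList(0.1, 0.5, 0.3);
        int best = selectBest(candidates, 9, 3, (p, train, test) -> p);
        System.out.println("best candidate index: " + best);
    }
}
```

The last step in the text, refitting the `Estimator` on the entire dataset with the winning `ParamMap`, would follow the call to `selectBest` and is left out of the sketch.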
|
||
|
@@ -710,9 +717,12 @@ val pipeline = new Pipeline() | |
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. | ||
// This will allow us to jointly choose parameters for all Pipeline stages. | ||
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. | ||
// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric | ||
nit: "default metric used is areaUnderROC" -> "default areaUnderROC metric is used" |
||
// used is areaUnderROC. | ||
val crossval = new CrossValidator() | ||
.setEstimator(pipeline) | ||
.setEvaluator(new BinaryClassificationEvaluator) | ||
|
||
// We use a ParamGridBuilder to construct a grid of parameters to search over. | ||
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, | ||
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. | ||
|
@@ -832,9 +842,12 @@ Pipeline pipeline = new Pipeline() | |
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. | ||
// This will allow us to jointly choose parameters for all Pipeline stages. | ||
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. | ||
// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric | ||
Same as L720 |
||
// used is areaUnderROC. | ||
CrossValidator crossval = new CrossValidator() | ||
.setEstimator(pipeline) | ||
.setEvaluator(new BinaryClassificationEvaluator()); | ||
|
||
// We use a ParamGridBuilder to construct a grid of parameters to search over. | ||
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, | ||
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. | ||
|
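The "3 x 2 = 6 parameter settings" count in the code comments is a Cartesian product over the per-parameter value lists. A minimal plain-Java sketch of that expansion follows. It does not use Spark's `ParamGridBuilder`; `ParamGridSketch` and `build` are hypothetical names, and the concrete values in `main` are placeholders chosen for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class ParamGridSketch {
    /** Expands {name -> candidate values} into every combination of one value per name. */
    static List<Map<String, Object>> build(Map<String, List<Object>> grid) {
        List<Map<String, Object>> settings = new ArrayList<>();
        settings.add(new LinkedHashMap<>()); // start from one empty setting
        for (Map.Entry<String, List<Object>> entry : grid.entrySet()) {
            List<Map<String, Object>> next = new ArrayList<>();
            for (Map<String, Object> partial : settings) {
                for (Object value : entry.getValue()) {
                    // Copy the partial setting and add one value for this parameter.
                    Map<String, Object> extended = new LinkedHashMap<>(partial);
                    extended.put(entry.getKey(), value);
                    next.add(extended);
                }
            }
            settings = next;
        }
        return settings;
    }

    public static void main(String[] args) {
        Map<String, List<Object>> grid = new LinkedHashMap<>();
        // Placeholder values: three for numFeatures, two for regParam.
        grid.put("hashingTF.numFeatures", Arrays.<Object>asList(10, 100, 1000));
        grid.put("lr.regParam", Arrays.<Object>asList(0.1, 0.01));
        System.out.println("number of settings: " + build(grid).size());
    }
}
```

Each resulting map plays the role of one `ParamMap` that `CrossValidator` would train and evaluate.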
Not related to this PR, but the MathJax formatting will override the code backticks; we should remove the backticks here.