
[SPARK-9910][ML] User guide for train validation split #8377

Closed · wants to merge 7 commits
167 changes: 167 additions & 0 deletions docs/ml-guide.md
@@ -801,6 +801,173 @@ jsc.stop();

</div>

## Example: Model Selection via Train Validation Split
In addition to `CrossValidator`, Spark also offers
Contributor:
I would cut L805-808 since L811-819 says essentially the same thing but is better IMO.

Contributor Author:
I have kept L805 as introduction, deleted the others as suggested.

[`TrainValidationSplit`](api/scala/index.html#org.apache.spark.ml.tuning.TrainValidationSplit) for hyper-parameter tuning.
It randomly splits the input dataset into training and validation sets based on the ratio passed as a parameter,
and uses the evaluation metric on the validation set to select the best model.
Its use is similar to `CrossValidator`, but it is simpler and less computationally expensive.

`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s, and an
[`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator).
Contributor:
@jkbradley has told me in the past to keep the scala API links under the scala codetab and to create a separate version linking java API docs under java codetabs, for example

However, it looks like we have been inconsistent throughout ml-guide.md about this so I'm OK with how it looks now. Just wanted to flag it in case others have something to say.

Contributor:
We need to fix inconsistency. It is bad to show links to Scala API doc to Python users. We can remove this link for now.

It begins by splitting the dataset into two parts using *trainRatio* parameter
Contributor:
"using trainRatio" -> "using the trainRatio" (use backticks for code references)

which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default),
`TrainValidationSplit` will generate training and test dataset pair where 75% of the data is used for training and 25% for validation.
Contributor:
nit: "will generate training..." -> "will generate a training..."

Similarly to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s.
Contributor:
nit: Similarly -> Similar

Contributor:
Also, "set of ParamMaps" -> "set of ParamMaps provided in the estimatorParamMaps parameter"

For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
The `ParamMap` which produces the best evaluation metric is selected as the best option.
`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.

`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in
Contributor:
I would move this to the top of TrainValidationSplit's description since it is important in helping the user choose between this and CrossValidator

case of `CrossValidator`. It is therefore less expensive, but will not produce as reliable results.

<div class="codetabs">

<div data-lang="scala">
Member:
Use this instead:

<div data-lang="scala" markdown="1">

That helps when, e.g., you put links inside the code tabs (which will likely be done in the future as docs expand).

{% highlight scala %}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
Contributor:
No need for importing SQLContext, SparkConf, and SparkContext anymore (they are provided in spark-shell as sqlContext and sc)
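For illustration, the trimmed header of the snippet could then shrink to something like the following sketch (it assumes the example is meant to be pasted straight into bin/spark-shell, where sc and sqlContext already exist):

// Only the ML/MLlib classes used by the example still need importing;
// SparkConf, SparkContext and SQLContext come from the shell session itself.
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.Row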


val conf = new SparkConf().setAppName("TrainValidationSplitExample")
Contributor:
No need to set SparkContext or SQLContext (examples should be copy/pasteable directly into bin/spark-shell). Also remove the imports when this is cut

val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
Contributor:
This is imported by default in console. We can remove this line and write:

val df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
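If that route is taken, the only extra piece is the MLUtils import; a minimal sketch of the data-loading step, using the path suggested above:

// loadLibSVMFile returns an RDD[LabeledPoint]; toDF() is available because
// sqlContext.implicits._ is already in scope in the shell.
import org.apache.spark.mllib.util.MLUtils
val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()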


val training = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))

val lr = new LinearRegression()

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
val trainValidationSplit = new TrainValidationSplit()
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator)

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
val paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.01))
Member:
I would make this simpler (maybe only a 2x2 grid).

Contributor Author:
Simplified to 2x2x3.

I wanted the example to be different from CrossValidator example. It only evaluates parameters of the LinearRegression rather than whole pipeline, it uses RegressionEvaluator as opposed to BinaryClassificationEvaluator and I also wanted to show how multiple parameter combinations can be evaluated.

.addGrid(lr.fitIntercept, Array(true, false))
.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
.addGrid(lr.maxIter, Array(10, 100))
.addGrid(lr.tol, Array(1E-5, 1E-6))
.build()

trainValidationSplit.setEstimatorParamMaps(paramGrid)
Contributor:
Shall we prepare the paramGrid first and then construct the trainValidationSplit with setters chained. That seems simpler to me.
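A sketch of how that could read, with the grid prepared first and the setters chained (parameter values copied from the example above, grid trimmed for brevity):

// Build the parameter grid up front...
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept, Array(true, false))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// ...then construct TrainValidationSplit in a single chained expression.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)  // 80% training, 20% validation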


// 80% of the data will be used for training and the remaining 20% for validation.
trainValidationSplit.setTrainRatio(0.8)
Member:
Does this reliably work? A lot of the time, this should use all 4 data points for training and 0 for testing. I'd use more data (maybe loading from "data/mllib/sample_libsvm_data.txt").

Contributor Author:
Good point. It was indeed unreliable. Using data/mllib/sample_libsvm_data.txt now instead.


// Run train validation split, and choose the best set of parameters.
val model = trainValidationSplit.fit(training.toDF())

// Prepare unlabeled test data.
val test = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))

// Make predictions on test data. model is the model with combination of parameters
// that performed best.
model.transform(test.toDF())
.select("features", "label", "prediction")
.collect()
Member:
No need to use collect-foreach-println anymore. You can just call show().

.foreach { case Row(features: Vector, label: Double, prediction: Double) =>
println(s"($features, $label) --> prediction=$prediction")
}
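Following the show() suggestion above, the prediction step could then collapse to a short sketch like this:

// show() prints the selected columns directly, so the Row pattern match
// and the explicit println are no longer needed.
model.transform(test.toDF())
  .select("features", "label", "prediction")
  .show()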

sc.stop()
Contributor:
this is not necessary in example code

{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.SparkConf;
Contributor:
Remove SparkConf and JavaSparkContext

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.*;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
Contributor:
Remove SQLContext import


SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
Contributor:
Ditto, no need for jsc and jsql

SQLContext jsql = new SQLContext(jsc);

List<LabeledPoint> localTraining = Lists.newArrayList(
Contributor:
Use java.util.Arrays.asList instead so we can remove the Guava (com.google.common) dependency

new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));

DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);

LinearRegression lr = new LinearRegression();

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator());

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
ParamMap[] paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[]{0.1, 0.01})
Contributor:
space before and after []

.addGrid(lr.fitIntercept())
.addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
.addGrid(lr.maxIter(), new int[]{10, 100})
.addGrid(lr.tol(), new double[]{1E-5, 1E-6})
.build();

trainValidationSplit.setEstimatorParamMaps(paramGrid);
Contributor:
Ditto. Construct paramGrid first, then create trainValidationSplit and chain all setters.


// 80% of the data will be used for training and the remaining 20% for validation.
trainValidationSplit.setTrainRatio(0.8);

// Run train validation split, and choose the best set of parameters.
TrainValidationSplitModel model = trainValidationSplit.fit(training);

// Prepare unlabeled test data.
List<LabeledPoint> localTest = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));

DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);

// Make predictions on test data. model is the model with combination of parameters
// that performed best.
DataFrame results = model.transform(test);
for (Row r: results.select("features", "label", "prediction").collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> " + "prediction=" + r.get(2));
}

jsc.stop();
Contributor:
remove this

{% endhighlight %}
</div>

</div>

# Dependencies

Spark ML currently depends on MLlib and has the same dependencies.
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.*;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

/**
* A simple example demonstrating model selection using TrainValidationSplit.
*
* The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
* using linear regression.
*
* Run with
* {{{
* bin/run-example ml.JavaTrainValidationSplitExample
* }}}
*/
public class JavaTrainValidationSplitExample {

public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);

List<LabeledPoint> localTraining = Lists.newArrayList(
Contributor:
Arrays.asList and remove Guava

new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));

DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);

LinearRegression lr = new LinearRegression();

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator());

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
ParamMap[] paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[]{0.1, 0.01})
.addGrid(lr.fitIntercept())
.addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
.addGrid(lr.maxIter(), new int[]{10, 100})
.addGrid(lr.tol(), new double[]{1E-5, 1E-6})
.build();

trainValidationSplit.setEstimatorParamMaps(paramGrid);

// 80% of the data will be used for training and the remaining 20% for validation.
trainValidationSplit.setTrainRatio(0.8);

// Run train validation split, and choose the best set of parameters.
TrainValidationSplitModel model = trainValidationSplit.fit(training);

// Prepare unlabeled test data.
List<LabeledPoint> localTest = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));

DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);

// Make predictions on test data. model is the model with combination of parameters
// that performed best.
DataFrame results = model.transform(test);
for (Row r: results.select("features", "label", "prediction").collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> " + "prediction=" + r.get(2));
Contributor:
Use DataFrame.show()

}

jsc.stop();
}
}
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
Contributor:
Keep the scalastyle override as close as possible to the println; even better use show() instead of println

package org.apache.spark.examples.ml

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

/**
* A simple example demonstrating model selection using TrainValidationSplit.
*
* The example is based on [[SimpleParamsExample]] using linear regression.
* Run with
* {{{
* bin/run-example ml.TrainValidationSplitExample
* }}}
*/
object TrainValidationSplitExample {

def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("TrainValidationSplitExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val training = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))

val lr = new LinearRegression()

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
val trainValidationSplit = new TrainValidationSplit()
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator)

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
val paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.01))
.addGrid(lr.fitIntercept, Array(true, false))
.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
.addGrid(lr.maxIter, Array(10, 100))
.addGrid(lr.tol, Array(1E-5, 1E-6))
.build()

trainValidationSplit.setEstimatorParamMaps(paramGrid)

// 80% of the data will be used for training and the remaining 20% for validation.
trainValidationSplit.setTrainRatio(0.8)

// Run train validation split, and choose the best set of parameters.
val model = trainValidationSplit.fit(training.toDF())

// Prepare unlabeled test data.
val test = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))

// Make predictions on test data. model is the model with combination of parameters
// that performed best.
model.transform(test.toDF())
.select("features", "label", "prediction")
.collect()
.foreach { case Row(features: Vector, label: Double, prediction: Double) =>
println(s"($features, $label) --> prediction=$prediction")
}

sc.stop()
}
}
// scalastyle:on println