Skip to content

Commit

Permalink
TrainValidationSplit user guide and examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
zapletal-martin committed Aug 23, 2015
1 parent c980e20 commit a9988f5
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 0 deletions.
167 changes: 167 additions & 0 deletions docs/ml-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,173 @@ jsc.stop();

</div>

## Example: Model Selection via Train Validation Split
In addition to `CrossValidator` Spark also offers
[`TrainValidationSplit`](api/scala/index.html#org.apache.spark.ml.tuning.TrainValidationSplit) for hyper-parameter tuning.
It randomly splits the input dataset into train and validation sets based on ratio passed as parameter
and use evaluation metric on the validation set to select the best model.
The use is similar to `CrossValidator`, but simpler and less computationally expensive.

`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s, and an
[`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator).
It begins by splitting the dataset into two parts using *trainRatio* parameter
which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default),
`TrainValidationSplit` will generate training and test dataset pair where 75% of the data is used for training and 25% for validation.
Similarly to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s.
For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
The `ParamMap` which produces the best evaluation metric is selected as the best option.
`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.

`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in
case of `CrossValidator`. It is therefore less expensive, but will not produce as reliable results.

<div class="codetabs">

<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("TrainValidationSplitExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val training = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))

val lr = new LinearRegression()

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
val trainValidationSplit = new TrainValidationSplit()
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator)

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
val paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.01))
.addGrid(lr.fitIntercept, Array(true, false))
.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
.addGrid(lr.maxIter, Array(10, 100))
.addGrid(lr.tol, Array(1E-5, 1E-6))
.build()

trainValidationSplit.setEstimatorParamMaps(paramGrid)

// 80% of the data will be used for training and the remaining 20% for validation.
trainValidationSplit.setTrainRatio(0.8)

// Run train validation split, and choose the best set of parameters.
val model = trainValidationSplit.fit(training.toDF())

// Prepare unlabeled test data.
val test = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))

// Make predictions on test data. model is the model with combination of parameters
// that performed best.
model.transform(test.toDF())
.select("features", "label", "prediction")
.collect()
.foreach { case Row(features: Vector, label: Double, prediction: Double) =>
println(s"($features, $label) --> prediction=$prediction")
}

sc.stop()
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.*;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);

List<LabeledPoint> localTraining = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));

DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);

LinearRegression lr = new LinearRegression();

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator());

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
ParamMap[] paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[]{0.1, 0.01})
.addGrid(lr.fitIntercept())
.addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
.addGrid(lr.maxIter(), new int[]{10, 100})
.addGrid(lr.tol(), new double[]{1E-5, 1E-6})
.build();

trainValidationSplit.setEstimatorParamMaps(paramGrid);

// 80% of the data will be used for training and the remaining 20% for validation.
trainValidationSplit.setTrainRatio(0.8);

// Run train validation split, and choose the best set of parameters.
TrainValidationSplitModel model = trainValidationSplit.fit(training);

// Prepare unlabeled test data.
List<LabeledPoint> localTest = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));

DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);

// Make predictions on test data. model is the model with combination of parameters
// that performed best.
DataFrame results = model.transform(test);
for (Row r: results.select("features", "label", "prediction").collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> " + "prediction=" + r.get(2));
}

jsc.stop();
{% endhighlight %}
</div>

</div>

# Dependencies

Spark ML currently depends on MLlib and has the same dependencies.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package org.apache.spark.examples.ml

class TrainValidationSplitExample {

}

0 comments on commit a9988f5

Please sign in to comment.