Changes as per comments in PR #8377.
zapletal-martin committed Aug 27, 2015
1 parent 6443d55 commit 1dd1cd1
Showing 2 changed files with 23 additions and 4 deletions.
docs/ml-guide.md (19 additions, 0 deletions)

@@ -888,6 +888,13 @@ The `ParamMap` which produces the best evaluation metric is selected as the best model.

<div data-lang="scala" markdown="1">
{% highlight scala %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+
import sqlContext.implicits._

// Prepare training and test data.
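
For context, these imports support the guide's new `TrainValidationSplit` example. Below is a minimal, self-contained Scala sketch of the flow they enable; the data path, split weights, seed, and grid values match the diff, while the app name, the `setTrainRatio(0.8)` value, and the final `select` columns are illustrative assumptions:

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("TrainValidationSplitExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Load LibSVM data and split it 90/10 into training and test sets.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)

val lr = new LinearRegression()

// Every combination of these values is tried during model selection.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// 80% of the training data is used to fit each candidate model,
// the remaining 20% to validate it.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)

// Fit on the training set, then predict on the held-out test set
// with the best model found.
val model = trainValidationSplit.fit(training)
model.transform(test).select("features", "label", "prediction").show()

sc.stop()
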
@@ -931,6 +938,18 @@ sc.stop()

<div data-lang="java" markdown="1">
{% highlight java %}
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.tuning.*;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
RDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
RDD<LabeledPoint>[] splits = data.randomSplit(new double []{0.9, 0.1}, 12345);
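
A side note on the Java variant: `MLUtils.loadLibSVMFile` expects a Scala `SparkContext`, which is why the example unwraps the `JavaSparkContext` with `jsc.sc()`. The weights `{0.9, 0.1}` and the fixed seed `12345` give a reproducible 90/10 train/test split.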

JavaTrainValidationSplitExample.java (4 additions, 4 deletions)

@@ -43,12 +43,12 @@
public class JavaTrainValidationSplitExample {

public static void main(String[] args) {
-SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample").setMaster("local[2]");
+SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);

RDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
-RDD<LabeledPoint>[] splits = data.randomSplit(new double []{0.9, 0.1}, 12345);
+RDD<LabeledPoint>[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);

// Prepare training and test data.
DataFrame training = jsql.createDataFrame(splits[0], LabeledPoint.class);
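
This hunk drops the hard-coded `.setMaster("local[2]")`: leaving the master unset lets it be chosen at launch time (for example via `spark-submit --master`), so the example runs unchanged both locally and on a cluster. The whitespace-only changes here and in the next hunk just insert a space between `[]` and `{` in array literals to match the surrounding code style.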
@@ -66,9 +66,9 @@ public static void main(String[] args) {
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
ParamMap[] paramGrid = new ParamGridBuilder()
-.addGrid(lr.regParam(), new double[]{0.1, 0.01})
+.addGrid(lr.regParam(), new double[] {0.1, 0.01})
.addGrid(lr.fitIntercept())
-.addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
+.addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
.build();

trainValidationSplit.setEstimatorParamMaps(paramGrid);
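
Since the grid pairs two `regParam` values, two `fitIntercept` settings, and three `elasticNetParam` values, it expands to 2 × 2 × 3 = 12 `ParamMap`s; `TrainValidationSplit` fits one model per combination on the training portion, scores each on the validation portion with the configured evaluator, and then fits the best combination on the whole training set.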
