From 9872424607a39699a5b23cc83b3ebfd7be062957 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 19 Jan 2015 15:53:15 -0800 Subject: [PATCH] fixed JavaLinearRegressionSuite.java Java sql api --- .../JavaLinearRegressionSuite.java | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java index 1f47b711ac6d4..d918fc7caf6a0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.ml.classification; -import scala.Tuple2; - import java.io.Serializable; import java.util.ArrayList; import java.util.List; @@ -29,40 +27,33 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.regression.LinearRegressionModel; import static org.apache.spark.mllib.classification.LogisticRegressionSuite .generateLogisticInputAsList; -import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SchemaRDD; public class JavaLinearRegressionSuite implements Serializable { private transient JavaSparkContext jsc; - private transient JavaSQLContext jsql; - private transient JavaSchemaRDD dataset; + private transient SQLContext jsql; + private transient SchemaRDD dataset; private transient JavaRDD datasetRDD; - private transient JavaRDD featuresRDD; private double eps = 1e-5; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - jsql = new JavaSQLContext(jsc); + jsql = new SQLContext(jsc); List points = new ArrayList(); for (org.apache.spark.mllib.regression.LabeledPoint lp: generateLogisticInputAsList(1.0, 1.0, 100, 42)) { points.add(new LabeledPoint(lp.label(), lp.features())); } datasetRDD = jsc.parallelize(points, 2); - featuresRDD = datasetRDD.map(new Function() { - @Override public Vector call(LabeledPoint lp) { return lp.features(); } - }); dataset = jsql.applySchema(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } @@ -79,7 +70,7 @@ public void linearRegressionDefaultParams() { assert(lr.getLabelCol().equals("label")); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT label, prediction FROM prediction"); + SchemaRDD predictions = jsql.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assert(model.getFeaturesCol().equals("features"));