From 97a62d695086d342b93059566b5cf061a6fe8962 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Sat, 12 Mar 2016 19:43:02 +0800
Subject: [PATCH 1/2] Removes Dataset.collectRows()/takeRows()

---
 .../examples/ml/JavaBinarizerExample.java     |  2 +-
 .../ml/JavaCrossValidatorExample.java         |  2 +-
 .../examples/ml/JavaDeveloperApiExample.java  |  2 +-
 .../JavaEstimatorTransformerParamExample.java |  3 +-
 ...delSelectionViaCrossValidationExample.java |  2 +-
 .../spark/examples/ml/JavaNGramExample.java   |  2 +-
 .../examples/ml/JavaPipelineExample.java      |  2 +-
 .../ml/JavaPolynomialExpansionExample.java    |  5 ++-
 .../examples/ml/JavaSimpleParamsExample.java  |  3 +-
 .../JavaSimpleTextClassificationPipeline.java |  2 +-
 .../spark/examples/ml/JavaTfIdfExample.java   |  2 +-
 .../examples/ml/JavaTokenizerExample.java     |  2 +-
 .../examples/ml/JavaWord2VecExample.java      |  2 +-
 .../ml/feature/JavaVectorSlicerSuite.java     |  2 +-
 .../org/apache/spark/sql/DataFrame.scala      | 18 ---------
 .../spark/sql/JavaApplySchemaSuite.java       |  2 +-
 .../apache/spark/sql/JavaDataFrameSuite.java  | 39 ++++++++++---------
 17 files changed, 39 insertions(+), 53 deletions(-)

diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
index d554377975b1b..0a6e9c2a1f93c 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
@@ -58,7 +58,7 @@ public static void main(String[] args) {
       .setThreshold(0.5);
     Dataset<Row> binarizedDataFrame = binarizer.transform(continuousDataFrame);
     Dataset<Row> binarizedFeatures = binarizedDataFrame.select("binarized_feature");
-    for (Row r : binarizedFeatures.collectRows()) {
+    for (Row r : binarizedFeatures.collectAsList()) {
       Double binarized_value = r.getDouble(0);
       System.out.println(binarized_value);
     }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 90bc94c45bbf9..07edeb3e521c3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -117,7 +117,7 @@ public static void main(String[] args) {
 
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     Dataset<Row> predictions = cvModel.transform(test);
-    for (Row r: predictions.select("id", "text", "probability", "prediction").collectRows()) {
+    for (Row r: predictions.select("id", "text", "probability", "prediction").collectAsList()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index e8ae100d68529..8a10dd48aa72f 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -85,7 +85,7 @@ public static void main(String[] args) throws Exception {
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     Dataset<Row> results = model.transform(test);
     double sumPredictions = 0;
-    for (Row r : results.select("features", "label", "prediction").collectRows()) {
+    for (Row r : results.select("features", "label", "prediction").collectAsList()) {
       sumPredictions += r.getDouble(2);
     }
     if (sumPredictions != 0.0) {
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
index f13698ae5e07e..604b193dd489b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
@@ -100,7 +100,8 @@ public static void main(String[] args) {
     // Note that model2.transform() outputs a 'myProbability' column instead of the usual
     // 'probability' column since we renamed the lr.probabilityCol parameter previously.
     Dataset<Row> results = model2.transform(test);
-    for (Row r : results.select("features", "label", "myProbability", "prediction").collectRows()) {
+    Dataset<Row> rows = results.select("features", "label", "myProbability", "prediction");
+    for (Row r: rows.collectAsList()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
index e394605db70ea..c4122d1247a94 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
@@ -111,7 +111,7 @@ public static void main(String[] args) {
 
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     Dataset<Row> predictions = cvModel.transform(test);
-    for (Row r : predictions.select("id", "text", "probability", "prediction").collectRows()) {
+    for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
index 0305f737ca94c..608bd80285655 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
@@ -60,7 +60,7 @@ public static void main(String[] args) {
 
     Dataset<Row> ngramDataFrame = ngramTransformer.transform(wordDataFrame);
 
-    for (Row r : ngramDataFrame.select("ngrams", "label").takeRows(3)) {
+    for (Row r : ngramDataFrame.select("ngrams", "label").takeAsList(3)) {
       java.util.List<String> ngrams = r.getList(0);
       for (String ngram : ngrams) System.out.print(ngram + " --- ");
       System.out.println();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java
index 6ae418d564d1f..305420f208b79 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java
@@ -80,7 +80,7 @@ public static void main(String[] args) {
 
     // Make predictions on test documents.
     Dataset<Row> predictions = model.transform(test);
-    for (Row r : predictions.select("id", "text", "probability", "prediction").collectRows()) {
+    for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
index 5a4064c604301..48fc3c8acb0c0 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
@@ -23,6 +23,7 @@
 
 // $example on$
 import java.util.Arrays;
+import java.util.List;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.PolynomialExpansion;
@@ -61,8 +62,8 @@ public static void main(String[] args) {
     Dataset<Row> df = jsql.createDataFrame(data, schema);
     Dataset<Row> polyDF = polyExpansion.transform(df);
 
-    Row[] row = polyDF.select("polyFeatures").takeRows(3);
-    for (Row r : row) {
+    List<Row> rows = polyDF.select("polyFeatures").takeAsList(3);
+    for (Row r : rows) {
       System.out.println(r.get(0));
     }
     // $example off$
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 52bb4ec050376..cb911ef5ef586 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -103,7 +103,8 @@ public static void main(String[] args) {
     // Note that model2.transform() outputs a 'myProbability' column instead of the usual
     // 'probability' column since we renamed the lr.probabilityCol parameter previously.
     Dataset<Row> results = model2.transform(test);
-    for (Row r: results.select("features", "label", "myProbability", "prediction").collectRows()) {
+    Dataset<Row> rows = results.select("features", "label", "myProbability", "prediction");
+    for (Row r: rows.collectAsList()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 9bd543c44f983..a18a60f448166 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -84,7 +84,7 @@ public static void main(String[] args) {
 
     // Make predictions on test documents.
     Dataset<Row> predictions = model.transform(test);
-    for (Row r: predictions.select("id", "text", "probability", "prediction").collectRows()) {
+    for (Row r: predictions.select("id", "text", "probability", "prediction").collectAsList()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
index fd1ce424bf8c4..37a3d0d84dae2 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
@@ -66,7 +66,7 @@ public static void main(String[] args) {
     IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
     IDFModel idfModel = idf.fit(featurizedData);
     Dataset<Row> rescaledData = idfModel.transform(featurizedData);
-    for (Row r : rescaledData.select("features", "label").takeRows(3)) {
+    for (Row r : rescaledData.select("features", "label").takeAsList(3)) {
       Vector features = r.getAs(0);
       Double label = r.getDouble(1);
       System.out.println(features);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
index a2f8c436e32f6..9225fe2262f57 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
@@ -59,7 +59,7 @@ public static void main(String[] args) {
     Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
 
     Dataset<Row> wordsDataFrame = tokenizer.transform(sentenceDataFrame);
-    for (Row r : wordsDataFrame.select("words", "label").takeRows(3)) {
+    for (Row r : wordsDataFrame.select("words", "label").takeAsList(3)) {
       java.util.List<String> words = r.getList(0);
       for (String word : words) System.out.print(word + " ");
       System.out.println();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
index 2dce8c2168c2d..c5bb1eaaa3446 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
@@ -59,7 +59,7 @@ public static void main(String[] args) {
       .setMinCount(0);
     Word2VecModel model = word2Vec.fit(documentDF);
     Dataset<Row> result = model.transform(documentDF);
-    for (Row r : result.select("result").takeRows(3)) {
+    for (Row r : result.select("result").takeAsList(3)) {
       System.out.println(r);
     }
     // $example off$
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index b87605ebfd6a3..e2da11183b93f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -78,7 +78,7 @@ public void vectorSlice() {
 
     Dataset<Row> output = vectorSlicer.transform(dataset);
 
-    for (Row r : output.select("userFeatures", "features").takeRows(2)) {
+    for (Row r : output.select("userFeatures", "features").takeAsList(2)) {
       Vector features = r.getAs(1);
       Assert.assertEquals(features.size(), 2);
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f1791e6943bb7..1ea7db0388689 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1762,10 +1762,6 @@ class Dataset[T] private[sql](
    */
   def take(n: Int): Array[T] = head(n)
 
-  def takeRows(n: Int): Array[Row] = withTypedCallback("takeRows", limit(n)) { ds =>
-    ds.collectRows(needCallback = false)
-  }
-
   /**
    * Returns the first `n` rows in the [[DataFrame]] as a list.
    *
@@ -1790,8 +1786,6 @@ class Dataset[T] private[sql](
    */
   def collect(): Array[T] = collect(needCallback = true)
 
-  def collectRows(): Array[Row] = collectRows(needCallback = true)
-
   /**
    * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
    *
@@ -1820,18 +1814,6 @@ class Dataset[T] private[sql](
     }
   }
 
-  private def collectRows(needCallback: Boolean): Array[Row] = {
-    def execute(): Array[Row] = withNewExecutionId {
-      queryExecution.executedPlan.executeCollectPublic()
-    }
-
-    if (needCallback) {
-      withCallback("collect", toDF())(_ => execute())
-    } else {
-      execute()
-    }
-  }
-
   /**
    * Returns the number of rows in the [[DataFrame]].
    * @group action
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 42af813bc1cd3..a8f4d3972a7fa 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -109,7 +109,7 @@ public Row call(Person person) throws Exception {
     Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
 
-    Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows();
+    List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList();
 
     List<Row> expected = new ArrayList<>(2);
     expected.add(RowFactory.create("Michael", 29));
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 47cc74dbc1f28..d40704e917edf 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -19,6 +19,7 @@
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
@@ -65,7 +66,7 @@ public void tearDown() {
   @Test
   public void testExecution() {
     Dataset<Row> df = context.table("testData").filter("key = 1");
-    Assert.assertEquals(1, df.select("key").collectRows()[0].get(0));
+    Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
   }
 
   @Test
@@ -208,8 +209,8 @@ public void testCreateDataFromFromList() {
     StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
     List<Row> rows = Arrays.asList(RowFactory.create(0));
     Dataset<Row> df = context.createDataFrame(rows, schema);
-    Row[] result = df.collectRows();
-    Assert.assertEquals(1, result.length);
+    List<Row> result = df.collectAsList();
+    Assert.assertEquals(1, result.size());
   }
 
   @Test
@@ -241,8 +242,8 @@ public void testCrosstab() {
     Assert.assertEquals("a_b", columnNames[0]);
     Assert.assertEquals("2", columnNames[1]);
     Assert.assertEquals("1", columnNames[2]);
-    Row[] rows = crosstab.collectRows();
-    Arrays.sort(rows, crosstabRowComparator);
+    List<Row> rows = crosstab.collectAsList();
+    Collections.sort(rows, crosstabRowComparator);
     Integer count = 1;
     for (Row row : rows) {
       Assert.assertEquals(row.get(0).toString(), count.toString());
@@ -257,7 +258,7 @@ public void testFrequentItems() {
     Dataset<Row> df = context.table("testData2");
     String[] cols = {"a"};
     Dataset<Row> results = df.stat().freqItems(cols, 0.2);
-    Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1));
+    Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1));
   }
 
   @Test
@@ -278,27 +279,27 @@ public void testCovariance() {
   public void testSampleBy() {
     Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
     Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
-    Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows();
-    Assert.assertEquals(0, actual[0].getLong(0));
-    Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8);
-    Assert.assertEquals(1, actual[1].getLong(0));
-    Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
+    List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
+    Assert.assertEquals(0, actual.get(0).getLong(0));
+    Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
+    Assert.assertEquals(1, actual.get(0).getLong(0));
+    Assert.assertTrue(2 <= actual.get(0).getLong(1) && actual.get(1).getLong(1) <= 13);
   }
 
   @Test
   public void pivot() {
     Dataset<Row> df = context.table("courseSales");
-    Row[] actual = df.groupBy("year")
+    List<Row> actual = df.groupBy("year")
       .pivot("course", Arrays.asList("dotNET", "Java"))
-      .agg(sum("earnings")).orderBy("year").collectRows();
+      .agg(sum("earnings")).orderBy("year").collectAsList();
 
-    Assert.assertEquals(2012, actual[0].getInt(0));
-    Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
-    Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);
+    Assert.assertEquals(2012, actual.get(0).getInt(0));
+    Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
+    Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
 
-    Assert.assertEquals(2013, actual[1].getInt(0));
-    Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
-    Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
+    Assert.assertEquals(2013, actual.get(1).getInt(0));
+    Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
+    Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
   }
 
   @Test

From e7898939af2a8e839099a4f65b2787d60ee12ea7 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Sat, 12 Mar 2016 22:30:16 +0800
Subject: [PATCH 2/2] Fixes test failures

---
 .../java/test/org/apache/spark/sql/JavaApplySchemaSuite.java | 2 +-
 .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index a8f4d3972a7fa..ae9c8cc1ba9ff 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -115,7 +115,7 @@ public Row call(Person person) throws Exception {
     expected.add(RowFactory.create("Michael", 29));
     expected.add(RowFactory.create("Yin", 28));
 
-    Assert.assertEquals(expected, Arrays.asList(actual));
+    Assert.assertEquals(expected, actual);
   }
 
   @Test
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index d40704e917edf..42554720edae5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -282,8 +282,8 @@ public void testSampleBy() {
     List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
     Assert.assertEquals(0, actual.get(0).getLong(0));
     Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
-    Assert.assertEquals(1, actual.get(0).getLong(0));
-    Assert.assertTrue(2 <= actual.get(0).getLong(1) && actual.get(1).getLong(1) <= 13);
+    Assert.assertEquals(1, actual.get(1).getLong(0));
+    Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
   }
 
   @Test
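
---

Migration note (illustrative sketch, not part of the patches above): callers of
the removed methods follow the same mechanical rewrite these diffs apply --
collectRows() becomes collectAsList(), takeRows(n) becomes takeAsList(n), and
Row[] index/length access becomes List<Row> get(i)/size(). In the sketch below,
the df variable and the "key" column are assumed stand-ins, not names taken
from the patch:

    import java.util.List;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;

    class CollectRowsMigrationSketch {
      // Hypothetical caller: "df" and column "key" are illustrative only.
      static void printKeys(Dataset<Row> df) {
        // Before (API removed by this patch):
        //   Row[] rows = df.select("key").collectRows();
        //   System.out.println(rows[0].get(0) + " of " + rows.length);

        // After: collectAsList() returns java.util.List<Row>, so switch
        // to get(i)/size() (and Collections.sort instead of Arrays.sort,
        // as JavaDataFrameSuite.testCrosstab does above).
        List<Row> rows = df.select("key").collectAsList();
        System.out.println(rows.get(0).get(0) + " of " + rows.size());

        // takeRows(n) maps to takeAsList(n) the same way.
        for (Row r : df.select("key").takeAsList(3)) {
          System.out.println(r.get(0));
        }
      }
    }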