Skip to content

Commit

Permalink
[SPARK-13244][SQL] Migrates DataFrame to Dataset
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and make `DataFrame` a type alias of `Dataset[Row]`.

Most Scala code changes are source compatible, but Java API is broken as Java knows nothing about Scala type alias (mostly replacing `DataFrame` with `Dataset<Row>`).

There are several noticeable API changes related to those returning arrays:

1.  `collect`/`take`

    -   Old APIs in class `DataFrame`:

        ```scala
        def collect(): Array[Row]
        def take(n: Int): Array[Row]
        ```

    -   New APIs in class `Dataset[T]`:

        ```scala
        def collect(): Array[T]
        def take(n: Int): Array[T]

        def collectRows(): Array[Row]
        def takeRows(n: Int): Array[Row]
        ```

    Two specialized methods `collectRows` and `takeRows` are added because Java doesn't support returning generic arrays. Thus, for example, `DataFrame.collect(): Array[T]` actually returns `Object` instead of `Array<T>` from Java side.

    Normally, Java users may fall back to `collectAsList` and `takeAsList`.  The two new specialized versions are added to avoid performance regression in ML related code (but maybe I'm wrong and they are not necessary here).

1.  `randomSplit`

    -   Old APIs in class `DataFrame`:

        ```scala
        def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
        def randomSplit(weights: Array[Double]): Array[DataFrame]
        ```

    -   New APIs in class `Dataset[T]`:

        ```scala
        def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
        def randomSplit(weights: Array[Double]): Array[Dataset[T]]
        ```

    Similar problem as above, but hasn't been addressed for Java API yet.  We can probably add `randomSplitAsList` to fix this one.

1.  `groupBy`

    Some original `DataFrame.groupBy` methods have conflicting signature with original `Dataset.groupBy` methods.  To distinguish these two, typed `Dataset.groupBy` methods are renamed to `groupByKey`.

Other noticeable changes:

1.  Dataset always do eager analysis now

    We used to support disabling DataFrame eager analysis to help reporting partially analyzed malformed logical plan on analysis failure.  However, Dataset encoders requires eager analysi during Dataset construction.  To preserve the error reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures.  This plan is passed by `QueryExecution.assertAnalyzed`.

## How was this patch tested?

Existing tests do the work.

## TODO

- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>

Closes #11443 from liancheng/ds-to-df.
  • Loading branch information
liancheng authored and yhuai committed Mar 11, 2016
1 parent 27fe6ba commit 1d54278
Show file tree
Hide file tree
Showing 116 changed files with 1,069 additions and 1,444 deletions.
9 changes: 5 additions & 4 deletions dev/run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,10 +561,11 @@ def main():
# spark build
build_apache_spark(build_tool, hadoop_version)

# backwards compatibility checks
if build_tool == "sbt":
# Note: compatibility tests only supported in sbt for now
detect_binary_inop_with_mima()
# TODO Temporarily disable MiMA check for DF-to-DS migration prototyping
# # backwards compatibility checks
# if build_tool == "sbt":
# # Note: compatiblity tests only supported in sbt for now
# detect_binary_inop_with_mima()

# run the test suites
run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.mllib.linalg.*;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
Expand All @@ -52,7 +53,7 @@ public static void main(String[] args) {
new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
DataFrame training = jsql.createDataFrame(data, schema);
Dataset<Row> training = jsql.createDataFrame(data, schema);
double[] quantileProbabilities = new double[]{0.3, 0.6};
AFTSurvivalRegression aft = new AFTSurvivalRegression()
.setQuantileProbabilities(quantileProbabilities)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

// $example on$
Expand Down Expand Up @@ -93,10 +95,10 @@ public Rating call(String str) {
return Rating.parseRating(str);
}
});
DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
DataFrame[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
DataFrame training = splits[0];
DataFrame test = splits[1];
Dataset<Row> ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];

// Build the recommendation model using ALS on the training data
ALS als = new ALS()
Expand All @@ -108,8 +110,8 @@ public Rating call(String str) {
ALSModel model = als.fit(training);

// Evaluate the model by computing the RMSE on the test data
DataFrame rawPredictions = model.transform(test);
DataFrame predictions = rawPredictions
Dataset<Row> rawPredictions = model.transform(test);
Dataset<Row> predictions = rawPredictions
.withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
.withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;

// $example on$
Expand Down Expand Up @@ -51,18 +52,18 @@ public static void main(String[] args) {
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema);
Dataset<Row> continuousDataFrame = jsql.createDataFrame(jrdd, schema);
Binarizer binarizer = new Binarizer()
.setInputCol("feature")
.setOutputCol("binarized_feature")
.setThreshold(0.5);
DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame);
DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature");
for (Row r : binarizedFeatures.collect()) {
Dataset<Row> binarizedDataFrame = binarizer.transform(continuousDataFrame);
Dataset<Row> binarizedFeatures = binarizedDataFrame.select("binarized_feature");
for (Row r : binarizedFeatures.collectRows()) {
Double binarized_value = r.getDouble(0);
System.out.println(binarized_value);
}
// $example off$
jsc.stop();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
Expand Down Expand Up @@ -62,7 +62,7 @@ public static void main(String[] args) {
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});

DataFrame dataset = jsql.createDataFrame(data, schema);
Dataset<Row> dataset = jsql.createDataFrame(data, schema);

BisectingKMeans bkm = new BisectingKMeans().setK(2);
BisectingKMeansModel model = bkm.fit(dataset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
Expand All @@ -53,15 +53,15 @@ public static void main(String[] args) {
StructType schema = new StructType(new StructField[]{
new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
});
DataFrame dataFrame = jsql.createDataFrame(data, schema);
Dataset<Row> dataFrame = jsql.createDataFrame(data, schema);

Bucketizer bucketizer = new Bucketizer()
.setInputCol("features")
.setOutputCol("bucketedFeatures")
.setSplits(splits);

// Transform original data into its bucket index.
DataFrame bucketedData = bucketizer.transform(dataFrame);
Dataset<Row> bucketedData = bucketizer.transform(dataFrame);
bucketedData.show();
// $example off$
jsc.stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;

// $example on$
Expand All @@ -28,7 +29,6 @@
import org.apache.spark.ml.feature.ChiSqSelector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
Expand All @@ -55,15 +55,15 @@ public static void main(String[] args) {
new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty())
});

DataFrame df = sqlContext.createDataFrame(jrdd, schema);
Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);

ChiSqSelector selector = new ChiSqSelector()
.setNumTopFeatures(1)
.setFeaturesCol("features")
.setLabelCol("clicked")
.setOutputCol("selectedFeatures");

DataFrame result = selector.fit(df).transform(df);
Dataset<Row> result = selector.fit(df).transform(df);
result.show();
// $example off$
jsc.stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
Expand All @@ -48,7 +48,7 @@ public static void main(String[] args) {
StructType schema = new StructType(new StructField [] {
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
DataFrame df = sqlContext.createDataFrame(jrdd, schema);
Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);

// fit a CountVectorizerModel from the corpus
CountVectorizerModel cvModel = new CountVectorizer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

Expand Down Expand Up @@ -71,7 +72,8 @@ public static void main(String[] args) {
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
Dataset<Row> training = jsql.createDataFrame(
jsc.parallelize(localTraining), LabeledDocument.class);

// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
Expand Down Expand Up @@ -112,11 +114,11 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);

// Make predictions on test documents. cvModel uses the best model found (lrModel).
DataFrame predictions = cvModel.transform(test);
for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
Dataset<Row> predictions = cvModel.transform(test);
for (Row r: predictions.select("id", "text", "probability", "prediction").collectRows()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;

// $example on$
Expand All @@ -28,7 +29,6 @@
import org.apache.spark.ml.feature.DCT;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.Metadata;
Expand All @@ -51,12 +51,12 @@ public static void main(String[] args) {
StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
DataFrame df = jsql.createDataFrame(data, schema);
Dataset<Row> df = jsql.createDataFrame(data, schema);
DCT dct = new DCT()
.setInputCol("features")
.setOutputCol("featuresDCT")
.setInverse(false);
DataFrame dctDf = dct.transform(df);
Dataset<Row> dctDf = dct.transform(df);
dctDf.select("featuresDCT").show(3);
// $example off$
jsc.stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$

Expand All @@ -38,7 +39,7 @@ public static void main(String[] args) {

// $example on$
// Load the data stored in LIBSVM format as a DataFrame.
DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
Dataset<Row> data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
Expand All @@ -55,9 +56,9 @@ public static void main(String[] args) {
.fit(data);

// Split the data into training and test sets (30% held out for testing)
DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
DataFrame trainingData = splits[0];
DataFrame testData = splits[1];
Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];

// Train a DecisionTree model.
DecisionTreeClassifier dt = new DecisionTreeClassifier()
Expand All @@ -78,7 +79,7 @@ public static void main(String[] args) {
PipelineModel model = pipeline.fit(trainingData);

// Make predictions.
DataFrame predictions = model.transform(testData);
Dataset<Row> predictions = model.transform(testData);

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$

Expand All @@ -38,7 +39,7 @@ public static void main(String[] args) {
SQLContext sqlContext = new SQLContext(jsc);
// $example on$
// Load the data stored in LIBSVM format as a DataFrame.
DataFrame data = sqlContext.read().format("libsvm")
Dataset<Row> data = sqlContext.read().format("libsvm")
.load("data/mllib/sample_libsvm_data.txt");

// Automatically identify categorical features, and index them.
Expand All @@ -50,9 +51,9 @@ public static void main(String[] args) {
.fit(data);

// Split the data into training and test sets (30% held out for testing)
DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
DataFrame trainingData = splits[0];
DataFrame testData = splits[1];
Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];

// Train a DecisionTree model.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
Expand All @@ -66,7 +67,7 @@ public static void main(String[] args) {
PipelineModel model = pipeline.fit(trainingData);

// Make predictions.
DataFrame predictions = model.transform(testData);
Dataset<Row> predictions = model.transform(testData);

// Select example rows to display.
predictions.select("label", "features").show(5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

Expand Down Expand Up @@ -61,7 +62,8 @@ public static void main(String[] args) throws Exception {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
Dataset<Row> training = jsql.createDataFrame(
jsc.parallelize(localTraining), LabeledPoint.class);

// Create a LogisticRegression instance. This instance is an Estimator.
MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
Expand All @@ -79,12 +81,12 @@ public static void main(String[] args) throws Exception {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);

// Make predictions on test documents. cvModel uses the best model found (lrModel).
DataFrame results = model.transform(test);
Dataset<Row> results = model.transform(test);
double sumPredictions = 0;
for (Row r : results.select("features", "label", "prediction").collect()) {
for (Row r : results.select("features", "label", "prediction").collectRows()) {
sumPredictions += r.getDouble(2);
}
if (sumPredictions != 0.0) {
Expand Down Expand Up @@ -145,7 +147,7 @@ MyJavaLogisticRegression setMaxIter(int value) {

// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
public MyJavaLogisticRegressionModel train(DataFrame dataset) {
public MyJavaLogisticRegressionModel train(Dataset<Row> dataset) {
// Extract columns from data using helper method.
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();

Expand Down
Loading

0 comments on commit 1d54278

Please sign in to comment.