diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala similarity index 51% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index dc13f82488af7..424f00158c2f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -16,7 +16,7 @@ */ // scalastyle:off println -package org.apache.spark.examples.mllib +package org.apache.spark.examples.ml import java.io.File @@ -24,25 +24,22 @@ import com.google.common.io.Files import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * ./bin/run-example ml.DataFrameExample [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object DatasetExample { +object DataFrameExample { - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -52,9 +49,6 @@ object DatasetExample { opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) checkConfig { params => success } @@ -69,55 +63,42 @@ object DatasetExample { def run(params: Params) { - val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) + // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() sc.stop() } - } // scalastyle:on println