Skip to content

Commit

Permalink
rename DatasetExample under mllib/examples and remove support for dense format
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 20, 2015
1 parent 921900f commit 95f7b6d
Showing 1 changed file with 26 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,30 @@
*/

// scalastyle:off println
package org.apache.spark.examples.mllib
package org.apache.spark.examples.ml

import java.io.File

import com.google.common.io.Files
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, DataFrame}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

/**
* An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
* An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with
* {{{
* ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
* ./bin/run-example ml.DataFrameExample [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object DatasetExample {
object DataFrameExample {

// Command-line parameters for the example; `input` defaults to the bundled sample LIBSVM file.
// NOTE(review): this listing is a rendered diff without +/- markers, so two versions of Params
// appear back to back. The two-field version (input, dataFormat) is presumably the one removed
// by the commit and the single-field version its replacement -- confirm against the commit.
case class Params(
input: String = "data/mllib/sample_libsvm_data.txt",
dataFormat: String = "libsvm") extends AbstractParams[Params]
case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
extends AbstractParams[Params]

def main(args: Array[String]) {
val defaultParams = Params()
Expand All @@ -52,9 +49,6 @@ object DatasetExample {
opt[String]("input")
.text(s"input path to dataset")
.action((x, c) => c.copy(input = x))
opt[String]("dataFormat")
.text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
.action((x, c) => c.copy(input = x))
checkConfig { params =>
success
}
Expand All @@ -69,55 +63,42 @@ object DatasetExample {

/**
* Runs the example with the given parameters: creates a SparkContext and SQLContext, loads
* `params.input`, prints the schema and summary statistics, writes the DataFrame to a temporary
* Parquet directory, reads it back, and stops the SparkContext.
*
* NOTE(review): this listing is a rendered diff without +/- markers, so statements removed by
* the commit and their replacements appear together (`conf`, `df` and `features` are each
* defined twice). It does not compile as shown. The comments below mark which group each run
* of lines presumably belongs to -- confirm against the actual commit.
*
* @param params parsed command-line parameters; `params.input` is the path of the data to load
*/
def run(params: Params) {

// Presumably the pre-rename app name (removed), followed by its replacement.
val conf = new SparkConf().setAppName(s"DatasetExample with $params")
val conf = new SparkConf().setAppName(s"DataFrameExample with $params")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._ // for implicit conversions

// Presumably removed by the commit (through the `features: RDD[Vector]` line below): this path
// dispatches on params.dataFormat, a field that only the two-field Params declares.
// Load input data
val origData: RDD[LabeledPoint] = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
}
println(s"Loaded ${origData.count()} instances from file: ${params.input}")

// Convert input data to DataFrame explicitly.
val df: DataFrame = origData.toDF()
println(s"Inferred schema:\n${df.schema.prettyJson}")
println(s"Converted to DataFrame with ${df.count()} records")

// Select columns
val labelsDf: DataFrame = df.select("label")
val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
val numLabels = labels.count()
// Mean label computed by hand: sum of all labels divided by their count.
val meanLabel = labels.fold(0.0)(_ + _) / numLabels
println(s"Selected label column with average value $meanLabel")

val featuresDf: DataFrame = df.select("features")
val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
// Presumably added by the commit: read the file through the DataFrame reader using the
// "libsvm" format and cache the result, since `df` is used several times below.
println(s"Loading LIBSVM file with UDT from ${params.input}.")
val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache()
println("Schema from LIBSVM:")
df.printSchema()
println(s"Loaded training data as a DataFrame with ${df.count()} records.")

// Show statistical summary of labels.
val labelSummary = df.describe("label")
labelSummary.show()

// Convert features column to an RDD of vectors.
val features = df.select("features").map { case Row(v: Vector) => v }
// Fold every feature vector into a summarizer (first function) and merge the partial
// summarizers (second function); only the resulting mean is printed.
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

// Save the records in a parquet file.
val tmpDir = Files.createTempDir()
// NOTE(review): deleteOnExit is registered for the directory itself; a directory that still
// contains the Parquet output may not actually be removed at JVM exit -- confirm.
tmpDir.deleteOnExit()
val outputDir = new File(tmpDir, "dataset").toString
println(s"Saving to $outputDir as Parquet file.")
df.write.parquet(outputDir)

// Load the records back.
println(s"Loading Parquet file with UDT from $outputDir.")
// Presumably removed by the commit (through the second "average values" println): reloads
// into `newDataset`, prints the schema as JSON, and recomputes the feature summary.
val newDataset = sqlContext.read.parquet(outputDir)

println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
// Presumably added by the commit: reload into `newDF` and print only the schema.
val newDF = sqlContext.read.parquet(outputDir)
println(s"Schema from Parquet:")
newDF.printSchema()

sc.stop()
}

}
// scalastyle:on println

0 comments on commit 95f7b6d

Please sign in to comment.