From 9786a4ee7d49a502d885f8c172022063fc3af98b Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 10 Nov 2015 18:11:38 +0800 Subject: [PATCH 1/9] [SPARK-11622][MLLIB] Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter --- .../ml/source/libsvm/LibSVMRelation.scala | 105 +++++++++++++++--- .../source/libsvm/LibSVMRelationSuite.scala | 21 +++- 2 files changed, 111 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 1bed542c40316..57cfa9618069b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -17,16 +17,27 @@ package org.apache.spark.ml.source.libsvm +import java.io.{CharArrayWriter, IOException} + +import com.fasterxml.jackson.core.JsonFactory import com.google.common.base.Objects +import org.apache.hadoop.fs.{Path, FileStatus} +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.annotation.Since +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration /** * LibSVMRelation provides the DataFrame constructed from LibSVM format data. @@ -37,14 +48,10 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) - extends BaseRelation with TableScan with Logging with Serializable { - - override def schema: StructType = StructType( - StructField("label", DoubleType, nullable = false) :: - StructField("features", new VectorUDT(), nullable = false) :: Nil - ) + extends HadoopFsRelation with Logging with Serializable { - override def buildScan(): RDD[Row] = { + override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]) + : RDD[Row] = { val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) val sparse = vectorType == "sparse" @@ -66,8 +73,63 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val case _ => false } + + override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job): + _root_.org.apache.spark.sql.sources.OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new LibSVMOutputWriter(path, dataSchema, context) + } + } + } + + override def paths: Array[String] = Array(path) + + override def dataSchema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil) } + +private[libsvm] class LibSVMOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private[this] val buffer = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = { + val label = row.get(0) + val vector = row.get(1).asInstanceOf[Vector] + val sb = new StringBuilder(label.toString) + vector.foreachActive { case (i, v) => + sb += ' ' + sb ++= s"${i + 1}:$v" + } + buffer.set(sb.mkString) + recordWriter.write(NullWritable.get(), buffer) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} /** * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and @@ -99,16 +161,31 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") -class DefaultSource extends RelationProvider with DataSourceRegister { +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" - @Since("1.6.0") - override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) - : BaseRelation = { - val path = parameters.getOrElse("path", - throw new IllegalArgumentException("'path' must be specified")) + private def verifySchema(dataSchema: StructType): Unit = { + if (dataSchema.size != 2 || + (!dataSchema(0).dataType.sameType(DataTypes.DoubleType) + || !dataSchema(1).dataType.sameType(new VectorUDT()))) { + throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}") + } + } + + override def createRelation(sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + val path = if (paths.length == 1) paths(0) + else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data") + else throw new IOException("Multiple input paths are not supported for libsvm data") + if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) { + throw new IOException("Partition is not supported for libsvm data") + } + dataSchema.foreach(verifySchema(_)) val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt val vectorType = parameters.getOrElse("vectorType", "sparse") new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 5f4d5f11bdd68..90460f2cf859d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.source.libsvm -import java.io.File +import java.io.{IOException, File} import com.google.common.base.Charsets import com.google.common.io.Files @@ -82,4 +82,23 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { val v = row1.getAs[SparseVector](1) assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } + + test("write libsvm data and read it again") { + val df = sqlContext.read.format("libsvm").load(path) + val writepath = path + "_2" + df.write.save(writepath) + + val df2 = sqlContext.read.format("libsvm").load(writepath) + val row1 = df.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("write libsvm data failed due to invalid schema") { + val df = sqlContext.read.format("text").load(path) + val e = intercept[IOException] { + df.write.format("libsvm").save(path + "_2") + } + assert(e.getMessage.contains("Illegal schema for libsvm data")) + } } From 7cf79efa3005e05a84edb922efd38874e3503106 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 11 Nov 2015 12:32:57 +0800 Subject: [PATCH 2/9] minor change (remove Logging trait) --- .../org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 57cfa9618069b..3fd54ff09dc57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -48,7 +48,7 @@ import org.apache.spark.util.SerializableConfiguration */ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) - extends HadoopFsRelation with Logging with Serializable { + extends HadoopFsRelation with Serializable { override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]) : RDD[Row] = { From 409f4d53eb6b6171f7a5559f51d279f5958986fd Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 19 Jan 2016 20:45:40 -0800 Subject: [PATCH 3/9] address review comments --- .../spark/ml/source/libsvm/LibSVMRelation.scala | 11 ++++++----- .../spark/ml/source/libsvm/LibSVMRelationSuite.scala | 8 +++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 3fd54ff09dc57..cfa783f9fcb6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -174,11 +174,12 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { } } - override def createRelation(sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { val path = if (paths.length == 1) paths(0) else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data") else throw new IOException("Multiple input paths are not supported for libsvm data") diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 90460f2cf859d..e00d9b2703ed5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -25,6 +25,7 @@ import com.google.common.io.Files import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SaveMode import org.apache.spark.util.Utils class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -85,11 +86,12 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("write libsvm data and read it again") { val df = sqlContext.read.format("libsvm").load(path) - val writepath = path + "_2" - df.write.save(writepath) + val tempDir2 = Utils.createTempDir() + val writepath = tempDir2.toURI.toString + df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) val df2 = sqlContext.read.format("libsvm").load(writepath) - val row1 = df.first() + val row1 = df2.first() val v = row1.getAs[SparseVector](1) assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } From 75fcb5019565b747e44d128b8ebbb0b11b98977e Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 20 Jan 2016 11:46:34 -0800 Subject: [PATCH 4/9] minor code style fix --- .../org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index cfa783f9fcb6c..46c4db0173d85 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -21,7 +21,7 @@ import java.io.{CharArrayWriter, IOException} import com.fasterxml.jackson.core.JsonFactory import com.google.common.base.Objects -import org.apache.hadoop.fs.{Path, FileStatus} +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} From 65cfb1177f0db9a28c44a370d337f8783ac0107c Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 20 Jan 2016 15:00:28 -0800 Subject: [PATCH 5/9] fix import ordering --- .../spark/ml/source/libsvm/LibSVMRelation.scala | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 46c4db0173d85..7ab1e750b6e25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -17,27 +17,19 @@ package org.apache.spark.ml.source.libsvm -import java.io.{CharArrayWriter, IOException} - -import com.fasterxml.jackson.core.JsonFactory +import java.io.IOException import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} - -import org.apache.spark.Logging +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.annotation.Since -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.util.SerializableConfiguration /** * LibSVMRelation provides the DataFrame constructed from LibSVM format data. From 6957dfe9d281eb0e18e4d61c0130306373fcd59f Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 20 Jan 2016 15:13:09 -0800 Subject: [PATCH 6/9] fix code style --- .../org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 7ab1e750b6e25..20ba6f8016f07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -18,11 +18,14 @@ package org.apache.spark.ml.source.libsvm import java.io.IOException + import com.google.common.base.Objects + import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat + import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils From 0d6d06dc7aa98f2f2e6a1fb20d0af59f31ae4531 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 20 Jan 2016 20:33:17 -0800 Subject: [PATCH 7/9] fix code style issue --- .../scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 20ba6f8016f07..b9c364b05dc11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -20,7 +20,6 @@ package org.apache.spark.ml.source.libsvm import java.io.IOException import com.google.common.base.Objects - import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} From 8a2c96fc28021e28c8009d4c58ce3d94e9227683 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 20 Jan 2016 20:43:42 -0800 Subject: [PATCH 8/9] code style issue --- .../org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index e00d9b2703ed5..528d9e21cb1fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.source.libsvm -import java.io.{IOException, File} +import java.io.{File, IOException} import com.google.common.base.Charsets import com.google.common.io.Files From 5bdf2249a970e443796ab6f88f1680646109e570 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 20 Jan 2016 21:46:17 -0800 Subject: [PATCH 9/9] fix binary incompatibilities --- project/MimaExcludes.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 905fb4cd90377..9b738daa4bb04 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -167,6 +167,10 @@ object MimaExcludes { // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") + ) ++ Seq( + // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") ) case v if v.startsWith("1.6") => Seq(