[SPARK-21723][ML] Fix writing LibSVM (key not found: numFeatures)
## What changes were proposed in this pull request?

Check the "numFeatures" option only when reading LibSVM data, not when writing it. Previously, writing raised an exception (`key not found: numFeatures`) when the features column carried no such metadata; after this change the option is ignored completely when writing. cc liancheng HyukjinKwon

(Perhaps using the option when writing should be forbidden outright in a future major version?)
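
For illustration only, a minimal sketch of the scenario this fixes (the column names, values, and output path are made up, and `spark` is assumed to be an active `SparkSession`): a DataFrame built from scratch has no `numFeatures` metadata on its vector column, so before this patch the write-side schema check failed with `key not found: numFeatures`; after the patch the write succeeds and the feature count is inferred on the next read.

```scala
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// A DataFrame created from scratch: the vector column carries no numFeatures metadata.
val schema = StructType(
  StructField("label", DoubleType, nullable = false) ::
  StructField("features", VectorType, nullable = false) :: Nil)
val rows = java.util.Arrays.asList(
  Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))),
  Row(4.0, Vectors.sparse(3, Seq((0, 5.0), (2, 6.0)))))
val df = spark.createDataFrame(rows, schema)

// Before this patch: fails with "key not found: numFeatures".
// After this patch: the numFeatures check is skipped when writing.
df.coalesce(1).write.format("libsvm").save("/tmp/libsvm-write-demo")
```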

## How was this patch tested?

Manually tested that loading and writing LibSVM files work fine, both with and without the numFeatures option. A unit test that writes LibSVM data from a DataFrame created from scratch and reads it back is also added.
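
A rough sketch of that manual round trip, assuming a SparkSession named `spark` and purely illustrative paths:

```scala
// Read an existing LibSVM file, once with an explicit feature count and once letting Spark infer it.
val withOption    = spark.read.format("libsvm").option("numFeatures", "6").load("/tmp/sample.libsvm")
val withoutOption = spark.read.format("libsvm").load("/tmp/sample.libsvm")

// Writing should succeed in both cases; the numFeatures option is simply ignored on write.
withOption.coalesce(1).write.format("libsvm").mode("overwrite").save("/tmp/libsvm-roundtrip")

// Reading the result back yields the same label/features rows.
spark.read.format("libsvm").load("/tmp/libsvm-roundtrip").show()
```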

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Jan Vrsovsky <jan.vrsovsky@firma.seznam.cz>

Closes apache#18872 from ProtD/master.
Jan Vrsovsky authored and srowen committed Aug 16, 2017
1 parent 8c54f1e commit 8321c14
Showing 2 changed files with 34 additions and 10 deletions.
@@ -76,12 +76,12 @@ private[libsvm] class LibSVMFileFormat

override def toString: String = "LibSVM"

- private def verifySchema(dataSchema: StructType): Unit = {
+ private def verifySchema(dataSchema: StructType, forWriting: Boolean): Unit = {
if (
dataSchema.size != 2 ||
!dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
!dataSchema(1).dataType.sameType(new VectorUDT()) ||
- !(dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
+ !(forWriting || dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
) {
throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
}
@@ -119,7 +119,7 @@ private[libsvm] class LibSVMFileFormat
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- verifySchema(dataSchema)
+ verifySchema(dataSchema, true)
new OutputWriterFactory {
override def newInstance(
path: String,
@@ -142,7 +142,7 @@ private[libsvm] class LibSVMFileFormat
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
- verifySchema(dataSchema)
+ verifySchema(dataSchema, false)
val numFeatures = dataSchema("features").metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt
assert(numFeatures > 0)

@@ -19,13 +19,16 @@ package org.apache.spark.ml.source.libsvm

import java.io.{File, IOException}
import java.nio.charset.StandardCharsets
+ import java.util.List

import com.google.common.io.Files

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+ import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SaveMode}
+ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.util.Utils


@@ -44,14 +47,14 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
"""
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+ val dir = Utils.createTempDir()
val succ = new File(dir, "_SUCCESS")
val file0 = new File(dir, "part-00000")
val file1 = new File(dir, "part-00001")
Files.write("", succ, StandardCharsets.UTF_8)
Files.write(lines0, file0, StandardCharsets.UTF_8)
Files.write(lines1, file1, StandardCharsets.UTF_8)
- path = dir.toURI.toString
+ path = dir.getPath
}

override def afterAll(): Unit = {
@@ -108,12 +111,12 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {

test("write libsvm data and read it again") {
val df = spark.read.format("libsvm").load(path)
- val tempDir2 = new File(tempDir, "read_write_test")
- val writepath = tempDir2.toURI.toString
+ val writePath = Utils.createTempDir().getPath

// TODO: Remove requirement to coalesce by supporting multiple reads.
- df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+ df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)

- val df2 = spark.read.format("libsvm").load(writepath)
+ val df2 = spark.read.format("libsvm").load(writePath)
val row1 = df2.first()
val v = row1.getAs[SparseVector](1)
assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
@@ -126,6 +129,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}

test("write libsvm data from scratch and read it again") {
val rawData = new java.util.ArrayList[Row]()
rawData.add(Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))))
rawData.add(Row(4.0, Vectors.sparse(3, Seq((0, 5.0), (2, 6.0)))))

val struct = StructType(
StructField("labelFoo", DoubleType, false) ::
StructField("featuresBar", VectorType, false) :: Nil
)
val df = spark.sqlContext.createDataFrame(rawData, struct)

val writePath = Utils.createTempDir().getPath

df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)

val df2 = spark.read.format("libsvm").load(writePath)
val row1 = df2.first()
val v = row1.getAs[SparseVector](1)
assert(v == Vectors.sparse(3, Seq((0, 2.0), (1, 3.0))))
}

test("select features from libsvm relation") {
val df = spark.read.format("libsvm").load(path)
df.select("features").rdd.map { case Row(d: Vector) => d }.first
