[SPARK-10117] Implement SQL data source API for reading LIBSVM data
Lewuathe committed Sep 3, 2015
1 parent 00d9af5 commit 3fd8dce
Showing 3 changed files with 211 additions and 0 deletions.
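In brief, the commit adds a LIBSVM data source under org.apache.spark.ml.source.libsvm, so that LIBSVM-formatted files can be loaded as DataFrames with a label: Double column and a features: Vector column. A minimal usage sketch, assuming the implicit import introduced below and a placeholder file path:

import org.apache.spark.ml.source.libsvm._

// "data/sample_libsvm.txt" is a placeholder path. numFeatures may be omitted,
// in which case -1 is passed through and MLUtils infers the dimension.
val df = sqlContext.read
  .options(Map("numFeatures" -> "6", "featuresType" -> "dense"))
  .libsvm("data/sample_libsvm.txt")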
mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source.libsvm

import com.google.common.base.Objects

import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, PrunedScan, RelationProvider}
import org.apache.spark.sql.types._


/**
 * LibSVMRelation provides a DataFrame view of a LIBSVM-formatted file:
 * a label column of DoubleType and a features column holding the vector UDT.
 * Column pruning is supported through PrunedScan.
 */
class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
    (@transient val sqlContext: SQLContext)
  extends BaseRelation with PrunedScan with Logging {

  // The UDT instance that represents MLlib vectors in Spark SQL's type system.
  private final val vectorType: DataType =
    classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()

  override def schema: StructType = StructType(
    StructField("label", DoubleType, nullable = false) ::
      StructField("features", vectorType, nullable = false) :: Nil
  )

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val sc = sqlContext.sparkContext
    val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)

    // One extractor per requested column, so a pruned scan only materializes
    // the columns the query actually needs.
    val rowBuilders = requiredColumns.map {
      case "label" => (pt: LabeledPoint) => Seq(pt.label)
      case "features" if featuresType == "sparse" =>
        (pt: LabeledPoint) => Seq(pt.features.toSparse)
      case "features" if featuresType == "dense" =>
        (pt: LabeledPoint) => Seq(pt.features.toDense)
    }

    baseRdd.map { pt =>
      Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty))
    }
  }

  override def hashCode(): Int = {
    Objects.hashCode(path, schema)
  }

  override def equals(other: Any): Boolean = other match {
    case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema)
    case _ => false
  }
}

class DefaultSource extends RelationProvider with DataSourceRegister {

  /**
   * The string that represents the format that this data source provider uses. This is
   * overridden by children to provide a nice alias for the data source. For example:
   *
   * {{{
   *   override def shortName(): String = "parquet"
   * }}}
   *
   * @since 1.5.0
   */
  override def shortName(): String = "libsvm"

  private def checkPath(parameters: Map[String, String]): String = {
    parameters.getOrElse("path", sys.error("'path' must be specified"))
  }

  /**
   * Returns a new base relation with the given parameters.
   * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
   * by the Map that is passed to the function.
   */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
    : BaseRelation = {
    val path = checkPath(parameters)
    val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
    // "featuresType" selects between "sparse" (the default) and "dense", and decides
    // the concrete vector type of the features column.
    val featuresType = parameters.getOrElse("featuresType", "sparse")
    new LibSVMRelation(path, numFeatures, featuresType)(sqlContext)
  }
}
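Because DefaultSource also mixes in DataSourceRegister, the source could in principle be resolved by its short name instead of the fully-qualified class name; that lookup, however, requires the class to be listed in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister, which this commit does not add. A hypothetical sketch of that path:

// Hypothetical: resolve the provider by its registered short name. This only
// works once a META-INF/services entry exists for DataSourceRegister.
val df = sqlContext.read
  .format("libsvm")
  .option("numFeatures", "6")
  .load("data/sample_libsvm.txt")   // placeholder path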
mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source

import org.apache.spark.sql.{DataFrame, DataFrameReader}

package object libsvm {

  /**
   * Implicit class that adds a `libsvm` method to [[DataFrameReader]], so that
   * LIBSVM files can be read through SQLContext. To enable it, import
   * org.apache.spark.ml.source.libsvm._
   * @param read the DataFrameReader to extend
   */
  implicit class LibSVMReader(read: DataFrameReader) {
    def libsvm(filePath: String): DataFrame =
      read.format(classOf[DefaultSource].getName).load(filePath)
  }
}
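For reference, the implicit only fills in the provider class name; the call below is the explicit equivalent of sqlContext.read.libsvm(path), with a placeholder path:

// Explicit equivalent of sqlContext.read.libsvm(...): format() receives
// classOf[DefaultSource].getName, i.e. the fully-qualified provider name.
val df = sqlContext.read
  .format("org.apache.spark.ml.source.libsvm.DefaultSource")
  .load("data/sample_libsvm.txt")   // placeholder path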
mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source

import java.io.File

import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.source.libsvm._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils

class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
  var path: String = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // LIBSVM uses one record per line: "<label> <index>:<value> ...". Feature
    // indices are 1-based; MLUtils converts them to 0-based vector indices.
    val lines =
      """
        |1 1:1.0 3:2.0 5:3.0
        |0
        |0 2:4.0 4:5.0 6:6.0
      """.stripMargin
    val tempDir = Utils.createTempDir()
    val file = new File(tempDir.getPath, "part-00000")
    Files.write(lines, file, Charsets.US_ASCII)
    path = tempDir.toURI.toString
  }

  test("select as sparse vector") {
    val df = sqlContext.read.options(Map("numFeatures" -> "6")).libsvm(path)
    assert(df.columns(0) == "label")
    assert(df.columns(1) == "features")
    val row1 = df.first()
    assert(row1.getDouble(0) == 1.0)
    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
  }

  test("select as dense vector") {
    val df = sqlContext.read.options(Map("numFeatures" -> "6", "featuresType" -> "dense"))
      .libsvm(path)
    assert(df.columns(0) == "label")
    assert(df.columns(1) == "features")
    assert(df.count() == 3)
    val row1 = df.first()
    assert(row1.getDouble(0) == 1.0)
    assert(row1.getAs[DenseVector](1) == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
  }

  test("select without any option") {
    val df = sqlContext.read.libsvm(path)
    val row1 = df.first()
    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
  }
}
