diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
new file mode 100644
index 0000000000000..52b808e01a0bf
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.source.libsvm
+
+import com.google.common.base.Objects
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, PrunedScan, RelationProvider}
+import org.apache.spark.sql.types._
+
+class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
+    (@transient val sqlContext: SQLContext)
+  extends BaseRelation with PrunedScan with Logging {
+
+  // The UserDefinedType registered for Vector via its @SQLUserDefinedType annotation.
+  private final val vectorType: DataType =
+    classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+
+  override def schema: StructType = StructType(
+    StructField("label", DoubleType, nullable = false) ::
+      StructField("features", vectorType, nullable = false) :: Nil
+  )
+
+  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
+    val sc = sqlContext.sparkContext
+    val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
+
+    // Build one value extractor per requested column so that column pruning is honored.
+    val rowBuilders = requiredColumns.map {
+      case "label" => (pt: LabeledPoint) => Seq(pt.label)
+      case "features" if featuresType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
+      case "features" if featuresType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
+    }
+
+    baseRdd.map { pt =>
+      Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty))
+    }
+  }
+
+  override def hashCode(): Int = {
+    Objects.hashCode(path, schema)
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema)
+    case _ => false
+  }
+}
+
+class DefaultSource extends RelationProvider with DataSourceRegister {
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def shortName(): String = "parquet"
+   * }}}
+   *
+   * @since 1.5.0
+   */
+  override def shortName(): String = "libsvm"
+
+  private def checkPath(parameters: Map[String, String]): String = {
+    parameters.getOrElse("path", sys.error("'path' must be specified"))
+  }
+
+  /**
+   * Returns a new base relation with the given parameters.
+   * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
+   * by the Map that is passed to the function.
+   */
+  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]):
+      BaseRelation = {
+    val path = checkPath(parameters)
+    val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
+    // The "featuresType" option may be either "dense" or "sparse" (the default); it
+    // determines the type of the feature vectors in the resulting DataFrame.
+    val featuresType = parameters.getOrElse("featuresType", "sparse")
+    new LibSVMRelation(path, numFeatures, featuresType)(sqlContext)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala
new file mode 100644
index 0000000000000..92c021e4b4e69
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/package.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.source
+
+import org.apache.spark.sql.{DataFrame, DataFrameReader}
+
+package object libsvm {
+
+  /**
+   * Implicit class that adds a `libsvm` method to [[DataFrameReader]].
+   * To use it, import org.apache.spark.ml.source.libsvm._ into scope.
+   * @param read the [[DataFrameReader]] being extended
+   */
+  implicit class LibSVMReader(read: DataFrameReader) {
+    def libsvm(filePath: String): DataFrame =
+      read.format(classOf[DefaultSource].getName).load(filePath)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
new file mode 100644
index 0000000000000..accf37d9886a9
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.source
+
+import java.io.File
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.source.libsvm._
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
+class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
+  var path: String = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val lines =
+      """
+        |1 1:1.0 3:2.0 5:3.0
+        |0
+        |0 2:4.0 4:5.0 6:6.0
+      """.stripMargin
+    val tempDir = Utils.createTempDir()
+    val file = new File(tempDir.getPath, "part-00000")
+    Files.write(lines, file, Charsets.US_ASCII)
+    path = tempDir.toURI.toString
+  }
+
+  test("select as sparse vector") {
+    val df = sqlContext.read.options(Map("numFeatures" -> "6")).libsvm(path)
+    assert(df.columns(0) == "label")
+    assert(df.columns(1) == "features")
+    val row1 = df.first()
+    assert(row1.getDouble(0) == 1.0)
+    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+  }
+
+  test("select as dense vector") {
+    val df = sqlContext.read.options(Map("numFeatures" -> "6", "featuresType" -> "dense"))
+      .libsvm(path)
+    assert(df.columns(0) == "label")
+    assert(df.columns(1) == "features")
+    assert(df.count() == 3)
+    val row1 = df.first()
+    assert(row1.getDouble(0) == 1.0)
+    assert(row1.getAs[DenseVector](1) == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
+  }
+
+  test("select without any option") {
+    val df = sqlContext.read.libsvm(path)
+    val row1 = df.first()
+    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+  }
+}
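
For reference, a minimal usage sketch of the data source added above, assuming a SQLContext named sqlContext is in scope and a LIBSVM file exists at the hypothetical path "data/sample_libsvm_data.txt":

    // Import the package object so that DataFrameReader gains the implicit `libsvm` method.
    import org.apache.spark.ml.source.libsvm._

    // "numFeatures" is optional and defaults to -1, in which case
    // MLUtils.loadLibSVMFile determines the feature dimension from the data;
    // "featuresType" selects the vector type and defaults to "sparse".
    val sparseDF = sqlContext.read
      .options(Map("numFeatures" -> "6", "featuresType" -> "sparse"))
      .libsvm("data/sample_libsvm_data.txt")

    // The same relation through the generic data source API, addressed by the
    // provider's class name. (The "libsvm" alias returned by shortName() would
    // additionally require service-loader registration of DefaultSource, which
    // this diff does not add.)
    val denseDF = sqlContext.read
      .format(classOf[DefaultSource].getName)
      .options(Map("numFeatures" -> "6", "featuresType" -> "dense"))
      .load("data/sample_libsvm_data.txt")

    sparseDF.select("label", "features").show(3)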