From 70ee4dd4fc8081c4b1abd52c4bd25b158299b907 Mon Sep 17 00:00:00 2001
From: lewuathe
Date: Wed, 2 Sep 2015 23:35:18 +0900
Subject: [PATCH] Add Java test

---
 .../ml/source/libsvm/LibSVMRelation.scala     | 37 +++++-------
 .../ml/source/JavaLibSVMRelationSuite.java    | 59 +++++++++++++++++++
 .../spark/ml/source/LibSVMRelationSuite.scala |  2 -
 3 files changed, 74 insertions(+), 24 deletions(-)
 create mode 100644 mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java

diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 52b808e01a0bf..bf10536f3955d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,7 +19,8 @@ package org.apache.spark.ml.source.libsvm
 import com.google.common.base.Objects
 
 import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.annotation.Since
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
@@ -27,18 +28,20 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider}
 
-
-class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
+/**
+ * LibSVMRelation provides a DataFrame constructed from LibSVM format data.
+ * @param path path of the LibSVM format file
+ * @param numFeatures number of features; "-1" means it is determined from the data
+ * @param vectorType type of the feature vectors, "sparse" or "dense"
+ * @param sqlContext the SQL context used to build the DataFrame
+ */
+private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
   extends BaseRelation with PrunedScan with Logging {
 
-  private final val vectorType: DataType
-    = classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-
-
   override def schema: StructType = StructType(
     StructField("label", DoubleType, nullable = false) ::
-      StructField("features", vectorType, nullable = false) :: Nil
+      StructField("features", new VectorUDT(), nullable = false) :: Nil
   )
 
   override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
@@ -47,8 +50,8 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S
 
     val rowBuilders = requiredColumns.map {
       case "label" => (pt: LabeledPoint) => Seq(pt.label)
-      case "features" if featuresType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
-      case "features" if featuresType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
+      case "features" if vectorType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
+      case "features" if vectorType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
     }
 
     baseRdd.map(pt => {
@@ -69,16 +72,6 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: S
 
 class DefaultSource extends RelationProvider with DataSourceRegister {
 
-  /**
-   * The string that represents the format that this data source provider uses. This is
-   * overridden by children to provide a nice alias for the data source. For example:
-   *
-   * {{{
-   *   override def format(): String = "parquet"
-   * }}}
-   *
-   * @since 1.5.0
-   */
   override def shortName(): String = "libsvm"
 
   private def checkPath(parameters: Map[String, String]): String = {
@@ -90,8 +83,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
    * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
    * by the Map that is passed to the function.
    */
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]):
-    BaseRelation = {
+  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
+    : BaseRelation = {
     val path = checkPath(parameters)
     val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
     /**
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
new file mode 100644
index 0000000000000..0464988f99803
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
@@ -0,0 +1,59 @@
+package org.apache.spark.ml.source;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Files;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.util.Utils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Test LibSVMRelation in Java.
+ */
+public class JavaLibSVMRelationSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+  private transient DataFrame dataset;
+
+  private File path;
+
+  @Before
+  public void setUp() throws IOException {
+    jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
+    jsql = new SQLContext(jsc);
+
+    path = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
+      "datasource").getCanonicalFile();
+    if (path.exists()) {
+      path.delete();
+    }
+
+    String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
+    Files.write(s, path, Charsets.US_ASCII);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void verifyLibSvmDF() {
+    dataset = jsql.read().format("libsvm").load(path.getPath());
+    Assert.assertEquals("label", dataset.columns()[0]);
+    Assert.assertEquals("features", dataset.columns()[1]);
+    Row r = dataset.first();
+    Assert.assertTrue(r.getDouble(0) == 1.0);
+    Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), r.getAs(1));
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
index accf37d9886a9..960ab8575fa52 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
@@ -69,6 +69,4 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     val row1 = df.first()
     assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
-
-
 }
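
Usage sketch (illustrative, not part of the patch): with DefaultSource registered under the short name "libsvm", the relation can be loaded through the DataFrame reader API roughly as shown below. The "numFeatures" option and its "-1" default come from createRelation above; the "vectorType" option key, the value "6", and the sample file path are assumptions made only for this example, since the corresponding code is not visible in this diff.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Minimal local setup; in spark-shell, `sc` and `sqlContext` already exist.
val sc = new SparkContext(new SparkConf().setAppName("libsvm-example").setMaster("local"))
val sqlContext = new SQLContext(sc)

// Read LibSVM data through the data source added by this patch.
val df = sqlContext.read
  .format("libsvm")                           // short name registered by DefaultSource
  .option("numFeatures", "6")                 // optional; defaults to "-1" in createRelation
  .option("vectorType", "dense")              // assumed option key matching the vectorType parameter
  .load("data/mllib/sample_libsvm_data.txt")  // illustrative path

df.select("label", "features").show()
sc.stop()

The Java test above exercises the same code path via jsql.read().format("libsvm").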