Add Java test
Lewuathe committed Sep 3, 2015
1 parent 3fd8dce commit 70ee4dd
Showing 3 changed files with 74 additions and 24 deletions.
LibSVMRelation.scala

@@ -19,26 +19,29 @@ package org.apache.spark.ml.source.libsvm
 
 import com.google.common.base.Objects
 import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.annotation.Since
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider}
 
-
-class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
+/**
+ * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
+ * @param path file path of the LibSVM format data
+ * @param numFeatures number of features per record
+ * @param vectorType representation of the features column, "sparse" or "dense"
+ * @param sqlContext the SQLContext the relation is bound to
+ */
+private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
   extends BaseRelation with PrunedScan with Logging {
 
-  private final val vectorType: DataType
-    = classOf[Vector].getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-
-
   override def schema: StructType = StructType(
     StructField("label", DoubleType, nullable = false) ::
-      StructField("features", vectorType, nullable = false) :: Nil
+      StructField("features", new VectorUDT(), nullable = false) :: Nil
   )
 
   override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
@@ -47,8 +50,8 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
 
     val rowBuilders = requiredColumns.map {
       case "label" => (pt: LabeledPoint) => Seq(pt.label)
-      case "features" if featuresType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
-      case "features" if featuresType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
+      case "features" if vectorType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
+      case "features" if vectorType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
     }
 
     baseRdd.map(pt => {
@@ -69,16 +72,6 @@ class LibSVMRelation(val path: String, val numFeatures: Int, val featuresType: String)
 
 class DefaultSource extends RelationProvider with DataSourceRegister {
 
-  /**
-   * The string that represents the format that this data source provider uses. This is
-   * overridden by children to provide a nice alias for the data source. For example:
-   *
-   * {{{
-   *   override def format(): String = "parquet"
-   * }}}
-   *
-   * @since 1.5.0
-   */
   override def shortName(): String = "libsvm"
 
   private def checkPath(parameters: Map[String, String]): String = {
@@ -90,8 +83,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
    * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
    * by the Map that is passed to the function.
    */
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]):
-    BaseRelation = {
+  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
+    : BaseRelation = {
     val path = checkPath(parameters)
     val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
     /**
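For context, a minimal sketch of how this data source would be driven from Scala once registered. The "numFeatures" key and its "-1" default come from createRelation above; the "vectorType" key's "sparse" default and the sample path are assumptions here, since the rest of createRelation is collapsed in this view:

val df = sqlContext.read
  .format("libsvm")               // resolved through DataSourceRegister.shortName()
  .option("numFeatures", "6")     // "-1" (the default) asks the loader to infer the dimension
  .option("vectorType", "dense")  // assumed option; "sparse" appears to be the default
  .load("data/mllib/sample_libsvm_data.txt")
df.select("label", "features").show()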
JavaLibSVMRelationSuite.java (new file)

@@ -0,0 +1,59 @@
+package org.apache.spark.ml.source;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Files;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.util.Utils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Test LibSVMRelation in Java.
+ */
+public class JavaLibSVMRelationSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+  private transient DataFrame dataset;
+
+  private File path;
+
+  @Before
+  public void setUp() throws IOException {
+    jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
+    jsql = new SQLContext(jsc);
+
+    path = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
+      "datasource").getCanonicalFile();
+    if (path.exists()) {
+      path.delete();
+    }
+
+    String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
+    Files.write(s, path, Charsets.US_ASCII);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void verifyLibSvmDF() {
+    dataset = jsql.read().format("libsvm").load(path.getPath());
+    Assert.assertEquals("label", dataset.columns()[0]);
+    Assert.assertEquals("features", dataset.columns()[1]);
+    Row r = dataset.first();
+    Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
+    Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), r.getAs(1));
+  }
+}
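The fixture string in setUp follows the standard LibSVM text format: each line is "label index1:value1 index2:value2 ...", with one-based, strictly increasing indices, so "1 1:1.0 3:2.0 5:3.0" yields label 1.0 and values 1.0, 2.0, 3.0 at zero-based positions 0, 2, 4 of a six-feature vector. A short Scala sketch of the same parse through MLUtils.loadLibSVMFile, which the relation imports to build its base RDD (the path literal stands in for the temp file written above):

import org.apache.spark.mllib.util.MLUtils
val points = MLUtils.loadLibSVMFile(sc, "/tmp/datasource", 6)  // RDD[LabeledPoint]
points.first().features  // SparseVector: (6,[0,2,4],[1.0,2.0,3.0])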
LibSVMRelationSuite.scala

@@ -69,6 +69,4 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     val row1 = df.first()
     assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
-
-
 }
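One point worth noting across the two suites: the Java test compares the features column against a dense vector, while this Scala test compares against a sparse one built from the same values. Both expectations can hold at once assuming MLlib vector equality is structural, comparing size and element values rather than representation; a quick sketch of that assumption:

import org.apache.spark.mllib.linalg.Vectors
val sparse = Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))
val dense = Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)
assert(sparse == dense)  // equal by contents, independent of representation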
