Commit 4f40891

Improve test suites
Lewuathe committed Sep 6, 2015
1 parent 5ab62ab commit 4f40891
Showing 3 changed files with 42 additions and 38 deletions.

@@ -21,7 +21,6 @@ import com.google.common.base.Objects

 import org.apache.spark.Logging
 import org.apache.spark.mllib.linalg.VectorUDT
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
@@ -37,7 +36,7 @@ import org.apache.spark.sql.sources._
  */
 private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
-  extends BaseRelation with TableScan with Logging {
+  extends BaseRelation with TableScan with Logging with Serializable {

   override def schema: StructType = StructType(
     StructField("label", DoubleType, nullable = false) ::
@@ -48,18 +47,10 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec
     val sc = sqlContext.sparkContext
     val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)

-    val rowBuilders = Array(
-      (pt: LabeledPoint) => Seq(pt.label),
-      if (vectorType == "dense") {
-        (pt: LabeledPoint) => Seq(pt.features.toDense)
-      } else {
-        (pt: LabeledPoint) => Seq(pt.features.toSparse)
-      }
-    )
-
-    baseRdd.map(pt => {
-      Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty))
-    })
+    baseRdd.map { pt =>
+      val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse
+      Row(pt.label, features)
+    }
   }

   override def hashCode(): Int = {
@@ -95,7 +86,7 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
      * vectorType can be "dense" or "sparse".
      * This option decides the type of the returned feature vector.
      */
-    val featuresType = parameters.getOrElse("featuresType", "sparse")
-    new LibSVMRelation(path, numFeatures, featuresType)(sqlContext)
+    val vectorType = parameters.getOrElse("vectorType", "sparse")
+    new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
   }
 }
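
For context, a minimal usage sketch of this data source after the rename, in Scala. The option keys, column names, and format name are the ones appearing in this commit's tests; the load path is hypothetical:

    val df = sqlContext.read
      .format("org.apache.spark.ml.source.libsvm")
      .option("numFeatures", "6")     // optional; the dimension is inferred when omitted
      .option("vectorType", "dense")  // "sparse" (default) or "dense"
      .load("data/sample.libsvm")     // hypothetical path
    // df has two columns: "label" (DoubleType) and "features" (VectorUDT)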
@@ -17,23 +17,25 @@

 package org.apache.spark.ml.source;

+import java.io.File;
+import java.io.IOException;
+
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;

+import org.apache.spark.mllib.linalg.DenseVector;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.util.Utils;

-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.io.File;
-import java.io.IOException;

 /**
  * Test LibSVMRelation in Java.
@@ -50,11 +52,8 @@ public void setUp() throws IOException {
     jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
     jsql = new SQLContext(jsc);

-    path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource")
-        .getCanonicalFile();
-    if (path.exists()) {
-      path.delete();
-    }
+    File tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
+    path = File.createTempFile("datasource", "libsvm-relation", tmpDir);

     String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
     Files.write(s, path, Charsets.US_ASCII);
@@ -69,11 +68,13 @@ public void tearDown() {

   @Test
   public void verifyLibSVMDF() {
-    dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").load(path.getPath());
+    dataset = jsql.read().format("org.apache.spark.ml.source.libsvm").option("vectorType", "dense")
+      .load(path.getPath());
     Assert.assertEquals("label", dataset.columns()[0]);
     Assert.assertEquals("features", dataset.columns()[1]);
     Row r = dataset.first();
-    Assert.assertEquals(Double.valueOf(r.getDouble(0)), Double.valueOf(1.0));
-    Assert.assertEquals(r.getAs(1), Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0));
+    Assert.assertEquals(Double.valueOf(1.0), Double.valueOf(r.getDouble(0)));
+    DenseVector v = r.getAs(1);
+    Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
   }
 }
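
A note on the fixture: LibSVM feature indices are 1-based while MLlib vector indices are 0-based, so the first fixture line "1 1:1.0 3:2.0 5:3.0" becomes label 1.0 with entries at positions 0, 2, and 4, which with six features densifies to exactly the vector asserted above. A minimal Scala sketch of that equivalence:

    import org.apache.spark.mllib.linalg.Vectors

    // 1-based LibSVM entries 1:1.0 3:2.0 5:3.0 land at 0-based positions 0, 2, 4.
    val sv = Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))
    assert(sv.toDense == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))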
@@ -45,28 +45,40 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }

   test("select as sparse vector") {
-    val df = sqlContext.read.options(Map("numFeatures" -> "6")).libsvm(path)
+    val df = sqlContext.read.libsvm(path)
     assert(df.columns(0) == "label")
     assert(df.columns(1) == "features")
     val row1 = df.first()
     assert(row1.getDouble(0) == 1.0)
-    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }

   test("select as dense vector") {
-    val df = sqlContext.read.options(Map("numFeatures" -> "6", "featuresType" -> "dense"))
+    val df = sqlContext.read.options(Map("vectorType" -> "dense"))
       .libsvm(path)
     assert(df.columns(0) == "label")
     assert(df.columns(1) == "features")
     assert(df.count() == 3)
     val row1 = df.first()
     assert(row1.getDouble(0) == 1.0)
-    assert(row1.getAs[DenseVector](1) == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
+    val v = row1.getAs[DenseVector](1)
+    assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
   }

-  test("select without any option") {
-    val df = sqlContext.read.libsvm(path)
+  test("select long vector with specifying the number of features") {
+    val lines =
+      """
+        |1 1:1 10:2 20:3 30:4 40:5 50:6 60:7 70:8 80:9 90:10 100:1
+        |0 1:1 10:10 20:9 30:8 40:7 50:6 60:5 70:4 80:3 90:2 100:1
+      """.stripMargin
+    val tempDir = Utils.createTempDir()
+    val file = new File(tempDir.getPath, "part-00001")
+    Files.write(lines, file, Charsets.US_ASCII)
+    val df = sqlContext.read.option("numFeatures", "100").libsvm(tempDir.toURI.toString)
     val row1 = df.first()
-    assert(row1.getAs[SparseVector](1) == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(100, Seq((0, 1.0), (9, 2.0), (19, 3.0), (29, 4.0), (39, 5.0),
+      (49, 6.0), (59, 7.0), (69, 8.0), (79, 9.0), (89, 10.0), (99, 1.0))))
   }
 }
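
The new long-vector test pins the dimension through the numFeatures option, which the relation forwards to MLUtils.loadLibSVMFile; when the option is omitted, the dimension is inferred from the largest index in the data. A minimal sketch of that underlying call, assuming an existing SparkContext sc and a hypothetical path:

    import org.apache.spark.mllib.util.MLUtils

    // Force a 100-dimensional feature space rather than inferring it.
    val points = MLUtils.loadLibSVMFile(sc, "data/part-00001", 100)  // hypothetical path
    assert(points.first().features.size == 100)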
