Commit: Fix
Lewuathe committed Sep 3, 2015
1 parent 70ee4dd commit aef9564
Showing 3 changed files with 47 additions and 23 deletions.
@@ -18,41 +18,44 @@
package org.apache.spark.ml.source.libsvm

import com.google.common.base.Objects

import org.apache.spark.Logging
import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.sources.{DataSourceRegister, PrunedScan, BaseRelation, RelationProvider}
+import org.apache.spark.sql.sources._

/**
 * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
- * @param path
- * @param numFeatures
- * @param vectorType
- * @param sqlContext
+ * @param path Path to the LibSVM format file
+ * @param numFeatures The number of features
+ * @param vectorType The type of vector. It can be 'sparse' or 'dense'
+ * @param sqlContext The Spark SQLContext
 */
private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
    (@transient val sqlContext: SQLContext)
-  extends BaseRelation with PrunedScan with Logging {
+  extends BaseRelation with TableScan with Logging {

  override def schema: StructType = StructType(
    StructField("label", DoubleType, nullable = false) ::
      StructField("features", new VectorUDT(), nullable = false) :: Nil
  )

- override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
+ override def buildScan(): RDD[Row] = {
    val sc = sqlContext.sparkContext
    val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)

-   val rowBuilders = requiredColumns.map {
-     case "label" => (pt: LabeledPoint) => Seq(pt.label)
-     case "features" if vectorType == "sparse" => (pt: LabeledPoint) => Seq(pt.features.toSparse)
-     case "features" if vectorType == "dense" => (pt: LabeledPoint) => Seq(pt.features.toDense)
-   }
+   val rowBuilders = Array(
+     (pt: LabeledPoint) => Seq(pt.label),
+     // 'dense' requests DenseVector rows; anything else yields SparseVector rows
+     if (vectorType == "dense") {
+       (pt: LabeledPoint) => Seq(pt.features.toDense)
+     } else {
+       (pt: LabeledPoint) => Seq(pt.features.toSparse)
+     }
+   )

    baseRdd.map(pt => {
      Row.fromSeq(rowBuilders.map(_(pt)).reduceOption(_ ++ _).getOrElse(Seq.empty))
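Note the switch from PrunedScan to TableScan above: buildScan no longer receives requiredColumns, so the scan always materializes both the label and features columns and leaves column pruning to Spark SQL. A minimal usage sketch of the data source follows; the reader option names "numFeatures" and "vectorType" are assumptions here, since the DefaultSource.createRelation body that would parse them is collapsed in this diff, and sc and the file path are placeholders.

import org.apache.spark.sql.SQLContext

// sc: an existing SparkContext; the path below is a placeholder.
val sqlContext = new SQLContext(sc)
val df = sqlContext.read
  .format("libsvm")              // resolved through DataSourceRegister.shortName()
  .option("numFeatures", "6")    // assumed option name parsed by DefaultSource
  .option("vectorType", "dense") // 'sparse' or 'dense', per the scaladoc above
  .load("data/sample_libsvm_data.txt")

// Schema matches the relation: label (double), features (vector)
df.printSchema()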
@@ -75,7 +78,8 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
  override def shortName(): String = "libsvm"

  private def checkPath(parameters: Map[String, String]): String = {
-   parameters.getOrElse("path", sys.error("'path' must be specified"))
+   require(parameters.contains("path"), "'path' must be specified")
+   parameters("path")
  }

  /**
@@ -24,7 +24,7 @@ package object libsvm {
  /**
   * Implicit declaration so that this reader can be used from SQLContext.
   * It is necessary to import org.apache.spark.ml.source.libsvm._
-  * @param read
+  * @param read The original DataFrameReader
   */
  implicit class LibSVMReader(read: DataFrameReader) {
    def libsvm(filePath: String): DataFrame
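A short sketch of how this implicit is intended to be used, assuming the collapsed method body delegates to the "libsvm" format; the path is a placeholder:

import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.ml.source.libsvm._ // brings LibSVMReader into scope

// sqlContext: an existing SQLContext. Through the implicit class above,
// DataFrameReader gains a libsvm(...) method.
val training: DataFrame = sqlContext.read.libsvm("data/sample_libsvm_data.txt")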
@@ -1,13 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source;

import com.google.common.base.Charsets;
import com.google.common.io.Files;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -31,8 +50,8 @@ public void setUp() throws IOException {
    jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
    jsql = new SQLContext(jsc);

-   path = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
-     "datasource").getCanonicalFile();
+   path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource")
+     .getCanonicalFile();
    if (path.exists()) {
      path.delete();
    }
@@ -45,15 +64,16 @@ public void setUp() throws IOException {
  public void tearDown() {
    jsc.stop();
    jsc = null;
+   path.delete();
  }

  @Test
- public void verifyLibSvmDF() {
+ public void verifyLibSVMDF() {
    dataset = jsql.read().format("libsvm").load();
-   Assert.assertEquals(dataset.columns()[0], "label");
-   Assert.assertEquals(dataset.columns()[1], "features");
+   Assert.assertEquals("label", dataset.columns()[0]);
+   Assert.assertEquals("features", dataset.columns()[1]);
    Row r = dataset.first();
-   Assert.assertTrue(r.getDouble(0) == 1.0);
+   Assert.assertEquals(Double.valueOf(1.0), Double.valueOf(r.getDouble(0)));
    Assert.assertEquals(r.getAs(1), Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0));
  }
}
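The assertions above pin down the fixture written in the collapsed setUp: a single LibSVM record with label 1.0 whose 1-based sparse entries decode to the dense vector (1.0, 0.0, 2.0, 0.0, 3.0, 0.0). A sketch of an equivalent fixture, inferred from the assertions rather than copied from the hidden setUp body; the output path is a placeholder:

import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

// LibSVM text format: <label> <index>:<value> ..., with 1-based indices.
// "1 1:1.0 3:2.0 5:3.0" decodes to label 1.0 and features (1.0, 0.0, 2.0, 0.0, 3.0, 0.0).
val record = "1 1:1.0 3:2.0 5:3.0\n"
Files.write(Paths.get("/tmp/datasource/part-00000"), record.getBytes(StandardCharsets.UTF_8))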
