Skip to content

Commit

Permalink
Added JavaLabeledPointSuite.java for spark.ml, and added constructor …
Browse files Browse the repository at this point in the history
…to LabeledPoint which defaults weight to 1.0
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent adbe50a commit 1680905
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 2 deletions.
3 changes: 3 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ import org.apache.spark.mllib.linalg.Vector
@BeanInfo
case class LabeledPoint(label: Double, features: Vector, weight: Double) {

/**
 * Auxiliary constructor for the common unweighted case: sets the instance
 * weight to the default of 1.0. Also exposed to Java callers as
 * `new LabeledPoint(label, features)` (see JavaLabeledPointSuite).
 */
def this(label: Double, features: Vector) = this(label, features, 1.0)

/** Renders as "(label,features,weight)", matching the case-class field order. */
override def toString: String = {
  s"($label,$features,$weight)"
}
Expand Down
61 changes: 61 additions & 0 deletions mllib/src/test/java/org/apache/spark/ml/JavaLabeledPointSuite.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.apache.spark.ml;

import java.util.List;

import static org.junit.Assert.assertEquals;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;

/**
 * Test {@link LabeledPoint} in Java.
 */
public class JavaLabeledPointSuite {

  private transient JavaSparkContext jsc;
  private transient JavaSQLContext jsql;

  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local", "JavaLabeledPointSuite");
    jsql = new JavaSQLContext(jsc);
  }

  @After
  public void tearDown() {
    jsc.stop();
    jsc = null;
    // Also drop the SQL context so no reference leaks between tests.
    jsql = null;
  }

  /** The two-arg constructor must be equivalent to passing an explicit weight of 1.0. */
  @Test
  public void labeledPointDefaultWeight() {
    double label = 1.0;
    Vector features = Vectors.dense(1.0, 2.0, 3.0);
    LabeledPoint lp1 = new LabeledPoint(label, features);
    LabeledPoint lp2 = new LabeledPoint(label, features, 1.0);
    // Use JUnit's assertEquals instead of the `assert` keyword: Java language
    // assertions are disabled unless the JVM runs with -ea, so `assert` would
    // silently pass even when the two points differ.
    assertEquals(lp2, lp1);
  }

  /** LabeledPoints should round-trip through a SchemaRDD and be queryable via SQL. */
  @Test
  public void labeledPointSchemaRDD() {
    List<LabeledPoint> arr = Lists.newArrayList(
      new LabeledPoint(0.0, Vectors.dense(1.0, 2.0, 3.0)),
      new LabeledPoint(1.0, Vectors.dense(1.1, 2.1, 3.1)),
      new LabeledPoint(0.0, Vectors.dense(1.2, 2.2, 3.2)),
      new LabeledPoint(1.0, Vectors.dense(1.3, 2.3, 3.3)));
    JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
    JavaSchemaRDD schemaRDD = jsql.applySchema(rdd, LabeledPoint.class);
    schemaRDD.registerTempTable("points");
    List<Row> points = jsql.sql("SELECT label, features FROM points").collect();
    // assertEquals(expected, actual): every input row must come back from the query.
    assertEquals(arr.size(), points.size());
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package org.apache.spark.ml

import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{SQLContext, SchemaRDD}
import org.apache.spark.sql.SQLContext

/**
* Test [[LabeledPoint]]
*/
class LabeledPointSuite extends FunSuite with MLlibTestSparkContext {

@transient var sqlContext: SQLContext = _
Expand Down

0 comments on commit 1680905

Please sign in to comment.