diff --git a/mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala
index 8a310fc7b1fee..8b6b2f3fa2756 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala
@@ -34,6 +34,9 @@ import org.apache.spark.mllib.linalg.Vector
 @BeanInfo
 case class LabeledPoint(label: Double, features: Vector, weight: Double) {
 
+  /** Constructor which sets instance weight to 1.0 */
+  def this(label: Double, features: Vector) = this(label, features, 1.0)
+
   override def toString: String = {
     "(%s,%s,%s)".format(label, features, weight)
   }
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaLabeledPointSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaLabeledPointSuite.java
new file mode 100644
index 0000000000000..ac6cb7aa3b344
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaLabeledPointSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml;
+
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+
+/**
+ * Test {@link LabeledPoint} in Java
+ */
+public class JavaLabeledPointSuite {
+
+  private transient JavaSparkContext jsc;
+  private transient JavaSQLContext jsql;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaLabeledPointSuite");
+    jsql = new JavaSQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void labeledPointDefaultWeight() {
+    double label = 1.0;
+    Vector features = Vectors.dense(1.0, 2.0, 3.0);
+    LabeledPoint lp1 = new LabeledPoint(label, features);
+    LabeledPoint lp2 = new LabeledPoint(label, features, 1.0);
+    org.junit.Assert.assertEquals(lp2, lp1);
+  }
+
+  @Test
+  public void labeledPointSchemaRDD() {
+    List<LabeledPoint> arr = Lists.newArrayList(
+      new LabeledPoint(0.0, Vectors.dense(1.0, 2.0, 3.0)),
+      new LabeledPoint(1.0, Vectors.dense(1.1, 2.1, 3.1)),
+      new LabeledPoint(0.0, Vectors.dense(1.2, 2.2, 3.2)),
+      new LabeledPoint(1.0, Vectors.dense(1.3, 2.3, 3.3)));
+    JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
+    JavaSchemaRDD schemaRDD = jsql.applySchema(rdd, LabeledPoint.class);
+    schemaRDD.registerTempTable("points");
+    List<Row> points = jsql.sql("SELECT label, features FROM points").collect();
+    org.junit.Assert.assertEquals(arr.size(), points.size());
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/LabeledPointSuite.scala
index 34460a9e21d0d..94659ba95b1be 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/LabeledPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/LabeledPointSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.ml
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.SQLContext
 
+/**
+ * Test [[LabeledPoint]]
+ */
 class LabeledPointSuite extends FunSuite with MLlibTestSparkContext {
 
   @transient var sqlContext: SQLContext = _