From 9e4413f6869ff9d0c1ff915b564ef50de196833e Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Mon, 10 Oct 2016 13:44:47 +0800
Subject: [PATCH 01/10] create pr

---
 .../main/scala/org/apache/spark/ml/Predictor.scala   | 13 +++++++++++--
 .../apache/spark/ml/classification/Classifier.scala  |  4 ++--
 .../spark/ml/classification/GBTClassifier.scala      |  2 +-
 .../ml/classification/LogisticRegression.scala       |  2 +-
 .../apache/spark/ml/classification/NaiveBayes.scala  |  2 +-
 .../org/apache/spark/ml/util/MLTestingUtils.scala    |  2 +-
 6 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e29d7f48a1d6b..5b7f6e2920dd5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -87,7 +87,8 @@ abstract class Predictor[
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
-    copyValues(train(dataset).setParent(this))
+    val casted = castPoints(dataset)
+    copyValues(train(casted).setParent(this))
   }
 
   override def copy(extra: ParamMap): Learner
@@ -121,10 +122,18 @@ abstract class Predictor[
    * and put it in an RDD with strong types.
    */
   protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
-    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
    }
  }
+
+  /**
+   * Return the given DataFrame, with [[labelCol]] casted to DoubleType.
+   */
+  protected def castPoints(dataset: Dataset[_]): DataFrame = {
+    val labelColMeta = dataset.schema.fields.filter(_.name == $(labelCol)).head.metadata
+    dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelColMeta)
+  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index d1b21b16f2342..a3da3067e1b5f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -71,7 +71,7 @@ abstract class Classifier[
    * and put it in an RDD with strong types.
    *
    * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
-   *                and features ([[Vector]]). Labels are cast to [[DoubleType]].
+   *                and features ([[Vector]]).
    * @param numClasses Number of classes label can take. Labels must be integers in the range
    *                   [0, numClasses).
    * @throws SparkException if any label is not an integer >= 0
    */
   protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
     require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
       s" $numClasses, but requires numClasses > 0.")
-    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) =>
         require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
           s" dataset with invalid label $label. Labels must be integers in range" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 8bffe0cda0327..f8f164e8c14bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
     // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
     // 2 classes now. This lets us provide a more precise error message.
     val oldDataset: RDD[LabeledPoint] =
-      dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
         case Row(label: Double, features: Vector) =>
           require(label == 0 || label == 1, s"GBTClassifier was given" +
             s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fdaae04c42ec..58ff4958a322b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1168,7 +1168,7 @@ class BinaryLogisticRegressionSummary private[classification] (
   // TODO: Allow the user to vary the number of bins using a setBins method in
   // BinaryClassificationMetrics. For now the default is set to 100.
   @transient private val binaryMetrics = new BinaryClassificationMetrics(
-    predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
+    predictions.select(col(probabilityCol), col(labelCol)).rdd.map {
       case Row(score: Vector, label: Double) => (score(1), label)
     }, 100
   )
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 994ed993c99df..b03a07a6bc1e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
     // Aggregates term frequencies per label.
     // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
+    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
       .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
       }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
       seqOp = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 472a5af06e7a2..5efce504c056e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -117,7 +117,7 @@ object MLTestingUtils extends SparkFunSuite {
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
     types.map { t =>
       val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
-      t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
+      t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
     }.toMap
   }
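Note on PATCH 01: the core move is to cast the label column once, in Predictor.fit(), while preserving the column's ML metadata (for example a NominalAttribute recording the number of classes), so that the per-algorithm .cast(DoubleType) calls become redundant. The three-argument withColumn used in castPoints is private to Spark's own packages; as a rough sketch only, the same effect should be reachable from user code with the public Column.as(alias, metadata):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DoubleType

    // Cast the label column to DoubleType while re-attaching its original
    // metadata, which a bare cast() would otherwise drop.
    def castLabel(df: DataFrame, labelCol: String): DataFrame = {
      val meta = df.schema(labelCol).metadata
      df.withColumn(labelCol, col(labelCol).cast(DoubleType).as(labelCol, meta))
    }
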
From 6ad65080593d5463bd663c7130efd95ac6e3f1e1 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Mon, 10 Oct 2016 14:04:59 +0800
Subject: [PATCH 02/10] rename func

---
 mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 5b7f6e2920dd5..d06ddb714c76b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -87,7 +87,7 @@ abstract class Predictor[
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
-    val casted = castPoints(dataset)
+    val casted = castDataSet(dataset)
     copyValues(train(casted).setParent(this))
   }
 
@@ -130,9 +130,9 @@ abstract class Predictor[
   /**
    * Return the given DataFrame, with [[labelCol]] casted to DoubleType.
    */
-  protected def castPoints(dataset: Dataset[_]): DataFrame = {
-    val labelColMeta = dataset.schema.fields.filter(_.name == $(labelCol)).head.metadata
-    dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelColMeta)
+  protected def castDataSet(dataset: Dataset[_]): DataFrame = {
+    val labelMeta = dataset.schema.fields.filter(_.name == $(labelCol)).head.metadata
+    dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
   }
 }
From 41e63e223411ac0e50dd3e634baa6f87852fc1b5 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Mon, 10 Oct 2016 16:02:13 +0800
Subject: [PATCH 03/10] revert lr

---
 .../org/apache/spark/ml/classification/LogisticRegression.scala  | 2 +-
 .../spark/ml/classification/LogisticRegressionSuite.scala        | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 58ff4958a322b..8fdaae04c42ec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1168,7 +1168,7 @@ class BinaryLogisticRegressionSummary private[classification] (
   // TODO: Allow the user to vary the number of bins using a setBins method in
   // BinaryClassificationMetrics. For now the default is set to 100.
   @transient private val binaryMetrics = new BinaryClassificationMetrics(
-    predictions.select(col(probabilityCol), col(labelCol)).rdd.map {
+    predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
       case Row(score: Vector, label: Double) => (score(1), label)
     }, 100
   )
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index bc631dc6d3149..8771fd2e9d2b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1807,7 +1807,6 @@ class LogisticRegressionSuite
       .objectiveHistory
       .sliding(2)
       .forall(x => x(0) >= x(1)))
-
   }
 
   test("binary logistic regression with weighted data") {
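Note on PATCH 03: the summary's own cast comes back here, presumably because BinaryLogisticRegressionSummary is not created only inside fit(); LogisticRegressionModel.evaluate() builds one over an arbitrary predictions DataFrame whose label column never went through the new cast in Predictor.fit(), so the defensive cast has to stay. A self-contained sketch of the extraction the summary performs, with the column names "probability" and "label" assumed:

    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
    import org.apache.spark.sql.{DataFrame, Row}
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DoubleType

    // Pull (P(class 1), label) pairs out of a predictions DataFrame and feed
    // them to the mllib metrics implementation with 100 score bins.
    def binaryMetrics(predictions: DataFrame): BinaryClassificationMetrics =
      new BinaryClassificationMetrics(
        predictions.select(col("probability"), col("label").cast(DoubleType)).rdd.map {
          case Row(score: Vector, label: Double) => (score(1), label)
        }, 100)
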
From 764650a73ddb30ff0cfaa1f5a2d2c4c00556c63b Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Mon, 10 Oct 2016 19:54:33 +0800
Subject: [PATCH 04/10] del cast in regression

---
 .../spark/ml/regression/GeneralizedLinearRegression.scala       | 2 +-
 .../scala/org/apache/spark/ml/regression/LinearRegression.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 33cb25c8c7f66..8656ecf609ea4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 519f3bdec82df..ae876b3839734 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] = dataset.select(
-      col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      col($(labelCol)), w, col($(featuresCol))).rdd.map {
       case Row(label: Double, weight: Double, features: Vector) =>
         Instance(label, weight, features)
     }
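Note on PATCH 04: the regression estimators rely on the same invariant; by the time train() runs, fit() has already forced the label to DoubleType, so the Row pattern match on a Double is safe without a local cast. The pattern in isolation, as a sketch (Instance is a stand-in case class here, since Spark's org.apache.spark.ml.feature.Instance is not public, and the column names are assumptions):

    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, Row}
    import org.apache.spark.sql.functions.{col, lit}

    case class Instance(label: Double, weight: Double, features: Vector)

    // Default the weight to 1.0 when no weight column is configured, mirroring
    // the `val w = ... lit(1.0) ...` lines in the hunks above.
    def extractInstances(df: DataFrame, weightCol: Option[String]): RDD[Instance] = {
      val w = weightCol.map(col).getOrElse(lit(1.0))
      df.select(col("label"), w, col("features")).rdd.map {
        case Row(label: Double, weight: Double, features: Vector) =>
          Instance(label, weight, features)
      }
    }
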
From 59d02d5d2530e920ef2ec5230a8ebb9b9edb30c9 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Tue, 11 Oct 2016 11:19:20 +0800
Subject: [PATCH 05/10] add testsuite for predictor

---
 .../scala/org/apache/spark/ml/Predictor.scala      |  2 +-
 .../classification/LogisticRegression.scala        |  2 +-
 .../org/apache/spark/ml/PredictorSuite.scala       | 57 +++++++++++++++++++
 3 files changed, 59 insertions(+), 2 deletions(-)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index d06ddb714c76b..ceb69370509cf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -131,7 +131,7 @@ abstract class Predictor[
    * Return the given DataFrame, with [[labelCol]] casted to DoubleType.
    */
   protected def castDataSet(dataset: Dataset[_]): DataFrame = {
-    val labelMeta = dataset.schema.fields.filter(_.name == $(labelCol)).head.metadata
+    val labelMeta = dataset.schema($(labelCol)).metadata
     dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fdaae04c42ec..c4651054fd765 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
     LogisticRegressionModel = {
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
new file mode 100644
index 0000000000000..881ebecefe400
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types._
+
+class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  class MockPredictor(override val uid: String)
+    extends Predictor[Vector, MockPredictor, MockPredictionModel] {
+
+    override def train(dataset: Dataset[_]): MockPredictionModel = {
+      require(dataset.schema("label").dataType == DoubleType)
+      new MockPredictionModel(uid)
+    }
+
+    override def copy(extra: ParamMap): MockPredictor = defaultCopy(extra)
+  }
+
+  class MockPredictionModel(override val uid: String)
+    extends PredictionModel[Vector, MockPredictionModel] {
+
+    override def predict(features: Vector): Double = 1.0
+
+    override def copy(extra: ParamMap): MockPredictionModel = defaultCopy(extra)
+  }
+
+  test("should support all NumericType labels and not support other types") {
+    val predictor = new MockPredictor("mock")
+    MLTestingUtils.checkNumericTypes[MockPredictionModel, MockPredictor](
+      predictor, spark) { (expected, actual) => true
+    }
+  }
+}
\ No newline at end of file
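Note on PATCH 05: the new suite delegates to MLTestingUtils.checkNumericTypes, which builds one DataFrame per numeric label type and fits a model on each. The type sweep can be sketched on its own (column names assumed, the type list copied from the helper):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types._

    // One DataFrame per numeric label type; fit() should accept all of them.
    def labelTypeVariants(df: DataFrame): Map[DataType, DataFrame] = {
      val types =
        Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
      types.map(t => t -> df.select(col("label").cast(t), col("features"))).toMap
    }
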
From e0bbc34e2898fe7025982feeb2108bcfb538e2cc Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Tue, 11 Oct 2016 11:20:01 +0800
Subject: [PATCH 06/10] fix one nit

---
 mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
index 881ebecefe400..2167a5ea152bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -54,4 +54,4 @@ class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
       predictor, spark) { (expected, actual) => true
     }
   }
-}
\ No newline at end of file
+}
From db83800771fcea6dd0d8a489c1d69a623f497c00 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Wed, 12 Oct 2016 19:44:18 +0800
Subject: [PATCH 07/10] update Predictor and PredictorSuite

---
 .../scala/org/apache/spark/ml/Predictor.scala | 14 +++---
 .../org/apache/spark/ml/PredictorSuite.scala  | 43 ++++++++++++++-----
 .../apache/spark/ml/util/MLTestingUtils.scala |  2 +-
 3 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index ceb69370509cf..3322ef78d6e69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -87,7 +87,11 @@ abstract class Predictor[
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
-    val casted = castDataSet(dataset)
+
+    // Cast LabelCol to DoubleType and keep the metadata.
+    val labelMeta = dataset.schema($(labelCol)).metadata
+    val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
     copyValues(train(casted).setParent(this))
   }
@@ -126,14 +130,6 @@ abstract class Predictor[
       case Row(label: Double, features: Vector) => LabeledPoint(label, features)
     }
   }
-
-  /**
-   * Return the given DataFrame, with [[labelCol]] casted to DoubleType.
-   */
-  protected def castDataSet(dataset: Dataset[_]): DataFrame = {
-    val labelMeta = dataset.schema($(labelCol)).metadata
-    dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
-  }
 }
 
 /**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
index 2167a5ea152bc..20fa0d995a030 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -22,16 +22,43 @@ import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
 class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  import testImplicits._
+  import PredictorSuite._
+
+  test("should support all NumericType labels and not support other types") {
+    val df = spark.createDataFrame(Seq(
+      (0, Vectors.dense(0, 2, 3)),
+      (1, Vectors.dense(0, 3, 9)),
+      (0, Vectors.dense(0, 2, 6))
+    )).toDF("label", "features")
+
+    val types =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+
+    val predictor = new MockPredictor()
+
+    types.foreach { t =>
+      predictor.fit(df.select(col("label").cast(t), col("features")))
+    }
+
+    intercept[IllegalArgumentException] {
+      predictor.fit(df.select(col("label").cast(StringType), col("features")))
+    }
+  }
+}
+
+object PredictorSuite {
 
   class MockPredictor(override val uid: String)
     extends Predictor[Vector, MockPredictor, MockPredictionModel] {
 
+    def this() = this(Identifiable.randomUID("mockpredictor"))
+
     override def train(dataset: Dataset[_]): MockPredictionModel = {
       require(dataset.schema("label").dataType == DoubleType)
       new MockPredictionModel(uid)
@@ -43,15 +70,11 @@ object PredictorSuite {
   class MockPredictionModel(override val uid: String)
     extends PredictionModel[Vector, MockPredictionModel] {
 
-    override def predict(features: Vector): Double = 1.0
+    def this() = this(Identifiable.randomUID("mockpredictormodel"))
 
-    override def copy(extra: ParamMap): MockPredictionModel = defaultCopy(extra)
-  }
+    override def predict(features: Vector): Double =
+      throw new NotImplementedError()
 
-  test("should support all NumericType labels and not support other types") {
-    val predictor = new MockPredictor("mock")
-    MLTestingUtils.checkNumericTypes[MockPredictionModel, MockPredictor](
-      predictor, spark) { (expected, actual) => true
-    }
+    override def copy(extra: ParamMap): MockPredictionModel = defaultCopy(extra)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 5efce504c056e..472a5af06e7a2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -117,7 +117,7 @@ object MLTestingUtils extends SparkFunSuite {
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
     types.map { t =>
      val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
-      t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
+      t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
     }.toMap
   }
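Note on PATCH 07: this settles the design; the metadata-preserving cast is inlined in fit(), the castDataSet helper is gone, and the suite now drives fit() directly with a mock instead of going through checkNumericTypes. The user-visible effect can be sketched in spark-shell (a SparkSession named spark is assumed, and LogisticRegression stands in for any concrete Predictor):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.linalg.Vectors

    // The label column is IntegerType here; fit() casts it to DoubleType
    // internally, so no manual cast is needed before training.
    val df = spark.createDataFrame(Seq(
      (0, Vectors.dense(0.0, 1.0)),
      (1, Vectors.dense(1.0, 0.0))
    )).toDF("label", "features")

    val model = new LogisticRegression().fit(df)
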
From 1944cf1a1bc15daec18a38ddba43d26b3a4c7f54 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Thu, 13 Oct 2016 13:50:19 +0800
Subject: [PATCH 08/10] update copy() & del unused interface

---
 .../src/test/scala/org/apache/spark/ml/PredictorSuite.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
index 20fa0d995a030..7d12513bc9009 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
-class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   import PredictorSuite._
@@ -75,6 +75,7 @@ object PredictorSuite {
     override def predict(features: Vector): Double =
       throw new NotImplementedError()
 
-    override def copy(extra: ParamMap): MockPredictionModel = defaultCopy(extra)
+    override def copy(extra: ParamMap): MockPredictionModel =
+      throw new NotImplementedError()
   }
 }

From 5b4f34abdb97f5567b2fac5cd8a9212982d86c09 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Thu, 13 Oct 2016 15:35:41 +0800
Subject: [PATCH 09/10] update another copy

---
 mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
index 7d12513bc9009..03e0c536a973e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -64,7 +64,8 @@ object PredictorSuite {
       new MockPredictionModel(uid)
     }
 
-    override def copy(extra: ParamMap): MockPredictor = defaultCopy(extra)
+    override def copy(extra: ParamMap): MockPredictor =
+      throw new NotImplementedError()
   }

From 810c973d7394263a047318d7c0ab82cf6814ee7e Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Tue, 1 Nov 2016 09:47:07 +0800
Subject: [PATCH 10/10] add doc

---
 mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 3322ef78d6e69..aa92edde7acd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params
 
 /**
  * :: DeveloperApi ::
- * Abstraction for prediction problems (regression and classification).
+ * Abstraction for prediction problems (regression and classification). It accepts all NumericType
+ * labels and will automatically cast it to DoubleType in [[fit()]].
  *
  * @tparam FeaturesType  Type of features.
  *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
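Note on PATCH 10: the added scaladoc states the final contract; any NumericType label is accepted and cast to DoubleType inside fit(). Non-numeric labels are still rejected by schema validation (the suite's StringType case expects an IllegalArgumentException), so string labels have to be indexed first. A spark-shell sketch of that path, again assuming a SparkSession named spark:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.ml.linalg.Vectors

    val raw = spark.createDataFrame(Seq(
      ("spam", Vectors.dense(1.0, 0.0)),
      ("ham", Vectors.dense(0.0, 1.0))
    )).toDF("category", "features")

    // StringIndexer produces a DoubleType "label" column carrying nominal
    // metadata, which the new fit() preserves through its cast.
    val indexed = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("label")
      .fit(raw)
      .transform(raw)

    val model = new LogisticRegression().fit(indexed)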