From fbfb7f44db03cb50096cb03848a5cb9c015a9dc7 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 22 Mar 2017 16:55:27 +0800 Subject: [PATCH 1/3] revert domain --- .../ml/regression/IsotonicRegression.scala | 22 +++++----- .../regression/IsotonicRegressionSuite.scala | 43 +++++++++++++++++++ .../apache/spark/ml/util/MLTestingUtils.scala | 2 +- 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 529f66eadbcff..0388127806138 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsoton import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} import org.apache.spark.storage.StorageLevel /** @@ -84,7 +84,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures val extract = udf { v: Vector => v(idx) } extract(col($(featuresCol))) } else { - col($(featuresCol)) + col($(featuresCol)).cast(DoubleType) } val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0) @@ -112,7 +112,9 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } } val featuresType = schema($(featuresCol)).dataType - require(featuresType == DoubleType || featuresType.isInstanceOf[VectorUDT]) + require(featuresType.isInstanceOf[NumericType] || featuresType.isInstanceOf[VectorUDT], + s"Column $featuresCol must be of type NumericType or VectorUDT," + + s" but was actually of type $featuresType") SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } } @@ -241,14 +243,14 @@ class IsotonicRegressionModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val predict = dataset.schema($(featuresCol)).dataType match { - case DoubleType => - udf { feature: Double => oldModel.predict(feature) } - case _: VectorUDT => - val idx = $(featureIndex) - udf { features: Vector => oldModel.predict(features(idx)) } + if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { + val idx = $(featureIndex) + val predict = udf { features: Vector => oldModel.predict(features(idx)) } + dataset.withColumn($(predictionCol), predict(col($(featuresCol)))) + } else { + val predict = udf { feature: Double => oldModel.predict(feature) } + dataset.withColumn($(predictionCol), predict(col($(featuresCol)).cast(DoubleType))) } - dataset.withColumn($(predictionCol), predict(col($(featuresCol)))) } @Since("1.5.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index f41a3601b1fa8..b0c8917e6526b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -189,6 +191,47 @@ class IsotonicRegressionSuite assert(expected.predictions === actual.predictions) } } + + test("Besides VectorUDT, should support all NumericType features, and not support other types") { + val df = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression() + + val expected = ir.fit(df) + val expectedPrediction = + expected.transform(df).select("prediction").map(_.getDouble(0)).collect() + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + + types.foreach { t => + val df2 = df.select(col("label"), col("features").cast(t), col("weight")) + + val actual = ir.fit(df2) + assert(expected.boundaries === actual.boundaries) + assert(expected.predictions === actual.predictions) + + val actualPrediction = + actual.transform(df).select("prediction").map(_.getDouble(0)).collect() + assert(expectedPrediction === actualPrediction) + } + + val dfWithStringFeatures = + df.select(col("label"), col("features").cast(StringType), col("weight")) + + val thrown = intercept[IllegalArgumentException] { + ir.fit(dfWithStringFeatures) + } + assert(thrown.getMessage.contains( + "Column features must be of type NumericType or VectorUDT," + + " but was actually of type StringType")) + + val thrown2 = intercept[IllegalArgumentException] { + expected.transform(dfWithStringFeatures) + } + assert(thrown2.getMessage.contains( + "Column features must be of type NumericType or VectorUDT," + + " but was actually of type StringType")) + } } object IsotonicRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index f1ed568d5e60a..7034f90034936 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.recommendation.{ALS, ALSModel} import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} From 26ff8d8e147d9317c8208022a528a3529071fef9 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 22 Mar 2017 17:03:07 +0800 Subject: [PATCH 2/3] update tests --- .../apache/spark/ml/regression/IsotonicRegressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index b0c8917e6526b..9d0e503f93238 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -211,7 +211,7 @@ class IsotonicRegressionSuite assert(expected.predictions === actual.predictions) val actualPrediction = - actual.transform(df).select("prediction").map(_.getDouble(0)).collect() + actual.transform(df2).select("prediction").map(_.getDouble(0)).collect() assert(expectedPrediction === actualPrediction) } From 647ff461ee5ac1e36e6d9a83ca62271403ad01ab Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 22 Mar 2017 17:40:32 +0800 Subject: [PATCH 3/3] fix one nit --- .../org/apache/spark/ml/regression/IsotonicRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 0388127806138..3bcad1fcc8316 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -113,7 +113,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } val featuresType = schema($(featuresCol)).dataType require(featuresType.isInstanceOf[NumericType] || featuresType.isInstanceOf[VectorUDT], - s"Column $featuresCol must be of type NumericType or VectorUDT," + + s"Column ${$(featuresCol)} must be of type NumericType or VectorUDT," + s" but was actually of type $featuresType") SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) }