From 492a08c6f90b9f5ee13fd189340a6b65740b430a Mon Sep 17 00:00:00 2001
From: Grzegorz Chilkiewicz
Date: Tue, 16 Feb 2016 15:46:10 +0100
Subject: [PATCH 1/4] [SPARK-13340][ML] PolynomialExpansion and Normalizer
 should validate input type

---
 .../apache/spark/ml/feature/Normalizer.scala       |  4 ++++
 .../spark/ml/feature/PolynomialExpansion.scala     |  4 ++++
 .../spark/ml/feature/NormalizerSuite.scala         | 17 +++++++++++++++++
 .../ml/feature/PolynomialExpansionSuite.scala      | 17 +++++++++++++++++
 4 files changed, 42 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index a603b3f833202..51aa2cffc875e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -55,6 +55,10 @@ class Normalizer(override val uid: String)
     normalizer.transform
   }
 
+  override protected def validateInputType(inputType: DataType): Unit = {
+    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
+  }
+
   override protected def outputDataType: DataType = new VectorUDT()
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 0a9b9719c15d3..ccc16dbfbd7c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -60,6 +60,10 @@ class PolynomialExpansion(override val uid: String)
     PolynomialExpansion.expand(v, $(degree))
   }
 
+  override protected def validateInputType(inputType: DataType): Unit = {
+    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
+  }
+
   override protected def outputDataType: DataType = new VectorUDT()
 
   override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index 468833901995a..56802aeada38c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -105,6 +105,23 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     assertValues(result, l1Normalized)
   }
 
+  test("Normalization should throw adequate exception on input type mismatch") {
+    val data = Seq(Tuple1("string value"))
+
+    val df = sqlContext.createDataFrame(data).toDF("features")
+
+    val normalizer = new Normalizer()
+      .setInputCol("features")
+      .setOutputCol("polyFeatures")
+
+    val thrown = intercept[IllegalArgumentException] {
+      normalizer.transform(df).collect()
+    }
+    assert(thrown.getClass === classOf[IllegalArgumentException])
+    assert(
+      thrown.getMessage == "requirement failed: Input type must be VectorUDT but got StringType.")
+  }
+
   test("read/write") {
     val t = new Normalizer()
       .setInputCol("myInputCol")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index 86dbee1cf4a5a..526d256769581 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -109,6 +109,23 @@ class PolynomialExpansionSuite
     }
   }
 
+  test("Polynomial expansion should throw adequate exception on input type mismatch") {
test("Polynomial expansion should throw adequate exception on input type mismatch") { + val data = Seq(Tuple1("string value")) + + val df = sqlContext.createDataFrame(data).toDF("features") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + val thrown = intercept[IllegalArgumentException] { + polynomialExpansion.transform(df) + } + assert(thrown.getClass === classOf[IllegalArgumentException]) + assert( + thrown.getMessage == "requirement failed: Input type must be VectorUDT but got StringType.") + } + test("read/write") { val t = new PolynomialExpansion() .setInputCol("myInputCol") From a753e3607fccb66ea20a72c2e1b0bfa3a2b45e30 Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Fri, 26 Feb 2016 11:15:20 +0100 Subject: [PATCH 2/4] Fix NormalizerSuite --- .../scala/org/apache/spark/ml/feature/NormalizerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 56802aeada38c..3fc3a2e0e92a9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -112,7 +112,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val normalizer = new Normalizer() .setInputCol("features") - .setOutputCol("polyFeatures") + .setOutputCol("normalized_features") val thrown = intercept[IllegalArgumentException] { normalizer.transform(df).collect() From fb3a8486ed11c6cbbe4c4f0aacce54fa52acfef8 Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Mon, 18 Apr 2016 12:55:36 +0200 Subject: [PATCH 3/4] ElementwiseProduct should validate input type --- .../spark/ml/feature/ElementwiseProduct.scala | 5 +++++ .../apache/spark/ml/feature/Normalizer.scala | 1 + .../spark/ml/feature/PolynomialExpansion.scala | 1 + .../ml/feature/ElementwiseProductSuite.scala | 17 +++++++++++++++++ 4 files changed, 24 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 1b0a9a12e83bc..ff3139711de48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -55,6 +55,11 @@ class ElementwiseProduct(override val uid: String) elemScaler.transform } + override protected def validateInputType(inputType: DataType): Unit = { + super.validateInputType(inputType) + require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + } + override protected def outputDataType: DataType = new VectorUDT() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 51aa2cffc875e..57fff7f37ef82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -56,6 +56,7 @@ class Normalizer(override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { + super.validateInputType(inputType) require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala 
index ccc16dbfbd7c2..28490a6404fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -61,6 +61,7 @@ class PolynomialExpansion(override val uid: String)
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
+    super.validateInputType(inputType)
     require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
index fc1c05de233ea..3de28fdba828a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
@@ -25,6 +25,23 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 class ElementwiseProductSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
+  test("ElementwiseProduct should throw adequate exception on input type mismatch") {
+    val data = Seq(Tuple1("string value"))
+
+    val df = sqlContext.createDataFrame(data).toDF("features")
+
+    val elementwiseProduct = new ElementwiseProduct()
+      .setInputCol("features")
+      .setOutputCol("scaled_features")
+
+    val thrown = intercept[IllegalArgumentException] {
+      elementwiseProduct.transform(df).collect()
+    }
+    assert(thrown.getClass === classOf[IllegalArgumentException])
+    assert(
+      thrown.getMessage == "requirement failed: Input type must be VectorUDT but got StringType.")
+  }
+
   test("read/write") {
     val ep = new ElementwiseProduct()
       .setInputCol("myInputCol")

From 174e2a5bd8e8e76e9e2c6f10e35779a7f7cc4e81 Mon Sep 17 00:00:00 2001
From: Grzegorz Chilkiewicz
Date: Wed, 4 May 2016 10:56:29 +0200
Subject: [PATCH 4/4] Move checkDataTypeEquality logic to SchemaUtils

---
 .../apache/spark/ml/feature/ElementwiseProduct.scala  |  4 ++--
 .../scala/org/apache/spark/ml/feature/Normalizer.scala |  2 +-
 .../apache/spark/ml/feature/PolynomialExpansion.scala |  4 ++--
 .../scala/org/apache/spark/ml/util/SchemaUtils.scala  | 10 ++++++++++
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index ff3139711de48..b55c7f5245685 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.types.DataType
@@ -57,7 +57,7 @@ class ElementwiseProduct(override val uid: String)
 
   override protected def validateInputType(inputType: DataType): Unit = {
     super.validateInputType(inputType)
-    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
+    SchemaUtils.checkDataTypeEquality(inputType, outputDataType)
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 57fff7f37ef82..994f0fef51475 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -57,7 +57,7 @@ class Normalizer(override val uid: String)
 
   override protected def validateInputType(inputType: DataType): Unit = {
     super.validateInputType(inputType)
-    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
+    SchemaUtils.checkDataTypeEquality(inputType, outputDataType)
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 28490a6404fa1..79e1c52a1e89f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.{VectorUDT, _}
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -62,7 +62,7 @@ class PolynomialExpansion(override val uid: String)
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
     super.validateInputType(inputType)
-    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
+    SchemaUtils.checkDataTypeEquality(inputType, outputDataType)
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 334410c9620de..b2732b1d1845a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -74,6 +74,16 @@ private[spark] object SchemaUtils {
       s"NumericType but was actually of type $actualDataType.$message")
   }
 
+  /**
+   * Checks that two data types are equal.
+   * @param dataType actual data type
+   * @param expectedDataType expected data type
+   */
+  def checkDataTypeEquality(dataType: DataType, expectedDataType: DataType): Unit = {
+    require(dataType.getClass.equals(expectedDataType.getClass),
+      s"Input type must be ${expectedDataType.getClass.getSimpleName} but got $dataType.")
+  }
+
   /**
    * Appends a new column to the input schema. This fails if the given output column already exists.
    * @param schema input schema
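
After [PATCH 4/4], all three transformers route their input check through SchemaUtils.checkDataTypeEquality, which compares the runtime classes of the actual and expected DataTypes, so any VectorUDT instance matches any other. The sketch below is not part of the patches; it shows how the validation surfaces to a caller, reusing the data from the test suites above and assuming a `sqlContext` is in scope as in those suites:

    import org.apache.spark.ml.feature.Normalizer

    // A DataFrame whose "features" column is StringType rather than VectorUDT.
    val df = sqlContext.createDataFrame(Seq(Tuple1("string value"))).toDF("features")

    val normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normalized_features")

    // validateInputType is invoked from UnaryTransformer.transformSchema, so the
    // mismatch fails fast at transform time rather than inside the UDF, with:
    //   java.lang.IllegalArgumentException: requirement failed:
    //   Input type must be VectorUDT but got StringType.
    normalizer.transform(df)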