From 7bc026e364771c596b80a5ab22b4d1eebe04f3f3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 9 Sep 2018 14:54:20 +0200 Subject: [PATCH 1/3] [SPARK-25371][SQL] VectorAssembler should not fail with empty inputCols --- .../org/apache/spark/ml/feature/VectorAssembler.scala | 10 ++++++++-- .../apache/spark/ml/feature/VectorAssemblerSuite.scala | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 57e23d5072b88..0a72aa1154769 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -30,7 +30,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -150,8 +151,13 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } + val udfInput = if (args.length > 0) { + struct(args: _*) + } else { + new Column(Literal.default(new StructType)) + } - filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(udfInput).as($(outputCol), metadata)) } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index ed15a1d88a269..a4d388fd321db 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -256,4 +256,9 @@ class VectorAssemblerSuite assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) } + test("SPARK-25371: VectorAssembler with empty inputCols") { + val vectorAssembler = new VectorAssembler().setInputCols(Array()).setOutputCol("a") + val output = vectorAssembler.transform(dfWithNullsAndNaNs) + assert(output.select("a").limit(1).collect().head == Row(Vectors.sparse(0, Seq.empty))) + } } From 41bd5bdbc1b8fbaca3ffbba02ac5d367e3fdc8cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 10 Sep 2018 15:05:22 +0200 Subject: [PATCH 2/3] Revert check in struct for more than one argument --- .../org/apache/spark/ml/feature/VectorAssembler.scala | 7 +------ .../sql/catalyst/expressions/complexTypeCreator.scala | 5 +---- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 -- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0a72aa1154769..9fd0279ef0c87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -151,13 +151,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - val udfInput = if (args.length > 0) { - struct(args: _*) - } else { - new Column(Literal.default(new StructType)) - } - filteredDataset.select(col("*"), assembleFunc(udfInput).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } @Since("1.4.0") diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 077a6dc93bd17..aba9c6c8ad6fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -379,10 +379,7 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.length < 1) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName requires at least one argument") - } else if (children.size % 2 != 0) { + if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4b83e51fa8992..121db442c77f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2677,8 +2677,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val funcsMustHaveAtLeastOneArg = ("coalesce", (df: DataFrame) => df.select(coalesce())) :: ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: - ("named_struct", (df: DataFrame) => df.select(struct())) :: - ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: ("hash", (df: DataFrame) => df.select(hash())) :: ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil funcsMustHaveAtLeastOneArg.foreach { case (name, func) => From 69ff3cb28744f4d50a6629708d7ebd0feea3de95 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 10 Sep 2018 15:07:09 +0200 
Subject: [PATCH 3/3] fix imports --- .../scala/org/apache/spark/ml/feature/VectorAssembler.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 9fd0279ef0c87..57e23d5072b88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -30,8 +30,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._