From daf793599a6da5c11dbc4a6bd6e5dea3e0d47afd Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 15 Aug 2018 14:47:46 -0700 Subject: [PATCH] review comments --- .../expressions/higherOrderFunctions.scala | 21 +++---- .../HigherOrderFunctionsSuite.scala | 44 ++++++++----- .../inputs/higher-order-functions.sql | 4 +- .../results/higher-order-functions.sql.out | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 61 +++++-------------- 5 files changed, 53 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 63cf5fc36a5af..5c7fb9a675aeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -504,9 +504,9 @@ case class ArrayAggregate( usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> v + 1); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); map(array(1, 2, 3), array(2, 3, 4)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); map(array(1, 2, 3), array(2, 4, 6)) """, since = "2.4.0") @@ -517,33 +517,26 @@ case class TransformValues( override def nullable: Boolean = argument.nullable - override def dataType: DataType = { - val map = argument.dataType.asInstanceOf[MapType] - MapType(map.keyType, function.dataType, function.nullable) - } + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType - @transient val MapType(keyType, valueType, valueContainsNull) = argument.dataType + override def dataType: DataType = MapType(keyType, function.dataType, valueContainsNull) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) : TransformValues = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } - @transient lazy val (keyVar, valueVar) = { - val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function - (keyVar, valueVar) - } + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] - val f = functionForEval val resultValues = new GenericArrayData(new Array[Any](map.numElements)) var i = 0 while (i < map.numElements) { keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) - resultValues.update(i, f.eval(inputRow)) + resultValues.update(i, functionForEval.eval(inputRow)) i += 1 } new ArrayBasedMapData(map.keyArray(), resultValues) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 60ed04672ccea..8c3ff3ed918ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -96,9 +96,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val valueType = expr.dataType.asInstanceOf[MapType].valueType - val keyType = expr.dataType.asInstanceOf[MapType].keyType - TransformValues(expr, createLambda(keyType, false, valueType, true, f)) + val map = expr.dataType.asInstanceOf[MapType] + TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) } test("ArrayTransform") { @@ -292,13 +291,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("TransformValues") { val ai0 = Literal.create( Map(1 -> 1, 2 -> 2, 3 -> 3), - MapType(IntegerType, IntegerType)) + MapType(IntegerType, IntegerType, valueContainsNull = false)) val ai1 = Literal.create( Map(1 -> 1, 2 -> null, 3 -> 3), - MapType(IntegerType, IntegerType)) - val ain = Literal.create( + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( Map.empty[Int, Int], - MapType(IntegerType, IntegerType)) + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1 val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k @@ -311,13 +311,18 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) checkEvaluation( transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) - checkEvaluation(transformValues(ain, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai3, plusOne), null) val as0 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType)) + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) val as1 = Literal.create( - Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), MapType(StringType, StringType)) - val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType)) + Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true)) val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) val valueTypeUpdate: (Expression, Expression) => Expression = @@ -337,13 +342,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( transformValues(transformValues(as1, concatValue), concatValue), Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx")) - checkEvaluation(transformValues(asn, concatValue), Map.empty[String, String]) - checkEvaluation(transformValues(asn, valueTypeUpdate), Map.empty[String, Int]) + checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String]) + checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int]) checkEvaluation( - transformValues(transformValues(asn, concatValue), valueTypeUpdate), + transformValues(transformValues(as2, concatValue), valueTypeUpdate), Map.empty[String, Int]) - } - + checkEvaluation(transformValues(as3, concatValue), null) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + } + test("MapZipWith") { def map_zip_with( left: Expression, diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index aad5dcbeda303..bdb884ae9ab27 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -53,8 +53,8 @@ select exists(ys, y -> y > 30) as v from nested; select exists(cast(null as array), y -> y > 30) as v; create or replace temporary view nested as values - (1, map(1,1,2,2,3,3)), - (2, map(4,4,5,5,6,6)) + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys); -- Identity Transform Keys in a map diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 770e7875e0514..06e0e231ec710 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -167,8 +167,8 @@ NULL -- !query 17 create or replace temporary view nested as values - (1, map(1,1,2,2,3,3)), - (2, map(4,4,5,5,6,6)) + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys) -- !query 17 schema struct<> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 28bfc8c2cf4d4..ae5e9424a59e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2302,7 +2302,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) } - test("transform values function - test various primitive data types combinations") { + test("transform values function - test primitive data types") { val dfExample1 = Seq( Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) ).toDF("i") @@ -2316,29 +2316,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).toDF("y") val dfExample4 = Seq( - Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0) + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) ).toDF("z") val dfExample5 = Seq( - Map[Int, Boolean](25 -> true, 26 -> false) - ).toDF("a") - - val dfExample6 = Seq( - Map[Int, String](25 -> "ab", 26 -> "cd") - ).toDF("b") - - val dfExample7 = Seq( Map[Int, Array[Int]](1 -> Array(1, 2)) ).toDF("c") - val dfExample8 = Seq( - Map[Int, Double](25 -> 26.1E0, 26 -> 31.2E0, 27 -> 37.1E0) - ).toDF("d") - - val dfExample10 = Seq( - Map[String, String]("s0" -> "abc", "s1" -> "def") - ).toDF("f") - def testMapOfPrimitiveTypesCombination(): Unit = { checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) @@ -2373,32 +2357,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) checkAnswer( - dfExample5.selectExpr("transform_values(a, (k, v) -> if(v, k + 1, k + 2))"), - Seq(Row(Map(25 -> 26, 26 -> 28)))) - - checkAnswer( - dfExample5.selectExpr("transform_values(a, (k, v) -> v AND k = 25)"), - Seq(Row(Map(25 -> true, 26 -> false)))) - - checkAnswer( - dfExample5.selectExpr("transform_values(a, (k, v) -> v OR k = 26)"), - Seq(Row(Map(25 -> true, 26 -> true)))) - - checkAnswer( - dfExample6.selectExpr("transform_values(b, (k, v) -> k + length(v))"), - Seq(Row(Map(25 -> 27, 26 -> 28)))) - - checkAnswer( - dfExample7.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), Seq(Row(Map(1 -> 3)))) - - checkAnswer( - dfExample8.selectExpr("transform_values(d, (k, v) -> CAST(v - k AS BIGINT))"), - Seq(Row(Map(25 -> 1, 26 -> 5, 27 -> 10)))) - - checkAnswer( - dfExample10.selectExpr("transform_values(f, (k, v) -> k || ':' || v)"), - Seq(Row(Map("s0" -> "s0:abc", "s1" -> "s1:def")))) } // Test with local relation, the Project will be evaluated without codegen @@ -2408,10 +2368,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample3.cache() dfExample4.cache() dfExample5.cache() - dfExample6.cache() - dfExample7.cache() - dfExample8.cache() - dfExample10.cache() // Test with cached relation, the Project will be evaluated with codegen testMapOfPrimitiveTypesCombination() } @@ -2487,6 +2443,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Map[String, String]("a" -> "b") ).toDF("j") + val dfExample3 = Seq( + Seq(1, 2, 3, 4) + ).toDF("x") + def testInvalidLambdaFunctions(): Unit = { val ex1 = intercept[AnalysisException] { @@ -2498,11 +2458,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") } assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[AnalysisException] { + dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)") + } + assert(ex3.getMessage.contains( + "data type mismatch: argument 1 requires map type")) } testInvalidLambdaFunctions() dfExample1.cache() dfExample2.cache() + dfExample3.cache() testInvalidLambdaFunctions() }