From f16140975db758cf17e9687a465a9864bd1e0b50 Mon Sep 17 00:00:00 2001 From: codeatri Date: Fri, 17 Aug 2018 11:50:06 +0900 Subject: [PATCH] [SPARK-23940][SQL] Add transform_values SQL function ## What changes were proposed in this pull request? This pr adds `transform_values` function which applies the function to each entry of the map and transforms the values. ```javascript > SELECT transform_values(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> v + 1); map(1->2, 2->3, 3->4) > SELECT transform_values(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + v); map(1->2, 2->4, 3->6) ``` ## How was this patch tested? New Tests added to `DataFrameFunctionsSuite` `HigherOrderFunctionsSuite` `SQLQueryTestSuite` Closes #22045 from codeatri/SPARK-23940. Authored-by: codeatri Signed-off-by: Takuya UESHIN --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 50 ++++- .../HigherOrderFunctionsSuite.scala | 73 ++++++++ .../inputs/higher-order-functions.sql | 9 + .../results/higher-order-functions.sql.out | 29 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 173 +++++++++++++++++- 6 files changed, 332 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 061336455189e..77860e1584f42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -446,6 +446,7 @@ object FunctionRegistry { expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), + expression[TransformValues]("transform_values"), expression[TransformKeys]("transform_keys"), expression[MapZipWith]("map_zip_with"), expression[ZipWith]("zip_with"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 9d603d79eedcf..f667a64f7f8d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -527,7 +527,7 @@ case class TransformKeys( } @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { @@ -550,6 +550,54 @@ case class TransformKeys( override def prettyName: String = "transform_keys" } +/** + * Returns a map that applies the function to each value of the map. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + map(array(1, 2, 3), array(2, 3, 4)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(1, 2, 3), array(2, 4, 6)) + """, + since = "2.4.0") +case class TransformValues( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = argument.nullable + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) + : TransformValues = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultValues = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + resultValues.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(map.keyArray(), resultValues) + } + + override def prettyName: String = "transform_values" +} + /** * Merges two given maps into a single map by applying function to the pair of values with * the same key. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 3a78f14c8b2cb..9d992c52e5357 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -101,6 +101,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper aggregate(expr, zero, merge, identity) } + def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val map = expr.dataType.asInstanceOf[MapType] + TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + } + test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) @@ -358,6 +363,74 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) } + test("TransformValues") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1 + val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k + + checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4)) + checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4)) + checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val valueTypeUpdate: (Expression, Expression) => Expression = + (k, v) => Length(v) + 1 + + checkEvaluation( + transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx")) + checkEvaluation(transformValues(as0, valueTypeUpdate), + Map("a" -> 3, "bb" -> 3, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as0, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as1, concatValue), + Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx")) + checkEvaluation(transformValues(as1, valueTypeUpdate), + Map("a" -> 3, "bb" -> null, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as1, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String]) + checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int]) + checkEvaluation( + transformValues(transformValues(as2, concatValue), valueTypeUpdate), + Map.empty[String, Int]) + checkEvaluation(transformValues(as3, concatValue), null) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + } + test("MapZipWith") { def map_zip_with( left: Expression, diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 05ec5effdf146..02ad5e3538689 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -74,3 +74,12 @@ select transform_keys(ys, (k, v) -> k + 1) as v from nested; -- Transform Keys in a map using values select transform_keys(ys, (k, v) -> k + v) as v from nested; + +-- Identity Transform values in a map +select transform_values(ys, (k, v) -> v) as v from nested; + +-- Transform values in a map by adding constant +select transform_values(ys, (k, v) -> v + 1) as v from nested; + +-- Transform values in a map using values +select transform_values(ys, (k, v) -> k + v) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 5a39616191e81..32d20d1b73415 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 27 -- !query 0 @@ -226,3 +226,30 @@ struct> -- !query 23 output {10:5,12:6,8:4} {2:1,4:2,6:3} + + +-- !query 24 +select transform_values(ys, (k, v) -> v) as v from nested +-- !query 24 schema +struct> +-- !query 24 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 25 +select transform_values(ys, (k, v) -> v + 1) as v from nested +-- !query 25 schema +struct> +-- !query 25 output +{1:2,2:3,3:4} +{4:5,5:6,6:7} + + +-- !query 26 +select transform_values(ys, (k, v) -> k + v) as v from nested +-- !query 26 schema +struct> +-- !query 26 output +{1:2,2:4,3:6} +{4:8,5:10,6:12} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9e2bfd3b7fba8..156e54300e38b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2346,6 +2346,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) } + // Test with local relation, the Project will be evaluated without codegen testMapOfPrimitiveTypesCombination() dfExample1.cache() @@ -2357,7 +2358,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("transform keys function - Invalid lambda functions and exceptions") { - val dfExample1 = Seq( Map[String, String]("a" -> null) ).toDF("i") @@ -2389,6 +2389,177 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "data type mismatch: argument 1 requires map type")) } + test("transform values function - test primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Boolean, String](false -> "abc", true -> "def") + ).toDF("x") + + val dfExample3 = Seq( + Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3) + ).toDF("y") + + val dfExample4 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("z") + + val dfExample5 = Seq( + Map[Int, Array[Int]](1 -> Array(1, 2)) + ).toDF("c") + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), + Seq(Row(Map(false -> "false", true -> "def")))) + + checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), + Seq(Row(Map(false -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + + checkAnswer(dfExample3.selectExpr( + "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) + + checkAnswer( + dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + + checkAnswer( + dfExample4.selectExpr( + "transform_values(" + + "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + + "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + Seq(Row(Map(1 -> 3)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + dfExample5.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform values function - test empty") { + val dfExample1 = Seq( + Map.empty[Integer, Integer] + ).toDF("i") + + val dfExample2 = Seq( + Map.empty[BigInt, String] + ).toDF("j") + + def testEmpty(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> NULL)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> v)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 0)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 'value')"), + Seq(Row(Map.empty[Integer, String]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> true)"), + Seq(Row(Map.empty[Integer, Boolean]))) + + checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), + Seq(Row(Map.empty[BigInt, BigInt]))) + } + + testEmpty() + dfExample1.cache() + dfExample2.cache() + testEmpty() + } + + test("transform values function - test null values") { + val dfExample1 = Seq( + Map[Int, Integer](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4) + ).toDF("a") + + val dfExample2 = Seq( + Map[Int, String](1 -> "a", 2 -> "b", 3 -> null) + ).toDF("b") + + def testNullValue(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(a, (k, v) -> null)"), + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + } + + testNullValue() + dfExample1.cache() + dfExample2.cache() + testNullValue() + } + + test("transform values function - test invalid functions") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[String, String]("a" -> "b") + ).toDF("j") + + val dfExample3 = Seq( + Seq(1, 2, 3, 4) + ).toDF("x") + + def testInvalidLambdaFunctions(): Unit = { + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_values(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[AnalysisException] { + dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)") + } + assert(ex3.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + + testInvalidLambdaFunctions() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + testInvalidLambdaFunctions() + } + test("arrays zip_with function - for primitive types") { val df1 = Seq[(Seq[Integer], Seq[Integer])]( (Seq(9001, 9002, 9003), Seq(4, 5, 6)),