From ef56011f03d8bae4634e5d3108e4d6502482383c Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 7 Aug 2018 01:42:45 +0200 Subject: [PATCH 01/13] [SPARK-23938][SQL] Add map_zip_with function --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 63 +++++++----- .../expressions/complexTypeExtractors.scala | 30 ++++-- .../expressions/higherOrderFunctions.scala | 99 ++++++++++++++++++- .../HigherOrderFunctionsSuite.scala | 96 ++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 61 ++++++++++++ 6 files changed, 316 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 35f8de1328b50..08d9000e280d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -443,6 +443,7 @@ object FunctionRegistry { expression[ArrayTransform]("transform"), expression[ArrayFilter]("filter"), expression[ArrayAggregate]("aggregate"), + expression[MapZipWith]("map_zip_with"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3f94f25796634..16b44eef415c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3667,10 +3667,8 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { @transient protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient protected lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false + @transient protected lazy val elementTypeSupportEquals = { + ArraySetLike.typeSupportsEquals(elementType) } } @@ -3680,25 +3678,21 @@ object ArraySetLike { s"elements due to exceeding the array size limit " + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") } -} + def typeSupportsEquals(dateType: DataType): Boolean = dateType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} /** - * Returns an array of the elements in the union of x and y, without duplicates + * The class performs union operation with two [[ArrayData]] objects. */ -@ExpressionDescription( - usage = """ - _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, - without duplicates. 
- """, - examples = """ - Examples: - > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 2, 3, 5) - """, - since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike - with ComplexTypeMergingExpression { +class ArrayDataMerger(elementType: DataType) { + lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) + lazy val elementTypeSupportEquals = ArraySetLike.typeSupportsEquals(elementType) + var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ @@ -3767,10 +3761,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike pos } - override def nullSafeEval(input1: Any, input2: Any): Any = { - val array1 = input1.asInstanceOf[ArrayData] - val array2 = input2.asInstanceOf[ArrayData] - + def merge(array1: ArrayData, array2: ArrayData): ArrayData = { if (elementTypeSupportEquals) { elementType match { case IntegerType => @@ -3834,6 +3825,32 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike ArrayUnion.unionOrdering(array1, array2, elementType, ordering) } } +} + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + + @transient lazy val merger = new ArrayDataMerger(elementType) + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + merger.merge(array1, array2) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 8994eeff92c7f..e6ca4f71a924b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -267,14 +267,15 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } } -/** - * Common base class for [[GetMapValue]] and [[ElementAt]]. - */ - -abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { +object GetMapValueUtil +{ // todo: current search is O(n), improve it. 
-  def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
-    val map = value.asInstanceOf[MapData]
+  def getValueEval(
+      map: MapData,
+      key: Any,
+      keyType: DataType,
+      valueType: DataType,
+      ordering: Ordering[Any]): Any = {
     val length = map.numElements()
     val keys = map.keyArray()
     val values = map.valueArray()
@@ -282,7 +283,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
     var i = 0
     var found = false
     while (i < length && !found) {
-      if (ordering.equiv(keys.get(i, keyType), ordinal)) {
+      if (ordering.equiv(keys.get(i, keyType), key)) {
         found = true
       } else {
         i += 1
@@ -292,9 +293,20 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
     if (!found || values.isNullAt(i)) {
       null
     } else {
-      values.get(i, dataType)
+      values.get(i, valueType)
     }
   }
+}
+
+
+/**
+ * Common base class for [[GetMapValue]] and [[ElementAt]].
+ */
+abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
+  // todo: current search is O(n), improve it.
+  def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
+    GetMapValueUtil.getValueEval(value.asInstanceOf[MapData], ordinal, keyType, dataType, ordering)
+  }
 
   def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
     val index = ctx.freshName("index")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 20c7f7d43b9dc..cb7eef39aaa31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -24,8 +24,7 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 
 /**
@@ -365,3 +364,99 @@ case class ArrayAggregate(
 
   override def prettyName: String = "aggregate"
 }
+
+/**
+ * Merges two given maps into a single map by applying the function to the pair of values with
+ * the same key.
+ */
+@ExpressionDescription(
+  usage =
+    """
+      _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying the
+      function to the pair of values with the same key. For keys only present in one map,
+      NULL will be passed as the value for the missing key. If an input map contains duplicated
+      keys, only the first entry of the duplicated key is passed into the lambda function.
+ """, + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); + {1:"ax",2:"by"} + """, + since = "2.4.0") +case class MapZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + @transient lazy val functionForEval: Expression = functionsForEval.head + + @transient lazy val MapType(keyType, leftValueType, _) = getMapType(left) + + @transient lazy val MapType(_, rightValueType, _) = getMapType(right) + + @transient lazy val merger = new ArrayDataMerger(keyType) + + override def inputs: Seq[Expression] = left :: right :: Nil + + override def functions: Seq[Expression] = function :: Nil + + override def nullable: Boolean = left.nullable || right.nullable + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) => + TypeUtils.checkForOrderingExpr(k1, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with the same key type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + } + } + + private def getMapType(expr: Expression) = expr.dataType match { + case m: MapType => m + case _ => MapType.defaultConcreteType + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { + val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) + copy(function = f(function, arguments)) + } + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + nullSafeEval(input, value1, value2) + } + } + } + + @transient lazy val LambdaFunction(_, Seq( + keyVar: NamedLambdaVariable, + value1Var: NamedLambdaVariable, + value2Var: NamedLambdaVariable), + _) = function + + private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { + val mapData1 = value1.asInstanceOf[MapData] + val mapData2 = value2.asInstanceOf[MapData] + val keys = merger.merge(mapData1.keyArray(), mapData2.keyArray()) + val values = new GenericArrayData(new Array[Any](keys.numElements())) + keys.foreach(keyType, (idx: Int, key: Any) => { + val v1 = GetMapValueUtil.getValueEval(mapData1, key, keyType, leftValueType, merger.ordering) + val v2 = GetMapValueUtil.getValueEval(mapData2, key, keyType, rightValueType, merger.ordering) + keyVar.value.set(key) + value1Var.value.set(v1) + value2Var.value.set(v2) + values.update(idx, functionForEval.eval(inputRow)) + }) + new ArrayBasedMapData(keys, values) + } + + override def prettyName: String = "map_zip_with" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 40cfc0ccc7c07..58fa80684a7a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -44,6 +44,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper LambdaFunction(function, Seq(lv1, 
lv2)) } + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + def transform(expr: Expression, f: Expression => Expression): Expression = { val at = expr.dataType.asInstanceOf[ArrayType] ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f)) @@ -181,4 +196,85 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), 15) } + + test("MapZipWith") { + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType] + val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType] + MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f)) + } + + val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii4 = MapFromArrays( + Literal.create(Seq(2, 2), ArrayType(IntegerType, false)), + Literal.create(Seq(20, 200), ArrayType(IntegerType, false))) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => k * v1 * v2 + } + + checkEvaluation( + map_zip_with(mii0, mii1, multiplyKeyWithValues), + Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null)) + checkEvaluation( + map_zip_with(mii0, mii2, multiplyKeyWithValues), + Map(1 -> null, 2 -> -80, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii3, multiplyKeyWithValues), + Map(1 -> null, 2 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii4, multiplyKeyWithValues), + Map(1 -> null, 2 -> 800, 3 -> null)) + checkEvaluation( + map_zip_with(mii4, mii0, multiplyKeyWithValues), + Map(2 -> 800, 1 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, miin, multiplyKeyWithValues), + null) + + val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false)) + val mss4 = MapFromArrays( + Literal.create(Seq("a", "a"), ArrayType(StringType, false)), + Literal.create(Seq("a", "n"), ArrayType(StringType, false))) + val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) + + val concat: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => Concat(Seq(k, v1, v2)) + 
} + + checkEvaluation( + map_zip_with(mss0, mss1, concat), + Map("a" -> null, "b" -> "byd", "d" -> "dzb")) + checkEvaluation( + map_zip_with(mss1, mss2, concat), + Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null)) + checkEvaluation( + map_zip_with(mss0, mss3, concat), + Map("a" -> null, "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mss4, concat), + Map("a" -> "axa", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss4, mss0, concat), + Map("a" -> "aax", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mssn, concat), + null) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index af3301b1599a9..cb488b0b61dd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2017,6 +2017,67 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) } + test("map_zip_with function - map of primitive types") { + val df = Seq( + (Map(8 -> 6L, 3 -> 5L, 6 -> 2L), Map[Integer, Integer]((6, 4), (8, 2), (3, 2))), + (Map(10 -> 6L, 8 -> 3L), Map[Integer, Integer]((8, 4), (4, null))), + (Map.empty[Int, Long], Map[Integer, Integer]((5, 1))), + (Map(5 -> 1L), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) + } + + test("map_zip_with function - map of complex types") { + val df = Seq( + (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")), + (Map("b" -> "a", "c" -> "d"), Map("c" -> "a", "b" -> null, "d" -> "k")), + (Map("a" -> "d"), Map.empty[String, String]), + (Map("a" -> "d"), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) + } + + test("map_zip_with function - invalid") + { + val df = Seq( + (Map(1 -> 2), Map(1 -> "a"), Map("a" -> "b"), Map(Map(1 -> 2) -> 2), 1) + ).toDF("mii", "mis", "mss", "mmi", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mii, mis, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, mss, (x, y, z) -> concat(x, y, z))") + } + assert(ex2.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with the same key type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") + } + assert(ex3.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with the same key type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mmi,mmi, (x, y, z) -> x)") + } + assert(ex4.getMessage.contains("function map_zip_with does not support ordering on type map")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for 
((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

From ec583eb29ba6fdb79d0b85cbecb3f709e6648b25 Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Tue, 7 Aug 2018 11:16:59 +0200
Subject: [PATCH 02/13] [SPARK-23938][SQL] Merging master into the feature
 branch and resolving conflicts.

---
 .../expressions/collectionOperations.scala    | 44 ++++++++++++-------
 .../expressions/higherOrderFunctions.scala    | 10 +++--
 2 files changed, 34 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 755cb95d59eba..60c830d070906 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3756,24 +3756,14 @@ object ArraySetLike {
 }
 
 /**
- * Returns an array of the elements in the union of x and y, without duplicates
+ * The class performs union operation with two [[ArrayData]] objects.
  */
-@ExpressionDescription(
-  usage = """
-    _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
-      without duplicates.
-  """,
-  examples = """
-    Examples:
-      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
-       array(1, 2, 3, 5)
-  """,
-  since = "2.4.0")
-case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
-  with ComplexTypeMergingExpression {
+class ArrayDataUnion(elementType: DataType) extends ((ArrayData, ArrayData) => ArrayData) {
 
-  @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
-    if (elementTypeSupportEquals) {
+  private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType)
+
+  private lazy val evalFunc: (ArrayData, ArrayData) => ArrayData = {
+    if (ArraySetLike.typeSupportsEquals(elementType)) {
       (array1, array2) =>
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
         val hs = new OpenHashSet[Any]
@@ -3834,6 +3824,28 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
     }
   }
 
+  def apply(array1: ArrayData, array2: ArrayData): ArrayData = evalFunc(array1, array2)
+}
+
+/**
+ * Returns an array of the elements in the union of x and y, without duplicates
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
+      without duplicates.
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + + @transient lazy val evalUnion = new ArrayDataUnion(elementType) + override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index cb7eef39aaa31..84d8156018e22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -392,7 +392,9 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) @transient lazy val MapType(_, rightValueType, _) = getMapType(right) - @transient lazy val merger = new ArrayDataMerger(keyType) + @transient lazy val arrayDataUnion = new ArrayDataUnion(keyType) + + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) override def inputs: Seq[Expression] = left :: right :: Nil @@ -445,11 +447,11 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { val mapData1 = value1.asInstanceOf[MapData] val mapData2 = value2.asInstanceOf[MapData] - val keys = merger.merge(mapData1.keyArray(), mapData2.keyArray()) + val keys = arrayDataUnion(mapData1.keyArray(), mapData2.keyArray()) val values = new GenericArrayData(new Array[Any](keys.numElements())) keys.foreach(keyType, (idx: Int, key: Any) => { - val v1 = GetMapValueUtil.getValueEval(mapData1, key, keyType, leftValueType, merger.ordering) - val v2 = GetMapValueUtil.getValueEval(mapData2, key, keyType, rightValueType, merger.ordering) + val v1 = GetMapValueUtil.getValueEval(mapData1, key, keyType, leftValueType, ordering) + val v2 = GetMapValueUtil.getValueEval(mapData2, key, keyType, rightValueType, ordering) keyVar.value.set(key) value1Var.value.set(v1) value2Var.value.set(v2) From be2e10d5c0f57110064e9ecf114160d45567f8b8 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 7 Aug 2018 23:11:39 +0200 Subject: [PATCH 03/13] [SPARK-23938][SQL] Addressing review comments --- .../expressions/collectionOperations.scala | 56 ++++----- .../expressions/complexTypeExtractors.scala | 29 ++--- .../expressions/higherOrderFunctions.scala | 112 ++++++++++++++++-- .../HigherOrderFunctionsSuite.scala | 47 ++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 3 +- 5 files changed, 172 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 60c830d070906..6adbbccd6ade3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3667,8 +3667,10 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { @transient protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient protected 
lazy val elementTypeSupportEquals = { - ArraySetLike.typeSupportsEquals(elementType) + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false } @transient protected lazy val canUseSpecializedHashSet = elementType match { @@ -3747,23 +3749,27 @@ object ArraySetLike { s"elements due to exceeding the array size limit " + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") } - - def typeSupportsEquals(dateType: DataType): Boolean = dateType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } } /** - * The class performs union operation with two [[ArrayData]] objects. + * Returns an array of the elements in the union of x and y, without duplicates */ -class ArrayDataUnion(elementType: DataType) extends ((ArrayData, ArrayData) => ArrayData) { - - private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { - private lazy val evalFunc: (ArrayData, ArrayData) => ArrayData = { - if (ArraySetLike.typeSupportsEquals(elementType)) { + @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { (array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new OpenHashSet[Any] @@ -3824,28 +3830,6 @@ class ArrayDataUnion(elementType: DataType) extends ((ArrayData, ArrayData) => A } } - def apply(array1: ArrayData, array2: ArrayData): ArrayData = evalFunc(array1, array2) -} - -/** - * Returns an array of the elements in the union of x and y, without duplicates - */ -@ExpressionDescription( - usage = """ - _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, - without duplicates. - """, - examples = """ - Examples: - > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 2, 3, 5) - """, - since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike - with ComplexTypeMergingExpression { - - @transient lazy val evalUnion = new ArrayDataUnion(elementType) - override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index e6ca4f71a924b..0091d1baca9ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -267,15 +267,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } } -object GetMapValueUtil -{ +/** + * Common base class for [[GetMapValue]] and [[ElementAt]]. + */ +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. 
- def getValueEval( - map: MapData, - key: Any, - keyType: DataType, - valueType: DataType, - ordering: Ordering[Any]): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { + val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() val values = map.valueArray() @@ -283,7 +281,7 @@ object GetMapValueUtil var i = 0 var found = false while (i < length && !found) { - if (ordering.equiv(keys.get(i, keyType), key)) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -293,20 +291,9 @@ object GetMapValueUtil if (!found || values.isNullAt(i)) { null } else { - values.get(i, valueType) + values.get(i, dataType) } } -} - - -/** - * Common base class for [[GetMapValue]] and [[ElementAt]]. - */ -abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { - // todo: current search is O(n), improve it. - def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { - GetMapValueUtil.getValueEval(value.asInstanceOf[MapData], ordinal, keyType, dataType, ordering) - } def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 84d8156018e22..d31ea9036137e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttrib import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods /** * A named lambda variable. 
@@ -392,8 +393,6 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) @transient lazy val MapType(_, rightValueType, _) = getMapType(right) - @transient lazy val arrayDataUnion = new ArrayDataUnion(keyType) - @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) override def inputs: Seq[Expression] = left :: right :: Nil @@ -444,19 +443,114 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) value2Var: NamedLambdaVariable), _) = function + private def keyTypeSupportsEquals = keyType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val getKeysWithValueIndexes: + (ArrayData, ArrayData) => Seq[(Any, Array[Option[Int]])] = { + if (keyTypeSupportsEquals) { + getKeysWithIndexesFast + } else { + getKeysWithIndexesBruteForce + } + } + + private def assertSizeOfArrayBuffer(size: Int): Unit = { + if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to zip maps with $size " + + s"unique keys due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + } + + private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = { + val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] + val hashMap = new mutable.OpenHashMap[Any, Array[Option[Int]]] + val keys = Array(keys1, keys2) + var z = 0 + while(z < 2) { + var i = 0 + val array = keys(z) + while (i < array.numElements()) { + val key = array.get(i, keyType) + hashMap.get(key) match { + case Some(indexes) => + if (indexes(z).isEmpty) indexes(z) = Some(i) + case None => + assertSizeOfArrayBuffer(arrayBuffer.size) + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + hashMap.put(key, indexes) + arrayBuffer += Tuple2(key, indexes) + } + i += 1 + } + z += 1 + } + arrayBuffer + } + + private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = { + val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] + val keys = Array(keys1, keys2) + var z = 0 + while(z < 2) { + var i = 0 + val array = keys(z) + while (i < array.numElements()) { + val key = array.get(i, keyType) + var found = false + var j = 0 + while (!found && j < arrayBuffer.size) { + val (bufferKey, indexes) = arrayBuffer(j) + if (ordering.equiv(bufferKey, key)) { + found = true + if(indexes(z).isEmpty) indexes(z) = Some(i) + } + j += 1 + } + if (!found) { + assertSizeOfArrayBuffer(arrayBuffer.size) + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + arrayBuffer += Tuple2(key, indexes) + } + i += 1 + } + z += 1 + } + arrayBuffer + } + + private def getValue(valueData: ArrayData, eType: DataType, index: Option[Int]) = index match { + case Some(i) => valueData.get(i, eType) + case None => null + } + private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { val mapData1 = value1.asInstanceOf[MapData] val mapData2 = value2.asInstanceOf[MapData] - val keys = arrayDataUnion(mapData1.keyArray(), mapData2.keyArray()) - val values = new GenericArrayData(new Array[Any](keys.numElements())) - keys.foreach(keyType, (idx: Int, key: Any) => { - val v1 = GetMapValueUtil.getValueEval(mapData1, key, keyType, leftValueType, ordering) - val v2 = GetMapValueUtil.getValueEval(mapData2, key, keyType, rightValueType, ordering) + val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray()) + val length = keysWithIndexes.length + val keys = new 
GenericArrayData(new Array[Any](length)) + val values = new GenericArrayData(new Array[Any](length)) + val valueData1 = mapData1.valueArray() + val valueData2 = mapData2.valueArray() + var i = 0 + while(i < length) { + val (key, indexes) = keysWithIndexes(i) + val v1 = getValue(valueData1, leftValueType, indexes(0)) + val v2 = getValue(valueData2, rightValueType, indexes(1)) keyVar.value.set(key) value1Var.value.set(v1) value2Var.value.set(v2) - values.update(idx, functionForEval.eval(inputRow)) - }) + keys.update(i, key) + values.update(i, functionForEval.eval(inputRow)) + i += 1 + } new ArrayBasedMapData(keys, values) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 58fa80684a7a7..a567db09b92af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -45,13 +45,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } private def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - dt3: DataType, - nullable3: Boolean, - f: (Expression, Expression, Expression) => Expression): Expression = { + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) @@ -276,5 +276,38 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mss0, mssn, concat), null) + + def b(data: Byte*): Array[Byte] = Array[Byte](data: _*) + + val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null), + MapType(BinaryType, BinaryType, valueContainsNull = true)) + val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb4 = MapFromArrays( + Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)), + Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false))) + val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false)) + + checkEvaluation( + map_zip_with(mbb0, mbb1, concat), + Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null)) + checkEvaluation( + map_zip_with(mbb1, mbb2, concat), + Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb3, concat), + Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb4, concat), + Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb4, mbb0, concat), + Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbbn, concat), + null) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5a2ee94ea86b4..1efa17eeab4e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2103,8 +2103,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null))) } - test("map_zip_with function - invalid") - { + test("map_zip_with function - invalid") { val df = Seq( (Map(1 -> 2), Map(1 -> "a"), Map("a" -> "b"), Map(Map(1 -> 2) -> 2), 1) ).toDF("mii", "mis", "mss", "mmi", "i") From 89a3da4e292690b78fbb41deef4104be3f843c1b Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 7 Aug 2018 23:36:11 +0200 Subject: [PATCH 04/13] [SPARK-23938][SQL] Addressing review comments --- .../expressions/higherOrderFunctions.scala | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 80a13bce53e1b..736cb61032e13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -24,8 +24,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods @@ -467,9 +466,11 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) @transient lazy val functionForEval: Expression = functionsForEval.head - @transient lazy val MapType(keyType, leftValueType, _) = getMapType(left) + @transient lazy val (keyType, leftValueType, _) = + HigherOrderFunction.mapKeyValueArgumentType(left.dataType) - @transient lazy val MapType(_, rightValueType, _) = getMapType(right) + @transient lazy val (_, rightValueType, _) = + HigherOrderFunction.mapKeyValueArgumentType(right.dataType) @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) @@ -491,11 +492,6 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) } } - private def getMapType(expr: Expression) = expr.dataType match { - case m: MapType => m - case _ => MapType.defaultConcreteType - } - override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) copy(function = f(function, arguments)) From 12ad8b2248b7acb4a04289ca8da439ecb63206a9 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 7 Aug 2018 23:42:34 +0200 Subject: [PATCH 05/13] [SPARK-23938][SQL] Adding new lines back --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 + .../spark/sql/catalyst/expressions/complexTypeExtractors.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6adbbccd6ade3..fbb182631eefa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3751,6 +3751,7 @@ object ArraySetLike { } } + /** * Returns an array of the elements in the union of x and y, without duplicates */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 0091d1baca9ee..8994eeff92c7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -270,6 +270,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) /** * Common base class for [[GetMapValue]] and [[ElementAt]]. */ + abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { From 6aeaaa807b41577d9549dd9955b8e50be3005407 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Wed, 8 Aug 2018 23:30:02 +0200 Subject: [PATCH 06/13] [SPARK-23938][SQL] Map comparison in tests ignoring ordering --- .../expressions/ExpressionEvalHelper.scala | 51 +++++++++++++++++-- .../HigherOrderFunctionsSuite.scala | 2 +- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 6684e5ce18d4c..a50e26dc8f5c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.scalacheck.Gen @@ -97,10 +98,48 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa isSame } case (result: MapData, expected: MapData) => - val kt = dataType.asInstanceOf[MapType].keyType - val vt = dataType.asInstanceOf[MapType].valueType - checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) && - checkResult(result.valueArray, expected.valueArray, ArrayType(vt)) + // Comparison ignoring order of entries and considering duplicated keys + // No hashing applied due to types without support of equals + result.numElements == expected.numElements && { + val MapType(keyType, valueType, _) = exprDataType + + // Calculate a number of entry occurrences (key, value, numberOfOccurrences) + def calculateOccurrences(mapData: MapData): Seq[(Any, Any, Int)] = { + var arrayBuffer = new ArrayBuffer[(Any, Any, Int)]() + val keys = mapData.keyArray() + val values = mapData.valueArray() + var i = 0 + while(i < keys.numElements) { + var found = false + val key = keys.get(i, keyType) + val value = values.get(i, valueType) + var j = 0 + while (!found && j < arrayBuffer.size) { + val (bufferKey, bufferValue, numOfOccurrences) = arrayBuffer(j) + if (checkResult(key, bufferKey, keyType) && + 
checkResult(value, bufferValue, valueType)) { + found = true + arrayBuffer(j) = (bufferKey, bufferValue, numOfOccurrences + 1) + } + j += 1 + } + if (!found) arrayBuffer += Tuple3(key, value, 1) + i += 1 + } + arrayBuffer + } + + val resultOccurrences = calculateOccurrences(result) + val expectedOccurrences = calculateOccurrences(expected) + + resultOccurrences.map { case (resultKey, resultValue, resultOccurrences) => + expectedOccurrences.exists { case (expectedKey, expectedValue, expectedOccurrences) => + resultOccurrences == expectedOccurrences && + checkResult(resultKey, expectedKey, keyType) && + checkResult(resultValue, expectedValue, valueType) + } + }.fold(true)(_ && _) + } case (result: Double, expected: Double) => if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => @@ -225,7 +264,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa val lit = InternalRow(expected, expected) val expectedRow = UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) - if (unsafeRow != expectedRow) { + val field = StructField("field", expression.dataType) + val dataType = StructType(field :: field :: Nil) + if (!checkResult(unsafeRow, expectedRow, dataType)) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 993a430874274..cf72e97106aeb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -277,7 +277,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null)) checkEvaluation( map_zip_with(mii0, mii2, multiplyKeyWithValues), - Map(1 -> null, 2 -> -80, 3 -> null)) + Map(1 -> null, 3 -> null, 2 -> -80)) checkEvaluation( map_zip_with(mii0, mii3, multiplyKeyWithValues), Map(1 -> null, 2 -> null, 3 -> null)) From 38ce4e72209d2f21cdb0993f89799d563e9ecd97 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Wed, 8 Aug 2018 23:31:30 +0200 Subject: [PATCH 07/13] [SPARK-23938][SQL] Addressing review comments --- .../expressions/higherOrderFunctions.scala | 113 ++++++++---------- 1 file changed, 53 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 736cb61032e13..2650895b7bbdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -524,7 +524,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) } @transient private lazy val getKeysWithValueIndexes: - (ArrayData, ArrayData) => Seq[(Any, Array[Option[Int]])] = { + (ArrayData, ArrayData) => mutable.Iterable[(Any, (Option[Int], Option[Int]))] = { if (keyTypeSupportsEquals) { getKeysWithIndexesFast } else { @@ -541,83 +541,76 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) } private def getKeysWithIndexesFast(keys1: 
ArrayData, keys2: ArrayData) = { - val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] - val hashMap = new mutable.OpenHashMap[Any, Array[Option[Int]]] - val keys = Array(keys1, keys2) - var z = 0 - while(z < 2) { - var i = 0 - val array = keys(z) - while (i < array.numElements()) { - val key = array.get(i, keyType) - hashMap.get(key) match { - case Some(indexes) => - if (indexes(z).isEmpty) indexes(z) = Some(i) - case None => - assertSizeOfArrayBuffer(arrayBuffer.size) - val indexes = Array[Option[Int]](None, None) - indexes(z) = Some(i) - hashMap.put(key, indexes) - arrayBuffer += Tuple2(key, indexes) - } - i += 1 + val hashMap = new mutable.OpenHashMap[Any, (Option[Int], Option[Int])] + var i = 0 + while (i < keys1.numElements) { + val key = keys1.get(i, keyType) + if(!hashMap.contains(key)) hashMap.put(key, (Some(i), None)) + i += 1 + } + i = 0 + while (i < keys2.numElements) { + val key = keys2.get(i, keyType) + hashMap.get(key) match { + case Some((index1, index2)) => + if (index2.isEmpty) hashMap.update(key, (index1, Some(i))) + case None => + hashMap.put(key, (None, Some(i))) } - z += 1 + i += 1 } - arrayBuffer + hashMap } private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = { - val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] - val keys = Array(keys1, keys2) - var z = 0 - while(z < 2) { - var i = 0 - val array = keys(z) - while (i < array.numElements()) { - val key = array.get(i, keyType) - var found = false - var j = 0 - while (!found && j < arrayBuffer.size) { - val (bufferKey, indexes) = arrayBuffer(j) - if (ordering.equiv(bufferKey, key)) { - found = true - if(indexes(z).isEmpty) indexes(z) = Some(i) - } - j += 1 - } - if (!found) { - assertSizeOfArrayBuffer(arrayBuffer.size) - val indexes = Array[Option[Int]](None, None) - indexes(z) = Some(i) - arrayBuffer += Tuple2(key, indexes) + val arrayBuffer = new mutable.ArrayBuffer[(Any, (Option[Int], Option[Int]))] + var i = 0 + while (i < keys1.numElements) { + val key = keys1.get(i, keyType) + var found = false + var j = 0 + while (!found && j < arrayBuffer.size) { + if (ordering.equiv(arrayBuffer(j)._1, key)) found = true + j += 1 + } + if (!found) arrayBuffer += Tuple2(key, (Some(i), None)) + i += 1 + } + i = 0 + while (i < keys2.numElements) { + val key = keys2.get(i, keyType) + var found = false + var j = 0 + while (!found && j < arrayBuffer.size) { + val (bufferKey, (index1, index2)) = arrayBuffer(j) + if (ordering.equiv(bufferKey, key)) { + found = true + if(index2.isEmpty) arrayBuffer(j) = Tuple2(key, (index1, Some(i))) } - i += 1 + j += 1 } - z += 1 + if (!found) { + assertSizeOfArrayBuffer(arrayBuffer.size) + arrayBuffer += Tuple2(key, (None, Some(i))) + } + i += 1 } arrayBuffer } - private def getValue(valueData: ArrayData, eType: DataType, index: Option[Int]) = index match { - case Some(i) => valueData.get(i, eType) - case None => null - } - private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { val mapData1 = value1.asInstanceOf[MapData] val mapData2 = value2.asInstanceOf[MapData] val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray()) - val length = keysWithIndexes.length - val keys = new GenericArrayData(new Array[Any](length)) - val values = new GenericArrayData(new Array[Any](length)) + val size = keysWithIndexes.size + val keys = new GenericArrayData(new Array[Any](size)) + val values = new GenericArrayData(new Array[Any](size)) val valueData1 = mapData1.valueArray() val valueData2 = 
mapData2.valueArray() var i = 0 - while(i < length) { - val (key, indexes) = keysWithIndexes(i) - val v1 = getValue(valueData1, leftValueType, indexes(0)) - val v2 = getValue(valueData2, rightValueType, indexes(1)) + keysWithIndexes.foreach { case (key, (index1, index2)) => + val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null) + val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null) keyVar.value.set(key) value1Var.value.set(v1) value2Var.value.set(v2) From 562ee81549c0ea4748b560d02c4d4fc44112c63d Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 9 Aug 2018 16:31:12 +0200 Subject: [PATCH 08/13] Revert "[SPARK-23938][SQL] Map comparison in tests ignoring ordering" This reverts commit 6aeaaa8 --- .../expressions/ExpressionEvalHelper.scala | 51 ++----------------- .../HigherOrderFunctionsSuite.scala | 2 +- 2 files changed, 6 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a50e26dc8f5c7..6684e5ce18d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.scalacheck.Gen @@ -98,48 +97,10 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa isSame } case (result: MapData, expected: MapData) => - // Comparison ignoring order of entries and considering duplicated keys - // No hashing applied due to types without support of equals - result.numElements == expected.numElements && { - val MapType(keyType, valueType, _) = exprDataType - - // Calculate a number of entry occurrences (key, value, numberOfOccurrences) - def calculateOccurrences(mapData: MapData): Seq[(Any, Any, Int)] = { - var arrayBuffer = new ArrayBuffer[(Any, Any, Int)]() - val keys = mapData.keyArray() - val values = mapData.valueArray() - var i = 0 - while(i < keys.numElements) { - var found = false - val key = keys.get(i, keyType) - val value = values.get(i, valueType) - var j = 0 - while (!found && j < arrayBuffer.size) { - val (bufferKey, bufferValue, numOfOccurrences) = arrayBuffer(j) - if (checkResult(key, bufferKey, keyType) && - checkResult(value, bufferValue, valueType)) { - found = true - arrayBuffer(j) = (bufferKey, bufferValue, numOfOccurrences + 1) - } - j += 1 - } - if (!found) arrayBuffer += Tuple3(key, value, 1) - i += 1 - } - arrayBuffer - } - - val resultOccurrences = calculateOccurrences(result) - val expectedOccurrences = calculateOccurrences(expected) - - resultOccurrences.map { case (resultKey, resultValue, resultOccurrences) => - expectedOccurrences.exists { case (expectedKey, expectedValue, expectedOccurrences) => - resultOccurrences == expectedOccurrences && - checkResult(resultKey, expectedKey, keyType) && - checkResult(resultValue, expectedValue, valueType) - } - }.fold(true)(_ && _) - } + val kt = dataType.asInstanceOf[MapType].keyType + val vt = dataType.asInstanceOf[MapType].valueType + checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) && + checkResult(result.valueArray, expected.valueArray, ArrayType(vt)) case (result: Double, expected: Double) => if (expected.isNaN) result.isNaN else expected == result case (result: Float, 
expected: Float) => @@ -264,9 +225,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa val lit = InternalRow(expected, expected) val expectedRow = UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) - val field = StructField("field", expression.dataType) - val dataType = StructType(field :: field :: Nil) - if (!checkResult(unsafeRow, expectedRow, dataType)) { + if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index cf72e97106aeb..993a430874274 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -277,7 +277,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null)) checkEvaluation( map_zip_with(mii0, mii2, multiplyKeyWithValues), - Map(1 -> null, 3 -> null, 2 -> -80)) + Map(1 -> null, 2 -> -80, 3 -> null)) checkEvaluation( map_zip_with(mii0, mii3, multiplyKeyWithValues), Map(1 -> null, 2 -> null, 3 -> null)) From 5d2a78ef09edb425593eb90b9bab616d2cd626ab Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 9 Aug 2018 17:44:36 +0200 Subject: [PATCH 09/13] [SPARK-23938][SQL] Modification to a solution with LinkedHashMap and indexes represented as arrays. --- .../expressions/higherOrderFunctions.scala | 102 +++++++++--------- .../spark/sql/DataFrameFunctionsSuite.scala | 2 +- 2 files changed, 53 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2650895b7bbdc..aa7aa7733611c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ @@ -466,12 +466,15 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) @transient lazy val functionForEval: Expression = functionsForEval.head - @transient lazy val (keyType, leftValueType, _) = + @transient lazy val (leftKeyType, leftValueType, _) = HigherOrderFunction.mapKeyValueArgumentType(left.dataType) - @transient lazy val (_, rightValueType, _) = + @transient lazy val (rightKeyType, rightValueType, _) = HigherOrderFunction.mapKeyValueArgumentType(right.dataType) + @transient lazy val keyType = + TypeCoercion.findTightestCommonType(leftKeyType, rightKeyType).getOrElse(NullType) + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) override def inputs: Seq[Expression] = left :: 
right :: Nil
@@ -523,8 +526,13 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
     case _ => false
   }
 
+  /**
+   * Accepts the key arrays of both maps and returns distinct keys, each paired with the index
+   * of its first occurrence in either value array (None if the key is missing from one map).
+   * The two indexes share a mutable two-element array so entries can be updated in place.
+   */
   @transient private lazy val getKeysWithValueIndexes:
-      (ArrayData, ArrayData) => mutable.Iterable[(Any, (Option[Int], Option[Int]))] = {
+      (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
     if (keyTypeSupportsEquals) {
       getKeysWithIndexesFast
     } else {
@@ -541,59 +549,53 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
   }
 
   private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = {
-    val hashMap = new mutable.OpenHashMap[Any, (Option[Int], Option[Int])]
-    var i = 0
-    while (i < keys1.numElements) {
-      val key = keys1.get(i, keyType)
-      if(!hashMap.contains(key)) hashMap.put(key, (Some(i), None))
-      i += 1
-    }
-    i = 0
-    while (i < keys2.numElements) {
-      val key = keys2.get(i, keyType)
-      hashMap.get(key) match {
-        case Some((index1, index2)) =>
-          if (index2.isEmpty) hashMap.update(key, (index1, Some(i)))
-        case None =>
-          hashMap.put(key, (None, Some(i)))
+    val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]]
+    for ((z, array) <- Array((0, keys1), (1, keys2))) {
+      var i = 0
+      while (i < array.numElements()) {
+        val key = array.get(i, keyType)
+        hashMap.get(key) match {
+          case Some(indexes) =>
+            if (indexes(z).isEmpty) {
+              indexes(z) = Some(i)
+            }
+          case None =>
+            val indexes = Array[Option[Int]](None, None)
+            indexes(z) = Some(i)
+            hashMap.put(key, indexes)
+        }
+        i += 1
       }
-      i += 1
     }
     hashMap
   }
 
   private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = {
-    val arrayBuffer = new mutable.ArrayBuffer[(Any, (Option[Int], Option[Int]))]
-    var i = 0
-    while (i < keys1.numElements) {
-      val key = keys1.get(i, keyType)
-      var found = false
-      var j = 0
-      while (!found && j < arrayBuffer.size) {
-        if (ordering.equiv(arrayBuffer(j)._1, key)) found = true
-        j += 1
-      }
-      if (!found) arrayBuffer += Tuple2(key, (Some(i), None))
-      i += 1
-    }
-    i = 0
-    while (i < keys2.numElements) {
-      val key = keys2.get(i, keyType)
-      var found = false
-      var j = 0
-      while (!found && j < arrayBuffer.size) {
-        val (bufferKey, (index1, index2)) = arrayBuffer(j)
-        if (ordering.equiv(bufferKey, key)) {
-          found = true
-          if(index2.isEmpty) arrayBuffer(j) = Tuple2(key, (index1, Some(i)))
+    val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
+    for ((z, array) <- Array((0, keys1), (1, keys2))) {
+      var i = 0
+      while (i < array.numElements()) {
+        val key = array.get(i, keyType)
+        var found = false
+        var j = 0
+        while (!found && j < arrayBuffer.size) {
+          val (bufferKey, indexes) = arrayBuffer(j)
+          if (ordering.equiv(bufferKey, key)) {
+            found = true
+            if (indexes(z).isEmpty) {
+              indexes(z) = Some(i)
+            }
+          }
+          j += 1
        }
-        j += 1
-      }
-      if (!found) {
-        assertSizeOfArrayBuffer(arrayBuffer.size)
-        arrayBuffer += Tuple2(key, (None, Some(i)))
+        if (!found) {
+          assertSizeOfArrayBuffer(arrayBuffer.size)
+          val indexes = Array[Option[Int]](None, None)
+          indexes(z) = Some(i)
+          arrayBuffer += Tuple2(key, indexes)
+        }
+        i += 1
      }
-      i += 1
    }
     arrayBuffer
   }
@@ -608,7 +610,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
     val valueData1 = mapData1.valueArray()
     val valueData2 = mapData2.valueArray()
     var i = 0
-    keysWithIndexes.foreach { case (key, (index1, 
index2)) => + for ((key, Array(index1, index2)) <- keysWithIndexes) { val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null) val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null) keyVar.value.set(key) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 00881de5058b8..36216ca09a0d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2172,7 +2172,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "been two maps with the same key type")) val ex4 = intercept[AnalysisException] { - df.selectExpr("map_zip_with(mmi,mmi, (x, y, z) -> x)") + df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") } assert(ex4.getMessage.contains("function map_zip_with does not support ordering on type map")) } From 3c849cbe70922bd22029b41f2558100dfbc16d9e Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 10 Aug 2018 00:19:33 +0200 Subject: [PATCH 10/13] [SPARK-23938][SQL] Coercion rule for the key type --- .../sql/catalyst/analysis/TypeCoercion.scala | 14 ++++++++++++++ .../expressions/higherOrderFunctions.scala | 6 +++--- .../spark/sql/DataFrameFunctionsSuite.scala | 19 ++++++++++++++++--- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 27839d72c6306..f0d714c0d6f77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -602,6 +602,20 @@ object TypeCoercion { CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) + case m @ MapZipWith(left, right, function) if MapType.acceptsType(left.dataType) && + MapType.acceptsType(right.dataType) && !m.leftKeyType.sameType(m.rightKeyType) => + findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + case Some(finalKeyType) => + val newLeft = castIfNotSameType( + left, + MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull)) + val newRight = castIfNotSameType( + right, + MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull)) + MapZipWith(newLeft, newRight, function) + case None => m + } + // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. 
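
The MapZipWith case added to FunctionArgumentConversion above boils down to: find the wider of the two key types, then cast both map arguments so their keys agree. A minimal standalone sketch of that idea in plain Scala follows; the toy rank lattice is an illustrative stand-in for Catalyst's findWiderTypeForTwo and is not Spark API:

object KeyWideningSketch {
  // Toy widening lattice: tinyint < smallint < int < bigint. Catalyst's real
  // rules cover far more types; this subset exists only for illustration.
  private val rank = Map("tinyint" -> 0, "smallint" -> 1, "int" -> 2, "bigint" -> 3)

  // Picks the wider of two key types, or None when neither widens to the other.
  def findWiderKeyType(t1: String, t2: String): Option[String] =
    (rank.get(t1), rank.get(t2)) match {
      case (Some(r1), Some(r2)) => Some(if (r1 >= r2) t1 else t2)
      case _ => None
    }

  def main(args: Array[String]): Unit = {
    // map<int,...> zipped with map<bigint,...>: both key arrays are cast to bigint.
    println(findWiderKeyType("int", "bigint"))  // Some(bigint)
    // No common key type: the rule leaves the expression unchanged, and the
    // expression's own type check reports the mismatch instead.
    println(findWiderKeyType("int", "boolean")) // None
  }
}
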
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index aa7aa7733611c..bab8e32e2669b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -466,10 +466,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) @transient lazy val functionForEval: Expression = functionsForEval.head - @transient lazy val (leftKeyType, leftValueType, _) = + @transient lazy val (leftKeyType, leftValueType, leftValueContainsNull) = HigherOrderFunction.mapKeyValueArgumentType(left.dataType) - @transient lazy val (rightKeyType, rightValueType, _) = + @transient lazy val (rightKeyType, rightValueType, rightValueContainsNull) = HigherOrderFunction.mapKeyValueArgumentType(right.dataType) @transient lazy val keyType = @@ -490,7 +490,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) => TypeUtils.checkForOrderingExpr(k1, s"function $prettyName") case _ => TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + - s"been two ${MapType.simpleString}s with the same key type, but it's " + + s"been two ${MapType.simpleString}s with compatible key types, but it's " + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 36216ca09a0d3..1b28c83a99898 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2133,6 +2133,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null))) } + test("map_zip_with function - key type coercion to long") { + // TODO: Add proper SQL coercion tests + val df = Seq( + (Map(8 -> 6, 3 -> 5, 6 -> 2), Map[Long, Integer]((6L, 4), (8L, 2), (3L, 2))), + (Map(10 -> 6, 8 -> 3), Map[Long, Integer]((8L, 4), (4L, null))) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), + Seq( + Row(Map(8L -> true, 3L -> false, 6L -> true)), + Row(Map(10L -> null, 8L -> false, 4L -> null)))) + } + test("map_zip_with function - map of complex types") { val df = Seq( (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")), @@ -2160,16 +2173,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) val ex2 = intercept[AnalysisException] { - df.selectExpr("map_zip_with(mis, mss, (x, y, z) -> concat(x, y, z))") + df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") } assert(ex2.getMessage.contains("The input to function map_zip_with should have " + - "been two maps with the same key type")) + "been two maps with compatible key types")) val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } assert(ex3.getMessage.contains("The input to function map_zip_with should have " + - "been two maps with the 
same key type")) + "been two maps with compatible key types")) val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") From 595161fefbf55711b76530a9e53aff73491febd6 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 10 Aug 2018 17:48:29 +0200 Subject: [PATCH 11/13] [SPARK-23938][SQL] Bug fix + type coercion tests --- .../sql/catalyst/analysis/TypeCoercion.scala | 11 +- .../expressions/higherOrderFunctions.scala | 2 +- .../inputs/typeCoercion/native/mapZipWith.sql | 66 ++++++++ .../typeCoercion/native/mapZipWith.sql.out | 142 ++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 13 -- 5 files changed, 219 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f0d714c0d6f77..4c43bc5ccf1f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -231,6 +231,15 @@ object TypeCoercion { }) } + /** + * Similar to [[findTightestCommonType]] but with string promotion. + */ + def findWiderTypeForTwoExceptDecimals(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) + .orElse(stringPromotion(t1, t2)) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwoExceptDecimals)) + } + /** * Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to * string. 
If the wider decimal type exceeds system limitation, this rule will truncate @@ -604,7 +613,7 @@ object TypeCoercion { case m @ MapZipWith(left, right, function) if MapType.acceptsType(left.dataType) && MapType.acceptsType(right.dataType) && !m.leftKeyType.sameType(m.rightKeyType) => - findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + findWiderTypeForTwoExceptDecimals(m.leftKeyType, m.rightKeyType) match { case Some(finalKeyType) => val newLeft = castIfNotSameType( left, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index bab8e32e2669b..daf539aba21a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -473,7 +473,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) HigherOrderFunction.mapKeyValueArgumentType(right.dataType) @transient lazy val keyType = - TypeCoercion.findTightestCommonType(leftKeyType, rightKeyType).getOrElse(NullType) + TypeCoercion.findWiderTypeForTwoExceptDecimals(leftKeyType, rightKeyType).getOrElse(NullType) @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql new file mode 100644 index 0000000000000..119f868cb48e6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql @@ -0,0 +1,66 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +); + +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + 
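
The string_map queries in this new test file pair string keys with int, date, timestamp, and decimal keys; under findWiderTypeForTwoExceptDecimals these all widen to string via string promotion. A rough standalone illustration of that branch follows, with a deliberately simplified rule that is not the Spark implementation:

object StringPromotionSketch {
  sealed trait SqlType
  case object StringType extends SqlType
  case object IntType extends SqlType
  case object DateType extends SqlType

  // Simplified string promotion: any pairing with a string key widens to string.
  // (Catalyst's actual rule has more exclusions, e.g. binary and boolean operands.)
  def stringPromotion(t1: SqlType, t2: SqlType): Option[SqlType] = (t1, t2) match {
    case (StringType, _) | (_, StringType) => Some(StringType)
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    println(stringPromotion(StringType, DateType)) // Some(StringType): date keys become strings
    println(stringPromotion(IntType, DateType))    // None: no promotion between int and date
  }
}
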
+
+SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
new file mode 100644
index 0000000000000..7dc10f4fef4ca
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -0,0 +1,142 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 12
+
+
+-- !query 0
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false),
+  map(2Y, 1Y),
+  map(2S, 1S),
+  map(2, 1),
+  map(2L, 1L),
+  map(922337203685477897945456575809789456, 922337203685477897945456575809789456),
+  map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456),
+  map(2.0D, 1.0D),
+  map(float(2.0), float(1.0)),
+  map(date '2016-03-14', date '2016-03-13'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map('true', 'false', '2', '1'),
+  map('2016-03-14', '2016-03-13'),
+  map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'),
+  map('922337203685477897945456575809789456', 'text'),
+  map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)),
+  map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2))
+) AS various_maps(
+  boolean_map,
+  tinyint_map,
+  smallint_map,
+  int_map,
+  bigint_map,
+  decimal_map1, decimal_map2,
+  double_map,
+  float_map,
+  date_map,
+  timestamp_map,
+  string_map1, string_map2, string_map3, string_map4,
+  array_map1, array_map2,
+  struct_map1, struct_map2
+)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 1 schema
+struct<m:map<smallint,struct<k:smallint,v1:tinyint,v2:smallint>>>
+-- !query 1 output
+{2:{"k":2,"v1":1,"v2":1}}
+
+
+-- !query 2
+SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 2 schema
+struct<m:map<int,struct<k:int,v1:smallint,v2:int>>>
+-- !query 2 output
+{2:{"k":2,"v1":1,"v2":1}}
+
+
+-- !query 3
+SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 3 schema
+struct<m:map<bigint,struct<k:bigint,v1:int,v2:bigint>>>
+-- !query 3 output
+{2:{"k":2,"v1":1,"v2":1}}
+
+
+-- !query 4
+SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 4 schema
+struct<m:map<double,struct<k:double,v1:double,v2:float>>>
+-- !query 4 output
+{2.0:{"k":2.0,"v1":1.0,"v2":1.0}}
+
+
+-- !query 5
+SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct('k', namedlambdavariable(), 'v1', namedlambdavariable(), 'v2', namedlambdavariable()), namedlambdavariable(), namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but it's [map<decimal(36,0),decimal(36,0)>, map<decimal(36,35),decimal(36,35)>].; line 1 pos 7
+
+
+-- !query 6
+SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 6 schema
+struct<m:map<string,struct<k:string,v1:string,v2:int>>>
+-- !query 6 output
+{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}}
+
+
+-- !query 7
+SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 7 schema
+struct<m:map<string,struct<k:string,v1:string,v2:date>>>
+-- !query 7 output
+{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}}
+
+
+-- !query 8
+SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 8 schema
+struct<m:map<string,struct<k:string,v1:timestamp,v2:string>>>
+-- !query 8 output
+{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}}
+
+
+-- !query 9
+SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 9 schema
+struct<m:map<string,struct<k:string,v1:decimal(36,0),v2:string>>>
+-- !query 9 output
+{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}}
+
+
+-- !query 10
+SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 10 schema
+struct<m:map<array<bigint>,struct<k:array<bigint>,v1:array<bigint>,v2:array<int>>>>
+-- !query 10 output
+{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}}
+
+
+-- !query 11
+SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 11 schema
+struct<m:map<struct<col1:int,col2:bigint>,struct<k:struct<col1:int,col2:bigint>,v1:struct<col1:smallint,col2:bigint>,v2:struct<col1:int,col2:int>>>>
+-- !query 11 output
+{{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 1b28c83a99898..07ac7844f830d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2133,19 +2133,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(null)))
   }
 
-  test("map_zip_with function - key type coercion to long") {
-    // TODO: Add proper SQL coercion tests
-    val df = Seq(
-      (Map(8 -> 6, 3 -> 5, 6 -> 2), Map[Long, Integer]((6L, 4), (8L, 2), (3L, 2))),
-      (Map(10 -> 6, 8 -> 3), Map[Long, Integer]((8L, 4), (4L, null)))
-    ).toDF("m1", "m2")
-
-    checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"),
-      Seq(
-        Row(Map(8L -> true, 3L -> false, 6L -> true)),
-        Row(Map(10L -> null, 8L -> false, 4L -> null))))
-  }
-
   test("map_zip_with function - map of complex types") {
     val df = Seq(
       (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")),
From 2b7e9e59a69d991eba24ca86b8df5fe54b6e077f Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Tue, 14 Aug 2018 00:00:53 +0200
Subject: [PATCH 12/13] [SPARK-23938][SQL] Key type resolved by findCommonTypeDifferentOnlyInNullFlags + coercion rule refactoring

---
 .../sql/catalyst/analysis/TypeCoercion.scala  | 48 ++++++++++---------
 .../expressions/higherOrderFunctions.scala    | 41 +++++++++-------
 .../typeCoercion/native/mapZipWith.sql.out    |  2 +-
 .../spark/sql/DataFrameFunctionsSuite.scala   | 12 +++--
 4 files changed, 59 insertions(+), 44 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 4bd49cd636ff5..288b6358fbff1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + MapZipWithCoercion :: EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: @@ -238,15 +239,6 @@ object TypeCoercion { }) } - /** - * Similar to [[findTightestCommonType]] but with string promotion. - */ - def findWiderTypeForTwoExceptDecimals(t1: DataType, t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) - .orElse(stringPromotion(t1, t2)) - .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwoExceptDecimals)) - } - /** * Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to * string. If the wider decimal type exceeds system limitation, this rule will truncate @@ -618,20 +610,6 @@ object TypeCoercion { CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) - case m @ MapZipWith(left, right, function) if MapType.acceptsType(left.dataType) && - MapType.acceptsType(right.dataType) && !m.leftKeyType.sameType(m.rightKeyType) => - findWiderTypeForTwoExceptDecimals(m.leftKeyType, m.rightKeyType) match { - case Some(finalKeyType) => - val newLeft = castIfNotSameType( - left, - MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull)) - val newRight = castIfNotSameType( - right, - MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull)) - MapZipWith(newLeft, newRight, function) - case None => m - } - // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) @@ -785,6 +763,30 @@ object TypeCoercion { } } + /** + * Coerces key types of two different [[MapType]] arguments of the [[MapZipWith]] expression + * to a common type. + */ + object MapZipWithCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Lambda function isn't resolved when the rule is executed. + case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && + MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) => + findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && + !Cast.forceNullable(m.rightKeyType, finalKeyType) => + val newLeft = castIfNotSameType( + left, + MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull)) + val newRight = castIfNotSameType( + right, + MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull)) + MapZipWith(newLeft, newRight, function) + case _ => m + } + } + } + /** * Coerces the types of [[Elt]] children to expected ones. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index b952d23e02680..afaa991d67235 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -520,40 +520,49 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) @transient lazy val functionForEval: Expression = functionsForEval.head - @transient lazy val (leftKeyType, leftValueType, leftValueContainsNull) = - HigherOrderFunction.mapKeyValueArgumentType(left.dataType) + @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType - @transient lazy val (rightKeyType, rightValueType, rightValueContainsNull) = - HigherOrderFunction.mapKeyValueArgumentType(right.dataType) + @transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType @transient lazy val keyType = - TypeCoercion.findWiderTypeForTwoExceptDecimals(leftKeyType, rightKeyType).getOrElse(NullType) + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) - override def inputs: Seq[Expression] = left :: right :: Nil + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil override def functions: Seq[Expression] = function :: Nil + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) => - TypeUtils.checkForOrderingExpr(k1, s"function $prettyName") - case _ => TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + - s"been two ${MapType.simpleString}s with compatible key types, but it's " + - s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") - } - } - override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) copy(function = f(function, arguments)) } + override def checkArgumentDataTypes(): TypeCheckResult = { + super.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (leftKeyType.sameType(rightKeyType)) { + TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName") + } else { + TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + + s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].") + } + case failure => failure + } + } + + // Nothing to check since the data type of the lambda function can be anything. 
+  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+
   override def eval(input: InternalRow): Any = {
     val value1 = left.eval(input)
     if (value1 == null) {
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index 7dc10f4fef4ca..7f7e2f07b9e74 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -85,7 +85,7 @@ FROM various_maps
 struct<>
 -- !query 5 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct('k', namedlambdavariable(), 'v1', namedlambdavariable(), 'v2', namedlambdavariable()), namedlambdavariable(), namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but it's [map<decimal(36,0),decimal(36,0)>, map<decimal(36,35),decimal(36,35)>].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
 
 
 -- !query 6
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index a7c2be302c353..387bbc480d65b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2287,15 +2287,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       "been two maps with compatible key types"))
 
     val ex3 = intercept[AnalysisException] {
-      df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))")
+      df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))")
     }
-    assert(ex3.getMessage.contains("The input to function map_zip_with should have " +
-      "been two maps with compatible key types"))
+    assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type"))
 
     val ex4 = intercept[AnalysisException] {
+      df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))")
+    }
+    assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type"))
+
+    val ex5 = intercept[AnalysisException] {
       df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)")
     }
-    assert(ex4.getMessage.contains("function map_zip_with does not support ordering on type map"))
+    assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map"))
   }
From bcd4e0f3956f622d99e5fcf4a98c155b92dbeb9a Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Tue, 14 Aug 2018 09:54:56 +0200
Subject: [PATCH 13/13] [SPARK-23938][SQL] Addressing review comments

---
 .../sql/catalyst/expressions/higherOrderFunctions.scala      | 5 ++---
 .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index afaa991d67235..22210f692e755 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -518,7 +518,7 @@ case class ArrayAggregate( case class MapZipWith(left: Expression, right: Expression, function: Expression) extends HigherOrderFunction with CodegenFallback { - @transient lazy val functionForEval: Expression = functionsForEval.head + def functionForEval: Expression = functionsForEval.head @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType @@ -560,8 +560,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) } } - // Nothing to check since the data type of the lambda function can be anything. - override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess + override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes() override def eval(input: InternalRow): Any = { val value1 = left.eval(input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 387bbc480d65b..8d7695b6ebbcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2254,7 +2254,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null))) } - test("map_zip_with function - map of complex types") { + test("map_zip_with function - map of non-primitive types") { val df = Seq( (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")), (Map("b" -> "a", "c" -> "d"), Map("c" -> "a", "b" -> null, "d" -> "k")),