From 54f0f31aaa14de7c44c336580c7ed18e8ffb4b54 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 4 Dec 2018 13:35:09 +0100
Subject: [PATCH 1/3] [SPARK-25829][SQL][FOLLOWUP] Refactor MapConcat in order to properly check the size limit

---
 .../expressions/collectionOperations.scala    | 77 +------------------
 .../catalyst/util/ArrayBasedMapBuilder.scala  | 17 +++-
 2 files changed, 16 insertions(+), 78 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index fa8e38acd522d..67f6739b1e18f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
       return null
     }
 
-    val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements())
-    if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-      throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
-        s"elements due to exceeding the map size limit " +
-        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
-    }
-
     for (map <- maps) {
       mapBuilder.putAll(map.keyArray(), map.valueArray())
     }
@@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val mapCodes = children.map(_.genCode(ctx))
-    val keyType = dataType.keyType
-    val valueType = dataType.valueType
     val argsName = ctx.freshName("args")
     val hasNullName = ctx.freshName("hasNull")
     val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
@@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
     )
 
     val idxName = ctx.freshName("idx")
-    val numElementsName = ctx.freshName("numElems")
-    val finKeysName = ctx.freshName("finalKeys")
-    val finValsName = ctx.freshName("finalValues")
-
-    val keyConcat = genCodeForArrays(ctx, keyType, false)
-
-    val valueConcat =
-      if (valueType.sameType(keyType) &&
-          !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
-        keyConcat
-      } else {
-        genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
-      }
-
-    val keyArgsName = ctx.freshName("keyArgs")
-    val valArgsName = ctx.freshName("valArgs")
-
     val mapMerge =
       s"""
-        |ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}];
-        |ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}];
-        |long $numElementsName = 0;
         |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
-        |  $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
-        |  $valArgsName[$idxName] = $argsName[$idxName].valueArray();
-        |  $numElementsName += $argsName[$idxName].numElements();
+        |  $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray());
        |}
-        |if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-        |  throw new RuntimeException("Unsuccessful attempt to concat maps with " +
-        |    $numElementsName + " elements due to exceeding the map size limit " +
-        |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
-        |}
-        |ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName);
-        |ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName);
-        |${ev.value} = $builderTerm.from($finKeysName, $finValsName);
+        |${ev.value} = $builderTerm.build();
       """.stripMargin
 
     ev.copy(
@@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
       """.stripMargin)
   }
 
-  private def genCodeForArrays(
-      ctx: CodegenContext,
-      elementType: DataType,
-      checkForNull: Boolean): String = {
-    val counter = ctx.freshName("counter")
-    val arrayData = ctx.freshName("arrayData")
-    val argsName = ctx.freshName("args")
-    val numElemName = ctx.freshName("numElements")
-    val y = ctx.freshName("y")
-    val z = ctx.freshName("z")
-
-    val allocation = CodeGenerator.createArrayData(
-      arrayData, elementType, numElemName, s" $prettyName failed.")
-    val assignment = CodeGenerator.createArrayAssignment(
-      arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)
-
-    val concat = ctx.freshName("concat")
-    val concatDef =
-      s"""
-        |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
-        |  $allocation
-        |  int $counter = 0;
-        |  for (int $y = 0; $y < ${children.length}; $y++) {
-        |    for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
-        |      $assignment
-        |      $counter++;
-        |    }
-        |  }
-        |  return $arrayData;
-        |}
-      """.stripMargin
-
-    ctx.addNewFunction(concat, concatDef)
-  }
-
   override def prettyName: String = "map_concat"
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index e7cd61655dc9a..ffe46e09f6a8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods
 
 /**
  * A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes
@@ -47,13 +48,17 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
   private lazy val keyGetter = InternalRow.getAccessor(keyType)
   private lazy val valueGetter = InternalRow.getAccessor(valueType)
 
-  def put(key: Any, value: Any): Unit = {
+  def put(key: Any, value: Any, withSizeCheck: Boolean = false): Unit = {
     if (key == null) {
       throw new RuntimeException("Cannot use null as map key.")
     }
 
     val index = keyToIndex.getOrDefault(key, -1)
     if (index == -1) {
+      if (withSizeCheck && size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw new RuntimeException(s"Unsuccessful attempt to concat maps with $size elements " +
+          s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+      }
       keyToIndex.put(key, values.length)
       keys.append(key)
       values.append(value)
@@ -76,10 +81,11 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
       throw new RuntimeException(
         "The key array and value array of MapData must have the same length.")
     }
-
+    val sizeCheckRequired =
+      size + keyArray.numElements() > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
     var i = 0
     while (i < keyArray.numElements()) {
-      put(keyGetter(keyArray, i), valueGetter(valueArray, i))
+      put(keyGetter(keyArray, i), valueGetter(valueArray, i), withSizeCheck = sizeCheckRequired)
       i += 1
     }
   }
@@ -117,4 +123,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
       build()
     }
   }
+
+  /**
+   * Returns the current size of the map which is going to be produced by the current builder.
+   */
+  def size: Int = keys.size
 }
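Before the follow-up commits, a quick illustration of the behaviour PATCH 1/3 delegates to ArrayBasedMapBuilder. The sketch below is illustrative only, not Spark code: plain Scala collections stand in for ArrayData and the accessor machinery, the class is generic where Spark uses Any, and the default limit assumes ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH equals Int.MaxValue - 15.

import scala.collection.mutable

// Simplified analogue of ArrayBasedMapBuilder after PATCH 1/3 (sketch only).
class SketchedMapBuilder[K, V](limit: Int = Int.MaxValue - 15) {
  private val keyToIndex = mutable.HashMap.empty[K, Int]
  private val keys = mutable.ArrayBuffer.empty[K]
  private val values = mutable.ArrayBuffer.empty[V]

  def put(key: K, value: V, withSizeCheck: Boolean = false): Unit = {
    if (key == null) {
      throw new RuntimeException("Cannot use null as map key.")
    }
    keyToIndex.get(key) match {
      case Some(index) =>
        // Duplicate key: overwrite in place, so the size does not grow.
        values(index) = value
      case None =>
        // Only a genuinely new key can push the map over the limit.
        if (withSizeCheck && size >= limit) {
          throw new RuntimeException(s"Unsuccessful attempt to build maps with $size " +
            s"elements due to exceeding the map size limit $limit.")
        }
        keyToIndex.put(key, values.length)
        keys.append(key)
        values.append(value)
    }
  }

  def putAll(keyArray: IndexedSeq[K], valueArray: IndexedSeq[V]): Unit = {
    if (keyArray.length != valueArray.length) {
      throw new RuntimeException(
        "The key array and value array of MapData must have the same length.")
    }
    // Enable the per-key check only when this batch could overflow the limit.
    val sizeCheckRequired = size.toLong + keyArray.length > limit
    var i = 0
    while (i < keyArray.length) {
      put(keyArray(i), valueArray(i), withSizeCheck = sizeCheckRequired)
      i += 1
    }
  }

  def size: Int = keys.size

  def build(): Map[K, V] = keys.zip(values).toMap
}

The withSizeCheck flag keeps the common small-map path free of an extra comparison per key; putAll turns it on only when the incoming batch could overflow. PATCH 3/3 below drops the flag and checks unconditionally.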
From 724db5cd752d2c79032a887e8ae2806d9a5acc65 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 4 Dec 2018 15:45:22 +0100
Subject: [PATCH 2/3] address comment

---
 .../apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index ffe46e09f6a8f..b5103dfa6a198 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -56,7 +56,7 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
     val index = keyToIndex.getOrDefault(key, -1)
     if (index == -1) {
       if (withSizeCheck && size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-        throw new RuntimeException(s"Unsuccessful attempt to concat maps with $size elements " +
+        throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
           s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
       }
       keyToIndex.put(key, values.length)

From 38f3bfa237570a3204c355774bb323973f962d67 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 4 Dec 2018 16:08:39 +0100
Subject: [PATCH 3/3] address comment

---
 .../spark/sql/catalyst/util/ArrayBasedMapBuilder.scala | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index b5103dfa6a198..98934368205ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -48,14 +48,14 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
   private lazy val keyGetter = InternalRow.getAccessor(keyType)
   private lazy val valueGetter = InternalRow.getAccessor(valueType)
 
-  def put(key: Any, value: Any, withSizeCheck: Boolean = false): Unit = {
+  def put(key: Any, value: Any): Unit = {
     if (key == null) {
       throw new RuntimeException("Cannot use null as map key.")
     }
 
     val index = keyToIndex.getOrDefault(key, -1)
     if (index == -1) {
-      if (withSizeCheck && size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+      if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
         throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
           s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
       }
@@ -81,11 +81,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
       throw new RuntimeException(
         "The key array and value array of MapData must have the same length.")
     }
-    val sizeCheckRequired =
-      size + keyArray.numElements() > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
     var i = 0
     while (i < keyArray.numElements()) {
-      put(keyGetter(keyArray, i), valueGetter(valueArray, i), withSizeCheck = sizeCheckRequired)
+      put(keyGetter(keyArray, i), valueGetter(valueArray, i))
       i += 1
     }
   }
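Taken together, the series replaces MapConcat's up-front check, which summed numElements() across the inputs before any deduplication, with a check inside the builder that counts only distinct keys. A toy demonstration of the difference, using an assumed limit of 3 and a stripped-down stand-in for put (illustrative numbers and names, not Spark code):

import scala.collection.mutable

object MapSizeLimitDemo {
  def main(args: Array[String]): Unit = {
    val limit = 3 // assumed tiny limit, for illustration only
    val entries = mutable.LinkedHashMap.empty[Int, Int]

    // Stripped-down analogue of the builder's put: the limit applies only
    // when a genuinely new key is inserted.
    def put(key: Int, value: Int): Unit = {
      if (!entries.contains(key) && entries.size >= limit) {
        throw new RuntimeException(s"Unsuccessful attempt to build maps with " +
          s"${entries.size} elements due to exceeding the map size limit $limit.")
      }
      entries(key) = value // duplicate keys overwrite: last wins
    }

    // Four raw pairs arrive across the two inputs, so a pre-summed
    // numElements() check against the limit of 3 would throw, yet only
    // three distinct keys remain after deduplication.
    Seq(1 -> 10, 2 -> 20).foreach { case (k, v) => put(k, v) }
    Seq(2 -> 200, 3 -> 30).foreach { case (k, v) => put(k, v) }
    println(entries.toMap) // Map(1 -> 10, 2 -> 200, 3 -> 30)
  }
}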