From 5cbbf7afb164d090bfe5730380a2fbe0a18146c2 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 10 Apr 2018 15:49:53 +0200
Subject: [PATCH 1/7] [SPARK-23930][SQL] Add slice function

---
 python/pyspark/sql/functions.py               | 13 +++
 .../catalyst/analysis/FunctionRegistry.scala  |  1 +
 .../expressions/collectionOperations.scala    | 98 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala          | 24 +++++
 .../expressions/ExpressionEvalHelper.scala    |  6 ++
 .../expressions/ObjectExpressionsSuite.scala  |  1 -
 .../org/apache/spark/sql/functions.scala      | 10 ++
 .../spark/sql/DataFrameFunctionsSuite.scala   | 16 +++
 8 files changed, 168 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1b192680f0795..802d0a76780d8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1846,6 +1846,19 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
 
 
+@since(2.4)
+def slice(x, start, length):
+    """
+    Collection function: returns an array containing all the elements in `x` from index `start`
+    (or starting from the end if `start` is negative) with the specified `length`.
+    >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
+    >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
+    [Row(sliced=[2, 3]), Row(sliced=[5])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.slice(_to_java_column(x), start, length))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 747016beb06e7..2721b9ddaa77f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -407,6 +407,7 @@ object FunctionRegistry {
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
     expression[Size]("size"),
+    expression[Slice]("slice"),
     expression[SortArray]("sort_array"),
 
     CreateStruct.registryEntry,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 91188da8b0bd3..1263cd577786b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression)
 
   override def prettyName: String = "array_contains"
 }
+
+
+/**
+ * Slices an array according to the requested start index and length
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
+       [2,3]
+      > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
+       [3,4]
+  """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class Slice(x: Expression, start: Expression, length: Expression)
+  extends TernaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = x.dataType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def children: Seq[Expression] = Seq(x, start, length)
+
+  override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
+    val startInt = startVal.asInstanceOf[Int]
+    val lengthInt = lengthVal.asInstanceOf[Int]
+    val arr = xVal.asInstanceOf[ArrayData]
+    val startIndex = if (startInt == 0) {
+      throw new RuntimeException(
+        s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
+    } else if (startInt < 0) {
+      startInt + arr.numElements()
+    } else {
+      startInt - 1
+    }
+    if (lengthInt < 0) {
+      throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
+        s"length must be greater than or equal to 0.")
+    }
+    // this can happen if start is negative and its absolute value is greater than the
+    // number of elements in the array
+    if (startIndex < 0) {
+      return new GenericArrayData(Array.empty[AnyRef])
+    }
+    val elementType = x.dataType.asInstanceOf[ArrayType].elementType
+    val data = arr.toArray[AnyRef](elementType)
+    new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val elementType = x.dataType.asInstanceOf[ArrayType].elementType
+    nullSafeCodeGen(ctx, ev, (x, start, length) => {
+      val arrayClass = classOf[GenericArrayData].getName
+      val values = ctx.freshName("values")
+      val i = ctx.freshName("i")
+      val startIdx = ctx.freshName("startIdx")
+      val resLength = ctx.freshName("resLength")
+      val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
+      s"""
+       |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
+       |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
+       |if ($start == 0) {
+       |  throw new RuntimeException("Unexpected value for start in function $prettyName: "
+       |    + "SQL array indices start at 1.");
+       |} else if ($start < 0) {
+       |  $startIdx = $start + $x.numElements();
+       |} else {
+       |  // arrays in SQL are 1-based instead of 0-based
+       |  $startIdx = $start - 1;
+       |}
+       |if ($length < 0) {
+       |  throw new RuntimeException("Unexpected value for length in function $prettyName: "
+       |    + "length must be greater than or equal to 0.");
+       |} else if ($length > $x.numElements() - $startIdx) {
+       |  $resLength = $x.numElements() - $startIdx;
+       |} else {
+       |  $resLength = $length;
+       |}
+       |Object[] $values;
+       |if ($startIdx < 0) {
+       |  $values = new Object[0];
+       |} else {
+       |  $values = new Object[$resLength];
+       |  for (int $i = 0; $i < $resLength; $i ++) {
+       |    $values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")};
+       |  }
+       |}
+       |${ev.value} = new $arrayClass($values);
+     """.stripMargin
+    })
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 020687e4b3a27..87c6aa8be38cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -105,4 +105,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayContains(a3, Literal("")), null)
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
+
+  test("Slice") {
+    val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
+    val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
+
+    checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2))
+    checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5))
+    checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6))
+    checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6))
+    checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)),
+      "Unexpected value for length")
+    checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)),
+      "Unexpected value for start")
+    checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int])
+    checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null)
+    checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)),
+      null)
+
+    checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
+    checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
+
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a5ecd1b68fac4..c3e5351755458 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -102,6 +102,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     }
   }
 
+  protected def checkExceptionInExpression[T <: Throwable : ClassTag](
+      expression: Expression,
+      expectedErrMsg: String): Unit = {
+    checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg)
+  }
+
   protected def checkExceptionInExpression[T <: Throwable : ClassTag](
       expression: => Expression,
       inputRow: InternalRow,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index b1bc67dfac1b5..54dd190120f3a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -204,7 +204,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       Literal.fromObject(new java.util.LinkedList[Int]),
       Map("nonexisting" -> Literal(1)))
     checkExceptionInExpression[Exception](initializeWithNonexistingMethod,
-      InternalRow.fromSeq(Seq()),
       """A method named "nonexisting" is not declared in any enclosing class """ +
         "nor any supertype")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c658f25ced053..33ff9fe62e25e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3046,6 +3046,16 @@ object functions {
     ArrayContains(column.expr, Literal(value))
   }
 
+  /**
+   * Returns an array containing all the elements in `x` from index `start` (or starting from the
+   * end if `start` is negative) with the specified `length`.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def slice(x: Column, start: Int, length: Int): Column = withExpr {
+    Slice(x.expr, Literal(start), Literal(length))
+  }
+
   /**
    * Creates a new row for each element in the given array or map column.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 50e475984f458..0cc7eec2fa292 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -413,6 +413,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("slice function") {
+    val df = Seq(
+      Seq(1, 2, 3),
+      Seq(4, 5)
+    ).toDF("x")
+
+    val answer = Seq(Row(Seq(2, 3)), Row(Seq(5)))
+
+    checkAnswer(df.select(slice(df("x"), 2, 2)), answer)
+    checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer)
+
+    val answerNegative = Seq(Row(Seq(3)), Row(Seq(5)))
+    checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative)
+    checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative)
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

From 367aaf2338aa40f3cd0a24eacfe3c8a19ca96ae0 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Fri, 20 Apr 2018 17:52:18 +0200
Subject: [PATCH 2/7] fix typo

---
 .../sql/catalyst/expressions/collectionOperations.scala | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 1263cd577786b..f6dab4edb18c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -294,7 +294,7 @@ case class ArrayContains(left: Expression, right: Expression)
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
+  usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
   examples = """
     Examples:
       > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
@@ -310,10 +310,6 @@ case class Slice(x: Expression, start: Expression, length: Expression)
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
 
-  override def nullable: Boolean = children.exists(_.nullable)
-
-  override def foldable: Boolean = children.forall(_.foldable)
-
   override def children: Seq[Expression] = Seq(x, start, length)
 
   override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {

From b94d067d3358c96a638dbe5c4fbb7270def453c3 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Fri, 20 Apr 2018 18:05:16 +0200
Subject: [PATCH 3/7] review comments

---
 .../catalyst/expressions/collectionOperations.scala  | 12 ++++++------
 .../expressions/CollectionExpressionsSuite.scala     |  2 +-
 .../catalyst/expressions/ExpressionEvalHelper.scala  |  2 +-
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ab5197ea25f9b..210fdc8c2c30f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -407,7 +407,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)
     val arr = xVal.asInstanceOf[ArrayData]
     val startIndex = if (startInt == 0) {
       throw new RuntimeException(
-         s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
+        s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
     } else if (startInt < 0) {
       startInt + arr.numElements()
     } else {
@@ -415,15 +415,15 @@ case class Slice(x: Expression, start: Expression, length: Expression)
     }
     if (lengthInt < 0) {
       throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
-        s"length must be greater than or equal to 0.")
+        "length must be greater than or equal to 0.")
     }
-    // this can happen if start is negative and its absolute value is greater than the
+    // startIndex can be negative if start is negative and its absolute value is greater than the
     // number of elements in the array
-    if (startIndex < 0) {
+    if (startIndex < 0 || startIndex >= arr.numElements()) {
       return new GenericArrayData(Array.empty[AnyRef])
     }
     val elementType = x.dataType.asInstanceOf[ArrayType].elementType
-    val data = arr.toArray[AnyRef](elementType)
+    val data = arr.toSeq[AnyRef](elementType)
     new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
   }
 
@@ -457,7 +457,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)
        |  $resLength = $length;
        |}
        |Object[] $values;
-       |if ($startIdx < 0) {
+       |if ($startIdx < 0 || $startIdx >= $x.numElements()) {
        |  $values = new Object[0];
        |} else {
        |  $values = new Object[$resLength];
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index b9f143ca21f2d..0a3c4242724f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -127,7 +127,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
     checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
-
+    checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int])
   }
 
   test("Array Min") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 8db1a3a00cd5b..a22e9d4655e8c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -105,7 +105,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
   }
 
   protected def checkExceptionInExpression[T <: Throwable : ClassTag](
-      expression: Expression,
+      expression: => Expression,
       expectedErrMsg: String): Unit = {
     checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg)
   }

From dc6cb60f5bee56473d65e50b500dea694c28d2b3 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Fri, 27 Apr 2018 14:53:42 +0200
Subject: [PATCH 4/7] specialize codegen for primitive types

---
 .../expressions/collectionOperations.scala    | 55 ++++++++++++++++---
 .../CollectionExpressionsSuite.scala          |  4 ++
 2 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 210fdc8c2c30f..691f211d904ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -401,6 +401,8 @@ case class Slice(x: Expression, start: Expression, length: Expression)
 
   override def children: Seq[Expression] = Seq(x, start, length)
 
+  lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
+
   override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
     val startInt = startVal.asInstanceOf[Int]
     val lengthInt = lengthVal.asInstanceOf[Int]
@@ -422,17 +424,12 @@ case class Slice(x: Expression, start: Expression, length: Expression)
     if (startIndex < 0 || startIndex >= arr.numElements()) {
       return new GenericArrayData(Array.empty[AnyRef])
     }
-    val elementType = x.dataType.asInstanceOf[ArrayType].elementType
     val data = arr.toSeq[AnyRef](elementType)
     new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val elementType = x.dataType.asInstanceOf[ArrayType].elementType
     nullSafeCodeGen(ctx, ev, (x, start, length) => {
-      val arrayClass = classOf[GenericArrayData].getName
-      val values = ctx.freshName("values")
-      val i = ctx.freshName("i")
       val startIdx = ctx.freshName("startIdx")
       val resLength = ctx.freshName("resLength")
       val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
@@ -456,18 +453,60 @@ case class Slice(x: Expression, start: Expression, length: Expression)
        |} else {
        |  $resLength = $length;
        |}
+       |${genCodeForResult(ctx, ev, x, startIdx, resLength)}
+     """.stripMargin
+    })
+  }
+
+  def genCodeForResult(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      inputArray: String,
+      startIdx: String,
+      resLength: String): String = {
+    val values = ctx.freshName("values")
+    val i = ctx.freshName("i")
+    val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
+    if (!CodeGenerator.isPrimitiveType(elementType)) {
+      val arrayClass = classOf[GenericArrayData].getName
+      s"""
        |Object[] $values;
-       |if ($startIdx < 0 || $startIdx >= $x.numElements()) {
+       |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
        |  $values = new Object[0];
        |} else {
        |  $values = new Object[$resLength];
        |  for (int $i = 0; $i < $resLength; $i ++) {
-       |    $values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")};
+       |    $values[$i] = $getValue;
        |  }
        |}
        |${ev.value} = new $arrayClass($values);
      """.stripMargin
-    })
+    } else {
+      val sizeInBytes = ctx.freshName("sizeInBytes")
+      val bytesArray = ctx.freshName("bytesArray")
+      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+      s"""
+       |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+       |  $resLength = 0;
+       |}
+       |${CodeGenerator.JAVA_INT} $sizeInBytes =
+       |  UnsafeArrayData.calculateHeaderPortionInBytes($resLength) +
+       |  ${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
+       |    ${elementType.defaultSize} * $resLength);
+       |byte[] $bytesArray = new byte[$sizeInBytes];
+       |UnsafeArrayData $values = new UnsafeArrayData();
+       |Platform.putLong($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $resLength);
+       |$values.pointTo($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $sizeInBytes);
+       |for (int $i = 0; $i < $resLength; $i ++) {
+       |  if ($inputArray.isNullAt($i + $startIdx)) {
+       |    $values.setNullAt($i);
+       |  } else {
+       |    $values.set$primitiveValueTypeName($i, $getValue);
+       |  }
+       |}
+       |${ev.value} = $values;
+     """.stripMargin
+    }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 0a3c4242724f4..7e2bd613317ae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -110,6 +110,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
     val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
+    val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType))
 
     checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2))
     checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5))
@@ -120,6 +121,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)),
       "Unexpected value for start")
     checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int])
+    checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String])
     checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null)
     checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null)
     checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)),
@@ -128,6 +130,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
     checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
     checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int])
+    checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String])
+    checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4))
   }
 
   test("Array Min") {

From 9d655708c2f0bbf18ab7044fb03cf899a0eba4eb Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Mon, 30 Apr 2018 09:53:08 +0200
Subject: [PATCH 5/7] fix indent

---
 .../catalyst/expressions/collectionOperations.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 491896500315c..cf658ff155214 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -459,11 +459,11 @@ case class Slice(x: Expression, start: Expression, length: Expression)
   }
 
   def genCodeForResult(
-      ctx: CodegenContext,
-      ev: ExprCode,
-      inputArray: String,
-      startIdx: String,
-      resLength: String): String = {
+    ctx: CodegenContext,
+    ev: ExprCode,
+    inputArray: String,
+    startIdx: String,
+    resLength: String): String = {
     val values = ctx.freshName("values")
     val i = ctx.freshName("i")
     val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")

From e2eb21ee322682e5803615b75b293e62a6a84743 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Fri, 4 May 2018 18:01:04 +0200
Subject: [PATCH 6/7] add checks for size greater than maxint

---
 .../expressions/codegen/CodeGenerator.scala   | 34 +++++++++++++
 .../expressions/collectionOperations.scala    | 51 ++-----------------
 2 files changed, 37 insertions(+), 48 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index cf0a91ff00626..e3503d9bc6785 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types._
 import org.apache.spark.util.{ParentClassLoader, Utils}
@@ -730,6 +731,39 @@ class CodegenContext {
     """.stripMargin
   }
 
+  /**
+   * Generates code creating a [[UnsafeArrayData]].
+   *
+   * @param arrayName name of the array to create
+   * @param numElements code representing the number of elements the array should contain
+   * @param elementType data type of the elements in the array
+   * @param additionalErrorMessage string to include in the error message
+   */
+  def createUnsafeArray(
+      arrayName: String,
+      numElements: String,
+      elementType: DataType,
+      additionalErrorMessage: String): String = {
+    val arraySize = freshName("size")
+    val arrayBytes = freshName("arrayBytes")
+
+    s"""
+       |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+       |  $numElements,
+       |  ${elementType.defaultSize});
+       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+       |  throw new RuntimeException("Unsuccessful try to create array with " + $arraySize +
+       |    " bytes of data due to exceeding the limit " +
+       |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." +
+       |    "$additionalErrorMessage");
+       |}
+       |byte[] $arrayBytes = new byte[(int)$arraySize];
+       |UnsafeArrayData $arrayName = new UnsafeArrayData();
+       |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+       |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+     """.stripMargin
+  }
+
   /**
    * Generates code to do null safe execution, i.e. only execute the code when the input is not
    * null by adding null check if necessary.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 64f0ae79371ab..965ac196aa470 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
@@ -482,21 +481,12 @@ case class Slice(x: Expression, start: Expression, length: Expression)
        |${ev.value} = new $arrayClass($values);
      """.stripMargin
     } else {
-      val sizeInBytes = ctx.freshName("sizeInBytes")
-      val bytesArray = ctx.freshName("bytesArray")
       val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
       s"""
        |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
        |  $resLength = 0;
        |}
-       |${CodeGenerator.JAVA_INT} $sizeInBytes =
-       |  UnsafeArrayData.calculateHeaderPortionInBytes($resLength) +
-       |  ${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
-       |    ${elementType.defaultSize} * $resLength);
-       |byte[] $bytesArray = new byte[$sizeInBytes];
-       |UnsafeArrayData $values = new UnsafeArrayData();
-       |Platform.putLong($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $resLength);
-       |$values.pointTo($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $sizeInBytes);
+       |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")}
        |for (int $i = 0; $i < $resLength; $i ++) {
        |  if ($inputArray.isNullAt($i + $startIdx)) {
        |    $values.setNullAt($i);
@@ -1107,24 +1097,11 @@ case class Concat(children: Seq[Expression]) extends Expression {
   }
 
   private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
-    val arrayName = ctx.freshName("array")
-    val arraySizeName = ctx.freshName("size")
     val counter = ctx.freshName("counter")
     val arrayData = ctx.freshName("arrayData")
 
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
 
-    val unsafeArraySizeInBytes = s"""
-      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
-      |  $numElemName,
-      |  ${elementType.defaultSize});
-      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
-      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
-      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
-      |    " for UnsafeArrayData.");
-      |}
-      """.stripMargin
-    val baseOffset = Platform.BYTE_ARRAY_OFFSET
-
     val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
 
     s"""
       |public ArrayData concat($javaType[] args) {
       |  ${nullArgumentProtection()}
       |  $numElemCode
-      |  $unsafeArraySizeInBytes
-      |  byte[] $arrayName = new byte[(int)$arraySizeName];
-      |  UnsafeArrayData $arrayData = new UnsafeArrayData();
-      |  Platform.putLong($arrayName, $baseOffset, $numElemName);
-      |  $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+      |  ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
       |  int $counter = 0;
       |  for (int y = 0; y < ${children.length}; y++) {
       |    for (int z = 0; z < args[y].numElements(); z++) {
@@ -1288,34 +1261,16 @@ case class Flatten(child: Expression) extends UnaryExpression {
       ctx: CodegenContext,
       childVariableName: String,
       arrayDataName: String): String = {
-    val arrayName = ctx.freshName("array")
-    val arraySizeName = ctx.freshName("size")
     val counter = ctx.freshName("counter")
     val tempArrayDataName = ctx.freshName("tempArrayData")
 
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
 
-    val unsafeArraySizeInBytes = s"""
-      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
-      |  $numElemName,
-      |  ${elementType.defaultSize});
-      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
-      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
-      |    $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
-      |    " bytes for UnsafeArrayData.");
-      |}
-      """.stripMargin
-    val baseOffset = Platform.BYTE_ARRAY_OFFSET
-
     val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
 
     s"""
     |$numElemCode
-    |$unsafeArraySizeInBytes
-    |byte[] $arrayName = new byte[(int)$arraySizeName];
-    |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
-    |Platform.putLong($arrayName, $baseOffset, $numElemName);
-    |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+    |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")}
     |int $counter = 0;
     |for (int k = 0; k < $childVariableName.numElements(); k++) {
     |  ArrayData arr = $childVariableName.getArray(k);

From 07604e0c2d8c46210f39a0fb5a583d3532428553 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Fri, 4 May 2018 18:33:45 +0200
Subject: [PATCH 7/7] fix scalastyle

---
 .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e3503d9bc6785..4dda525294259 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -41,8 +41,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types._
 import org.apache.spark.util.{ParentClassLoader, Utils}
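
A quick usage sketch of the slice function this series adds (illustrative only,
not part of the patches above; it assumes a Spark build with these changes
applied and a spark-shell session available as `spark`):

    import spark.implicits._
    import org.apache.spark.sql.functions.slice

    val df = Seq(Seq(1, 2, 3, 4)).toDF("x")

    // Positive start is a 1-based index from the beginning of the array.
    df.select(slice($"x", 2, 2)).show()   // expected: [2, 3]

    // Negative start counts back from the end of the array.
    df.select(slice($"x", -2, 2)).show()  // expected: [3, 4]

    // The SQL form resolves through the FunctionRegistry entry from PATCH 1/7.
    spark.sql("SELECT slice(array(1, 2, 3, 4), 2, 2)").show()

The expected results follow the ExpressionDescription examples above; per the
implementation, start = 0 throws a RuntimeException (SQL array indices start
at 1), and a negative length is rejected the same way.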