From 6d31daeb6a2c5607ffe3b23ffb381626ad57f576 Mon Sep 17 00:00:00 2001
From: "xuewei.linxuewei"
Date: Thu, 12 Nov 2020 08:50:32 +0000
Subject: [PATCH] [SPARK-33386][SQL] Accessing array elements in
 ElementAt/Elt/GetArrayItem should fail if the index is out of bounds
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?

Instead of returning NULL, the `element_at`, `elt`, and `GetArrayItem` functions now throw a runtime `ArrayIndexOutOfBoundsException` when ANSI mode is enabled.

### Why are the changes needed?

For ANSI mode compliance: under the ANSI SQL standard, accessing an array element with an out-of-range index is an error, not NULL.

### Does this PR introduce any user-facing change?

When `spark.sql.ansi.enabled` is true, Spark throws `ArrayIndexOutOfBoundsException` if an out-of-range index is used when accessing array elements.

### How was this patch tested?

Added new unit tests and ran the existing ones.

Closes #30297 from leanken/leanken-SPARK-33386.

Authored-by: xuewei.linxuewei
Signed-off-by: Wenchen Fan
---
 docs/sql-ref-ansi-compliance.md               |   9 +-
 .../sql/catalyst/analysis/TypeCoercion.scala  |   4 +-
 .../expressions/ProjectionOverSchema.scala    |   6 +-
 .../catalyst/expressions/SelectedField.scala  |   2 +-
 .../expressions/collectionOperations.scala    |  53 ++--
 .../expressions/complexTypeExtractors.scala   |  67 +++--
 .../expressions/stringExpressions.scala       |  33 ++-
 .../sql/catalyst/optimizer/ComplexTypes.scala |   2 +-
 .../apache/spark/sql/internal/SQLConf.scala   |   7 +-
 .../CollectionExpressionsSuite.scala          | 136 ++++++----
 .../expressions/ComplexTypeSuite.scala        |  23 ++
 .../expressions/StringExpressionsSuite.scala  |  32 ++-
 .../resources/sql-tests/inputs/ansi/array.sql |   1 +
 .../test/resources/sql-tests/inputs/array.sql |  12 +
 .../sql-tests/results/ansi/array.sql.out      | 234 ++++++++++++++++++
 .../resources/sql-tests/results/array.sql.out |  67 ++++-
 16 files changed, 584 insertions(+), 104 deletions(-)
 create mode 100644 sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql
 create mode 100644 sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out

diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index d6e99312bb66e..c2b36033e318e 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -110,7 +110,14 @@ SELECT * FROM t;
 ### SQL Functions
 
 The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
-  - `size`: This function returns null for null input under ANSI mode.
+  - `size`: This function returns null for null input.
+  - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
+  - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
+
+### SQL Operators
+
+The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
+  - `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices.
 
 ### SQL Keywords
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index becdef8b9c603..e8dab28b5e907 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -840,8 +840,8 @@ object TypeCoercion {
       plan resolveOperators { case p =>
         p transformExpressionsUp {
           // Skip nodes if unresolved or not enough children
-          case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
-          case c @ Elt(children) =>
+          case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c
+          case c @ Elt(children, _) =>
             val index = children.head
             val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
             val newInputs = if (conf.eltOutputAsString ||
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 13c6f8db7c129..6f1d9d065ab1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -34,8 +34,10 @@ case class ProjectionOverSchema(schema: StructType) {
     expr match {
       case a: AttributeReference if fieldNames.contains(a.name) =>
         Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
-      case GetArrayItem(child, arrayItemOrdinal) =>
-        getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
+      case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
+        getProjection(child).map {
+          projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
+        }
       case a: GetArrayStructFields =>
         getProjection(a.child).map(p => (p, p.dataType)).map {
           case (projection, ArrayType(projSchema @ StructType(_), _)) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index 7ba3d302d553b..adcc4be10687e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -119,7 +119,7 @@ object SelectedField {
           throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
       }
       selectField(child, opt)
-    case GetArrayItem(child, _) =>
+    case GetArrayItem(child, _, _) =>
       // GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
       // the top-level extractor. However it can be part of an extractor chain.
       val ArrayType(_, containsNull) = child.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index cb081b80ba096..ee98ebf5a8a50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1906,8 +1906,10 @@ case class ArrayPosition(left: Expression, right: Expression)
 @ExpressionDescription(
   usage = """
     _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
-      accesses elements from the last to the first. Returns NULL if the index exceeds the length
-      of the array.
+      accesses elements from the last to the first. The function returns NULL
+      if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false.
+      If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException
+      for invalid indices.
 
     _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
   """,
@@ -1919,9 +1921,14 @@ case class ArrayPosition(left: Expression, right: Expression)
       b
   """,
   since = "2.4.0")
-case class ElementAt(left: Expression, right: Expression)
+case class ElementAt(
+    left: Expression,
+    right: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
 
+  def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)
+
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
 
   @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
@@ -1969,7 +1976,7 @@ case class ElementAt(left: Expression, right: Expression)
         if (ordinal == 0) {
           false
         } else if (elements.length < math.abs(ordinal)) {
-          true
+          !failOnError
         } else {
           if (ordinal < 0) {
             elements(elements.length + ordinal).nullable
@@ -1979,24 +1986,9 @@ case class ElementAt(left: Expression, right: Expression)
       }
     }
 
-  override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
-    if (ordinal.foldable && !ordinal.nullable) {
-      val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
-      child match {
-        case CreateArray(ar, _) =>
-          nullability(ar, intOrdinal)
-        case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
-          nullability(elements, intOrdinal) || field.nullable
-        case _ =>
-          true
-      }
-    } else {
-      true
-    }
-  }
-
   override def nullable: Boolean = left.dataType match {
-    case _: ArrayType => computeNullabilityFromArray(left, right)
+    case _: ArrayType =>
+      computeNullabilityFromArray(left, right, failOnError, nullability)
     case _: MapType => true
   }
 
@@ -2008,7 +2000,12 @@ case class ElementAt(left: Expression, right: Expression)
         val array = value.asInstanceOf[ArrayData]
         val index = ordinal.asInstanceOf[Int]
         if (array.numElements() < math.abs(index)) {
-          null
+          if (failOnError) {
+            throw new ArrayIndexOutOfBoundsException(
+              s"Invalid index: $index, numElements: ${array.numElements()}")
+          } else {
+            null
+          }
         } else {
           val idx = if (index == 0) {
             throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
@@ -2042,10 +2039,20 @@ case class ElementAt(left: Expression, right: Expression)
         } else {
           ""
         }
+
+        val indexOutOfBoundBranch = if (failOnError) {
+          s"""throw new ArrayIndexOutOfBoundsException(
+             |  "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
+             |);
+           """.stripMargin
+        } else {
+          s"${ev.isNull} = true;"
+        }
+
         s"""
         |int $index = (int) $eval2;
         |if ($eval1.numElements() < Math.abs($index)) {
-        |  ${ev.isNull} = true;
+        |  $indexOutOfBoundBranch
         |} else {
         |  if ($index == 0) {
         |    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 60afe140960cc..363d388692c9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -222,10 +223,15 @@ case class GetArrayStructFields(
  *
  * We need to do type checking here as `ordinal` expression maybe unresolved.
  */
-case class GetArrayItem(child: Expression, ordinal: Expression)
+case class GetArrayItem(
+    child: Expression,
+    ordinal: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
   with NullIntolerant {
 
+  def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled)
+
   // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
@@ -234,13 +240,29 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
   override def left: Expression = child
   override def right: Expression = ordinal
 
-  override def nullable: Boolean = computeNullabilityFromArray(left, right)
+  override def nullable: Boolean =
+    computeNullabilityFromArray(left, right, failOnError, nullability)
 
   override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
 
+  private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
+    if (ordinal >= 0 && ordinal < elements.length) {
+      elements(ordinal).nullable
+    } else {
+      !failOnError
+    }
+  }
+
   protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
     val baseValue = value.asInstanceOf[ArrayData]
     val index = ordinal.asInstanceOf[Number].intValue()
-    if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
+    if (index >= baseValue.numElements() || index < 0) {
+      if (failOnError) {
+        throw new ArrayIndexOutOfBoundsException(
+          s"Invalid index: $index, numElements: ${baseValue.numElements()}")
+      } else {
+        null
+      }
+    } else if (baseValue.isNullAt(index)) {
       null
     } else {
       baseValue.get(index, dataType)
@@ -251,15 +273,28 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       val index = ctx.freshName("index")
       val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) {
-        s" || $eval1.isNullAt($index)"
+        s"""else if ($eval1.isNullAt($index)) {
+               ${ev.isNull} = true;
+             }
+         """
       } else {
         ""
       }
+
+      val indexOutOfBoundBranch = if (failOnError) {
+        s"""throw new ArrayIndexOutOfBoundsException(
+           |  "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
+           |);
+         """.stripMargin
+      } else {
+        s"${ev.isNull} = true;"
+      }
+
       s"""
         final int $index = (int) $eval2;
-        if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
-          ${ev.isNull} = true;
-        } else {
+        if ($index >= $eval1.numElements() || $index < 0) {
+          $indexOutOfBoundBranch
+        } $nullCheck else {
          ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
        }
      """
@@ -273,20 +308,24 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 trait GetArrayItemUtil {
 
   /** `Null` is returned for invalid ordinals. */
-  protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
+  protected def computeNullabilityFromArray(
+      child: Expression,
+      ordinal: Expression,
+      failOnError: Boolean,
+      nullability: (Seq[Expression], Int) => Boolean): Boolean = {
+    val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
     if (ordinal.foldable && !ordinal.nullable) {
       val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
       child match {
-        case CreateArray(ar, _) if intOrdinal < ar.length =>
-          ar(intOrdinal).nullable
-        case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
-          if intOrdinal < elements.length =>
-          elements(intOrdinal).nullable || field.nullable
+        case CreateArray(ar, _) =>
+          nullability(ar, intOrdinal)
+        case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
+          nullability(elements, intOrdinal) || field.nullable
         case _ =>
           true
       }
     } else {
-      true
+      if (failOnError) arrayContainsNull else true
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 1fe990207160c..16e22940495f1 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -231,7 +232,12 @@ case class ConcatWs(children: Seq[Expression])
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
+  usage = """
+    _FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
+    The function returns NULL if the index exceeds the length of the array
+    and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true,
+    it throws ArrayIndexOutOfBoundsException for invalid indices.
+  """,
   examples = """
     Examples:
       > SELECT _FUNC_(1, 'scala', 'java');
        scala
   """,
   since = "2.0.0")
 // scalastyle:on line.size.limit
-case class Elt(children: Seq[Expression]) extends Expression {
+case class Elt(
+    children: Seq[Expression],
+    failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression {
+
+  def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)
 
   private lazy val indexExpr = children.head
   private lazy val inputExprs = children.tail.toArray
@@ -275,7 +285,12 @@ case class ConcatWs(children: Seq[Expression])
     } else {
       val index = indexObj.asInstanceOf[Int]
       if (index <= 0 || index > inputExprs.length) {
-        null
+        if (failOnError) {
+          throw new ArrayIndexOutOfBoundsException(
+            s"Invalid index: $index, numElements: ${inputExprs.length}")
+        } else {
+          null
+        }
       } else {
         inputExprs(index - 1).eval(input)
       }
@@ -323,6 +338,17 @@ case class Elt(children: Seq[Expression]) extends Expression {
       """.stripMargin
     }.mkString)
 
+    val indexOutOfBoundBranch = if (failOnError) {
+      s"""
+         |if (!$indexMatched) {
+         |  throw new ArrayIndexOutOfBoundsException(
+         |    "Invalid index: " + ${index.value} + ", numElements: " + ${inputExprs.length});
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
     ev.copy(
       code"""
         |${index.code}
@@ -332,6 +358,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
        |do {
        |  $codes
        |} while (false);
+       |$indexOutOfBoundBranch
        |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
        |final boolean ${ev.isNull} = ${ev.value} == null;
      """.stripMargin)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 2ac8f62b67b3d..7a21ce254a235 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -61,7 +61,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
         CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)
 
       // Remove redundant map lookup.
-      case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
+      case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx), _) =>
         // Instead of creating the array and then selecting one row, remove array creation
         // altogether.
         if (idx >= 0 && idx < elems.size) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 21357a492e39e..ef988052affcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2144,9 +2144,10 @@ object SQLConf {
 
   val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
     .doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
-      "throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
-      "field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
-      "the SQL parser.")
+      "throw an exception at runtime if the inputs to a SQL operator/function are invalid, " +
+      "e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " +
+      "2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
+      "the SQL parser. 3. Spark will return NULL for null input for function `size`.")
     .version("3.0.0")
     .booleanConf
     .createWithDefault(false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d59d13d49cef4..6ee88c9eaef86 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1118,58 +1118,62 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
   }
 
   test("correctly handles ElementAt nullability for arrays") {
-    // CreateArray case
-    val a = AttributeReference("a", IntegerType, nullable = false)()
-    val b = AttributeReference("b", IntegerType, nullable = true)()
-    val array = CreateArray(a :: b :: Nil)
-    assert(!ElementAt(array, Literal(1)).nullable)
-    assert(!ElementAt(array, Literal(-2)).nullable)
-    assert(ElementAt(array, Literal(2)).nullable)
-    assert(ElementAt(array, Literal(-1)).nullable)
-    assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
-    assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
-
-    // CreateArray case invalid indices
-    assert(!ElementAt(array, Literal(0)).nullable)
-    assert(ElementAt(array, Literal(4)).nullable)
-    assert(ElementAt(array, Literal(-4)).nullable)
-
-    // GetArrayStructFields case
-    val f1 = StructField("a", IntegerType, nullable = false)
-    val f2 = StructField("b", IntegerType, nullable = true)
-    val structType = StructType(f1 :: f2 :: Nil)
-    val c = AttributeReference("c", structType, nullable = false)()
-    val inputArray1 = CreateArray(c :: Nil)
-    val inputArray1ContainsNull = c.nullable
-    val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
-    assert(!ElementAt(stArray1, Literal(1)).nullable)
-    assert(!ElementAt(stArray1, Literal(-1)).nullable)
-    val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
-    assert(ElementAt(stArray2, Literal(1)).nullable)
-    assert(ElementAt(stArray2, Literal(-1)).nullable)
-
-    val d = AttributeReference("d", structType, nullable = true)()
-    val inputArray2 = CreateArray(c :: d :: Nil)
-    val inputArray2ContainsNull = c.nullable || d.nullable
-    val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
-    assert(!ElementAt(stArray3, Literal(1)).nullable)
-    assert(!ElementAt(stArray3, Literal(-2)).nullable)
-    assert(ElementAt(stArray3, Literal(2)).nullable)
-    assert(ElementAt(stArray3, Literal(-1)).nullable)
-    val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
-    assert(ElementAt(stArray4, Literal(1)).nullable)
-    assert(ElementAt(stArray4, Literal(-2)).nullable)
-    assert(ElementAt(stArray4, Literal(2)).nullable)
-    assert(ElementAt(stArray4, Literal(-1)).nullable)
-
-    // GetArrayStructFields case invalid indices
-    assert(!ElementAt(stArray3, Literal(0)).nullable)
-    assert(ElementAt(stArray3, Literal(4)).nullable)
-    assert(ElementAt(stArray3, Literal(-4)).nullable)
-
-    assert(ElementAt(stArray4, Literal(0)).nullable)
-    assert(ElementAt(stArray4, Literal(4)).nullable)
-    assert(ElementAt(stArray4, Literal(-4)).nullable)
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        // CreateArray case
+        val a = AttributeReference("a", IntegerType, nullable = false)()
+        val b = AttributeReference("b", IntegerType, nullable = true)()
+        val array = CreateArray(a :: b :: Nil)
+        assert(!ElementAt(array, Literal(1)).nullable)
+        assert(!ElementAt(array, Literal(-2)).nullable)
+        assert(ElementAt(array, Literal(2)).nullable)
+        assert(ElementAt(array, Literal(-1)).nullable)
+        assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
+        assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
+
+        // CreateArray case invalid indices
+        assert(!ElementAt(array, Literal(0)).nullable)
+        assert(ElementAt(array, Literal(4)).nullable == !ansiEnabled)
+        assert(ElementAt(array, Literal(-4)).nullable == !ansiEnabled)
+
+        // GetArrayStructFields case
+        val f1 = StructField("a", IntegerType, nullable = false)
+        val f2 = StructField("b", IntegerType, nullable = true)
+        val structType = StructType(f1 :: f2 :: Nil)
+        val c = AttributeReference("c", structType, nullable = false)()
+        val inputArray1 = CreateArray(c :: Nil)
+        val inputArray1ContainsNull = c.nullable
+        val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
+        assert(!ElementAt(stArray1, Literal(1)).nullable)
+        assert(!ElementAt(stArray1, Literal(-1)).nullable)
+        val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
+        assert(ElementAt(stArray2, Literal(1)).nullable)
+        assert(ElementAt(stArray2, Literal(-1)).nullable)
+
+        val d = AttributeReference("d", structType, nullable = true)()
+        val inputArray2 = CreateArray(c :: d :: Nil)
+        val inputArray2ContainsNull = c.nullable || d.nullable
+        val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
+        assert(!ElementAt(stArray3, Literal(1)).nullable)
+        assert(!ElementAt(stArray3, Literal(-2)).nullable)
+        assert(ElementAt(stArray3, Literal(2)).nullable)
+        assert(ElementAt(stArray3, Literal(-1)).nullable)
+        val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
+        assert(ElementAt(stArray4, Literal(1)).nullable)
+        assert(ElementAt(stArray4, Literal(-2)).nullable)
+        assert(ElementAt(stArray4, Literal(2)).nullable)
+        assert(ElementAt(stArray4, Literal(-1)).nullable)
+
+        // GetArrayStructFields case invalid indices
+        assert(!ElementAt(stArray3, Literal(0)).nullable)
+        assert(ElementAt(stArray3, Literal(4)).nullable == !ansiEnabled)
+        assert(ElementAt(stArray3, Literal(-4)).nullable == !ansiEnabled)
+
+        assert(ElementAt(stArray4, Literal(0)).nullable)
+        assert(ElementAt(stArray4, Literal(4)).nullable)
+        assert(ElementAt(stArray4, Literal(-4)).nullable)
+      }
+    }
   }
 
   test("Concat") {
@@ -1883,4 +1887,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Literal(stringToInterval("interval 1 year"))),
       Seq(Date.valueOf("2018-01-01")))
   }
+
+  test("SPARK-33386: element_at ArrayIndexOutOfBoundsException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+        var expr: Expression = ElementAt(array, Literal(5))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: 5, numElements: 3"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        expr = ElementAt(array, Literal(-5))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: -5, numElements: 3"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        // The "SQL array indices start at 1" exception is thrown in both modes.
+        expr = ElementAt(array, Literal(0))
+        val errMsg = "SQL array indices start at 1"
+        checkExceptionInExpression[Exception](expr, errMsg)
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 38e32ff2518f7..67ab2071de037 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -62,6 +62,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
   }
 
+  test("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
+
+        if (ansiEnabled) {
+          checkExceptionInExpression[Exception](
+            GetArrayItem(array, Literal(5)),
+            "Invalid index: 5, numElements: 2"
+          )
+
+          checkExceptionInExpression[Exception](
+            GetArrayItem(array, Literal(-1)),
+            "Invalid index: -1, numElements: 2"
+          )
+        } else {
+          checkEvaluation(GetArrayItem(array, Literal(5)), null)
+          checkEvaluation(GetArrayItem(array, Literal(-1)), null)
+        }
+      }
+    }
+  }
+
   test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
     // CreateArray case
     val a = AttributeReference("a", IntegerType, nullable = false)()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 967ccc42c632d..a1b6cec24f23f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -18,9 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -968,4 +968,34 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     GenerateUnsafeProjection.generate(
       Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
   }
+
+  test("SPARK-33386: elt ArrayIndexOutOfBoundsException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        var expr: Expression = Elt(Seq(Literal(4), Literal("123"), Literal("456")))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: 4, numElements: 2"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        expr = Elt(Seq(Literal(0), Literal("123"), Literal("456")))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: 0, numElements: 2"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        expr = Elt(Seq(Literal(-1), Literal("123"), Literal("456")))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: -1, numElements: 2"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql
new file mode 100644
index 0000000000000..662756cbfb0b0
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql
@@ -0,0 +1 @@
+--IMPORT array.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index 984321ab795fc..f73b653659eb4 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -90,3 +90,15 @@ select
   size(date_array),
   size(timestamp_array)
 from primitive_arrays;
+
+-- index out of range for array elements
+select element_at(array(1, 2, 3), 5);
+select element_at(array(1, 2, 3), -5);
+select element_at(array(1, 2, 3), 0);
+
+select elt(4, '123', '456');
+select elt(0, '123', '456');
+select elt(-1, '123', '456');
+
+select array(1, 2, 3)[5];
+select array(1, 2, 3)[-1];
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
new file mode 100644
index 0000000000000..12a77e36273fa
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -0,0 +1,234 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 20
+
+
+-- !query
+create temporary view data as select * from values
+  ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))),
+  ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223)))
+  as data(a, b, c)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+select * from data
+-- !query schema
+struct<a:string,b:array<int>,c:array<array<int>>>
+-- !query output
+one	[11,12,13]	[[111,112,113],[121,122,123]]
+two	[21,22,23]	[[211,212,213],[221,222,223]]
+
+
+-- !query
+select a, b[0], b[0] + b[1] from data
+-- !query schema
+struct<a:string,b[0]:int,(b[0] + b[1]):int>
+-- !query output
+one	11	23
+two	21	43
+
+
+-- !query
+select a, c[0][0] + c[0][0 + 1] from data
+-- !query schema
+struct<a:string,(c[0][0] + c[0][(0 + 1)]):int>
+-- !query output
+one	223
+two	423
+
+
+-- !query
+create temporary view primitive_arrays as select * from values (
+  array(true),
+  array(2Y, 1Y),
+  array(2S, 1S),
+  array(2, 1),
+  array(2L, 1L),
+  array(9223372036854775809, 9223372036854775808),
+  array(2.0D, 1.0D),
+  array(float(2.0), float(1.0)),
+  array(date '2016-03-14', date '2016-03-13'),
+  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000')
+) as primitive_arrays(
+  boolean_array,
+  tinyint_array,
+  smallint_array,
+  int_array,
+  bigint_array,
+  decimal_array,
+  double_array,
+  float_array,
+  date_array,
+  timestamp_array
+)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+select * from primitive_arrays
+-- !query schema
+struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,date_array:array<date>,timestamp_array:array<timestamp>>
+-- !query output
+[true]	[2,1]	[2,1]	[2,1]	[2,1]	[9223372036854775809,9223372036854775808]	[2.0,1.0]	[2.0,1.0]	[2016-03-14,2016-03-13]	[2016-11-15 20:54:00,2016-11-12 20:54:00]
+
+
+-- !query
+select
+  array_contains(boolean_array, true), array_contains(boolean_array, false),
+  array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y),
+  array_contains(smallint_array, 2S), array_contains(smallint_array, 0S),
+  array_contains(int_array, 2), array_contains(int_array, 0),
+  array_contains(bigint_array, 2L), array_contains(bigint_array, 0L),
+  array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1),
+  array_contains(double_array, 2.0D), array_contains(double_array, 0.0D),
+  array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)),
+  array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'),
+  array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000')
+from primitive_arrays
+-- !query schema
+struct<array_contains(boolean_array, true):boolean,array_contains(boolean_array, false):boolean,array_contains(tinyint_array, 2):boolean,array_contains(tinyint_array, 0):boolean,array_contains(smallint_array, 2):boolean,array_contains(smallint_array, 0):boolean,array_contains(int_array, 2):boolean,array_contains(int_array, 0):boolean,array_contains(bigint_array, 2):boolean,array_contains(bigint_array, 0):boolean,array_contains(decimal_array, 9223372036854775809):boolean,array_contains(decimal_array, 1):boolean,array_contains(double_array, 2.0):boolean,array_contains(double_array, 0.0):boolean,array_contains(float_array, CAST(2.0 AS FLOAT)):boolean,array_contains(float_array, CAST(0.0 AS FLOAT)):boolean,array_contains(date_array, DATE '2016-03-14'):boolean,array_contains(date_array, DATE '2016-01-01'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-11-15 20:54:00'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-01-01 20:54:00'):boolean>
+-- !query output
+true	false	true	false	true	false	true	false	true	false	true	false	true	false	true	false	true	false	true	false
+
+
+-- !query
+select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data
+-- !query schema
+struct<array_contains(b, 11):boolean,array_contains(c, array(111, 112, 113)):boolean>
+-- !query output
+false	false
+true	true
+
+
+-- !query
+select
+  sort_array(boolean_array),
+  sort_array(tinyint_array),
+  sort_array(smallint_array),
+  sort_array(int_array),
+  sort_array(bigint_array),
+  sort_array(decimal_array),
+  sort_array(double_array),
+  sort_array(float_array),
+  sort_array(date_array),
+  sort_array(timestamp_array)
+from primitive_arrays
+-- !query schema
+struct<sort_array(boolean_array, true):array<boolean>,sort_array(tinyint_array, true):array<tinyint>,sort_array(smallint_array, true):array<smallint>,sort_array(int_array, true):array<int>,sort_array(bigint_array, true):array<bigint>,sort_array(decimal_array, true):array<decimal(19,0)>,sort_array(double_array, true):array<double>,sort_array(float_array, true):array<float>,sort_array(date_array, true):array<date>,sort_array(timestamp_array, true):array<timestamp>>
+-- !query output
+[true]	[1,2]	[1,2]	[1,2]	[1,2]	[9223372036854775808,9223372036854775809]	[1.0,2.0]	[1.0,2.0]	[2016-03-13,2016-03-14]	[2016-11-12 20:54:00,2016-11-15 20:54:00]
+
+
+-- !query
+select sort_array(array('b', 'd'), '1')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7
+
+
+-- !query
+select sort_array(array('b', 'd'), cast(NULL as boolean))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7
+
+
+-- !query
+select
+  size(boolean_array),
+  size(tinyint_array),
+  size(smallint_array),
+  size(int_array),
+  size(bigint_array),
+  size(decimal_array),
+  size(double_array),
+  size(float_array),
+  size(date_array),
+  size(timestamp_array)
+from primitive_arrays
+-- !query schema
+struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
+-- !query output
+1	2	2	2	2	2	2	2	2	2
+
+
+-- !query
+select element_at(array(1, 2, 3), 5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 5, numElements: 3
+
+
+-- !query
+select element_at(array(1, 2, 3), -5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: -5, numElements: 3
+
+
+-- !query
+select element_at(array(1, 2, 3), 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+select elt(4, '123', '456')
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 4, numElements: 2
+
+
+-- !query
+select elt(0, '123', '456')
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 0, numElements: 2
+
+
+-- !query
+select elt(-1, '123', '456')
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: -1, numElements: 2
+
+
+-- !query
+select array(1, 2, 3)[5]
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 5, numElements: 3
+
+
+-- !query
+select array(1, 2, 3)[-1]
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: -1, numElements: 3
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 2c2b1a7856304..9bf0d89ed71fe 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 12
+-- Number of queries: 20
 
 
 -- !query
@@ -160,3 +160,68 @@ from primitive_arrays
 struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
 -- !query output
 1	2	2	2	2	2	2	2	2	2
+
+
+-- !query
+select element_at(array(1, 2, 3), 5)
+-- !query schema
+struct<element_at(array(1, 2, 3), 5):int>
+-- !query output
+NULL
+
+
+-- !query
+select element_at(array(1, 2, 3), -5)
+-- !query schema
+struct<element_at(array(1, 2, 3), -5):int>
+-- !query output
+NULL
+
+
+-- !query
+select element_at(array(1, 2, 3), 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+select elt(4, '123', '456')
+-- !query schema
+struct<elt(4, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select elt(0, '123', '456')
+-- !query schema
+struct<elt(0, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select elt(-1, '123', '456')
+-- !query schema
+struct<elt(-1, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select array(1, 2, 3)[5]
+-- !query schema
+struct<array(1, 2, 3)[5]:int>
+-- !query output
+NULL
+
+
+-- !query
+select array(1, 2, 3)[-1]
+-- !query schema
+struct<array(1, 2, 3)[-1]:int>
+-- !query output
+NULL
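
A minimal sketch of the user-facing behavior change, assuming a Spark build that includes this commit (the `Spark33386Demo` object name is illustrative, and at execution time the `ArrayIndexOutOfBoundsException` may surface wrapped in a `SparkException` from the failed task):

```scala
import org.apache.spark.sql.SparkSession

object Spark33386Demo extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("SPARK-33386-demo")
    .getOrCreate()

  // Legacy behavior (default): out-of-range accesses return NULL.
  spark.conf.set("spark.sql.ansi.enabled", "false")
  spark.sql("SELECT element_at(array(1, 2, 3), 5)").show()  // NULL
  spark.sql("SELECT elt(4, '123', '456')").show()           // NULL
  spark.sql("SELECT array(1, 2, 3)[5]").show()              // NULL

  // ANSI mode: the same queries now fail at runtime with
  // "Invalid index: ..., numElements: ...".
  spark.conf.set("spark.sql.ansi.enabled", "true")
  Seq(
    "SELECT element_at(array(1, 2, 3), 5)",
    "SELECT elt(4, '123', '456')",
    "SELECT array(1, 2, 3)[5]"
  ).foreach { query =>
    try {
      spark.sql(query).show()
    } catch {
      case e: Exception => println(s"$query => ${e.getMessage}")
    }
  }

  // Note: element_at(..., 0) throws "SQL array indices start at 1"
  // regardless of the ANSI setting.
  spark.stop()
}
```

The expected messages correspond to the golden outputs in the new ansi/array.sql.out file above.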