[SPARK-33386][SQL] Accessing array elements in ElementAt/Elt/GetArrayItem should fail if index is out of bound

### What changes were proposed in this pull request?

Instead of returning NULL, the `element_at`, `elt`, and `GetArrayItem` functions now throw a runtime `ArrayIndexOutOfBoundsException` when ANSI mode (`spark.sql.ansi.enabled`) is enabled and the index is out of bounds.

### Why are the changes needed?

To make these functions follow ANSI semantics when `spark.sql.ansi.enabled` is true: an invalid index should raise an error rather than silently produce NULL.

### Does this PR introduce any user-facing change?

Yes. When `spark.sql.ansi.enabled` is `true`, Spark throws an `ArrayIndexOutOfBoundsException` if an out-of-range index is used to access array elements.
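A minimal spark-shell sketch of the difference (assuming a session `spark` built with this patch; the output shown in comments is approximate):

```scala
// Default (non-ANSI) behavior: an out-of-range index yields NULL.
spark.conf.set("spark.sql.ansi.enabled", false)
spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect()
// => Array([null])

// ANSI mode: the same access fails at runtime.
spark.conf.set("spark.sql.ansi.enabled", true)
spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect()
// => java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3
```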

### How was this patch tested?

Added new unit tests and ran the existing unit tests.

Closes #30297 from leanken/leanken-SPARK-33386.

Authored-by: xuewei.linxuewei <xuewei.linxuewei@alibaba-inc.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
leanken-zz authored and cloud-fan committed Nov 12, 2020
1 parent 22baf05 commit 6d31dae
Showing 16 changed files with 584 additions and 104 deletions.
docs/sql-ref-ansi-compliance.md (9 changes: 8 additions & 1 deletion)
@@ -110,7 +110,14 @@ SELECT * FROM t;
### SQL Functions

The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `size`: This function returns null for null input under ANSI mode.
- `size`: This function returns null for null input.
- `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
- `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.

### SQL Operators

The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices.
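A hedged illustration of the operator behavior documented above (assuming a session `spark` with `spark.sql.ansi.enabled=true`):

```scala
// Bracket indexing is 0-based; index 5 is out of range for a 3-element array.
spark.sql("SELECT array(1, 2, 3)[5]").collect()
// => java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3
// With spark.sql.ansi.enabled=false the same query returns NULL instead.
```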

### SQL Keywords

@@ -840,8 +840,8 @@ object TypeCoercion {
plan resolveOperators { case p =>
p transformExpressionsUp {
// Skip nodes if unresolved or not enough children
case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
case c @ Elt(children) =>
case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c
case c @ Elt(children, _) =>
val index = children.head
val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
val newInputs = if (conf.eltOutputAsString ||
@@ -34,8 +34,10 @@ case class ProjectionOverSchema(schema: StructType) {
expr match {
case a: AttributeReference if fieldNames.contains(a.name) =>
Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
case GetArrayItem(child, arrayItemOrdinal) =>
getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
getProjection(child).map {
projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
}
case a: GetArrayStructFields =>
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
@@ -119,7 +119,7 @@ object SelectedField {
throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
}
selectField(child, opt)
case GetArrayItem(child, _) =>
case GetArrayItem(child, _, _) =>
// GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
val ArrayType(_, containsNull) = child.dataType
@@ -1906,8 +1906,10 @@ case class ArrayPosition(left: Expression, right: Expression)
@ExpressionDescription(
usage = """
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
accesses elements from the last to the first. Returns NULL if the index exceeds the length
of the array.
accesses elements from the last to the first. The function returns NULL
if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false.
If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException
for invalid indices.
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
""",
@@ -1919,9 +1921,14 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression)
case class ElementAt(
left: Expression,
right: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {

def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
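A hedged illustration of the `element_at` semantics described in the usage text above (assuming a session `spark`; results shown as comments):

```scala
spark.sql("SELECT element_at(array(10, 20, 30), 1)").collect()   // => Array([10])  (indices are 1-based)
spark.sql("SELECT element_at(array(10, 20, 30), -1)").collect()  // => Array([30])  (negative indices count from the end)
spark.sql("SELECT element_at(map(1, 'a', 2, 'b'), 2)").collect() // => Array([b])   (map lookup by key)
spark.sql("SELECT element_at(array(10, 20, 30), 0)").collect()
// => java.lang.ArrayIndexOutOfBoundsException: SQL array indices start at 1  (thrown regardless of ANSI mode)
```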
@@ -1969,7 +1976,7 @@ case class ElementAt(left: Expression, right: Expression)
if (ordinal == 0) {
false
} else if (elements.length < math.abs(ordinal)) {
true
!failOnError
} else {
if (ordinal < 0) {
elements(elements.length + ordinal).nullable
@@ -1979,24 +1986,9 @@ case class ElementAt(left: Expression, right: Expression)
}
}

override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar, _) =>
nullability(ar, intOrdinal)
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
nullability(elements, intOrdinal) || field.nullable
case _ =>
true
}
} else {
true
}
}

override def nullable: Boolean = left.dataType match {
case _: ArrayType => computeNullabilityFromArray(left, right)
case _: ArrayType =>
computeNullabilityFromArray(left, right, failOnError, nullability)
case _: MapType => true
}

@@ -2008,7 +2000,12 @@ case class ElementAt(left: Expression, right: Expression)
val array = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Int]
if (array.numElements() < math.abs(index)) {
null
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${array.numElements()}")
} else {
null
}
} else {
val idx = if (index == 0) {
throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
@@ -2042,10 +2039,20 @@
} else {
""
}

val indexOutOfBoundBranch = if (failOnError) {
s"""throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
|);
""".stripMargin
} else {
s"${ev.isNull} = true;"
}

s"""
|int $index = (int) $eval2;
|if ($eval1.numElements() < Math.abs($index)) {
| ${ev.isNull} = true;
| $indexOutOfBoundBranch
|} else {
| if ($index == 0) {
| throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -222,10 +223,15 @@ case class GetArrayStructFields(
*
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
case class GetArrayItem(
child: Expression,
ordinal: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
with NullIntolerant {

def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled)

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)

@@ -234,13 +240,29 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def left: Expression = child
override def right: Expression = ordinal
override def nullable: Boolean = computeNullabilityFromArray(left, right)
override def nullable: Boolean =
computeNullabilityFromArray(left, right, failOnError, nullability)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
if (ordinal >= 0 && ordinal < elements.length) {
elements(ordinal).nullable
} else {
!failOnError
}
}

protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
if (index >= baseValue.numElements() || index < 0) {
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${baseValue.numElements()}")
} else {
null
}
} else if (baseValue.isNullAt(index)) {
null
} else {
baseValue.get(index, dataType)
@@ -251,15 +273,28 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) {
s" || $eval1.isNullAt($index)"
s"""else if ($eval1.isNullAt($index)) {
${ev.isNull} = true;
}
"""
} else {
""
}

val indexOutOfBoundBranch = if (failOnError) {
s"""throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
|);
""".stripMargin
} else {
s"${ev.isNull} = true;"
}

s"""
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
${ev.isNull} = true;
} else {
if ($index >= $eval1.numElements() || $index < 0) {
$indexOutOfBoundBranch
} $nullCheck else {
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
}
"""
@@ -273,20 +308,24 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
trait GetArrayItemUtil {

/** `Null` is returned for invalid ordinals. */
protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
protected def computeNullabilityFromArray(
child: Expression,
ordinal: Expression,
failOnError: Boolean,
nullability: (Seq[Expression], Int) => Boolean): Boolean = {
val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar, _) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case CreateArray(ar, _) =>
nullability(ar, intOrdinal)
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
nullability(elements, intOrdinal) || field.nullable
case _ =>
true
}
} else {
true
if (failOnError) arrayContainsNull else true
}
}
}
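A hedged illustration of how the refactored nullability computation behaves for a foldable ordinal over a literal array (assuming a session `spark`; the schema is derived at analysis time, nothing is executed):

```scala
val df = spark.sql("SELECT element_at(array(1, 2, 3), 5) AS x")
df.schema("x").nullable
// => true  when spark.sql.ansi.enabled=false (the out-of-range access produces NULL)
// => false when spark.sql.ansi.enabled=true  (the same access throws, so NULL is impossible)
```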
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -231,15 +232,24 @@ case class ConcatWs(children: Seq[Expression])
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
usage = """
_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
The function returns NULL if the index exceeds the length of the array
and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true,
it throws ArrayIndexOutOfBoundsException for invalid indices.
""",
examples = """
Examples:
> SELECT _FUNC_(1, 'scala', 'java');
scala
""",
since = "2.0.0")
// scalastyle:on line.size.limit
case class Elt(children: Seq[Expression]) extends Expression {
case class Elt(
children: Seq[Expression],
failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression {

def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)

private lazy val indexExpr = children.head
private lazy val inputExprs = children.tail.toArray
@@ -275,7 +285,12 @@ case class Elt(children: Seq[Expression]) extends Expression {
} else {
val index = indexObj.asInstanceOf[Int]
if (index <= 0 || index > inputExprs.length) {
null
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${inputExprs.length}")
} else {
null
}
} else {
inputExprs(index - 1).eval(input)
}
@@ -323,6 +338,17 @@ case class Elt(children: Seq[Expression]) extends Expression {
""".stripMargin
}.mkString)

val indexOutOfBoundBranch = if (failOnError) {
s"""
|if (!$indexMatched) {
| throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + ${index.value} + ", numElements: " + ${inputExprs.length});
|}
""".stripMargin
} else {
""
}

ev.copy(
code"""
|${index.code}
@@ -332,6 +358,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
|do {
| $codes
|} while (false);
|$indexOutOfBoundBranch
|final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
|final boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
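A hedged sketch of `elt` under both modes, matching the updated usage text above (assuming a session `spark`):

```scala
spark.sql("SELECT elt(2, 'scala', 'java')").collect()  // => Array([java])

spark.conf.set("spark.sql.ansi.enabled", false)
spark.sql("SELECT elt(4, 'scala', 'java')").collect()  // => Array([null])

spark.conf.set("spark.sql.ansi.enabled", true)
spark.sql("SELECT elt(4, 'scala', 'java')").collect()
// => java.lang.ArrayIndexOutOfBoundsException: Invalid index: 4, numElements: 2
```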
@@ -61,7 +61,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)

// Remove redundant map lookup.
case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx), _) =>
// Instead of creating the array and then selecting one row, remove array creation
// altogether.
if (idx >= 0 && idx < elems.size) {
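A hedged sketch of the rewrite performed by the rule above: indexing into a freshly built array is replaced by the selected element itself, so the array is never materialized (assuming a session `spark`; the exact plan text may differ by version):

```scala
spark.sql("SELECT array('a', 'b', 'c')[1] AS x").queryExecution.optimizedPlan
// ~> Project [b AS x], i.e. GetArrayItem(CreateArray(...), 1) has been folded to the literal 'b'.
```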
@@ -2144,9 +2144,10 @@ object SQLConf {

val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
.doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
"throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
"field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
"the SQL parser.")
"throw an exception at runtime if the inputs to a SQL operator/function are invalid, " +
"e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " +
"2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
"the SQL parser. 3. Spark will return NULL for null input for function `size`.")
.version("3.0.0")
.booleanConf
.createWithDefault(false)
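A hedged sketch of the other ANSI-mode behaviors now listed in this config doc (assuming a session `spark` with `spark.sql.ansi.enabled=true`; point 2, reserved keywords, is a parser restriction and is not shown):

```scala
// 1. Invalid arithmetic fails instead of wrapping around.
spark.sql("SELECT 2147483647 + 1").collect()
// => java.lang.ArithmeticException: integer overflow

// 3. size() returns NULL for null input (instead of -1 when ANSI mode is off).
spark.sql("SELECT size(NULL)").collect()
// => Array([null])
```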