[SPARK-33386][SQL] Accessing array elements in ElementAt/Elt/GetArrayItem should fail if index is out of bound #30297

Closed · wants to merge 8 commits
Changes from 2 commits
2 changes: 2 additions & 0 deletions docs/sql-ref-ansi-compliance.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ SELECT * FROM t;

The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `size`: This function returns null for null input under ANSI mode.
Member: (This is not related to this PR though) could you remove "under ANSI mode" in this statement, too?

Contributor (author): OK

- `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices under ANSI mode.
- `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices under ANSI mode.
Member: I think it's better to describe the behaviour change of GetArrayItem too, so how about creating a new subsection for it, like "Other SQL Operations"?

Member: nit: how about removing "under ANSI mode" in each entry? They look redundant.

Contributor: +1


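The ANSI-mode behaviour documented above for `element_at` can be sketched in Python. This is an illustrative model only, not Spark's implementation; the `ansi_enabled` parameter stands in for `spark.sql.ansi.enabled`:

```python
def element_at(arr, index, ansi_enabled=False):
    """Illustrative model of Spark SQL's element_at on arrays.

    Indexing is 1-based; negative indices count from the end.
    Index 0 is always an error; an out-of-bound index yields NULL
    (None) unless ANSI mode is on, in which case it raises.
    """
    if index == 0:
        raise IndexError("SQL array indices start at 1")
    if abs(index) > len(arr):
        if ansi_enabled:
            raise IndexError(f"Invalid index: {index}")
        return None  # non-ANSI mode: out-of-bound access yields NULL
    # Convert the 1-based SQL ordinal to a Python index.
    return arr[index - 1] if index > 0 else arr[index]
```

For example, `element_at([1, 2, 3], -1)` returns `3`, while `element_at([1, 2, 3], 5)` returns `None` with ANSI off and raises with ANSI on.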
### SQL Keywords

Expand Down
Expand Up @@ -840,8 +840,8 @@ object TypeCoercion {
plan resolveOperators { case p =>
p transformExpressionsUp {
// Skip nodes if unresolved or not enough children
case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
case c @ Elt(children) =>
case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c
case c @ Elt(children, _) =>
val index = children.head
val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
val newInputs = if (conf.eltOutputAsString ||
Expand Down
Expand Up @@ -34,8 +34,10 @@ case class ProjectionOverSchema(schema: StructType) {
expr match {
case a: AttributeReference if fieldNames.contains(a.name) =>
Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
case GetArrayItem(child, arrayItemOrdinal) =>
getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
getProjection(child).map {
projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
}
case a: GetArrayStructFields =>
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
Expand Down
Expand Up @@ -119,7 +119,7 @@ object SelectedField {
throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
}
selectField(child, opt)
case GetArrayItem(child, _) =>
case GetArrayItem(child, _, _) =>
// GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
val ArrayType(_, containsNull) = child.dataType
Expand Down
Expand Up @@ -1906,8 +1906,8 @@ case class ArrayPosition(left: Expression, right: Expression)
@ExpressionDescription(
usage = """
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
accesses elements from the last to the first. Returns NULL if the index exceeds the length
of the array.
accesses elements from the last to the first. If the index exceeds the length of the array,
Returns NULL if Ansi mode is off; Throws ArrayIndexOutOfBoundsException when Ansi mode is on.
Contributor: Ansi -> ANSI


_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
Member: Do we want to make `ElementAt` behavior consistent on map type?

Contributor (author): ANSI mode support for the map type is planned for the next PR.

""",
Expand All @@ -1919,9 +1919,14 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression)
case class ElementAt(
Member: Please update the usage above of ExpressionDescription, too?

left: Expression,
right: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
viirya marked this conversation as resolved.
extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {

def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
Expand Down Expand Up @@ -1969,7 +1974,7 @@ case class ElementAt(left: Expression, right: Expression)
if (ordinal == 0) {
false
} else if (elements.length < math.abs(ordinal)) {
true
if (failOnError) false else true
Contributor: nit: !failOnError

} else {
if (ordinal < 0) {
elements(elements.length + ordinal).nullable
Expand All @@ -1991,7 +1996,7 @@ case class ElementAt(left: Expression, right: Expression)
true
}
} else {
true
if (failOnError) arrayContainsNull else true
}
}
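The nullability rules above can be modelled as a small Python function. This is a sketch under the assumption that the array's per-element nullability is statically known (as with `CreateArray`); the function name and signature are illustrative, not Spark API:

```python
def element_at_nullable(element_nullables, ordinal, fail_on_error):
    """Sketch of ElementAt.nullable for an array whose per-element
    nullability is known at analysis time.

    `ordinal` is the 1-based SQL index, or None when the index
    expression is not foldable.
    """
    n = len(element_nullables)
    if ordinal is None:
        # Unknown index: with ANSI on, out-of-bound always throws, so
        # nullability reduces to whether the array can contain nulls;
        # with ANSI off, an out-of-bound index can still yield null.
        return any(element_nullables) if fail_on_error else True
    if ordinal == 0:
        return False  # index 0 always throws, never returns null
    if n < abs(ordinal):
        return not fail_on_error  # ANSI: throws; non-ANSI: returns null
    # In-bound literal index: nullability of that exact element.
    return element_nullables[ordinal - 1] if ordinal > 0 else element_nullables[ordinal]
```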

Expand All @@ -2008,7 +2013,11 @@ case class ElementAt(left: Expression, right: Expression)
val array = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Int]
if (array.numElements() < math.abs(index)) {
null
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
Member: nit: can we include the total number of elements too in the error message? Sometimes that is helpful for debugging.

} else {
null
}
} else {
val idx = if (index == 0) {
throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
Expand Down Expand Up @@ -2042,10 +2051,17 @@ case class ElementAt(left: Expression, right: Expression)
} else {
""
}

val failOnErrorBranch = if (failOnError) {
Contributor: nit: it should be indexOutOfBoundBranch

s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
Member: Shall we remove .stripMargin because this is a single line?

} else {
s"${ev.isNull} = true;"
}

s"""
|int $index = (int) $eval2;
|if ($eval1.numElements() < Math.abs($index)) {
| ${ev.isNull} = true;
| $failOnErrorBranch
|} else {
| if ($index == 0) {
| throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
Expand Down
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -222,10 +223,15 @@ case class GetArrayStructFields(
*
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
case class GetArrayItem(
child: Expression,
ordinal: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
with NullIntolerant {

def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled)

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)

Expand All @@ -240,7 +246,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
if (index >= baseValue.numElements() || index < 0) {
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
} else {
null
}
} else if (baseValue.isNullAt(index)) {
null
} else {
baseValue.get(index, dataType)
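The `nullSafeEval` logic above can be sketched in Python. Note that, unlike `element_at`, `GetArrayItem` takes a 0-based ordinal and treats negative indices as invalid; the `fail_on_error` parameter is an illustrative stand-in for the ANSI flag:

```python
def get_array_item(arr, index, fail_on_error=False):
    """Sketch of GetArrayItem semantics: 0-based ordinal, negative
    indices invalid. Out-of-bound yields NULL (None) unless ANSI
    mode requires an exception."""
    if index >= len(arr) or index < 0:
        if fail_on_error:
            raise IndexError(f"Invalid index: {index}")
        return None
    # An in-bound slot may itself hold NULL (None); return it as-is.
    return arr[index]
```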
Expand All @@ -255,9 +267,18 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
} else {
""
}

val failOnErrorBranch = if (failOnError) {
Contributor: nit: it should be indexOutOfBoundBranch

s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
Member: Shall we remove .stripMargin because this is a single line?

} else {
s"${ev.isNull} = true;"
}

s"""
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
if ($index >= $eval1.numElements() || $index < 0) {
$failOnErrorBranch
} else if (false$nullCheck) {
Contributor: false$nullCheck -> $nullCheck
Contributor: OK, we can get rid of the entire else if (false$nullCheck) if containsNull == false.

${ev.isNull} = true;
} else {
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
Expand Down
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
Expand Down Expand Up @@ -231,15 +232,23 @@ case class ConcatWs(children: Seq[Expression])
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
usage = """
_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
If the index exceeds the length of the array, Returns NULL if Ansi mode is off;
Contributor: ditto, Ansi -> ANSI
Throws ArrayIndexOutOfBoundsException when Ansi mode is on.
""",
examples = """
Examples:
> SELECT _FUNC_(1, 'scala', 'java');
scala
""",
since = "2.0.0")
// scalastyle:on line.size.limit
case class Elt(children: Seq[Expression]) extends Expression {
case class Elt(
children: Seq[Expression],
failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression {

def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)

private lazy val indexExpr = children.head
private lazy val inputExprs = children.tail.toArray
Expand Down Expand Up @@ -275,7 +284,11 @@ case class Elt(children: Seq[Expression]) extends Expression {
} else {
val index = indexObj.asInstanceOf[Int]
if (index <= 0 || index > inputExprs.length) {
null
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
} else {
null
}
} else {
inputExprs(index - 1).eval(input)
}
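The interpreted `eval` path above amounts to the following Python sketch. As with the other sketches, `fail_on_error` is an illustrative stand-in for `spark.sql.ansi.enabled`:

```python
def elt(index, *inputs, fail_on_error=False):
    """Sketch of elt(): returns the index-th input, 1-based.

    A NULL index yields NULL; an index outside 1..len(inputs)
    yields NULL with ANSI off, or raises with ANSI on.
    """
    if index is None:
        return None
    if index <= 0 or index > len(inputs):
        if fail_on_error:
            raise IndexError(f"Invalid index: {index}")
        return None
    return inputs[index - 1]
```

For example, `elt(1, 'scala', 'java')` returns `'scala'`, matching the `ExpressionDescription` example.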
Expand Down Expand Up @@ -323,6 +336,16 @@ case class Elt(children: Seq[Expression]) extends Expression {
""".stripMargin
}.mkString)

val failOnErrorBranch = if (failOnError) {
Contributor: nit: it should be indexOutOfBoundBranch

s"""
|if (!$indexMatched) {
| throw new ArrayIndexOutOfBoundsException("Invalid index: " + ${index.value});
|}
""".stripMargin
} else {
""
}

ev.copy(
code"""
|${index.code}
Expand All @@ -332,6 +355,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
|do {
| $codes
|} while (false);
|$failOnErrorBranch
|final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
|final boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
Expand Down
Expand Up @@ -61,7 +61,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)

// Remove redundant map lookup.
case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx), _) =>
// Instead of creating the array and then selecting one row, remove array creation
// altogether.
if (idx >= 0 && idx < elems.size) {
Expand Down
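The optimizer rule above can be modelled as a constant-folding step: accessing a literal index of a freshly built array is replaced by the element itself, skipping array construction. In this sketch expressions are modelled as plain Python values, and the rule is only safe to fold out-of-range accesses to NULL when the access cannot throw (non-ANSI mode), which is why the match pattern now carries the extra `failOnError` field:

```python
def simplify_get_array_item(elems, idx, fail_on_error=False):
    """Sketch of the SimplifyExtractValueOps rewrite for
    GetArrayItem(CreateArray(elems), IntegerLiteral(idx)).

    Returns the folded result, or None for a NULL literal."""
    if 0 <= idx < len(elems):
        return elems[idx]  # fold to the element expression itself
    if fail_on_error:
        # Out-of-range under ANSI mode would throw at runtime,
        # so the rewrite must not fold it to a NULL literal.
        raise IndexError(f"Invalid index: {idx}")
    return None
```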
Expand Up @@ -2146,7 +2146,9 @@ object SQLConf {
.doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
"throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
"field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
"the SQL parser.")
"the SQL parser. 3. Spark will returns null for null input for function `size`. " +
"4. Spark will throw ArrayIndexOutOfBoundsException if invalid indices " +
Contributor @cloud-fan (Nov 10, 2020): We can merge 1 and 4: "1. Spark will throw an exception at runtime if the inputs to a SQL operator/function are invalid, e.g. overflow in arithmetic operations, out-of-range index when accessing array elements."

"used on function `element_at`/`elt`.")
.version("3.0.0")
.booleanConf
.createWithDefault(false)
Expand Down
Expand Up @@ -1118,58 +1118,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

test("correctly handles ElementAt nullability for arrays") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Literal(-2)).nullable)
assert(ElementAt(array, Literal(2)).nullable)
assert(ElementAt(array, Literal(-1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)

// CreateArray case invalid indices
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(4)).nullable)
assert(ElementAt(array, Literal(-4)).nullable)

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(1)).nullable)
assert(!ElementAt(stArray1, Literal(-1)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(1)).nullable)
assert(ElementAt(stArray2, Literal(-1)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(-2)).nullable)
assert(ElementAt(stArray3, Literal(2)).nullable)
assert(ElementAt(stArray3, Literal(-1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(1)).nullable)
assert(ElementAt(stArray4, Literal(-2)).nullable)
assert(ElementAt(stArray4, Literal(2)).nullable)
assert(ElementAt(stArray4, Literal(-1)).nullable)

// GetArrayStructFields case invalid indices
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(4)).nullable)
assert(ElementAt(stArray3, Literal(-4)).nullable)

assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(4)).nullable)
assert(ElementAt(stArray4, Literal(-4)).nullable)
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Literal(-2)).nullable)
assert(ElementAt(array, Literal(2)).nullable)
assert(ElementAt(array, Literal(-1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)

// CreateArray case invalid indices
assert(!ElementAt(array, Literal(0)).nullable)
if (ansiEnabled) {
Contributor: nit: assert(ElementAt(array, Literal(4)).nullable == !ansiEnabled)

assert(!ElementAt(array, Literal(4)).nullable)
assert(!ElementAt(array, Literal(-4)).nullable)
} else {
assert(ElementAt(array, Literal(4)).nullable)
assert(ElementAt(array, Literal(-4)).nullable)
}

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(1)).nullable)
assert(!ElementAt(stArray1, Literal(-1)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(1)).nullable)
assert(ElementAt(stArray2, Literal(-1)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(-2)).nullable)
assert(ElementAt(stArray3, Literal(2)).nullable)
assert(ElementAt(stArray3, Literal(-1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(1)).nullable)
assert(ElementAt(stArray4, Literal(-2)).nullable)
assert(ElementAt(stArray4, Literal(2)).nullable)
assert(ElementAt(stArray4, Literal(-1)).nullable)

// GetArrayStructFields case invalid indices
assert(!ElementAt(stArray3, Literal(0)).nullable)
if (ansiEnabled) {
Contributor: ditto

assert(!ElementAt(stArray3, Literal(4)).nullable)
assert(!ElementAt(stArray3, Literal(-4)).nullable)
} else {
assert(ElementAt(stArray3, Literal(4)).nullable)
assert(ElementAt(stArray3, Literal(-4)).nullable)
}

assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(4)).nullable)
assert(ElementAt(stArray4, Literal(-4)).nullable)
}
}
}

test("Concat") {
Expand Down