Skip to content

Commit

Permalink
1. fix GetArrayItem nullablity issue.
Browse files Browse the repository at this point in the history
2. move DataFrameSuite code into array.sql and ansi/array.sql
3. add numElements to exception message.
4. other code refine.

Change-Id: Ieb322ed7b036fc3322fd3b814c8508bfef266378
  • Loading branch information
leanken-zz committed Nov 12, 2020
1 parent 6729d7b commit 36d235a
Show file tree
Hide file tree
Showing 12 changed files with 364 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1986,24 +1986,9 @@ case class ElementAt(
}
}

override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar, _) =>
nullability(ar, intOrdinal)
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
nullability(elements, intOrdinal) || field.nullable
case _ =>
true
}
} else {
if (failOnError) arrayContainsNull else true
}
}

override def nullable: Boolean = left.dataType match {
case _: ArrayType => computeNullabilityFromArray(left, right)
case _: ArrayType =>
computeNullabilityFromArray(left, right, failOnError, nullability)
case _: MapType => true
}

Expand All @@ -2016,7 +2001,8 @@ case class ElementAt(
val index = ordinal.asInstanceOf[Int]
if (array.numElements() < math.abs(index)) {
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${array.numElements()}")
} else {
null
}
Expand Down Expand Up @@ -2055,7 +2041,10 @@ case class ElementAt(
}

val failOnErrorBranch = if (failOnError) {
s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
s"""throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
|);
""".stripMargin
} else {
s"${ev.isNull} = true;"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,25 @@ case class GetArrayItem(

override def left: Expression = child
override def right: Expression = ordinal
override def nullable: Boolean = computeNullabilityFromArray(left, right)
override def nullable: Boolean =
computeNullabilityFromArray(left, right, failOnError, nullability)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
if (ordinal >= 0 && ordinal < elements.length) {
elements(ordinal).nullable
} else {
!failOnError
}
}

protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.numElements() || index < 0) {
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${baseValue.numElements()}")
} else {
null
}
Expand All @@ -272,7 +282,10 @@ case class GetArrayItem(
}

val failOnErrorBranch = if (failOnError) {
s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
s"""throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
|);
""".stripMargin
} else {
s"${ev.isNull} = true;"
}
Expand All @@ -295,20 +308,24 @@ case class GetArrayItem(
trait GetArrayItemUtil {

/** `Null` is returned for invalid ordinals. */
protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
protected def computeNullabilityFromArray(
child: Expression,
ordinal: Expression,
failOnError: Boolean,
nullability: (Seq[Expression], Int) => Boolean): Boolean = {
val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar, _) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case CreateArray(ar, _) =>
nullability(ar, intOrdinal)
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
nullability(elements, intOrdinal) || field.nullable
case _ =>
true
}
} else {
true
if (failOnError) arrayContainsNull else true
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ case class Elt(
val index = indexObj.asInstanceOf[Int]
if (index <= 0 || index > inputExprs.length) {
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${inputExprs.length}")
} else {
null
}
Expand Down Expand Up @@ -340,7 +341,8 @@ case class Elt(
val failOnErrorBranch = if (failOnError) {
s"""
|if (!$indexMatched) {
| throw new ArrayIndexOutOfBoundsException("Invalid index: " + ${index.value});
| throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + ${index.value} + ", numElements: " + ${inputExprs.length});
|}
""".stripMargin
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2147,7 +2147,7 @@ object SQLConf {
"throw an exception at runtime if the inputs to a SQL operator/function are invalid, " +
"e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " +
"2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
"the SQL parser. 3. Spark will returns null for null input for function `size`.")
"the SQL parser. 3. Spark will return NULL for null input for function `size`.")
.version("3.0.0")
.booleanConf
.createWithDefault(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1888,21 +1888,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Date.valueOf("2018-01-01")))
}

test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") {
test("SPARK-33386: element_at ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
var expr: Expression = ElementAt(array, Literal(5))
if (ansiEnabled) {
val errMsg = "Invalid index: 5"
val errMsg = "Invalid index: 5, numElements: 3"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}

expr = ElementAt(array, Literal(-5))
if (ansiEnabled) {
val errMsg = "Invalid index: -5"
val errMsg = "Invalid index: -5, numElements: 3"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}

test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") {
test("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val array = Literal.create(Seq("a", "b"), ArrayType(StringType))

if (ansiEnabled) {
checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(5)),
"Invalid index: 5"
"Invalid index: 5, numElements: 2"
)

checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(-1)),
"Invalid index: -1"
"Invalid index: -1, numElements: 2"
)
} else {
checkEvaluation(GetArrayItem(array, Literal(5)), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,28 +969,28 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
}

test("SPARK-33391: elt ArrayIndexOutOfBoundsException") {
test("SPARK-33386: elt ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
var expr: Expression = Elt(Seq(Literal(4), Literal("123"), Literal("456")))
if (ansiEnabled) {
val errMsg = "Invalid index: 4"
val errMsg = "Invalid index: 4, numElements: 2"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}

expr = Elt(Seq(Literal(0), Literal("123"), Literal("456")))
if (ansiEnabled) {
val errMsg = "Invalid index: 0"
val errMsg = "Invalid index: 0, numElements: 2"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}

expr = Elt(Seq(Literal(-1), Literal("123"), Literal("456")))
if (ansiEnabled) {
val errMsg = "Invalid index: -1"
val errMsg = "Invalid index: -1, numElements: 2"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--IMPORT array.sql
12 changes: 12 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/array.sql
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,15 @@ select
size(date_array),
size(timestamp_array)
from primitive_arrays;

-- index out of range for array elements
select element_at(array(1, 2, 3), 5);
select element_at(array(1, 2, 3), -5);
select element_at(array(1, 2, 3), 0);

select elt(4, '123', '456');
select elt(0, '123', '456');
select elt(-1, '123', '456');

select array(1, 2, 3)[5];
select array(1, 2, 3)[-1];
Loading

0 comments on commit 36d235a

Please sign in to comment.