Skip to content

Commit

Permalink
[SPARK-33460][SQL] Accessing map values should fail if key is not found
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Instead of returning NULL, throws runtime NoSuchElementException towards invalid key accessing in map-like functions, such as element_at, GetMapValue, when ANSI mode is on.

### Why are the changes needed?

For ANSI mode.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added UT and Existing UT.

Closes #30386 from leanken/leanken-SPARK-33460.

Authored-by: xuewei.linxuewei <xuewei.linxuewei@alibaba-inc.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
leanken-zz authored and cloud-fan committed Nov 16, 2020
1 parent 6883f29 commit b5eca18
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docs/sql-ref-ansi-compliance.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ SELECT * FROM t;
The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `size`: This function returns null for null input.
- `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
- `element_at`: This function throws `NoSuchElementException` if key does not exist in map.
- `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.

### SQL Operators

The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices.
- `map_col[key]`: This operator throws `NoSuchElementException` if key does not exist in map.

### SQL Keywords

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ case class ProjectionOverSchema(schema: StructType) {
getProjection(child).map { projection => MapKeys(projection) }
case MapValues(child) =>
getProjection(child).map { projection => MapValues(projection) }
case GetMapValue(child, key) =>
getProjection(child).map { projection => GetMapValue(projection, key) }
case GetMapValue(child, key, failOnError) =>
getProjection(child).map { projection => GetMapValue(projection, key, failOnError) }
case GetStructFieldObject(child, field: StructField) =>
getProjection(child).map(p => (p, p.dataType)).map {
case (projection, projSchema: StructType) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ object SelectedField {
}
val newField = StructField(field.name, newFieldDataType, field.nullable)
selectField(child, Option(ArrayType(struct(newField), containsNull)))
case GetMapValue(child, _) =>
case GetMapValue(child, _, _) =>
// GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
val MapType(keyType, _, valueContainsNull) = child.dataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1911,7 +1911,9 @@ case class ArrayPosition(left: Expression, right: Expression)
If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException
for invalid indices.
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
_FUNC_(map, key) - Returns value for given key. The function returns NULL
if the key is not contained in the map and `spark.sql.ansi.enabled` is set to false.
If `spark.sql.ansi.enabled` is set to true, it throws NoSuchElementException instead.
""",
examples = """
Examples:
Expand All @@ -1931,6 +1933,9 @@ case class ElementAt(

@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

@transient private lazy val mapValueContainsNull =
left.dataType.asInstanceOf[MapType].valueContainsNull

@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull

@transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)
Expand Down Expand Up @@ -1989,7 +1994,7 @@ case class ElementAt(
override def nullable: Boolean = left.dataType match {
case _: ArrayType =>
computeNullabilityFromArray(left, right, failOnError, nullability)
case _: MapType => true
case _: MapType => if (failOnError) mapValueContainsNull else true
}

override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
Expand Down Expand Up @@ -2022,7 +2027,7 @@ case class ElementAt(
}
}
case _: MapType =>
(value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering)
(value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering, failOnError)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -2069,7 +2074,7 @@ case class ElementAt(
""".stripMargin
})
case _: MapType =>
doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType], failOnError)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,12 @@ trait GetArrayItemUtil {
trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {

// todo: current search is O(n), improve it.
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
def getValueEval(
value: Any,
ordinal: Any,
keyType: DataType,
ordering: Ordering[Any],
failOnError: Boolean): Any = {
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
Expand All @@ -352,14 +357,24 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
}
}

if (!found || values.isNullAt(i)) {
if (!found) {
if (failOnError) {
throw new NoSuchElementException(s"Key $ordinal does not exist.")
} else {
null
}
} else if (values.isNullAt(i)) {
null
} else {
values.get(i, dataType)
}
}

def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
def doGetValueGenCode(
ctx: CodegenContext,
ev: ExprCode,
mapType: MapType,
failOnError: Boolean): ExprCode = {
val index = ctx.freshName("index")
val length = ctx.freshName("length")
val keys = ctx.freshName("keys")
Expand All @@ -368,12 +383,22 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
val values = ctx.freshName("values")
val keyType = mapType.keyType
val nullCheck = if (mapType.valueContainsNull) {
s" || $values.isNullAt($index)"
s"""else if ($values.isNullAt($index)) {
${ev.isNull} = true;
}
"""
} else {
""
}

val keyJavaType = CodeGenerator.javaType(keyType)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val keyNotFoundBranch = if (failOnError) {
s"""throw new NoSuchElementException("Key " + $eval2 + " does not exist.");"""
} else {
s"${ev.isNull} = true;"
}

s"""
final int $length = $eval1.numElements();
final ArrayData $keys = $eval1.keyArray();
Expand All @@ -390,9 +415,9 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
}
}

if (!$found$nullCheck) {
${ev.isNull} = true;
} else {
if (!$found) {
$keyNotFoundBranch
} $nullCheck else {
${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
}
"""
Expand All @@ -405,9 +430,14 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
*
* We need to do type checking here as `key` expression maybe unresolved.
*/
case class GetMapValue(child: Expression, key: Expression)
case class GetMapValue(
child: Expression,
key: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends GetMapValueUtil with ExtractValue with NullIntolerant {

def this(child: Expression, key: Expression) = this(child, key, SQLConf.get.ansiEnabled)

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(keyType)

Expand Down Expand Up @@ -442,10 +472,10 @@ case class GetMapValue(child: Expression, key: Expression)

// todo: current search is O(n), improve it.
override def nullSafeEval(value: Any, ordinal: Any): Any = {
getValueEval(value, ordinal, keyType, ordering)
getValueEval(value, ordinal, keyType, ordering, failOnError)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType])
doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType], failOnError)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// out of bounds, mimic the runtime behavior and return null
Literal(null, ga.dataType)
}
case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems)
case GetMapValue(CreateMap(elems, _), key, _) => CaseKeyWhen(key, elems)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1915,4 +1915,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
}
}
}

test("SPARK-33460: element_at NoSuchElementException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val map = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))
val expr: Expression = ElementAt(map, Literal(5))
if (ansiEnabled) {
val errMsg = "Key 5 does not exist."
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,23 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("SPARK-33460: GetMapValue NoSuchElementException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val map = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))

if (ansiEnabled) {
checkExceptionInExpression[Exception](
GetMapValue(map, Literal(5)),
"Key 5 does not exist."
)
} else {
checkEvaluation(GetMapValue(map, Literal(5)), null)
}
}
}
}

test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
Expand Down
1 change: 1 addition & 0 deletions sql/core/src/test/resources/sql-tests/inputs/ansi/map.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--IMPORT map.sql
5 changes: 5 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/map.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- test cases for map functions

-- key does not exist
select element_at(map(1, 'a', 2, 'b'), 5);
select map(1, 'a', 2, 'b')[5];
20 changes: 20 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 2


-- !query
select element_at(map(1, 'a', 2, 'b'), 5)
-- !query schema
struct<>
-- !query output
java.util.NoSuchElementException
Key 5 does not exist.


-- !query
select map(1, 'a', 2, 'b')[5]
-- !query schema
struct<>
-- !query output
java.util.NoSuchElementException
Key 5 does not exist.
18 changes: 18 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/map.sql.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 2


-- !query
select element_at(map(1, 'a', 2, 'b'), 5)
-- !query schema
struct<element_at(map(1, a, 2, b), 5):string>
-- !query output
NULL


-- !query
select map(1, 'a', 2, 'b')[5]
-- !query schema
struct<map(1, a, 2, b)[5]:string>
-- !query output
NULL

0 comments on commit b5eca18

Please sign in to comment.