diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index c2b36033e318e..41921265b7c70 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -112,12 +112,14 @@ SELECT * FROM t;
 The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
   - `size`: This function returns null for null input.
   - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
+  - `element_at`: This function throws `NoSuchElementException` if the key does not exist in the map.
   - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
 
 ### SQL Operators
 
 The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
   - `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices.
+  - `map_col[key]`: This operator throws `NoSuchElementException` if the key does not exist in the map.
 
 ### SQL Keywords
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 6f1d9d065ab1a..241c761624b76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -55,8 +55,8 @@ case class ProjectionOverSchema(schema: StructType) {
         getProjection(child).map { projection => MapKeys(projection) }
       case MapValues(child) =>
         getProjection(child).map { projection => MapValues(projection) }
-      case GetMapValue(child, key) =>
-        getProjection(child).map { projection => GetMapValue(projection, key) }
+      case GetMapValue(child, key, failOnError) =>
+        getProjection(child).map { projection => GetMapValue(projection, key, failOnError) }
       case GetStructFieldObject(child, field: StructField) =>
         getProjection(child).map(p => (p, p.dataType)).map {
           case (projection, projSchema: StructType) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index adcc4be10687e..f2acb75ea6ac4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -91,7 +91,7 @@ object SelectedField {
         }
         val newField = StructField(field.name, newFieldDataType, field.nullable)
         selectField(child, Option(ArrayType(struct(newField), containsNull)))
-      case GetMapValue(child, _) =>
+      case GetMapValue(child, _, _) =>
         // GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be
         // the top-level extractor. However it can be part of an extractor chain.
         val MapType(keyType, _, valueContainsNull) = child.dataType
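The two Catalyst rules above change for the same mechanical reason: `GetMapValue` grows a third constructor field, so every rule that destructures it has to bind `failOnError` and thread it through to the expression it rebuilds; otherwise the copy would re-read the default from `SQLConf` and could silently change semantics. A minimal sketch of the pattern (the `replaceChild` helper is illustrative, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, GetMapValue}

// Hypothetical rewrite helper: swap the map child of a lookup while
// preserving its ANSI error behavior.
def replaceChild(e: Expression, newChild: Expression): Expression = e match {
  // Binding failOnError and reusing it keeps the original semantics;
  // writing GetMapValue(newChild, key) would re-read SQLConf.get.ansiEnabled.
  case GetMapValue(_, key, failOnError) => GetMapValue(newChild, key, failOnError)
  case other => other
}
```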
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ee98ebf5a8a50..0765bfdd78fa6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1911,7 +1911,9 @@ case class ArrayPosition(left: Expression, right: Expression)
       If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException
       for invalid indices.
 
-    _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
+    _FUNC_(map, key) - Returns the value for the given key. The function returns NULL
+      if the key is not contained in the map and `spark.sql.ansi.enabled` is set to false.
+      If `spark.sql.ansi.enabled` is set to true, it throws NoSuchElementException instead.
   """,
   examples = """
     Examples:
@@ -1931,6 +1933,9 @@ case class ElementAt(
 
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
 
+  @transient private lazy val mapValueContainsNull =
+    left.dataType.asInstanceOf[MapType].valueContainsNull
+
   @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
 
   @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)
@@ -1989,7 +1994,7 @@ case class ElementAt(
 
   override def nullable: Boolean = left.dataType match {
     case _: ArrayType => computeNullabilityFromArray(left, right, failOnError, nullability)
-    case _: MapType => true
+    case _: MapType => if (failOnError) mapValueContainsNull else true
   }
 
   override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
@@ -2022,7 +2027,7 @@ case class ElementAt(
         }
       }
     case _: MapType =>
-      (value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering)
+      (value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering, failOnError)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -2069,7 +2074,7 @@ case class ElementAt(
         """.stripMargin
       })
     case _: MapType =>
-      doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
+      doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType], failOnError)
   }
 }
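The `nullable` change is the subtle part of this file: once a missing key raises an error instead of producing NULL, the only remaining source of NULLs under ANSI mode is the map's own `valueContainsNull`. A rough spark-shell sketch of the consequence (assumes a `SparkSession` named `spark` on a build that contains this patch):

```scala
import org.apache.spark.sql.functions.{element_at, lit, map}

spark.conf.set("spark.sql.ansi.enabled", true)
// map(1, 'a') has valueContainsNull = false, and with failOnError = true a
// missing key throws instead of returning NULL, so the result column can be
// marked non-nullable.
val df = spark.range(1).select(element_at(map(lit(1), lit("a")), lit(1)).as("v"))
df.schema("v").nullable  // expected: false under ANSI mode, true without it
```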
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 363d388692c9f..767650d022200 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -336,7 +336,12 @@ trait GetArrayItemUtil {
 trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
 
   // todo: current search is O(n), improve it.
-  def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
+  def getValueEval(
+      value: Any,
+      ordinal: Any,
+      keyType: DataType,
+      ordering: Ordering[Any],
+      failOnError: Boolean): Any = {
     val map = value.asInstanceOf[MapData]
     val length = map.numElements()
     val keys = map.keyArray()
@@ -352,14 +357,24 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
       }
     }
 
-    if (!found || values.isNullAt(i)) {
+    if (!found) {
+      if (failOnError) {
+        throw new NoSuchElementException(s"Key $ordinal does not exist.")
+      } else {
+        null
+      }
+    } else if (values.isNullAt(i)) {
       null
     } else {
       values.get(i, dataType)
     }
   }
 
-  def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
+  def doGetValueGenCode(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      mapType: MapType,
+      failOnError: Boolean): ExprCode = {
     val index = ctx.freshName("index")
     val length = ctx.freshName("length")
     val keys = ctx.freshName("keys")
@@ -368,12 +383,22 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
     val values = ctx.freshName("values")
     val keyType = mapType.keyType
     val nullCheck = if (mapType.valueContainsNull) {
-      s" || $values.isNullAt($index)"
+      s"""else if ($values.isNullAt($index)) {
+        ${ev.isNull} = true;
+      }
+     """
     } else {
       ""
     }
 
+    val keyJavaType = CodeGenerator.javaType(keyType)
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      val keyNotFoundBranch = if (failOnError) {
+        s"""throw new NoSuchElementException("Key " + $eval2 + " does not exist.");"""
+      } else {
+        s"${ev.isNull} = true;"
+      }
+
       s"""
         final int $length = $eval1.numElements();
         final ArrayData $keys = $eval1.keyArray();
@@ -390,9 +415,9 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
           }
         }
 
-        if (!$found$nullCheck) {
-          ${ev.isNull} = true;
-        } else {
+        if (!$found) {
+          $keyNotFoundBranch
+        } $nullCheck else {
           ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
         }
       """
@@ -405,9 +430,14 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
  *
  * We need to do type checking here as `key` expression maybe unresolved.
  */
-case class GetMapValue(child: Expression, key: Expression)
+case class GetMapValue(
+    child: Expression,
+    key: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends GetMapValueUtil with ExtractValue with NullIntolerant {
 
+  def this(child: Expression, key: Expression) = this(child, key, SQLConf.get.ansiEnabled)
+
   @transient private lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(keyType)
 
@@ -442,10 +472,10 @@ case class GetMapValue(child: Expression, key: Expression)
 
   // todo: current search is O(n), improve it.
   override def nullSafeEval(value: Any, ordinal: Any): Any = {
-    getValueEval(value, ordinal, keyType, ordering)
+    getValueEval(value, ordinal, keyType, ordering, failOnError)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType])
+    doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType], failOnError)
   }
 }
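The interpreted path and the generated Java above encode the same three-way outcome, which is easier to see stripped of the `MapData` and codegen machinery. A standalone sketch of the semantics (plain Scala collections, not Spark's API):

```scala
// Linear scan over the keys (the O(n) search the TODO above refers to), then:
// miss -> throw or NULL depending on failOnError; hit -> the stored value,
// which may itself be NULL when valueContainsNull is true.
def lookup[K, V](
    keys: IndexedSeq[K],
    values: IndexedSeq[Option[V]],
    key: K,
    failOnError: Boolean): Option[V] = {
  val i = keys.indexOf(key)
  if (i < 0) {
    if (failOnError) throw new NoSuchElementException(s"Key $key does not exist.")
    else None // non-ANSI behavior: a missing key yields NULL
  } else {
    values(i) // found: the stored value may still be None (NULL)
  }
}
```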
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 7a21ce254a235..3dd79d153c236 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -71,7 +71,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
           // out of bounds, mimic the runtime behavior and return null
           Literal(null, ga.dataType)
         }
-      case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems)
+      case GetMapValue(CreateMap(elems, _), key, _) => CaseKeyWhen(key, elems)
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 6ee88c9eaef86..095894b9fffac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1915,4 +1915,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       }
     }
   }
+
+  test("SPARK-33460: element_at NoSuchElementException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        val map = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))
+        val expr: Expression = ElementAt(map, Literal(5))
+        if (ansiEnabled) {
+          val errMsg = "Key 5 does not exist."
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+      }
+    }
+  }
 }
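`SimplifyExtractValueOps` folds a lookup on a literal `CreateMap` into a `CASE` expression at optimization time; the added `_` just keeps the pattern compiling after the arity change. Roughly, the rule maps the first expression below to the second (illustrative construction; in practice the rule produces it rather than anyone writing it by hand):

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseKeyWhen, CreateMap, GetMapValue, Literal}

val elems = Seq(Literal(1), Literal("a"), Literal(2), Literal("b"))
// map(1, 'a', 2, 'b')[2] ...
val before = GetMapValue(CreateMap(elems), Literal(2))
// ... becomes CASE 2 WHEN 1 THEN 'a' WHEN 2 THEN 'b' END
val after = CaseKeyWhen(Literal(2), elems)
```

Note that the rule matches the new `failOnError` field with `_` and `CaseKeyWhen` carries no ELSE branch, so the folded expression still evaluates to NULL for a missing key.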
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 67ab2071de037..3d6f6937e780b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -85,6 +85,23 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("SPARK-33460: GetMapValue NoSuchElementException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        val map = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))
+
+        if (ansiEnabled) {
+          checkExceptionInExpression[Exception](
+            GetMapValue(map, Literal(5)),
+            "Key 5 does not exist."
+          )
+        } else {
+          checkEvaluation(GetMapValue(map, Literal(5)), null)
+        }
+      }
+    }
+  }
+
   test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
     // CreateArray case
     val a = AttributeReference("a", IntegerType, nullable = false)()
diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/map.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/map.sql
new file mode 100644
index 0000000000000..23e5b9562973b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/map.sql
@@ -0,0 +1 @@
+--IMPORT map.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/map.sql b/sql/core/src/test/resources/sql-tests/inputs/map.sql
new file mode 100644
index 0000000000000..e2d855fba154e
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/map.sql
@@ -0,0 +1,5 @@
+-- test cases for map functions
+
+-- key does not exist
+select element_at(map(1, 'a', 2, 'b'), 5);
+select map(1, 'a', 2, 'b')[5];
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out
new file mode 100644
index 0000000000000..12c599812cdee
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out
@@ -0,0 +1,20 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 2
+
+
+-- !query
+select element_at(map(1, 'a', 2, 'b'), 5)
+-- !query schema
+struct<>
+-- !query output
+java.util.NoSuchElementException
+Key 5 does not exist.
+
+
+-- !query
+select map(1, 'a', 2, 'b')[5]
+-- !query schema
+struct<>
+-- !query output
+java.util.NoSuchElementException
+Key 5 does not exist.
diff --git a/sql/core/src/test/resources/sql-tests/results/map.sql.out b/sql/core/src/test/resources/sql-tests/results/map.sql.out
new file mode 100644
index 0000000000000..7a0c0d776ca2b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/map.sql.out
@@ -0,0 +1,18 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 2
+
+
+-- !query
+select element_at(map(1, 'a', 2, 'b'), 5)
+-- !query schema
+struct<element_at(map(1, a, 2, b), 5):string>
+-- !query output
+NULL
+
+
+-- !query
+select map(1, 'a', 2, 'b')[5]
+-- !query schema
+struct<map(1, a, 2, b)[5]:string>
+-- !query output
+NULL
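Taken together, the golden files above pin down the user-visible contract. The same behavior can be reproduced interactively; a sketch (again assuming a `SparkSession` named `spark` on a build with this patch):

```scala
spark.conf.set("spark.sql.ansi.enabled", false)
spark.sql("select element_at(map(1, 'a', 2, 'b'), 5)").show() // NULL
spark.sql("select map(1, 'a', 2, 'b')[5]").show()             // NULL

spark.conf.set("spark.sql.ansi.enabled", true)
// Both forms now fail at runtime with
// java.util.NoSuchElementException: Key 5 does not exist.
spark.sql("select map(1, 'a', 2, 'b')[5]").collect()
```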