From 0fc7c4a29c4612518cbaae7cc6093b470d80cc1a Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 1 May 2024 17:52:02 -0700 Subject: [PATCH] [SPARK-45891][SQL][FOLLOW-UP] Added length check to the is_variant_null expression ### What changes were proposed in this pull request? Added a check to the `is_variant_null` expression that verifies the length of the value field is greater than zero. If the length is zero, a `MALFORMED_VARIANT` exception is thrown. ### Why are the changes needed? Earlier, `is_variant_null` was simply checking if the first byte of a variant value was zero. However, if the value field is empty, the first byte does not exist, so reading it could result in undefined behavior. Such a case should never occur in practice, but it could appear if the data is corrupted. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a unit test that checks that a zero-length variant value throws an exception. ### Was this patch authored or co-authored using generative AI tooling? No Closes #46311 from harshmotw-db/is_variant_null_fix. 
Authored-by: Harsh Motwani Signed-off-by: Dongjoon Hyun --- .../variant/VariantExpressionEvalUtils.scala | 10 +++++++--- .../apache/spark/sql/errors/QueryExecutionErrors.scala | 5 +++++ .../expressions/variant/VariantExpressionSuite.scala | 7 +++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index a8d629acbe2b9..f468e9745605b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -58,9 +58,13 @@ object VariantExpressionEvalUtils { false } else { val variantValue = input.getValue - // Variant NULL is denoted by basic_type == 0 and val_header == 0 - variantValue(0) == 0 - } + if (variantValue.isEmpty) { + throw QueryExecutionErrors.malformedVariant() + } else { + // Variant NULL is denoted by basic_type == 0 and val_header == 0 + variantValue(0) == 0 + } + } } /** Cast a Spark value from `dataType` into the variant type. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index a9d0bfbddcc1b..53ac788956d17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2727,6 +2727,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("path" -> path, "functionName" -> toSQLId(functionName))) } + def malformedVariant(): Throwable = new SparkRuntimeException( + "MALFORMED_VARIANT", + Map.empty + ) + def invalidCharsetError(functionName: String, charset: String): RuntimeException = { new SparkIllegalArgumentException( errorClass = "INVALID_PARAMETER_VALUE.CHARSET", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 1f9eec862bbeb..d001f0ec051eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -239,6 +239,13 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { check(expectedResult4, smallObject, smallMetadata) } + test("is_variant_null invalid input") { + checkErrorInExpression[SparkRuntimeException]( + IsVariantNull(Literal(new VariantVal(Array(), Array(1, 2, 3)))), + "MALFORMED_VARIANT" + ) + } + private def parseJson(input: String): VariantVal = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))