From 8be96303ba81546eb24abcb259afed5569435111 Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Tue, 30 Apr 2024 15:25:12 -0700 Subject: [PATCH] Added length check in is_variant_null --- .../variant/VariantExpressionEvalUtils.scala | 11 ++++++++--- .../expressions/variant/VariantExpressionSuite.scala | 7 +++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index ea90bb88a9069..39451fc1ee811 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant import scala.util.control.NonFatal +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, BadRecordException, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors @@ -50,9 +51,13 @@ object VariantExpressionEvalUtils { false } else { val variantValue = input.getValue - // Variant NULL is denoted by basic_type == 0 and val_header == 0 - variantValue(0) == 0 - } + if(variantValue.isEmpty) { + throw new SparkRuntimeException("MALFORMED_VARIANT", Map.empty) + } else { + // Variant NULL is denoted by basic_type == 0 and val_header == 0 + variantValue(0) == 0 + } + } } /** Cast a Spark value from `dataType` into the variant type. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 1f9eec862bbeb..d001f0ec051eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -239,6 +239,13 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { check(expectedResult4, smallObject, smallMetadata) } + test("is_variant_null invalid input") { + checkErrorInExpression[SparkRuntimeException]( + IsVariantNull(Literal(new VariantVal(Array(), Array(1, 2, 3)))), + "MALFORMED_VARIANT" + ) + } + private def parseJson(input: String): VariantVal = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))