Skip to content

Commit

Permalink
Added length check in is_variant_null
Browse files Browse the repository at this point in the history
  • Loading branch information
harshmotw-db committed Apr 30, 2024
1 parent ee98221 commit 8be9630
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant

import scala.util.control.NonFatal

import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, BadRecordException, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
Expand Down Expand Up @@ -50,9 +51,13 @@ object VariantExpressionEvalUtils {
false
} else {
val variantValue = input.getValue
// Variant NULL is denoted by basic_type == 0 and val_header == 0
variantValue(0) == 0
}
if(variantValue.isEmpty) {
throw new SparkRuntimeException("MALFORMED_VARIANT", Map.empty)
} else {
// Variant NULL is denoted by basic_type == 0 and val_header == 0
variantValue(0) == 0
}
}
}

/** Cast a Spark value from `dataType` into the variant type. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
check(expectedResult4, smallObject, smallMetadata)
}

test("is_variant_null invalid input") {
checkErrorInExpression[SparkRuntimeException](
IsVariantNull(Literal(new VariantVal(Array(), Array(1, 2, 3)))),
"MALFORMED_VARIANT"
)
}

private def parseJson(input: String): VariantVal =
VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))

Expand Down

0 comments on commit 8be9630

Please sign in to comment.