diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 4bdaecffde4b8..e19f1eacfd8b5 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3404,8 +3404,10 @@ class PlanGenerationTestSuite
       // Handle parameterized scala types e.g.: List, Seq and Map.
       fn.typedLit(Some(1)),
       fn.typedLit(Array(1, 2, 3)),
+      fn.typedLit[Array[Integer]](Array(null, null)),
       fn.typedLit(Seq(1, 2, 3)),
-      fn.typedLit(Map("a" -> 1, "b" -> 2)),
+      fn.typedLit(mutable.LinkedHashMap("a" -> 1, "b" -> 2)),
+      fn.typedLit(mutable.LinkedHashMap[String, Integer]("a" -> null, "b" -> null)),
       fn.typedLit(("a", 2, 1.0)),
       fn.typedLit[Option[Int]](None),
       fn.typedLit[Array[Option[Int]]](Array(Some(1))),
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 3b6dd090caf6e..afc2b1db023e7 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1681,6 +1681,13 @@ class ClientE2ETestSuite
       assert(df.count() == 100)
     }
   }
+
+  test("SPARK-53553: null value handling in literals") {
+    val df = spark.sql("select 1").select(typedlit(Array[Integer](1, null)).as("arr_col"))
+    val result = df.collect()
+    assert(result.length === 1)
+    assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null))
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index e8522d7118c27..d64f5d7cdf2df 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -163,6 +163,14 @@ object LiteralValueProtoConverter {
     }
 
     (literal, dataType) match {
+      case (v: Option[_], _: DataType) =>
+        if (v.isDefined) {
+          toLiteralProtoBuilder(v.get)
+        } else {
+          builder.setNull(toConnectProtoType(dataType))
+        }
+      case (null, _) =>
+        builder.setNull(toConnectProtoType(dataType))
       case (v: mutable.ArraySeq[_], ArrayType(_, _)) =>
         toLiteralProtoBuilder(v.array, dataType)
       case (v: immutable.ArraySeq[_], ArrayType(_, _)) =>
@@ -175,12 +183,6 @@ object LiteralValueProtoConverter {
         builder.setMap(mapBuilder(v, keyType, valueType))
       case (v, structType: StructType) =>
         builder.setStruct(structBuilder(v, structType))
-      case (v: Option[_], _: DataType) =>
-        if (v.isDefined) {
-          toLiteralProtoBuilder(v.get)
-        } else {
-          builder.setNull(toConnectProtoType(dataType))
-        }
       case _ => toLiteralProtoBuilder(literal)
     }
   }
@@ -297,7 +299,7 @@ object LiteralValueProtoConverter {
   }
 
   private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
-    if (dataType.hasShort) { v =>
+    val converter: proto.Expression.Literal => Any = if (dataType.hasShort) { v =>
       v.getShort.toShort
     } else if (dataType.hasInteger) { v =>
       v.getInteger
@@ -339,6 +341,7 @@ object LiteralValueProtoConverter {
     } else {
       throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
     }
+    v => if (v.hasNull) null else converter(v)
   }
 
   def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = {
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
index 508128bec26d0..a566430136f2a 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 19 more fields]
+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 21 more fields]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index 80b95e4664c14..456033244a945 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -77,7 +77,8 @@
     }, {
       "literal": {
         "null": {
-          "null": {
+          "string": {
+            "collation": "UTF8_BINARY"
           }
         }
       },
@@ -814,6 +815,43 @@
           }
         }
       }
+    }, {
+      "literal": {
+        "array": {
+          "elementType": {
+            "integer": {
+            }
+          },
+          "elements": [{
+            "null": {
+              "integer": {
+              }
+            }
+          }, {
+            "null": {
+              "integer": {
+              }
+            }
+          }]
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
     }, {
       "literal": {
         "array": {
@@ -888,6 +926,53 @@
           }
         }
       }
+    }, {
+      "literal": {
+        "map": {
+          "keyType": {
+            "string": {
+              "collation": "UTF8_BINARY"
+            }
+          },
+          "valueType": {
+            "integer": {
+            }
+          },
+          "keys": [{
+            "string": "a"
+          }, {
+            "string": "b"
+          }],
+          "values": [{
+            "null": {
+              "integer": {
+              }
+            }
+          }, {
+            "null": {
+              "integer": {
+              }
+            }
+          }]
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
     }, {
       "literal": {
         "struct": {
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
index 6aa367df32276..749da55007dc8 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ
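Note on the LiteralValueProtoConverter change above: getScalaConverter previously returned the type-specific converter directly, so the typed getter ran even when the literal's proto carried a null marker. The patch binds the typed converter to a local val and returns a wrapper that checks hasNull first, so a null literal short-circuits to null before any typed branch runs. A minimal standalone sketch of that wrapping pattern, where the Lit case class is a hypothetical stand-in for proto.Expression.Literal and wrapWithNullCheck is an illustrative name, not the patched method:

    object NullWrappingConverterSketch {
      // Hypothetical stand-in for proto.Expression.Literal: carries either
      // a null marker or a decoded value.
      final case class Lit(hasNull: Boolean, value: Any = null)

      // Build the type-specific converter first, then wrap it so a null
      // literal is returned as null before the typed converter runs.
      def wrapWithNullCheck(typed: Lit => Any): Lit => Any = { v =>
        if (v.hasNull) null else typed(v)
      }

      def main(args: Array[String]): Unit = {
        val intConverter = wrapWithNullCheck(_.value.asInstanceOf[Int])
        println(intConverter(Lit(hasNull = false, 42))) // prints 42
        println(intConverter(Lit(hasNull = true)))      // prints null
      }
    }

Doing the null check once in a single wrapper keeps each per-type branch (short, integer, map, and so on) free of its own null handling.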