diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index e4e66416dd9c7..f7649d85b7734 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -358,7 +358,8 @@ object JavaTypeInference {
   }
 
   private[catalyst] def serializerFor(
-      inputObject: Expression, typeToken: TypeToken[_]): Expression = {
+      inputObject: Expression,
+      typeToken: TypeToken[_]): Expression = {
 
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
       val (dataType, nullable) = inferDataType(elementType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 904709bd846a3..9c7e76467d153 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1308,7 +1308,7 @@ case class ExternalMapToCatalyst private(
     val result = child.eval(input)
     if (result != null) {
       val (keys, values) = mapCatalystConverter(result)
-      new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+      new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
     } else {
       null
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 59aa71cafead6..d1c99f5333256 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedDeserializer}
+import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -513,6 +513,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") {
+    // Simple test
     val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3")
     val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() {
       {
@@ -534,6 +535,30 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       ScalaReflection.serializerFor[scala.collection.Map[Int, String]](
         Literal.fromObject(scalaMap)), 0)
     checkEvaluation(serializer2, expected)
+
+    // NULL key test
+    val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String](
+      null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1")
+    val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() {
+      {
+        put(null, "v0")
+        put(1, "v1")
+      }
+    }
+
+    // Java Map
+    val serializer3 = GetStructField(
+      javaSerializerFor(javaMap.getClass)(Literal.fromObject(javaMapHasNullKey)), 0)
+    checkExceptionInExpression[RuntimeException](
+      serializer3, EmptyRow, "Cannot use null as map key!")
+
+    // Scala Map
+    val serializer4 = GetStructField(
+      ScalaReflection.serializerFor[scala.collection.Map[java.lang.Integer, String]](
+        Literal.fromObject(scalaMapHasNullKey)), 0)
+
+    checkExceptionInExpression[RuntimeException](
+      serializer4, EmptyRow, "Cannot use null as map key!")
   }
 }
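
For reference, a minimal standalone sketch (not part of the patch) of the scenario the new "NULL key test" cases exercise: a map carrying a null key is rejected when the object-to-Catalyst serializer runs. The SparkSession setup, the object name NullMapKeyExample, and the exact exception message are assumptions and may differ between Spark versions.

import org.apache.spark.sql.SparkSession

// Hypothetical reproduction of the null-map-key failure that the new test
// cases assert via checkExceptionInExpression. Names and the exact message
// are assumptions; behaviour may vary across Spark versions.
object NullMapKeyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("null-map-key-example")
      .getOrCreate()
    import spark.implicits._

    val withNullKey: Map[java.lang.Integer, String] =
      Map(null.asInstanceOf[java.lang.Integer] -> "v0", Integer.valueOf(1) -> "v1")

    try {
      // Building the Dataset runs the map serializer (ExternalMapToCatalyst),
      // which is where the null-key check fires.
      Seq(withNullKey).toDS().collect()
    } catch {
      case e: RuntimeException =>
        // Expected: a message along the lines of "Cannot use null as map key!"
        println(s"Rejected as expected: ${e.getMessage}")
    } finally {
      spark.stop()
    }
  }
}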