From 5d296ed39e3dd79ddb10c68657e773adba40a5e0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Jul 2020 20:07:33 -0700 Subject: [PATCH] [SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's nullability together ### What changes were proposed in this pull request? Fix nullability of `GetArrayStructFields`. It should consider both the original array's `containsNull` and the inner field's nullability. ### Why are the changes needed? Fix a correctness issue: previously `GetArrayStructFields` was created with only the original array's `containsNull`, so the extracted array could be declared as containing no nulls even when the inner field is nullable and the result actually holds nulls. ### Does this PR introduce _any_ user-facing change? Yes. See the added test. ### How was this patch tested? A new unit test and an end-to-end test. Closes #28992 from cloud-fan/bug. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/ComplexTypeSuite.scala | 26 +++++++++++++++++++ .../expressions/SelectedFieldSuite.scala | 8 +++--- .../apache/spark/sql/ComplexTypesSuite.scala | 11 ++++++++ 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9c600c9d39cf7..89ff4facd25a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -57,7 +57,7 @@ object ExtractValue { val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), - ordinal, fields.length, containsNull) + ordinal, fields.length, containsNull || fields(ordinal).nullable) case (_: ArrayType, _) => GetArrayItem(child, extraction) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 3df7d02fb6604..dbe43709d1d35 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext @@ -159,6 +160,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("SPARK-32167: nullability of GetArrayStructFields") { + val resolver = SQLConf.get.resolver + + val array1 = ArrayType( + new StructType().add("a", "int", nullable = true), + containsNull = false) + val data1 = Literal.create(Seq(Row(null)), array1) + val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(get1.containsNull) + + val array2 = ArrayType( + new StructType().add("a", "int", nullable = false), + containsNull = true) + val data2 = Literal.create(Seq(null), array2) + val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(get2.containsNull) + + val array3 = ArrayType( + new StructType().add("a", "int", nullable = false), + containsNull = false) + val data3 = Literal.create(Seq(Row(1)), array3) + val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(!get3.containsNull) + } + test("CreateArray") { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala index 3c826e812b5cc..76d6890cc8f6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala @@ -254,13 +254,13 @@ class SelectedFieldSuite extends AnalysisTest { StructField("col3", ArrayType(StructType( StructField("field1", StructType( StructField("subfield1", IntegerType, nullable = false) :: Nil)) - :: Nil), containsNull = false), nullable = false) + :: Nil), containsNull = true), nullable = false) } testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") { StructField("col3", ArrayType(StructType( StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) - :: Nil), containsNull = false), nullable = false) + :: Nil), containsNull = true), nullable = false) } // |-- col1: string (nullable = false) @@ -471,7 +471,7 @@ class SelectedFieldSuite extends AnalysisTest { testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field1 as foo") { StructField("col2", MapType( ArrayType(StructType( - StructField("field1", StringType) :: Nil), containsNull = false), + StructField("field1", StringType) :: Nil), containsNull = true), ArrayType(StructType( StructField("field3", StructType( StructField("subfield3", IntegerType) :: @@ -482,7 +482,7 @@ class SelectedFieldSuite extends AnalysisTest { StructField("col2", MapType( ArrayType(StructType( StructField("field2", StructType( - StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false), + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = true), ArrayType(StructType( StructField("field3", StructType( StructField("subfield3", IntegerType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index 
6b503334f9f23..bdcf7230e3211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{ArrayType, StructType} class ComplexTypesSuite extends QueryTest with SharedSparkSession { + import testImplicits._ override def beforeAll(): Unit = { super.beforeAll() @@ -106,4 +110,11 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) } + + test("SPARK-32167: get field from an array of struct") { + val innerStruct = new StructType().add("i", "int", nullable = true) + val schema = new StructType().add("arr", ArrayType(innerStruct, containsNull = false)) + val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema) + checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null))) + } }