Skip to content

Commit

Permalink
[SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's …
Browse files Browse the repository at this point in the history
…nullability together

### What changes were proposed in this pull request?

Fix nullability of `GetArrayStructFields`. It should consider both the original array's `containsNull` and the inner field's nullability.

### Why are the changes needed?

Fix a correctness issue.

### Does this PR introduce _any_ user-facing change?

Yes. See the added test.

### How was this patch tested?

a new UT and end-to-end test

Closes #28992 from cloud-fan/bug.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
cloud-fan authored and dongjoon-hyun committed Jul 7, 2020
1 parent 3fe3365 commit 5d296ed
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ object ExtractValue {
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
ordinal, fields.length, containsNull)
ordinal, fields.length, containsNull || fields(ordinal).nullable)

case (_: ArrayType, _) => GetArrayItem(child, extraction)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
Expand Down Expand Up @@ -159,6 +160,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
}

test("SPARK-32167: nullability of GetArrayStructFields") {
val resolver = SQLConf.get.resolver

val array1 = ArrayType(
new StructType().add("a", "int", nullable = true),
containsNull = false)
val data1 = Literal.create(Seq(Row(null)), array1)
val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
assert(get1.containsNull)

val array2 = ArrayType(
new StructType().add("a", "int", nullable = false),
containsNull = true)
val data2 = Literal.create(Seq(null), array2)
val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
assert(get2.containsNull)

val array3 = ArrayType(
new StructType().add("a", "int", nullable = false),
containsNull = false)
val data3 = Literal.create(Seq(Row(1)), array3)
val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
assert(!get3.containsNull)
}

test("CreateArray") {
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,13 @@ class SelectedFieldSuite extends AnalysisTest {
StructField("col3", ArrayType(StructType(
StructField("field1", StructType(
StructField("subfield1", IntegerType, nullable = false) :: Nil))
:: Nil), containsNull = false), nullable = false)
:: Nil), containsNull = true), nullable = false)
}

testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") {
StructField("col3", ArrayType(StructType(
StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false))
:: Nil), containsNull = false), nullable = false)
:: Nil), containsNull = true), nullable = false)
}

// |-- col1: string (nullable = false)
Expand Down Expand Up @@ -471,7 +471,7 @@ class SelectedFieldSuite extends AnalysisTest {
testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field1 as foo") {
StructField("col2", MapType(
ArrayType(StructType(
StructField("field1", StringType) :: Nil), containsNull = false),
StructField("field1", StringType) :: Nil), containsNull = true),
ArrayType(StructType(
StructField("field3", StructType(
StructField("subfield3", IntegerType) ::
Expand All @@ -482,7 +482,7 @@ class SelectedFieldSuite extends AnalysisTest {
StructField("col2", MapType(
ArrayType(StructType(
StructField("field2", StructType(
StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false),
StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = true),
ArrayType(StructType(
StructField("field3", StructType(
StructField("subfield3", IntegerType) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{ArrayType, StructType}

class ComplexTypesSuite extends QueryTest with SharedSparkSession {
import testImplicits._

override def beforeAll(): Unit = {
super.beforeAll()
Expand Down Expand Up @@ -106,4 +110,11 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession {
checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil)
checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
}

test("SPARK-32167: get field from an array of struct") {
val innerStruct = new StructType().add("i", "int", nullable = true)
val schema = new StructType().add("arr", ArrayType(innerStruct, containsNull = false))
val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema)
checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null)))
}
}

0 comments on commit 5d296ed

Please sign in to comment.