diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 2765ec7d8a0eb..2a76df440ac37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -25,8 +25,7 @@ package org.apache.spark.sql.catalyst.expressions * return the same answer given any input (i.e. false negatives are possible). * * The following rules are applied: - * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. - * - Names for [[GetStructField]] are stripped. + * - Names for [[org.apache.spark.sql.types.DataType]]s and [[GetStructField]] are stripped. * - TimeZoneId for [[Cast]] and [[AnsiCast]] are stripped if `needsTimeZone` is false. * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered * by `hashCode`. @@ -39,10 +38,10 @@ object Canonicalize { expressionReorder(ignoreTimeZone(ignoreNamesTypes(e))) } - /** Remove names and nullability from types, and names from `GetStructField`. */ + /** Remove names from types and `GetStructField`. */ private[expressions] def ignoreNamesTypes(e: Expression): Expression = e match { case a: AttributeReference => - AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) + AttributeReference("none", a.dataType)(exprId = a.exprId) case GetStructField(child, ordinal, Some(_)) => GetStructField(child, ordinal, None) case _ => e } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 1805189b268db..83307c9022dd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} class CanonicalizeSuite extends SparkFunSuite { @@ -177,4 +177,17 @@ class CanonicalizeSuite extends SparkFunSuite { assert(expr.semanticEquals(attr)) assert(attr.semanticEquals(expr)) } + + test("SPARK-38030: Canonicalization should not remove nullability of AttributeReference" + + " dataType") { + val structType = StructType(Seq(StructField("name", StringType, nullable = false))) + val attr = AttributeReference("col", structType)() + // AttributeReference dataType should not be converted to nullable + assert(attr.canonicalized.dataType === structType) + + val cast = Cast(attr, structType) + assert(cast.resolved) + // canonicalization should not converted resolved cast to unresolved + assert(cast.canonicalized.resolved) + } }