diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4a2f4e409079e..99f6062a0d243 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3957,13 +3957,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
       case i @ In(attr: Attribute, list)
         if attr.dataType == StringType && list.forall(_.foldable) =>
         CharVarcharUtils.getRawType(attr.metadata).flatMap {
           case CharType(length) =>
-            val literalCharLengths = list.map(_.eval().asInstanceOf[UTF8String].numChars())
+            val (nulls, literalChars) =
+              list.map(_.eval().asInstanceOf[UTF8String]).partition(_ == null)
+            val literalCharLengths = literalChars.map(_.numChars())
             val targetLen = (length +: literalCharLengths).max
             Some(i.copy(
               value = addPadding(attr, length, targetLen),
               list = list.zip(literalCharLengths).map { case (lit, charLength) =>
                 addPadding(lit, charLength, targetLen)
-              }))
+              } ++ nulls.map(Literal.create(_, StringType))))
           case _ => None
         }.getOrElse(i)
@@ -3984,13 +3986,17 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
       CharVarcharUtils.getRawType(attr.metadata).flatMap {
         case CharType(length) =>
           val str = lit.eval().asInstanceOf[UTF8String]
-          val stringLitLen = str.numChars()
-          if (length < stringLitLen) {
-            Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit))
-          } else if (length > stringLitLen) {
-            Some(Seq(attr, StringRPad(lit, Literal(length))))
-          } else {
+          if (str == null) {
             None
+          } else {
+            val stringLitLen = str.numChars()
+            if (length < stringLitLen) {
+              Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit))
+            } else if (length > stringLitLen) {
+              Some(Seq(attr, StringRPad(lit, Literal(length))))
+            } else {
+              None
+            }
           }
         case _ => None
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index 586f41d0f03c0..1e561747b6157 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -146,6 +146,22 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("SPARK-34233: char/varchar with null value for partitioned columns") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(i STRING, c $typ) USING $format PARTITIONED BY (c)")
+        sql("INSERT INTO t VALUES ('1', null)")
+        checkPlainResult(spark.table("t"), typ, null)
+        sql("INSERT OVERWRITE t VALUES ('1', null)")
+        checkPlainResult(spark.table("t"), typ, null)
+        sql("INSERT OVERWRITE t PARTITION (c=null) VALUES ('1')")
+        checkPlainResult(spark.table("t"), typ, null)
+        sql("ALTER TABLE t DROP PARTITION(c=null)")
+        checkAnswer(spark.table("t"), Nil)
+      }
+    }
+  }
+
   test("char/varchar type values length check: partitioned columns of other types") {
     Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
       withTable("t") {
@@ -427,7 +443,8 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
         ("c1 IN ('a', 'b')", true),
         ("c1 = c2", true),
         ("c1 < c2", false),
-        ("c1 IN (c2)", true)))
+        ("c1 IN (c2)", true),
+        ("c1 <=> null", false)))
     }
   }
 
@@ -443,7 +460,29 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
         ("c1 IN ('a', 'b')", true),
         ("c1 = c2", true),
         ("c1 < c2", false),
-        ("c1 IN (c2)", true)))
+        ("c1 IN (c2)", true),
+        ("c1 <=> null", false)))
+    }
+  }
+
+  private def testNullConditions(df: DataFrame, conditions: Seq[String]): Unit = {
+    conditions.foreach { cond =>
+      checkAnswer(df.selectExpr(cond), Row(null))
+    }
+  }
+
+  test("SPARK-34233: char type comparison with null values") {
+    val conditions = Seq("c = null", "c IN ('e', null)", "c IN (null)")
+    withTable("t") {
+      sql(s"CREATE TABLE t(c CHAR(2)) USING $format")
+      sql("INSERT INTO t VALUES ('a')")
+      testNullConditions(spark.table("t"), conditions)
+    }
+
+    withTable("t") {
+      sql(s"CREATE TABLE t(i INT, c CHAR(2)) USING $format PARTITIONED BY (c)")
+      sql("INSERT INTO t VALUES (1, 'a')")
+      testNullConditions(spark.table("t"), conditions)
+    }
     }
   }
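For reviewers, a minimal standalone sketch of the idea behind the `In`-list change (plain Scala, no Spark dependencies; `CharPaddingSketch`, `padRight`, and `padInList` are hypothetical stand-ins for `ApplyCharTypePadding`, `StringRPad`/`addPadding`, and the rule's `In` branch): null literals are partitioned out before any `numChars()`-style length measurement and re-appended unpadded afterward, which is why the analyzer no longer hits an NPE on `c IN ('e', null)`.

```scala
// Hypothetical, dependency-free illustration of the null handling in this patch.
// Not Spark code: padRight stands in for StringRPad/addPadding, and padInList
// mirrors the In-list branch of ApplyCharTypePadding.
object CharPaddingSketch {
  // Right-pad with spaces, mimicking CHAR(n) padding semantics.
  private def padRight(s: String, targetLen: Int): String =
    if (s.length >= targetLen) s else s + " " * (targetLen - s.length)

  // Before the fix, the equivalent of `_.numChars()` ran on every IN-list
  // literal, so a NULL literal caused a NullPointerException during analysis.
  def padInList(charColLen: Int, inList: Seq[String]): Seq[String] = {
    val (nulls, nonNulls) = inList.partition(_ == null)
    val targetLen = (charColLen +: nonNulls.map(_.length)).max
    // NULL compares equal to nothing, so the nulls are re-appended unpadded;
    // they are kept only so the rewritten IN list preserves its elements.
    nonNulls.map(padRight(_, targetLen)) ++ nulls
  }

  def main(args: Array[String]): Unit = {
    // CHAR(5) column compared with IN ('ab', NULL, 'abcdef'):
    // prints List(ab    , abcdef, null)
    println(padInList(5, Seq("ab", null, "abcdef")))
  }
}
```

The same reasoning drives the binary-comparison hunk: a null literal produces no padding at all (`None`), since `c = null` evaluates to NULL regardless of padding, as the new `c <=> null` and `testNullConditions` test cases verify.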