Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2308,9 +2308,16 @@ case class ArrayJoin(
}
}
} else {
// When array and delimiter are both non-nullable, neither nullSafeExec wrapper above runs,
// so reset ev.isNull here. doGenCode initializes ev.isNull to true whenever the expression
// is nullable (e.g. a nullable nullReplacement), and without this reset the computed result
// would be discarded as NULL. When the expression is non-nullable, ev.isNull is a literal
// false and must not be assigned.
val resetIsNull = if (nullable) s"${ev.isNull} = false;" else ""
s"""
|${arrayGen.code}
|${delimiterGen.code}
|$resetIsNull
|$resultCode""".stripMargin
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,33 @@ class CollectionExpressionsSuite
Some(Literal.create(null, StringType))), null)
}

test("ArrayJoin codegen with non-nullable array/delimiter and nullable " +
  "nullReplacement") {
  // Scenario: an upstream IsNotNull filter has tightened the array and the delimiter to
  // non-nullable, while nullReplacement is still a nullable column. ArrayJoin.nullable is then
  // true, so doGenCode starts with ev.isNull = true, and the non-nullable branch of
  // genCodeForArrayAndDelimiter has to reset it to false. Otherwise the generated code would
  // build the joined string and then report NULL, disagreeing with interpreted eval().
  val arrayRef = BoundReference(0, ArrayType(StringType, containsNull = true), nullable = false)
  val delimiterRef = BoundReference(1, StringType, nullable = false)
  val replacementRef = BoundReference(2, StringType, nullable = true)
  val joinExpr = ArrayJoin(arrayRef, delimiterRef, Some(replacementRef))

  // The nullable replacement is the only nullable child, yet it makes the expression nullable.
  assert(joinExpr.nullable)

  // Both checks use the same input array, which has a NULL element in the middle.
  val elements = Seq[String]("a", null, "b")

  // With a concrete replacement, the NULL element is substituted and a joined string comes back.
  checkEvaluation(joinExpr, "a,NR,b", create_row(elements, ",", "NR"))

  // With a NULL replacement value, the whole result is NULL, matching eval().
  checkEvaluation(joinExpr, null, create_row(elements, ",", null))
}

test("ArraysZip") {
val literals = Seq(
Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -1993,6 +1994,40 @@ class DataFrameFunctionsSuite extends SharedSparkSession {
)
}

test("array_join with nullable nullReplacement under whole-stage codegen") {
  // Regression scenario: the nullReplacement column is nullable, and an upstream IsNotNull
  // filter tightens the array (and delimiter) to non-nullable. Whole-stage codegen used to build
  // the joined string but leave ev.isNull = true, so every row came back as NULL. The codegen
  // result must agree with interpreted eval().
  //
  // The input goes through a cached temp view (an InMemoryRelation), so the plan is not folded
  // to interpreted eval by ConvertToLocalRelation.
  withTempView("array_join_codegen") {
    val input = Seq(
      (Seq[String]("a", null, "b"), ",", "NR"),
      (Seq[String]("a", null, "b"), ",", null),
      (Seq[String]("x", "y"), "-", "NR")
    ).toDF("arr", "delim_col", "repl_col")
    input.createOrReplaceTempView("array_join_codegen")
    spark.catalog.cacheTable("array_join_codegen")

    val arrayJoinSql =
      "SELECT array_join(arr, delim_col, repl_col) FROM array_join_codegen " +
        "WHERE arr IS NOT NULL AND delim_col IS NOT NULL"

    // Both execution modes must produce exactly these rows.
    val expectedRows = Seq(Row("a,NR,b"), Row(null), Row("x-y"))

    // Codegen-only run: first confirm the plan really contains a whole-stage codegen node, so a
    // silent switch to the interpreted path cannot make this test pass vacuously.
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY") {
      val codegenResult = sql(arrayJoinSql)
      val usesWholeStageCodegen = codegenResult.queryExecution.executedPlan.exists {
        case _: WholeStageCodegenExec => true
        case _ => false
      }
      assert(
        usesWholeStageCodegen,
        "expected the array_join query to run inside whole-stage codegen")
      checkAnswer(codegenResult, expectedRows)
    }

    // Interpreted run of the same query, as the baseline.
    withSQLConf(
      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
      SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN") {
      checkAnswer(sql(arrayJoinSql), expectedRows)
    }
  }
}

test("array_min function") {
val df = Seq(
Seq[Option[Int]](Some(1), Some(3), Some(2)),
Expand Down