diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index c087fdf5f962b..1b9432047d9d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -56,8 +56,9 @@ case class ExpandExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val groups = projections.map(projection).toArray + groups.foreach(_.initialize(index)) new Iterator[InternalRow] { private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index f659ca6329e2f..cb7e395212bf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.Row import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue} import org.apache.spark.sql.connector.expressions.LiteralValue +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType} abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { @@ -528,6 +529,25 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(2) :: Nil) } + test("SPARK-53538: update with nondeterministic assignments and no wholestage codegen") { + val extraColCount = SQLConf.get.wholeStageMaxNumFields - 4 + val schema = "pk INT NOT NULL, id INT, value DOUBLE, dep STRING, " + + ((1 to extraColCount).map(i => s"col$i INT").mkString(", ")) + val data = (1 to 3).map { i => + s"""{ "pk": $i, "id": $i, "value": 2.0, "dep": "hr", """ + + ((1 to extraColCount).map(j => s""""col$j": $i""").mkString(", ")) + + "}" + }.mkString("\n") + createAndInitTable(schema, data) + + // rand() always generates values in [0, 1) range + sql(s"UPDATE $tableNameAsString SET value = rand() WHERE id <= 2") + + checkAnswer( + sql(s"SELECT count(*) FROM $tableNameAsString WHERE value < 2.0"), + Row(2) :: Nil) + } + test("update with default values") { val idDefault = new ColumnDefaultValue("42", LiteralValue(42, IntegerType)) val columns = Array(