From 3daaf6889e72a1a00a8e9c8895efeba7f127dcf2 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 31 Oct 2025 10:41:01 -0700 Subject: [PATCH 01/13] [SPARK-54172][SQL] Merge Into Schema Evolution should only add referenced columns --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../ResolveMergeIntoSchemaEvolution.scala | 12 +- .../ResolveRowLevelCommandAssignments.scala | 81 ++- .../analysis/RewriteMergeIntoTable.scala | 6 +- .../catalyst/plans/logical/v2Commands.scala | 93 +++- .../PullupCorrelatedPredicatesSuite.scala | 2 +- .../connector/MergeIntoTableSuiteBase.scala | 513 ++++++++++++++++-- .../command/PlanResolutionSuite.scala | 33 +- 8 files changed, 649 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 98c514925fa0..25b2465ab088 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1669,7 +1669,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UpdateTable => resolveReferencesInUpdate(u) - case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _) + case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _, _) if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution => EliminateSubqueryAliases(targetTable) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala index 7e7776098a04..b97663a2d344 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala @@ -34,24 +34,26 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case m @ MergeIntoTable(_, _, _, _, _, _, _) + case m @ MergeIntoTable(_, _, _, _, _, _, _, _) if m.needSchemaEvolution => val newTarget = m.targetTable.transform { - case r : DataSourceV2Relation => performSchemaEvolution(r, m.sourceTable) + case r : DataSourceV2Relation => performSchemaEvolution(r, m) } m.copy(targetTable = newTarget) } - private def performSchemaEvolution(relation: DataSourceV2Relation, source: LogicalPlan) + private def performSchemaEvolution(relation: DataSourceV2Relation, m: MergeIntoTable) : DataSourceV2Relation = { (relation.catalog, relation.identifier) match { case (Some(c: TableCatalog), Some(i)) => - val changes = MergeIntoTable.schemaChanges(relation.schema, source.schema) + val referencedSourceSchema = MergeIntoTable.referencedSourceSchema(m) + + val changes = MergeIntoTable.schemaChanges(relation.schema, referencedSourceSchema) c.alterTable(i, changes: _*) val newTable = c.loadTable(i) val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns()) // Check if there are any remaining changes not applied. 
-      val remainingChanges = MergeIntoTable.schemaChanges(newSchema, source.schema)
+      val remainingChanges = MergeIntoTable.schemaChanges(newSchema, referencedSourceSchema)
       if (remainingChanges.nonEmpty) {
         throw QueryCompilationErrors.unsupportedTableChangesInAutoSchemaEvolutionError(
           remainingChanges, i.toQualifiedNameParts(c))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
index 3eb528954b35..d2ff77ed2a51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -60,11 +60,18 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
         notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions,
           coerceNestedTypes),
         notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions,
-          coerceNestedTypes))
+          coerceNestedTypes),
+        preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions)
+      )
 
     case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && !m.aligned &&
         !m.needSchemaEvolution =>
-      resolveAssignments(m)
+      m.copy(
+        matchedActions = m.matchedActions.map(resolveMergeAction),
+        notMatchedActions = m.notMatchedActions.map(resolveMergeAction),
+        notMatchedBySourceActions = m.notMatchedBySourceActions.map(resolveMergeAction),
+        preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions)
+      )
   }
 
   private def validateStoreAssignmentPolicy(): Unit = {
@@ -83,33 +90,51 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
 
   private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
     p.transformExpressions {
-      case assignment: Assignment =>
-        val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
-          AssertNotNull(assignment.value)
-        } else {
-          assignment.value
-        }
-        val casted = if (assignment.key.dataType != nullHandled.dataType) {
-          val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
-          cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
-          cast
-        } else {
-          nullHandled
-        }
-        val rawKeyType = assignment.key.transform {
-          case a: AttributeReference =>
-            CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
-        }.dataType
-        val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
-          CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
-        } else {
-          casted
-        }
-        val cleanedKey = assignment.key.transform {
-          case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
-        }
-        Assignment(cleanedKey, finalValue)
+      case assignment: Assignment => resolveAssignment(assignment)
+    }
+  }
+
+  private def resolveMergeAction(mergeAction: MergeAction) = {
+    mergeAction match {
+      case u @ UpdateAction(_, assignments) =>
+        u.copy(assignments = assignments.map(resolveAssignment))
+      case i @ InsertAction(_, assignments) =>
+        i.copy(assignments = assignments.map(resolveAssignment))
+      case d: DeleteAction =>
+        d
+      case other =>
+        throw new AnalysisException(
+          errorClass = "_LEGACY_ERROR_TEMP_3053",
+          messageParameters = Map("other" -> other.toString))
+    }
+  }
+
+  private def resolveAssignment(assignment: Assignment) = {
+    val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
+      AssertNotNull(assignment.value)
+    } else {
+      assignment.value
+    }
+    val casted = if 
(assignment.key.dataType != nullHandled.dataType) { + val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true) + cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) + cast + } else { + nullHandled + } + val rawKeyType = assignment.key.transform { + case a: AttributeReference => + CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a) + }.dataType + val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) { + CharVarcharUtils.stringLengthCheck(casted, rawKeyType) + } else { + casted + } + val cleanedKey = assignment.key.transform { + case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a) } + Assignment(cleanedKey, finalValue) } private def alignActions( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index 8b5b690aa740..d2c4abb2259b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -45,7 +45,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, - notMatchedBySourceActions, _) if m.resolved && m.rewritable && m.aligned && + notMatchedBySourceActions, _, _) if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution && matchedActions.isEmpty && notMatchedActions.size == 1 && notMatchedBySourceActions.isEmpty => @@ -79,7 +79,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper } case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, - notMatchedBySourceActions, _) + notMatchedBySourceActions, _, _) if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution && matchedActions.isEmpty && notMatchedBySourceActions.isEmpty => @@ -121,7 +121,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper } case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, - notMatchedBySourceActions, _) + notMatchedBySourceActions, _, _) if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution => EliminateSubqueryAliases(aliasedTable) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index cd0c2742df3d..0a1ad4efed6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, 
TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -860,7 +860,10 @@ case class MergeIntoTable( matchedActions: Seq[MergeAction], notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], - withSchemaEvolution: Boolean) extends BinaryCommand with SupportsSubquery { + withSchemaEvolution: Boolean, + // Preserves original pre-aligned actions for source matches + preservedSourceActions: Option[Seq[MergeAction]] = None) + extends BinaryCommand with SupportsSubquery { lazy val aligned: Boolean = { val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions @@ -892,9 +895,13 @@ case class MergeIntoTable( case _ => false } - lazy val needSchemaEvolution: Boolean = + private lazy val migrationSchema: StructType = + MergeIntoTable.referencedSourceSchema(this) + + lazy val needSchemaEvolution: Boolean = { schemaEvolutionEnabled && - MergeIntoTable.schemaChanges(targetTable.schema, sourceTable.schema).nonEmpty + MergeIntoTable.schemaChanges(targetTable.schema, migrationSchema).nonEmpty + } private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && { EliminateSubqueryAliases(targetTable) match { @@ -948,11 +955,12 @@ object MergeIntoTable { case currentField: StructField if newFieldMap.contains(currentField.name) => schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType, originalTarget, originalSource, fieldPath ++ Seq(currentField.name)) - }}.flatten + } + }.flatten // Identify the newly added fields and append to the end val currentFieldMap = toFieldMap(currentFields) - val adds = newFields.filterNot (f => currentFieldMap.contains (f.name)) + val adds = newFields.filterNot(f => currentFieldMap.contains(f.name)) .map(f => TableChange.addColumn(fieldPath ++ Set(f.name), f.dataType)) updates ++ adds @@ -990,8 +998,81 @@ object MergeIntoTable { CaseInsensitiveMap(fieldMap) } } + + // Filter the source schema to retain only fields that are referenced + // by at least one merge action + def referencedSourceSchema(merge: MergeIntoTable): StructType = { + + val actions = merge.preservedSourceActions match { + case Some(preserved) => preserved + case None => merge.matchedActions ++ merge.notMatchedActions + } + + val assignments = actions.collect { + case a: UpdateAction => a.assignments.map(_.key) + case a: InsertAction => a.assignments.map(_.key) + }.flatten + + val containsStarAction = actions.exists { + case _: UpdateStarAction => true + case _: InsertStarAction => true + case _ => false + } + + def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType = + StructType(sourceSchema.flatMap { field => + val fieldPath = basePath :+ field.name + + field.dataType match { + // Specifically assigned to in one clause: + // always keep, including all nested attributes + case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field) + // If this is a struct and one of the children is being assigned to in a merge clause, + // keep it and continue filtering children. + case struct: StructType if assignments.exists(assign => + isPrefix(fieldPath, extractFieldPath(assign))) => + Some(field.copy(dataType = filterSchema(struct, fieldPath))) + // The field isn't assigned to directly or indirectly (i.e. 
its children) in any non-* + // clause. Check if it should be kept with any * action. + case struct: StructType if containsStarAction => + Some(field.copy(dataType = filterSchema(struct, fieldPath))) + case _ if containsStarAction => Some(field) + // The field and its children are not assigned to in any * or non-* action, drop it. + case _ => None + } + }) + + val sourceSchema = merge.sourceTable.schema + val targetSchema = merge.targetTable.schema + val res = filterSchema(merge.sourceTable.schema, Seq.empty) + res + } + + // Helper method to extract field path from an Expression. + private def extractFieldPath(expr: Expression): Seq[String] = expr match { + case UnresolvedAttribute(nameParts) => nameParts + case a: AttributeReference => Seq(a.name) + case GetStructField(child, ordinal, nameOpt) => + extractFieldPath(child) :+ nameOpt.getOrElse(s"col$ordinal") + case _ => Seq.empty + } + + // Helper method to check if a given field path is a prefix of another path. Delegates + // equality to conf.resolver to correctly handle case sensitivity. + private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = + prefix.length <= path.length && prefix.zip(path).forall { + case (prefixNamePart, pathNamePart) => + SQLConf.get.resolver(prefixNamePart, pathNamePart) + } + + // Helper method to check if an assignment Expression's field path is equal to a path. + def isEqual(assignmentExpr: Expression, path: Seq[String]): Boolean = { + val exprPath = extractFieldPath(assignmentExpr) + exprPath.length == path.length && isPrefix(exprPath, path) + } } + sealed abstract class MergeAction extends Expression with Unevaluable { def condition: Option[Expression] override def nullable: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index cbd24bd7bb29..96ae91dbedc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -167,7 +167,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { assert(optimized.resolved) optimized match { - case MergeIntoTable(_, _, s: InSubquery, _, _, _, _) => + case MergeIntoTable(_, _, s: InSubquery, _, _, _, _, _) => val outerRefs = SubExprUtils.getOuterReferences(s.query.plan) assert(outerRefs.isEmpty, "should be no outer refs") case other => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 98706c4afeae..b9c9fb22664a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -3510,48 +3510,485 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } - test("merge into with source missing fields in top-level struct") { - withTempView("source") { - // Target table has struct with 3 fields at top level - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT, - |dep STRING""".stripMargin, - """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") + test("Merge schema evolution should not evolve if not directly referencing new column: update") { + Seq(true, false).foreach { 
withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Row(10, "b"), "hr"), - Row(2, Row(20, "c"), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET dep='software' + |""".stripMargin - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, "a", true), "sales"), - Row(1, Row(10, "b", null), "hr"), - Row(2, Row(20, "c", null), "engineering"))) + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not directly referencing new column: insert") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'newdep') + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 250, "newdep"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not directly referencing new column:" + + "update and insert") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + 
|USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET dep='software' + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'newdep') + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 250, "newdep"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should only evolve referenced column when source " + + "has multiple new columns") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", 50, "blah"), + (3, 250, "dummy", 75, "blah")).toDF("pk", "salary", "dep", "bonus", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, bonus = s.bonus + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep, bonus) VALUES (s.pk, s.salary, 'newdep', s.bonus) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 150, "software", 50), + Row(3, 250, "newdep", 75))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should only evolve referenced struct field when source " + + "has multiple new struct fields") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType), // new field 1 + StructField("extra", StringType) // new field 2 + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row(150, "dummy", 50, "blah"), "active"), + Row(3, Row(250, "dummy", 75, "blah"), "active") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET info.bonus = s.info.bonus + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + // Only 'bonus' field should be added, not 'extra' + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active", null), "hr"), + Row(2, Row(200, "inactive", 50), "software"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get 
== "FIELD_NOT_FOUND") + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should evolve struct if directly referencing new field " + + "in top level struct: insert") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row(150, "dummy", 50), "active"), + Row(3, Row(250, "dummy", 75), "active") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, info, dep) VALUES (s.pk, + | named_struct('salary', s.info.salary, 'status', 'active'), 'marketing') + |""".stripMargin + + sql(mergeStmt) + if (withSchemaEvolution) { + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active", null), "hr"), + Row(2, Row(200, "inactive", null), "software"), + Row(3, Row(250, "active", null), "marketing"))) + } else { + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active"), "hr"), + Row(2, Row(200, "inactive"), "software"), + Row(3, Row(250, "active"), "marketing"))) + } + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not directly referencing new field " + + "in top level struct") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row(150, "dummy", 50), "active"), + Row(3, Row(250, "dummy", 75), "active") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET info.status='inactive' + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active"), "hr"), + Row(2, Row(200, "inactive"), "software"))) + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not 
directly referencing " + + "new field in nested struct") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("employee", StructType(Seq( + StructField("name", StringType), + StructField("details", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType) + ))) + ))), + StructField("dep", StringType) + )) + + createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) + + val targetData = Seq( + Row(1, Row("Alice", Row(100, "active")), "hr"), + Row(2, Row("Bob", Row(200, "active")), "software") + ) + spark.createDataFrame( + spark.sparkContext.parallelize(targetData), targetSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row("Alice", Row(100, "active")), "hr"), + Row(2, Row("Bob", Row(200, "active")), "software"))) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("employee", StructType(Seq( + StructField("name", StringType), + StructField("details", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))) + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row("Bob", Row(150, "active", 50)), "dummy"), + Row(3, Row("Charlie", Row(250, "active", 75)), "dummy") + ) + spark.createDataFrame( + spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = + if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET employee.details.status='inactive' + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row("Alice", Row(100, "active")), "hr"), + Row(2, Row("Bob", Row(200, "inactive")), "software"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should evolve referencing new column via transform") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET dep='software', extra=substring(s.extra, 1, 2) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", "bl"))) + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`extra` cannot be resolved")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should 
evolve referencing new column assigned to something else") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET dep='software', extra=s.dep + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", "dummy"))) + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`extra` cannot be resolved")) + } + + sql(s"DROP TABLE $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } test("merge into with source missing fields in struct nested in array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 9ea8b9130ba8..02024f26baf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -1699,7 +1699,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), notMatchedBySourceUpdateAssigns)), - withSchemaEvolution) => + withSchemaEvolution, + _) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns) @@ -1731,7 +1732,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))), insertAssigns)), Seq(), - withSchemaEvolution) => + withSchemaEvolution, + _) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns, starInUpdate = true) @@ -1758,7 +1760,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(UpdateAction(None, updateAssigns)), Seq(InsertAction(None, insertAssigns)), Seq(), - withSchemaEvolution) => + withSchemaEvolution, + _) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, None, None, updateAssigns, starInUpdate = true) @@ -1789,7 +1792,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(InsertAction(None, insertAssigns)), Seq(DeleteAction(Some(EqualTo(_: AttributeReference, StringLiteral("delete")))), UpdateAction(None, notMatchedBySourceUpdateAssigns)), - withSchemaEvolution) => + withSchemaEvolution, + _) => 
checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, None, None, updateAssigns) checkNotMatchedClausesResolution(target, source, None, insertAssigns) @@ -1826,7 +1830,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), notMatchedBySourceUpdateAssigns)), - withSchemaEvolution) => + withSchemaEvolution, + _) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns) @@ -1865,7 +1870,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), notMatchedBySourceUpdateAssigns)), - withSchemaEvolution) => + withSchemaEvolution, + _) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns) @@ -2151,7 +2157,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Some(EqualTo(il: AttributeReference, StringLiteral("a"))), insertAssigns)), Seq(DeleteAction(Some(_)), UpdateAction(None, secondUpdateAssigns)), - withSchemaEvolution) => + withSchemaEvolution, + _) => val ti = target.output.find(_.name == "i").get val ts = target.output.find(_.name == "s").get val si = source.output.find(_.name == "i").get @@ -2258,7 +2265,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(), Seq(), notMatchedBySourceActions, - withSchemaEvolution) => + withSchemaEvolution, + _) => assert(notMatchedBySourceActions.length == 2) notMatchedBySourceActions(0) match { case DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("b")))) => @@ -2333,7 +2341,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(UpdateAction(None, updateAssigns)), // Matched actions Seq(), // Not matched actions Seq(), // Not matched by source actions - withSchemaEvolution) => + withSchemaEvolution, + _) => val ti = target.output.find(_.name == "i").get val si = source.output.find(_.name == "i").get assert(updateAssigns.size == 1) @@ -2358,7 +2367,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(), // Matched action Seq(InsertAction(None, insertAssigns)), // Not matched actions Seq(), // Not matched by source actions - withSchemaEvolution) => + withSchemaEvolution, + _) => val ti = target.output.find(_.name == "i").get val si = source.output.find(_.name == "i").get assert(insertAssigns.size == 1) @@ -2442,7 +2452,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(None)), Seq(InsertAction(None, insertAssigns)), Nil, - withSchemaEvolution) => + withSchemaEvolution, + _) => // There is only one assignment, the missing col is not filled with default value assert(insertAssigns.size == 1) // Special case: Spark does not resolve any columns in MERGE if table accepts any schema. 
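
Taken together, patch 01 narrows WITH SCHEMA EVOLUTION so that only source columns a merge action actually assigns are added to the target schema. A minimal sketch of the intended end-to-end behavior, in the style of the suite above (the table name is illustrative, and the target is assumed to be a v2 table in a catalog that supports MERGE with schema evolution):

    // Assumes a SparkSession `spark` and `import spark.implicits._` for toDF.
    Seq((2, 150, 50, "unused"))
      .toDF("pk", "salary", "bonus", "extra")
      .createOrReplaceTempView("source")

    // Target schema before the merge: pk INT, salary INT.
    spark.sql(
      """MERGE WITH SCHEMA EVOLUTION
        |INTO target t
        |USING source s
        |ON t.pk = s.pk
        |WHEN MATCHED THEN
        |  UPDATE SET salary = s.salary, bonus = s.bonus
        |""".stripMargin)

    // `bonus` is assigned by an action, so it is added to the target schema;
    // `extra` is never referenced by any action, so it is pruned from the
    // evolved schema rather than added.
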
From 24b1a515a0b6c61c769e9c7e099bf7b01156c386 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Tue, 4 Nov 2025 11:45:28 -0800
Subject: [PATCH 02/13] Refactor and add more test

---
 .../sql/catalyst/analysis/Analyzer.scala      |  3 +-
 .../ResolveRowLevelCommandAssignments.scala   | 81 +++++++-----------
 .../catalyst/plans/logical/v2Commands.scala   | 45 ++++++----
 .../connector/MergeIntoTableSuiteBase.scala   | 84 +++++++++++++++++++
 4 files changed, 143 insertions(+), 70 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 25b2465ab088..7620551e820c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1749,7 +1749,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       m.copy(mergeCondition = resolvedMergeCondition,
         matchedActions = newMatchedActions,
         notMatchedActions = newNotMatchedActions,
-        notMatchedBySourceActions = newNotMatchedBySourceActions)
+        notMatchedBySourceActions = newNotMatchedBySourceActions,
+        originalSourceActions = newMatchedActions ++ newNotMatchedActions)
       }
 
     // UnresolvedHaving can host grouping expressions and aggregate functions. We should resolve
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
index d2ff77ed2a51..3eb528954b35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -60,18 +60,11 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
         notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions,
           coerceNestedTypes),
         notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions,
-          coerceNestedTypes),
-        preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions)
-      )
+          coerceNestedTypes))
 
     case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && !m.aligned &&
         !m.needSchemaEvolution =>
-      m.copy(
-        matchedActions = m.matchedActions.map(resolveMergeAction),
-        notMatchedActions = m.notMatchedActions.map(resolveMergeAction),
-        notMatchedBySourceActions = m.notMatchedBySourceActions.map(resolveMergeAction),
-        preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions)
-      )
+      resolveAssignments(m)
   }
 
   private def validateStoreAssignmentPolicy(): Unit = {
@@ -90,51 +83,33 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
 
   private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
     p.transformExpressions {
-      case assignment: Assignment => resolveAssignment(assignment)
-    }
-  }
-
-  private def resolveMergeAction(mergeAction: MergeAction) = {
-    mergeAction match {
-      case u @ UpdateAction(_, assignments) =>
-        u.copy(assignments = assignments.map(resolveAssignment))
-      case i @ InsertAction(_, assignments) =>
-        i.copy(assignments = assignments.map(resolveAssignment))
-      case d: DeleteAction =>
-        d
-      case other =>
-        throw new AnalysisException(
-          errorClass = "_LEGACY_ERROR_TEMP_3053",
-          messageParameters = Map("other" -> other.toString))
-    }
-  }
-
-  private def resolveAssignment(assignment: Assignment) = {
-    val nullHandled 
= if (!assignment.key.nullable && assignment.value.nullable) { - AssertNotNull(assignment.value) - } else { - assignment.value - } - val casted = if (assignment.key.dataType != nullHandled.dataType) { - val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true) - cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) - cast - } else { - nullHandled - } - val rawKeyType = assignment.key.transform { - case a: AttributeReference => - CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a) - }.dataType - val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) { - CharVarcharUtils.stringLengthCheck(casted, rawKeyType) - } else { - casted - } - val cleanedKey = assignment.key.transform { - case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a) + case assignment: Assignment => + val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) { + AssertNotNull(assignment.value) + } else { + assignment.value + } + val casted = if (assignment.key.dataType != nullHandled.dataType) { + val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true) + cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) + cast + } else { + nullHandled + } + val rawKeyType = assignment.key.transform { + case a: AttributeReference => + CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a) + }.dataType + val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) { + CharVarcharUtils.stringLengthCheck(casted, rawKeyType) + } else { + casted + } + val cleanedKey = assignment.key.transform { + case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a) + } + Assignment(cleanedKey, finalValue) } - Assignment(cleanedKey, finalValue) } private def alignActions( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 0a1ad4efed6f..77beae1171bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -862,7 +862,7 @@ case class MergeIntoTable( notMatchedBySourceActions: Seq[MergeAction], withSchemaEvolution: Boolean, // Preserves original pre-aligned actions for source matches - preservedSourceActions: Option[Seq[MergeAction]] = None) + originalSourceActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery { lazy val aligned: Boolean = { @@ -895,12 +895,14 @@ case class MergeIntoTable( case _ => false } - private lazy val migrationSchema: StructType = + // a pruned version of source schema that only contains columns/nested fields + // explicitly assigned by MERGE INTO actions + private lazy val referencedSourceSchema: StructType = MergeIntoTable.referencedSourceSchema(this) lazy val needSchemaEvolution: Boolean = { schemaEvolutionEnabled && - MergeIntoTable.schemaChanges(targetTable.schema, migrationSchema).nonEmpty + MergeIntoTable.schemaChanges(targetTable.schema, referencedSourceSchema).nonEmpty } private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && { @@ -918,6 +920,26 @@ case class MergeIntoTable( } object MergeIntoTable { + + def apply( + targetTable: LogicalPlan, + sourceTable: LogicalPlan, + mergeCondition: Expression, + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction], + withSchemaEvolution: Boolean): MergeIntoTable = { + MergeIntoTable( + 
targetTable, + sourceTable, + mergeCondition, + matchedActions, + notMatchedActions, + notMatchedBySourceActions, + withSchemaEvolution, + matchedActions ++ notMatchedActions) + } + def getWritePrivileges( matchedActions: Iterable[MergeAction], notMatchedActions: Iterable[MergeAction], @@ -955,12 +977,11 @@ object MergeIntoTable { case currentField: StructField if newFieldMap.contains(currentField.name) => schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType, originalTarget, originalSource, fieldPath ++ Seq(currentField.name)) - } - }.flatten + }}.flatten // Identify the newly added fields and append to the end val currentFieldMap = toFieldMap(currentFields) - val adds = newFields.filterNot(f => currentFieldMap.contains(f.name)) + val adds = newFields.filterNot (f => currentFieldMap.contains (f.name)) .map(f => TableChange.addColumn(fieldPath ++ Set(f.name), f.dataType)) updates ++ adds @@ -1003,17 +1024,12 @@ object MergeIntoTable { // by at least one merge action def referencedSourceSchema(merge: MergeIntoTable): StructType = { - val actions = merge.preservedSourceActions match { - case Some(preserved) => preserved - case None => merge.matchedActions ++ merge.notMatchedActions - } - - val assignments = actions.collect { + val assignments = merge.originalSourceActions.collect { case a: UpdateAction => a.assignments.map(_.key) case a: InsertAction => a.assignments.map(_.key) }.flatten - val containsStarAction = actions.exists { + val containsStarAction = merge.originalSourceActions.exists { case _: UpdateStarAction => true case _: InsertStarAction => true case _ => false @@ -1042,8 +1058,6 @@ object MergeIntoTable { } }) - val sourceSchema = merge.sourceTable.schema - val targetSchema = merge.targetTable.schema val res = filterSchema(merge.sourceTable.schema, Seq.empty) res } @@ -1072,7 +1086,6 @@ object MergeIntoTable { } } - sealed abstract class MergeAction extends Expression with Unevaluable { def condition: Option[Expression] override def nullable: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index b9c9fb22664a..11f5d8aaf423 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -3721,6 +3721,46 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } + test("Merge schema evolution should not evolve when assigning existing target column " + + "from source column that does not exist in target") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", 50), + (3, 250, "dummy", 75)).toDF("pk", "salary", "dep", "bonus") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.bonus + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep') + |""".stripMargin + + sql(mergeStmt) + // bonus column should NOT be added to target schema + // Only salary is 
updated with bonus value
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 100, "hr"),
+          Row(2, 50, "software"),
+          Row(3, 75, "newdep")))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
   test("Merge schema evolution should evolve struct if directly referencing new field " +
     "in top level struct: insert") {
     Seq(true, false).foreach { withSchemaEvolution =>
       withTempView("source") {
@@ -3991,6 +4031,50 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
+  test("merge into with source missing fields in top-level struct") {
+    withTempView("source") {
+      // Target table has struct with 3 fields at top level
+      createAndInitTable(
+        s"""pk INT NOT NULL,
+           |s STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>,
+           |dep STRING""".stripMargin,
+        """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""")
+
+      // Source table has struct with only 2 fields (c1, c2) - missing c3
+      val sourceTableSchema = StructType(Seq(
+        StructField("pk", IntegerType, nullable = false),
+        StructField("s", StructType(Seq(
+          StructField("c1", IntegerType),
+          StructField("c2", StringType)))), // missing c3 field
+        StructField("dep", StringType)))
+      val data = Seq(
+        Row(1, Row(10, "b"), "hr"),
+        Row(2, Row(20, "c"), "engineering")
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+        .createOrReplaceTempView("source")
+
+      sql(
+        s"""MERGE INTO $tableNameAsString t
+           |USING source src
+           |ON t.pk = src.pk
+           |WHEN MATCHED THEN
+           |  UPDATE SET *
+           |WHEN NOT MATCHED THEN
+           |  INSERT *
+           |""".stripMargin)
+
+      // Missing field c3 should be filled with NULL
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(0, Row(1, "a", true), "sales"),
+          Row(1, Row(10, "b", null), "hr"),
+          Row(2, Row(20, "c", null), "engineering")))
+    }
+    sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+  }
+
   test("merge into with source missing fields in struct nested in array") {
     withTempView("source") {
       // Target table has struct with 3 fields (c1, c2, c3) in array

From abbeb1edcd95b6b0054c46ae65cf98040a576ee5 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Thu, 6 Nov 2025 14:02:06 -0800
Subject: [PATCH 03/13] Only allow schema evolution when a new field is the
 target of an assignment whose value is the same-named source column

---
 .../sql/catalyst/analysis/Analyzer.scala      |   5 +-
 .../ResolveMergeIntoSchemaEvolution.scala     |   4 +-
 .../analysis/RewriteMergeIntoTable.scala      |   6 +-
 .../catalyst/plans/logical/v2Commands.scala   |  74 ++++-----
 .../PullupCorrelatedPredicatesSuite.scala     |   2 +-
 .../connector/MergeIntoTableSuiteBase.scala   | 157 +++++++++---------
 .../command/PlanResolutionSuite.scala         |  33 ++--
 7 files changed, 132 insertions(+), 149 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7620551e820c..98c514925fa0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1669,7 +1669,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case u: UpdateTable => resolveReferencesInUpdate(u)
 
-      case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _, _)
+      case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
         if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution =>
 
         EliminateSubqueryAliases(targetTable) match {
@@ -1749,8 +1749,7 @@ class Analyzer(override val 
catalogManager: CatalogManager) extends RuleExecutor m.copy(mergeCondition = resolvedMergeCondition, matchedActions = newMatchedActions, notMatchedActions = newNotMatchedActions, - notMatchedBySourceActions = newNotMatchedBySourceActions, - originalSourceActions = newMatchedActions ++ newNotMatchedActions) + notMatchedBySourceActions = newNotMatchedBySourceActions) } // UnresolvedHaving can host grouping expressions and aggregate functions. We should resolve diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala index b97663a2d344..23018b1edfe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case m @ MergeIntoTable(_, _, _, _, _, _, _, _) + case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.needSchemaEvolution => val newTarget = m.targetTable.transform { case r : DataSourceV2Relation => performSchemaEvolution(r, m) @@ -46,7 +46,7 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] { : DataSourceV2Relation = { (relation.catalog, relation.identifier) match { case (Some(c: TableCatalog), Some(i)) => - val referencedSourceSchema = MergeIntoTable.referencedSourceSchema(m) + val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m) val changes = MergeIntoTable.schemaChanges(relation.schema, referencedSourceSchema) c.alterTable(i, changes: _*) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index d2c4abb2259b..8b5b690aa740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -45,7 +45,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, - notMatchedBySourceActions, _, _) if m.resolved && m.rewritable && m.aligned && + notMatchedBySourceActions, _) if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution && matchedActions.isEmpty && notMatchedActions.size == 1 && notMatchedBySourceActions.isEmpty => @@ -79,7 +79,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper } case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, - notMatchedBySourceActions, _, _) + notMatchedBySourceActions, _) if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution && matchedActions.isEmpty && notMatchedBySourceActions.isEmpty => @@ -121,7 +121,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper } case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions, - notMatchedBySourceActions, _, _) + notMatchedBySourceActions, _) if m.resolved && m.rewritable && 
m.aligned && !m.needSchemaEvolution => EliminateSubqueryAliases(aliasedTable) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 77beae1171bf..fe6e068690be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -860,9 +860,7 @@ case class MergeIntoTable( matchedActions: Seq[MergeAction], notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], - withSchemaEvolution: Boolean, - // Preserves original pre-aligned actions for source matches - originalSourceActions: Seq[MergeAction]) + withSchemaEvolution: Boolean) extends BinaryCommand with SupportsSubquery { lazy val aligned: Boolean = { @@ -895,14 +893,12 @@ case class MergeIntoTable( case _ => false } - // a pruned version of source schema that only contains columns/nested fields - // explicitly assigned by MERGE INTO actions - private lazy val referencedSourceSchema: StructType = - MergeIntoTable.referencedSourceSchema(this) + private lazy val sourceSchemaForEvolution: StructType = + MergeIntoTable.sourceSchemaForSchemaEvolution(this) lazy val needSchemaEvolution: Boolean = { schemaEvolutionEnabled && - MergeIntoTable.schemaChanges(targetTable.schema, referencedSourceSchema).nonEmpty + MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty } private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && { @@ -921,25 +917,6 @@ case class MergeIntoTable( object MergeIntoTable { - def apply( - targetTable: LogicalPlan, - sourceTable: LogicalPlan, - mergeCondition: Expression, - matchedActions: Seq[MergeAction], - notMatchedActions: Seq[MergeAction], - notMatchedBySourceActions: Seq[MergeAction], - withSchemaEvolution: Boolean): MergeIntoTable = { - MergeIntoTable( - targetTable, - sourceTable, - mergeCondition, - matchedActions, - notMatchedActions, - notMatchedBySourceActions, - withSchemaEvolution, - matchedActions ++ notMatchedActions) - } - def getWritePrivileges( matchedActions: Iterable[MergeAction], notMatchedActions: Iterable[MergeAction], @@ -1020,16 +997,18 @@ object MergeIntoTable { } } - // Filter the source schema to retain only fields that are referenced - // by at least one merge action - def referencedSourceSchema(merge: MergeIntoTable): StructType = { + // A pruned version of source schema that only contains columns/nested fields + // explicitly and directly assigned to a target counterpart in MERGE INTO actions. + // New columns/nested fields not existing in target will be added for schema evolution. 
+ def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = { - val assignments = merge.originalSourceActions.collect { - case a: UpdateAction => a.assignments.map(_.key) - case a: InsertAction => a.assignments.map(_.key) + val actions = merge.matchedActions ++ merge.notMatchedActions + val assignments = actions.collect { + case a: UpdateAction => a.assignments + case a: InsertAction => a.assignments }.flatten - val containsStarAction = merge.originalSourceActions.exists { + val containsStarAction = actions.exists { case _: UpdateStarAction => true case _: InsertStarAction => true case _ => false @@ -1046,7 +1025,7 @@ object MergeIntoTable { // If this is a struct and one of the children is being assigned to in a merge clause, // keep it and continue filtering children. case struct: StructType if assignments.exists(assign => - isPrefix(fieldPath, extractFieldPath(assign))) => + isPrefix(fieldPath, extractFieldPath(assign.key))) => Some(field.copy(dataType = filterSchema(struct, fieldPath))) // The field isn't assigned to directly or indirectly (i.e. its children) in any non-* // clause. Check if it should be kept with any * action. @@ -1058,8 +1037,7 @@ object MergeIntoTable { } }) - val res = filterSchema(merge.sourceTable.schema, Seq.empty) - res + filterSchema(merge.sourceTable.schema, Seq.empty) } // Helper method to extract field path from an Expression. @@ -1071,18 +1049,28 @@ object MergeIntoTable { case _ => Seq.empty } - // Helper method to check if a given field path is a prefix of another path. Delegates - // equality to conf.resolver to correctly handle case sensitivity. + // Helper method to check if a given field path is a prefix of another path. private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = prefix.length <= path.length && prefix.zip(path).forall { case (prefixNamePart, pathNamePart) => SQLConf.get.resolver(prefixNamePart, pathNamePart) } - // Helper method to check if an assignment Expression's field path is equal to a path. - def isEqual(assignmentExpr: Expression, path: Seq[String]): Boolean = { - val exprPath = extractFieldPath(assignmentExpr) - exprPath.length == path.length && isPrefix(exprPath, path) + // Helper method to check if a given field path is a suffix of another path. 
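+  // e.g. isSuffix(Seq("b"), Seq("a", "b")) is true while isSuffix(Seq("a"), Seq("a", "b"))
+  // is false; name parts are compared with SQLConf.get.resolver, like isPrefix above.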
+ private def isSuffix(prefix: Seq[String], path: Seq[String]): Boolean = + prefix.length <= path.length && prefix.reverse.zip(path.reverse).forall { + case (prefixNamePart, pathNamePart) => + SQLConf.get.resolver(prefixNamePart, pathNamePart) + } + + // Helper method to check if an assignment key is equal to a source column + // and if the assignment value is the corresponding source column directly + private def isEqual(assignment: Assignment, path: Seq[String]): Boolean = { + val assignmenKeyExpr = extractFieldPath(assignment.key) + val assignmentValueExpr = extractFieldPath(assignment.value) + // Valid assignments are: col = s.col or col.nestedField = s.col.nestedField + assignmenKeyExpr.length == path.length && isPrefix(assignmenKeyExpr, path) && + isSuffix(path, assignmentValueExpr) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 96ae91dbedc1..cbd24bd7bb29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -167,7 +167,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { assert(optimized.resolved) optimized match { - case MergeIntoTable(_, _, s: InSubquery, _, _, _, _, _) => + case MergeIntoTable(_, _, s: InSubquery, _, _, _, _) => val outerRefs = SubExprUtils.getOuterReferences(s.query.plan) assert(outerRefs.isEmpty, "should be no outer refs") case other => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 11f5d8aaf423..823cc02cc456 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -3510,6 +3510,41 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } + test("Merge schema evolution should not evolve referencing new column via transform") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET extra=substring(s.extra, 1, 2) + |""".stripMargin + + + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`extra` cannot be resolved")) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + test("Merge schema evolution should not evolve if not directly referencing new column: update") { Seq(true, false).foreach { withSchemaEvolution => withTempView("source") { @@ -3617,6 +3652,40 @@ abstract class MergeIntoTableSuiteBase extends 
RowLevelOperationSuiteBase
     }
   }
 
+  test("Merge schema evolution should not evolve if not having just column name: update") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", "blah"),
+          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET t.extra = s.extra
+             |""".stripMargin
+
+        val exception = intercept[org.apache.spark.sql.AnalysisException] {
+          sql(mergeStmt)
+        }
+        assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+        assert(exception.message.contains("A column, variable, or function parameter with name " +
+          "`t`.`extra` cannot be resolved"))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
   test("Merge schema evolution should only evolve referenced column when source " +
     "has multiple new columns") {
     Seq(true, false).foreach { withSchemaEvolution =>
@@ -3761,7 +3830,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
-  test("Merge schema evolution should evolve struct if directly referencing new field " +
+  test("Merge schema evolution should not evolve struct if not directly referencing new field " +
     "in top level struct: insert") {
     Seq(true, false).foreach { withSchemaEvolution =>
       withTempView("source") {
@@ -3801,28 +3870,19 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
             |""".stripMargin
         sql(mergeStmt)
 
-        if (withSchemaEvolution) {
-          checkAnswer(
-            sql(s"SELECT * FROM $tableNameAsString"),
-            Seq(
-              Row(1, Row(100, "active", null), "hr"),
-              Row(2, Row(200, "inactive", null), "software"),
-              Row(3, Row(250, "active", null), "marketing")))
-        } else {
-          checkAnswer(
-            sql(s"SELECT * FROM $tableNameAsString"),
-            Seq(
-              Row(1, Row(100, "active"), "hr"),
-              Row(2, Row(200, "inactive"), "software"),
-              Row(3, Row(250, "active"), "marketing")))
-        }
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, Row(100, "active"), "hr"),
+            Row(2, Row(200, "inactive"), "software"),
+            Row(3, Row(250, "active"), "marketing")))
 
         sql(s"DROP TABLE $tableNameAsString")
       }
     }
   }
 
   test("Merge schema evolution should not evolve if not directly referencing new field " +
-    "in top level struct") {
+    "in top level struct: update") {
     Seq(true, false).foreach { withSchemaEvolution =>
       withTempView("source") {
         createAndInitTable(
@@ -3945,49 +4005,6 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
-  test("Merge schema evolution should evolve referencing new column via transform") {
-    Seq(true, false).foreach { withSchemaEvolution =>
-      withTempView("source") {
-        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
-          """{ "pk": 1, "salary": 100, "dep": "hr" }
-            |{ "pk": 2, "salary": 200, "dep": "software" }
-            |""".stripMargin)
-
-        val sourceDF = Seq((2, 150, "dummy", "blah"),
-          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
-        sourceDF.createOrReplaceTempView("source")
-
-        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
-        val mergeStmt =
-          s"""MERGE $schemaEvolutionClause
-             |INTO 
$tableNameAsString t - |USING source s - |ON t.pk = s.pk - |WHEN MATCHED THEN - | UPDATE SET dep='software', extra=substring(s.extra, 1, 2) - |""".stripMargin - - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 100, "hr", null), - Row(2, 200, "software", "bl"))) - } else { - val e = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") - assert(e.getMessage.contains("A column, variable, or function parameter with name " + - "`extra` cannot be resolved")) - } - - sql(s"DROP TABLE $tableNameAsString") - } - } - } - test("Merge schema evolution should evolve referencing new column assigned to something else") { Seq(true, false).foreach { withSchemaEvolution => withTempView("source") { @@ -4007,25 +4024,15 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase |USING source s |ON t.pk = s.pk |WHEN MATCHED THEN - | UPDATE SET dep='software', extra=s.dep + | UPDATE SET extra=s.dep |""".stripMargin - if (withSchemaEvolution) { + val e = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 100, "hr", null), - Row(2, 200, "software", "dummy"))) - } else { - val e = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") - assert(e.getMessage.contains("A column, variable, or function parameter with name " + - "`extra` cannot be resolved")) } - + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`extra` cannot be resolved")) sql(s"DROP TABLE $tableNameAsString") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 02024f26baf0..9ea8b9130ba8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -1699,8 +1699,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), notMatchedBySourceUpdateAssigns)), - withSchemaEvolution, - _) => + withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns) @@ -1732,8 +1731,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))), insertAssigns)), Seq(), - withSchemaEvolution, - _) => + withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns, starInUpdate = true) @@ -1760,8 +1758,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(UpdateAction(None, updateAssigns)), Seq(InsertAction(None, insertAssigns)), Seq(), - withSchemaEvolution, - _) => + withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) 
checkMatchedClausesResolution(target, source, None, None, updateAssigns, starInUpdate = true) @@ -1792,8 +1789,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(InsertAction(None, insertAssigns)), Seq(DeleteAction(Some(EqualTo(_: AttributeReference, StringLiteral("delete")))), UpdateAction(None, notMatchedBySourceUpdateAssigns)), - withSchemaEvolution, - _) => + withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, None, None, updateAssigns) checkNotMatchedClausesResolution(target, source, None, insertAssigns) @@ -1830,8 +1826,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), notMatchedBySourceUpdateAssigns)), - withSchemaEvolution, - _) => + withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns) @@ -1870,8 +1865,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), notMatchedBySourceUpdateAssigns)), - withSchemaEvolution, - _) => + withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns) @@ -2157,8 +2151,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Some(EqualTo(il: AttributeReference, StringLiteral("a"))), insertAssigns)), Seq(DeleteAction(Some(_)), UpdateAction(None, secondUpdateAssigns)), - withSchemaEvolution, - _) => + withSchemaEvolution) => val ti = target.output.find(_.name == "i").get val ts = target.output.find(_.name == "s").get val si = source.output.find(_.name == "i").get @@ -2265,8 +2258,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(), Seq(), notMatchedBySourceActions, - withSchemaEvolution, - _) => + withSchemaEvolution) => assert(notMatchedBySourceActions.length == 2) notMatchedBySourceActions(0) match { case DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("b")))) => @@ -2341,8 +2333,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(UpdateAction(None, updateAssigns)), // Matched actions Seq(), // Not matched actions Seq(), // Not matched by source actions - withSchemaEvolution, - _) => + withSchemaEvolution) => val ti = target.output.find(_.name == "i").get val si = source.output.find(_.name == "i").get assert(updateAssigns.size == 1) @@ -2367,8 +2358,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(), // Matched action Seq(InsertAction(None, insertAssigns)), // Not matched actions Seq(), // Not matched by source actions - withSchemaEvolution, - _) => + withSchemaEvolution) => val ti = target.output.find(_.name == "i").get val si = source.output.find(_.name == "i").get assert(insertAssigns.size == 1) @@ -2452,8 +2442,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { Seq(DeleteAction(None)), Seq(InsertAction(None, insertAssigns)), Nil, - withSchemaEvolution, - 
_) =>
+          withSchemaEvolution) =>
         // There is only one assignment, the missing col is not filled with default value
         assert(insertAssigns.size == 1)
         // Special case: Spark does not resolve any columns in MERGE if table accepts any schema.

From 14612656c56c712a10288ab4fbb3f260e962040c Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Thu, 6 Nov 2025 16:14:51 -0800
Subject: [PATCH 04/13] Add more tests

---
 .../connector/MergeIntoTableSuiteBase.scala   | 117 ++++++++++++++++++
 1 file changed, 117 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 823cc02cc456..a3d6af981074 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -3930,6 +3930,123 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
+  test("Merge schema evolution should evolve when directly assigning struct with new field: " +
+    "UPDATE") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary: INT, status: STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType) // new field not in target
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(2, Row(150, "updated", 50), "engineering")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET info = s.info
+             |""".stripMargin
+
+        if (withSchemaEvolution) {
+          sql(mergeStmt)
+          // Schema should evolve - bonus field should be added
+          checkAnswer(
+            sql(s"SELECT * FROM $tableNameAsString"),
+            Seq(
+              Row(1, Row(100, "active", null), "hr"),
+              Row(2, Row(150, "updated", 50), "software")))
+        } else {
+          val exception = intercept[org.apache.spark.sql.AnalysisException] {
+            sql(mergeStmt)
+          }
+          assert(exception.getMessage.contains("Cannot safely cast") ||
+            exception.getMessage.contains("incompatible"))
+        }
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should evolve when directly assigning struct with new field: " +
+    "INSERT") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary: INT, status: STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType) // new field not in target
+          ))),
+          StructField("dep", 
StringType) + )) + val data = Seq( + Row(3, Row(150, "new", 50), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, info, dep) VALUES (s.pk, s.info, s.dep) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + // Schema should evolve - bonus field should be added + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active", null), "hr"), + Row(2, Row(200, "inactive", null), "software"), + Row(3, Row(150, "new", 50), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.getMessage.contains("Cannot safely cast") || + exception.getMessage.contains("incompatible")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + test("Merge schema evolution should not evolve if not directly referencing " + "new field in nested struct") { Seq(true, false).foreach { withSchemaEvolution => From 451da3e535220ad7e37bcc1fb463c06c35eb2c8d Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 6 Nov 2025 19:28:36 -0800 Subject: [PATCH 05/13] minor cleanup --- .../spark/sql/catalyst/plans/logical/v2Commands.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index fe6e068690be..480a208dbe90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1057,11 +1057,8 @@ object MergeIntoTable { } // Helper method to check if a given field path is a suffix of another path. 
- private def isSuffix(prefix: Seq[String], path: Seq[String]): Boolean = - prefix.length <= path.length && prefix.reverse.zip(path.reverse).forall { - case (prefixNamePart, pathNamePart) => - SQLConf.get.resolver(prefixNamePart, pathNamePart) - } + private def isSuffix(suffix: Seq[String], path: Seq[String]): Boolean = + isPrefix(suffix.reverse, path.reverse) // Helper method to check if an assignment key is equal to a source column // and if the assignment value is the corresponding source column directly From cee88a24faded8d00b0e995a0dafa89c45c023a3 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 7 Nov 2025 09:34:05 -0800 Subject: [PATCH 06/13] review comment --- .../spark/sql/catalyst/plans/logical/v2Commands.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 480a208dbe90..ccbb9b75ac8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -860,8 +860,7 @@ case class MergeIntoTable( matchedActions: Seq[MergeAction], notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], - withSchemaEvolution: Boolean) - extends BinaryCommand with SupportsSubquery { + withSchemaEvolution: Boolean) extends BinaryCommand with SupportsSubquery { lazy val aligned: Boolean = { val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions @@ -998,8 +997,10 @@ object MergeIntoTable { } // A pruned version of source schema that only contains columns/nested fields - // explicitly and directly assigned to a target counterpart in MERGE INTO actions. - // New columns/nested fields not existing in target will be added for schema evolution. + // explicitly and directly assigned to a target counterpart in MERGE INTO actions, + // which are relevant for schema evolution. + // New columns/nested fields in this schema that are not existing in target schema + // will be added for schema evolution. 
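+  // For example (illustrative only): UPDATE SET extra = s.extra qualifies, while
+  // UPDATE SET extra = substring(s.extra, 1, 2) does not, because the assigned value must
+  // be a direct reference to the corresponding source column.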
def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = { val actions = merge.matchedActions ++ merge.notMatchedActions From f23a985f78954bdd76233bf248149bdb84cd8937 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sat, 8 Nov 2025 17:53:45 -0800 Subject: [PATCH 07/13] Fix attempt --- .../sql/catalyst/analysis/Analyzer.scala | 81 ++++++++++++++----- .../analysis/ColumnResolutionHelper.scala | 15 +++- .../ResolveMergeIntoSchemaEvolution.scala | 80 +++++++++++++++++- .../catalyst/plans/logical/v2Commands.scala | 65 ++++++++++++++- .../connector/MergeIntoTableSuiteBase.scala | 2 +- 5 files changed, 212 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 98c514925fa0..8c6b4f670201 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1670,7 +1670,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UpdateTable => resolveReferencesInUpdate(u) case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _) - if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution => + if !m.resolved && targetTable.resolved && sourceTable.resolved => + + // This rule is run again after schema evolution to re-resolve based on evolved schema + // Schema evolution requires all assignments with keys being non candidate columns + // to be resolved. + // The final run will throw exceptions if not all expressions are resolved + val finalResolution = m.allAssignmentsResolvedOrEvolutionCandidate EliminateSubqueryAliases(targetTable) match { case r: NamedRelation if r.skipSchemaResolution => @@ -1680,6 +1686,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor m case _ => + def findAttrInTarget(name: String): Option[Attribute] = { + targetTable.output.find(targetAttr => conf.resolver(name, targetAttr.name)) + } val newMatchedActions = m.matchedActions.map { case DeleteAction(deleteCondition) => val resolvedDeleteCondition = deleteCondition.map( @@ -1691,18 +1700,33 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor UpdateAction( resolvedUpdateCondition, // The update value can access columns from both target and source tables. - resolveAssignments(assignments, m, MergeResolvePolicy.BOTH)) + resolveAssignments(assignments, m, MergeResolvePolicy.BOTH, + throws = finalResolution)) case UpdateStarAction(updateCondition) => // Use only source columns. Missing columns in target will be handled in // ResolveRowLevelCommandAssignments. 
- val assignments = targetTable.output.flatMap{ targetAttr => - sourceTable.output.find( - sourceCol => conf.resolver(sourceCol.name, targetAttr.name)) - .map(Assignment(targetAttr, _))} + val assignments = if (m.schemaEvolutionEnabled) { + sourceTable.output.map(sourceAttr => + findAttrInTarget(sourceAttr.name).map( + targetAttr => Assignment(targetAttr, sourceAttr)) + .getOrElse(Assignment( + UnresolvedAttribute(sourceAttr.name), + sourceAttr))) + } else { + sourceTable.output.flatMap { sourceAttr => + findAttrInTarget(sourceAttr.name).map( + targetAttr => Assignment(targetAttr, sourceAttr)) + } + } + + // sourceTable.output.find( +// sourceCol => conf.resolver(sourceCol.name, targetAttr.name)) +// .map(Assignment(targetAttr, _))} UpdateAction( updateCondition.map(resolveExpressionByPlanChildren(_, m)), // For UPDATE *, the value must be from source table. - resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE)) + resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE, + throws = finalResolution)) case o => o } val newNotMatchedActions = m.notMatchedActions.map { @@ -1713,7 +1737,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor resolveExpressionByPlanOutput(_, m.sourceTable)) InsertAction( resolvedInsertCondition, - resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE)) + resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE, + throws = finalResolution)) case InsertStarAction(insertCondition) => // The insert action is used when not matched, so its condition and value can only // access columns from the source table. @@ -1721,13 +1746,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor resolveExpressionByPlanOutput(_, m.sourceTable)) // Use only source columns. Missing columns in target will be handled in // ResolveRowLevelCommandAssignments. - val assignments = targetTable.output.flatMap{ targetAttr => - sourceTable.output.find( - sourceCol => conf.resolver(sourceCol.name, targetAttr.name)) - .map(Assignment(targetAttr, _))} + val assignments = if (m.schemaEvolutionEnabled) { + sourceTable.output.map(sourceAttr => + findAttrInTarget(sourceAttr.name).map( + targetAttr => Assignment(targetAttr, sourceAttr)) + .getOrElse(Assignment( + UnresolvedAttribute(sourceAttr.name), + sourceAttr))) + } else { + sourceTable.output.flatMap { sourceAttr => + findAttrInTarget(sourceAttr.name).map( + targetAttr => Assignment(targetAttr, sourceAttr)) + } + } InsertAction( resolvedInsertCondition, - resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE)) + resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE, + throws = finalResolution)) case o => o } val newNotMatchedBySourceActions = m.notMatchedBySourceActions.map { @@ -1741,7 +1776,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor UpdateAction( resolvedUpdateCondition, // The update value can access columns from the target table only. 
- resolveAssignments(assignments, m, MergeResolvePolicy.TARGET)) + resolveAssignments(assignments, m, MergeResolvePolicy.TARGET, + throws = finalResolution)) case o => o } @@ -1818,11 +1854,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor def resolveAssignments( assignments: Seq[Assignment], mergeInto: MergeIntoTable, - resolvePolicy: MergeResolvePolicy.Value): Seq[Assignment] = { + resolvePolicy: MergeResolvePolicy.Value, + throws: Boolean): Seq[Assignment] = { assignments.map { assign => val resolvedKey = assign.key match { case c if !c.resolved => - resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable)) + resolveMergeExpr(c, Project(Nil, mergeInto.targetTable), throws) case o => o } val resolvedValue = assign.value match { @@ -1842,7 +1879,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } else { resolvedExpr } - checkResolvedMergeExpr(withDefaultResolved, resolvePlan) + if (throws) { + checkResolvedMergeExpr(withDefaultResolved, resolvePlan) + } withDefaultResolved case o => o } @@ -1850,9 +1889,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } - private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = { - val resolved = resolveExprInAssignment(e, p) - checkResolvedMergeExpr(resolved, p) + private def resolveMergeExpr(e: Expression, p: LogicalPlan, throws: Boolean): Expression = { + val resolved = resolveExprInAssignment(e, p, throws) + if (throws) { + checkResolvedMergeExpr(resolved, p) + } resolved } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 53c92ca5425d..34541a8840cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -425,7 +425,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { def resolveExpressionByPlanChildren( e: Expression, q: LogicalPlan, - includeLastResort: Boolean = false): Expression = { + includeLastResort: Boolean = false, + throws: Boolean = true): Expression = { resolveExpression( tryResolveDataFrameColumns(e, q.children), resolveColumnByName = nameParts => { @@ -435,7 +436,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { assert(q.children.length == 1) q.children.head.output }, - throws = true, + throws, includeLastResort = includeLastResort) } @@ -475,8 +476,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { resolveVariables(resolveOuterRef(e)) } - def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = { - resolveExpressionByPlanChildren(expr, hostPlan) match { + def resolveExprInAssignment( + expr: Expression, + hostPlan: LogicalPlan, + throws: Boolean = true): Expression = { + resolveExpressionByPlanChildren(expr, + hostPlan, + includeLastResort = false, + throws = throws) match { // Assignment key and value does not need the alias when resolving nested columns. 
case Alias(child: ExtractValue, _) => child case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala index 23018b1edfe6..1e872c0d4b2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GetStructField} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -34,12 +35,85 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case m @ MergeIntoTable(_, _, _, _, _, _, _) - if m.needSchemaEvolution => + // This rule should run only if all assignments are resolved, except those + // that will be satisfied by schema evolution + case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.needSchemaEvolution => val newTarget = m.targetTable.transform { case r : DataSourceV2Relation => performSchemaEvolution(r, m) } - m.copy(targetTable = newTarget) + + // Unresolve the merge condition and all assignments + val unresolvedMergeCondition = unresolveCondition(m.mergeCondition) + val unresolvedMatchedActions = unresolveActions(m.matchedActions) + val unresolvedNotMatchedActions = unresolveActions(m.notMatchedActions) + val unresolvedNotMatchedBySourceActions = + unresolveActions(m.notMatchedBySourceActions) + + m.copy( + targetTable = newTarget, + mergeCondition = unresolvedMergeCondition, + matchedActions = unresolvedMatchedActions, + notMatchedActions = unresolvedNotMatchedActions, + notMatchedBySourceActions = unresolvedNotMatchedBySourceActions) + } + + private def unresolveActions(actions: Seq[MergeAction]): Seq[MergeAction] = { + actions.map { + case UpdateAction(condition, assignments) => + UpdateAction(condition.map(unresolveCondition), unresolveAssignmentKeys(assignments)) + case InsertAction(condition, assignments) => + InsertAction(condition.map(unresolveCondition), unresolveAssignmentKeys(assignments)) + case DeleteAction(condition) => + DeleteAction(condition.map(unresolveCondition)) + case other => other + } + } + + private def unresolveCondition(expr: Expression): Expression = { + expr.transform { + case attr: AttributeReference => + val nameParts = if (attr.qualifier.nonEmpty) { + attr.qualifier ++ Seq(attr.name) + } else { + Seq(attr.name) + } + UnresolvedAttribute(nameParts) + } + } + + private def unresolveAssignmentKeys(assignments: Seq[Assignment]): Seq[Assignment] = { + assignments.map { assignment => + val unresolvedKey = assignment.key match { + case _: UnresolvedAttribute => assignment.key + case gsf: GetStructField => + // Recursively collect all nested GetStructField names and the base AttributeReference + val nameParts = collectStructFieldNames(gsf) + nameParts match { + case Some(names) => UnresolvedAttribute(names) + case None => assignment.key + } + case attr: AttributeReference => + UnresolvedAttribute(Seq(attr.name)) + case attr: Attribute => + UnresolvedAttribute(Seq(attr.name)) + case other => 
other + } + Assignment(unresolvedKey, assignment.value) + } + } + + private def collectStructFieldNames(expr: Expression): Option[Seq[String]] = { + expr match { + case GetStructField(child, _, Some(fieldName)) => + collectStructFieldNames(child) match { + case Some(childNames) => Some(childNames :+ fieldName) + case None => None + } + case attr: AttributeReference => + Some(Seq(attr.name)) + case _ => + None + } } private def performSchemaEvolution(relation: DataSourceV2Relation, m: MergeIntoTable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index ccbb9b75ac8d..0edea67f0465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -895,12 +895,31 @@ case class MergeIntoTable( private lazy val sourceSchemaForEvolution: StructType = MergeIntoTable.sourceSchemaForSchemaEvolution(this) - lazy val needSchemaEvolution: Boolean = { + lazy val needSchemaEvolution: Boolean = schemaEvolutionEnabled && - MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty + allAssignmentsResolvedOrEvolutionCandidate && + (MergeIntoTable.assignmentForEvolutionCandidate(this).nonEmpty || + MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty) + + lazy val allAssignmentsResolvedOrEvolutionCandidate: Boolean = { + if ((!targetTable.resolved) || (!sourceTable.resolved)) { + false + } else { + val actions = matchedActions ++ notMatchedActions + val assignments = actions.collect { + case a: UpdateAction => a.assignments + case a: InsertAction => a.assignments + }.flatten + + val matchingAssignments = MergeIntoTable.assignmentForEvolutionCandidate(this).toSet + + assignments.forall { assignment => + assignment.resolved || matchingAssignments.contains(assignment) + } + } } - private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && { + def schemaEvolutionEnabled: Boolean = withSchemaEvolution && { EliminateSubqueryAliases(targetTable) match { case r: DataSourceV2Relation if r.autoSchemaEvolution() => true case _ => false @@ -1041,6 +1060,46 @@ object MergeIntoTable { filterSchema(merge.sourceTable.schema, Seq.empty) } + /** + * Returns all assignments with keys that match exactly a source field path from + * sourceTable's schema. 
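+   * For example (illustrative), an assignment extra = s.extra is a candidate when `extra`
+   * is a field path in the source schema that does not yet exist in the target schema.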
+ */ + def assignmentForEvolutionCandidate(merge: MergeIntoTable): Seq[Assignment] = { + // Collect all assignments from merge actions + val actions = merge.matchedActions ++ merge.notMatchedActions + val assignments = actions.collect { + case a: UpdateAction => a.assignments + case a: InsertAction => a.assignments + }.flatten + + // Extract all field paths from source schema + def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty): + Seq[Seq[String]] = { + schema.flatMap { field => + val fieldPath = basePath :+ field.name + field.dataType match { + case struct: StructType => + fieldPath +: extractAllFieldPaths(struct, fieldPath) + case _ => + Seq(fieldPath) + } + } + } + + val sourceFieldPaths = extractAllFieldPaths(merge.sourceTable.schema) + val targetFieldPaths = extractAllFieldPaths(merge.targetTable.schema) + val addedSourceFieldPaths = sourceFieldPaths.diff(targetFieldPaths) + + // Filter assignments whose key matches exactly a source field path + assignments.filter { assignment => + val keyPath = extractFieldPath(assignment.key) + addedSourceFieldPaths.exists { sourcePath => + keyPath.length == sourcePath.length && + isPrefix(keyPath, sourcePath) + } + } + } + // Helper method to extract field path from an Expression. private def extractFieldPath(expr: Expression): Seq[String] = expr match { case UnresolvedAttribute(nameParts) => nameParts diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index a3d6af981074..baa1a41fe43a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -2149,7 +2149,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("Merge schema evolution new column with set explicit column") { - Seq((true, true), (false, true), (true, false)).foreach { + Seq((true, true)).foreach { case (withSchemaEvolution, schemaEvolutionEnabled) => withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", From ea00e3c22bdd5f4140b9a51e5d73322596b7f664 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sun, 9 Nov 2025 11:58:09 -0800 Subject: [PATCH 08/13] Simplify logic --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++--- .../catalyst/plans/logical/v2Commands.scala | 67 +++++++------------ .../connector/MergeIntoTableSuiteBase.scala | 58 +++++++++++++++- 3 files changed, 91 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8c6b4f670201..63eb62faac65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1672,11 +1672,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _) if !m.resolved && targetTable.resolved && sourceTable.resolved => - // This rule is run again after schema evolution to re-resolve based on evolved schema - // Schema evolution requires all assignments with keys being non candidate columns - // to be resolved. 
-          // The final run will throw exceptions if not all expressions are resolved
-          val finalResolution = m.allAssignmentsResolvedOrEvolutionCandidate
+          // Do not throw exception for schema evolution case if it has not had a chance to run.
+          // This allows unresolved assignment keys a chance to be resolved in a second pass,
+          // against columns/nested fields newly added by schema evolution.
+          // If schema evolution has already had a chance to run, this will be the final pass.
+          val throws = !m.schemaEvolutionEnabled ||
+            (m.canEvaluateSchemaEvolution && !m.schemaChangesNonEmpty)
 
           EliminateSubqueryAliases(targetTable) match {
             case r: NamedRelation if r.skipSchemaResolution =>
@@ -1701,7 +1702,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                     resolvedUpdateCondition,
                     // The update value can access columns from both target and source tables.
                     resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
-                      throws = finalResolution))
+                      throws = throws))
                 case UpdateStarAction(updateCondition) =>
                   // Use only source columns. Missing columns in target will be handled in
                   // ResolveRowLevelCommandAssignments.
@@ -1726,7 +1727,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                     updateCondition.map(resolveExpressionByPlanChildren(_, m)),
                     // For UPDATE *, the value must be from source table.
                     resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
-                      throws = finalResolution))
+                      throws = throws))
                 case o => o
               }
               val newNotMatchedActions = m.notMatchedActions.map {
@@ -1738,7 +1739,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   InsertAction(
                     resolvedInsertCondition,
                     resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
-                      throws = finalResolution))
+                      throws = throws))
                 case InsertStarAction(insertCondition) =>
                   // The insert action is used when not matched, so its condition and value can only
                   // access columns from the source table.
                   val resolvedInsertCondition = insertCondition.map(
                     resolveExpressionByPlanOutput(_, m.sourceTable))
@@ -1762,7 +1763,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   InsertAction(
                     resolvedInsertCondition,
                     resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
-                      throws = finalResolution))
+                      throws = throws))
                 case o => o
               }
               val newNotMatchedBySourceActions = m.notMatchedBySourceActions.map {
@@ -1777,7 +1778,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                 UpdateAction(
                   resolvedUpdateCondition,
                   // The update value can access columns from the target table only.
resolveAssignments(assignments, m, MergeResolvePolicy.TARGET, - throws = finalResolution)) + throws = throws)) case o => o } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 0edea67f0465..069ab8ead1a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -895,13 +895,18 @@ case class MergeIntoTable( private lazy val sourceSchemaForEvolution: StructType = MergeIntoTable.sourceSchemaForSchemaEvolution(this) + lazy val schemaChangesNonEmpty = + MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty + lazy val needSchemaEvolution: Boolean = schemaEvolutionEnabled && - allAssignmentsResolvedOrEvolutionCandidate && - (MergeIntoTable.assignmentForEvolutionCandidate(this).nonEmpty || - MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty) + canEvaluateSchemaEvolution && + schemaChangesNonEmpty - lazy val allAssignmentsResolvedOrEvolutionCandidate: Boolean = { + // Guard that assignments are resolved or candidates for evolution before evaluating schema + // evolution. We need to use resolved assignment values to check candidates, see + // MergeIntoTable.assignmentForEvolutionCandidate. + lazy val canEvaluateSchemaEvolution: Boolean = { if ((!targetTable.resolved) || (!sourceTable.resolved)) { false } else { @@ -911,13 +916,14 @@ case class MergeIntoTable( case a: InsertAction => a.assignments }.flatten - val matchingAssignments = MergeIntoTable.assignmentForEvolutionCandidate(this).toSet - + val evolutionPaths = MergeIntoTable.extractAllFieldPaths(sourceSchemaForEvolution) assignments.forall { assignment => - assignment.resolved || matchingAssignments.contains(assignment) + assignment.resolved || + evolutionPaths.exists { path => MergeIntoTable.isEqual(assignment, path) } + } } } - } + def schemaEvolutionEnabled: Boolean = withSchemaEvolution && { EliminateSubqueryAliases(targetTable) match { @@ -1060,42 +1066,15 @@ object MergeIntoTable { filterSchema(merge.sourceTable.schema, Seq.empty) } - /** - * Returns all assignments with keys that match exactly a source field path from - * sourceTable's schema. 
- */ - def assignmentForEvolutionCandidate(merge: MergeIntoTable): Seq[Assignment] = { - // Collect all assignments from merge actions - val actions = merge.matchedActions ++ merge.notMatchedActions - val assignments = actions.collect { - case a: UpdateAction => a.assignments - case a: InsertAction => a.assignments - }.flatten - - // Extract all field paths from source schema - def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty): - Seq[Seq[String]] = { - schema.flatMap { field => - val fieldPath = basePath :+ field.name - field.dataType match { - case struct: StructType => - fieldPath +: extractAllFieldPaths(struct, fieldPath) - case _ => - Seq(fieldPath) - } - } - } - - val sourceFieldPaths = extractAllFieldPaths(merge.sourceTable.schema) - val targetFieldPaths = extractAllFieldPaths(merge.targetTable.schema) - val addedSourceFieldPaths = sourceFieldPaths.diff(targetFieldPaths) - - // Filter assignments whose key matches exactly a source field path - assignments.filter { assignment => - val keyPath = extractFieldPath(assignment.key) - addedSourceFieldPaths.exists { sourcePath => - keyPath.length == sourcePath.length && - isPrefix(keyPath, sourcePath) + private def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty): + Seq[Seq[String]] = { + schema.flatMap { field => + val fieldPath = basePath :+ field.name + field.dataType match { + case struct: StructType => + fieldPath +: extractAllFieldPaths(struct, fieldPath) + case _ => + Seq(fieldPath) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index baa1a41fe43a..d7af0d7cf06f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -2149,7 +2149,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("Merge schema evolution new column with set explicit column") { - Seq((true, true)).foreach { + Seq((true, true), (false, true), (true, false)).foreach { case (withSchemaEvolution, schemaEvolutionEnabled) => withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -4465,6 +4465,62 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(s"DROP TABLE IF EXISTS $tableNameAsString") } + test("Merge schema evolution should error on non-existent column in UPDATE and INSERT") { + withTable(tableNameAsString) { + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |salary INT, + |dep STRING""".stripMargin, + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("salary", IntegerType), + StructField("dep", StringType) + )) + + val data = Seq( + Row(2, 250, "engineering"), + Row(3, 300, "finance") + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val updateException = intercept[AnalysisException] { + sql( + s"""MERGE WITH SCHEMA EVOLUTION + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET non_existent = s.nonexistent_column + |""".stripMargin) + } + assert(updateException.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + 
assert(updateException.message.contains("A column, variable, or function parameter " + + "with name `non_existent` cannot be resolved")) + + val insertException = intercept[AnalysisException] { + sql( + s"""MERGE WITH SCHEMA EVOLUTION + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep, non_existent) VALUES (s.pk, s.salary, s.dep, s.dep) + |""".stripMargin) + } + assert(insertException.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(insertException.message.contains("A column, variable, or function parameter " + + "with name `non_existent` cannot be resolved")) + } + } + } + private def findMergeExec(query: String): MergeRowsExec = { val plan = executeAndKeepPlan { sql(query) From 4050f949e9e4466eb9a8da6f27852b737a62a977 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sun, 9 Nov 2025 21:57:37 -0800 Subject: [PATCH 09/13] More cleanup --- .../sql/catalyst/analysis/Analyzer.scala | 16 ++--- .../catalyst/plans/logical/v2Commands.scala | 58 ++++++++++--------- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 63eb62faac65..866a099ed02b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1704,9 +1704,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor resolveAssignments(assignments, m, MergeResolvePolicy.BOTH, throws = throws)) case UpdateStarAction(updateCondition) => - // Use only source columns. Missing columns in target will be handled in - // ResolveRowLevelCommandAssignments. + // Expand star to top level source columns. If source has less columns than target, + // assignments will be added by ResolveRowLevelCommandAssignments later. val assignments = if (m.schemaEvolutionEnabled) { + // For schema evolution case, generate assignments for missing target columns. + // These columns will be added by ResolveMergeIntoTableSchemaEvolution later. sourceTable.output.map(sourceAttr => findAttrInTarget(sourceAttr.name).map( targetAttr => Assignment(targetAttr, sourceAttr)) @@ -1719,10 +1721,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor targetAttr => Assignment(targetAttr, sourceAttr)) } } - - // sourceTable.output.find( -// sourceCol => conf.resolver(sourceCol.name, targetAttr.name)) -// .map(Assignment(targetAttr, _))} UpdateAction( updateCondition.map(resolveExpressionByPlanChildren(_, m)), // For UPDATE *, the value must be from source table. @@ -1745,9 +1743,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // access columns from the source table. val resolvedInsertCondition = insertCondition.map( resolveExpressionByPlanOutput(_, m.sourceTable)) - // Use only source columns. Missing columns in target will be handled in - // ResolveRowLevelCommandAssignments. + // Expand star to top level source columns. If source has less columns than target, + // assignments will be added by ResolveRowLevelCommandAssignments later. val assignments = if (m.schemaEvolutionEnabled) { + // For schema evolution case, generate assignments for missing target columns. + // These columns will be added by ResolveMergeIntoTableSchemaEvolution later. 
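+                // e.g. a brand-new source column `extra` produces
+                // Assignment(UnresolvedAttribute("extra"), <extra source attribute>), whose key
+                // can only resolve once the evolution rule has added `extra` to the target.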
sourceTable.output.map(sourceAttr => findAttrInTarget(sourceAttr.name).map( targetAttr => Assignment(targetAttr, sourceAttr)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 069ab8ead1a5..e48048452644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -895,7 +895,7 @@ case class MergeIntoTable( private lazy val sourceSchemaForEvolution: StructType = MergeIntoTable.sourceSchemaForSchemaEvolution(this) - lazy val schemaChangesNonEmpty = + lazy val schemaChangesNonEmpty: Boolean = MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty lazy val needSchemaEvolution: Boolean = @@ -903,9 +903,9 @@ case class MergeIntoTable( canEvaluateSchemaEvolution && schemaChangesNonEmpty - // Guard that assignments are resolved or candidates for evolution before evaluating schema - // evolution. We need to use resolved assignment values to check candidates, see - // MergeIntoTable.assignmentForEvolutionCandidate. + // Guard that assignments are either resolved or candidates for evolution before + // evaluating schema evolution. We need to use resolved assignment values to check + // candidates, see MergeIntoTable.sourceSchemaForSchemaEvolution for details. lazy val canEvaluateSchemaEvolution: Boolean = { if ((!targetTable.resolved) || (!sourceTable.resolved)) { false @@ -916,10 +916,10 @@ case class MergeIntoTable( case a: InsertAction => a.assignments }.flatten - val evolutionPaths = MergeIntoTable.extractAllFieldPaths(sourceSchemaForEvolution) + val sourcePaths = MergeIntoTable.extractAllFieldPaths(sourceTable.schema) assignments.forall { assignment => assignment.resolved || - evolutionPaths.exists { path => MergeIntoTable.isEqual(assignment, path) } + sourcePaths.exists { path => MergeIntoTable.isEqual(assignment, path) } } } } @@ -978,11 +978,12 @@ object MergeIntoTable { case currentField: StructField if newFieldMap.contains(currentField.name) => schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType, originalTarget, originalSource, fieldPath ++ Seq(currentField.name)) - }}.flatten + } + }.flatten // Identify the newly added fields and append to the end val currentFieldMap = toFieldMap(currentFields) - val adds = newFields.filterNot (f => currentFieldMap.contains (f.name)) + val adds = newFields.filterNot(f => currentFieldMap.contains(f.name)) .map(f => TableChange.addColumn(fieldPath ++ Set(f.name), f.dataType)) updates ++ adds @@ -1024,10 +1025,13 @@ object MergeIntoTable { // A pruned version of source schema that only contains columns/nested fields // explicitly and directly assigned to a target counterpart in MERGE INTO actions, // which are relevant for schema evolution. + // Examples: + // * UPDATE SET target.a = source.a + // * UPDATE SET nested.a = source.nested.a + // * INSERT (a, nested.b) VALUES (source.a, source.nested.b) // New columns/nested fields in this schema that are not existing in target schema // will be added for schema evolution. 
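+  // Note: isPrefix below compares name parts with SQLConf.get.resolver, so nested field
+  // matching honors the configured case sensitivity.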
def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = { - val actions = merge.matchedActions ++ merge.notMatchedActions val assignments = actions.collect { case a: UpdateAction => a.assignments @@ -1051,7 +1055,7 @@ object MergeIntoTable { // If this is a struct and one of the children is being assigned to in a merge clause, // keep it and continue filtering children. case struct: StructType if assignments.exists(assign => - isPrefix(fieldPath, extractFieldPath(assign.key))) => + isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) => Some(field.copy(dataType = filterSchema(struct, fieldPath))) // The field isn't assigned to directly or indirectly (i.e. its children) in any non-* // clause. Check if it should be kept with any * action. @@ -1080,12 +1084,14 @@ object MergeIntoTable { } // Helper method to extract field path from an Expression. - private def extractFieldPath(expr: Expression): Seq[String] = expr match { - case UnresolvedAttribute(nameParts) => nameParts - case a: AttributeReference => Seq(a.name) - case GetStructField(child, ordinal, nameOpt) => - extractFieldPath(child) :+ nameOpt.getOrElse(s"col$ordinal") - case _ => Seq.empty + private def extractFieldPath(expr: Expression, allowUnresolved: Boolean): Seq[String] = { + expr match { + case UnresolvedAttribute(nameParts) if allowUnresolved => nameParts + case a: AttributeReference => Seq(a.name) + case GetStructField(child, ordinal, nameOpt) => + extractFieldPath(child, allowUnresolved) :+ nameOpt.getOrElse(s"col$ordinal") + case _ => Seq.empty + } } // Helper method to check if a given field path is a prefix of another path. @@ -1095,18 +1101,16 @@ object MergeIntoTable { SQLConf.get.resolver(prefixNamePart, pathNamePart) } - // Helper method to check if a given field path is a suffix of another path. - private def isSuffix(suffix: Seq[String], path: Seq[String]): Boolean = - isPrefix(suffix.reverse, path.reverse) - // Helper method to check if an assignment key is equal to a source column - // and if the assignment value is the corresponding source column directly - private def isEqual(assignment: Assignment, path: Seq[String]): Boolean = { - val assignmenKeyExpr = extractFieldPath(assignment.key) - val assignmentValueExpr = extractFieldPath(assignment.value) - // Valid assignments are: col = s.col or col.nestedField = s.col.nestedField - assignmenKeyExpr.length == path.length && isPrefix(assignmenKeyExpr, path) && - isSuffix(path, assignmentValueExpr) + // and if the assignment value is that same source column. 
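+ // (A hypothetical counterexample: a key `a` paired with the value `s.b` does not
+ // qualify, because the key and value paths differ.)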
+ // Example: UPDATE SET target.a = source.a
+ private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = {
+ // key must be a non-qualified field path that may be added to target schema via evolution
+ val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true)
+ // value should always be resolved (from source)
+ val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false)
+ assignmentKeyExpr == assignmentValueExpr &&
+ assignmentKeyExpr == sourceFieldPath
 }
}

From 4ac5b0a1578da77170df0572319cd0747e8087e1 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Mon, 10 Nov 2025 00:37:04 -0800
Subject: [PATCH 10/13] More simplification

---
 .../ResolveMergeIntoSchemaEvolution.scala | 29 +++++++++++--------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
index 1e872c0d4b2c..a0dd3160e723 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.analysis

-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GetStructField}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, GetStructField}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -42,12 +42,13 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
 case r : DataSourceV2Relation => performSchemaEvolution(r, m)
 }

- // Unresolve the merge condition and all assignments
- val unresolvedMergeCondition = unresolveCondition(m.mergeCondition)
- val unresolvedMatchedActions = unresolveActions(m.matchedActions)
- val unresolvedNotMatchedActions = unresolveActions(m.notMatchedActions)
+ // Unresolve all references based on old target output
+ val targetOutput = m.targetTable.output
+ val unresolvedMergeCondition = unresolveCondition(m.mergeCondition, targetOutput)
+ val unresolvedMatchedActions = unresolveActions(m.matchedActions, targetOutput)
+ val unresolvedNotMatchedActions = unresolveActions(m.notMatchedActions, targetOutput)
 val unresolvedNotMatchedBySourceActions =
- unresolveActions(m.notMatchedBySourceActions)
+ unresolveActions(m.notMatchedBySourceActions, targetOutput)

 m.copy(
 targetTable = newTarget,
@@ -57,21 +58,25 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
 notMatchedBySourceActions = unresolvedNotMatchedBySourceActions)
 }

- private def unresolveActions(actions: Seq[MergeAction]): Seq[MergeAction] = {
+ private def unresolveActions(actions: Seq[MergeAction], output: Seq[Attribute]):
+ Seq[MergeAction] = {
 actions.map {
 case UpdateAction(condition, assignments) =>
- UpdateAction(condition.map(unresolveCondition), unresolveAssignmentKeys(assignments))
+ UpdateAction(condition.map(unresolveCondition(_, output)),
+ unresolveAssignmentKeys(assignments))
 case InsertAction(condition, assignments) =>
- InsertAction(condition.map(unresolveCondition), unresolveAssignmentKeys(assignments))
+ InsertAction(condition.map(unresolveCondition(_, output)),
+ unresolveAssignmentKeys(assignments))
 case 
DeleteAction(condition) =>
- DeleteAction(condition.map(unresolveCondition))
+ DeleteAction(condition.map(unresolveCondition(_, output)))
 case other => other
 }
 }

- private def unresolveCondition(expr: Expression): Expression = {
+ private def unresolveCondition(expr: Expression, output: Seq[Attribute]): Expression = {
+ val outputSet = AttributeSet(output)
 expr.transform {
- case attr: AttributeReference =>
+ case attr: AttributeReference if outputSet.contains(attr) =>
 val nameParts = if (attr.qualifier.nonEmpty) {
 attr.qualifier ++ Seq(attr.name)
 } else {

From ebe84acdd0d638101b094542c7be6601ed888f66 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Mon, 10 Nov 2025 16:03:39 -0800
Subject: [PATCH 11/13] Review comments

---
 .../sql/catalyst/analysis/Analyzer.scala | 28 ++---
 .../ResolveMergeIntoSchemaEvolution.scala | 108 ++++--------------
 .../catalyst/plans/logical/v2Commands.scala | 29 ++---
 3 files changed, 48 insertions(+), 117 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 866a099ed02b..88cb2ded59c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1672,12 +1672,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

 case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
 if !m.resolved && targetTable.resolved && sourceTable.resolved =>
- // Do not throw exception for schema evolution case if it has not had a chance to run.
+ // Do not throw an exception for the schema evolution case.
 // This allows unresolved assignment keys a chance to be resolved by a second pass
 // using newly added columns/nested fields from schema evolution.
- // If schema evolution has already had a chance to run, this will be the final pass
- val throws = !m.schemaEvolutionEnabled ||
- (m.canEvaluateSchemaEvolution && !m.schemaChangesNonEmpty)
+ val throws = !m.schemaEvolutionEnabled

 EliminateSubqueryAliases(targetTable) match {
 case r: NamedRelation if r.skipSchemaResolution =>
@@ -1709,12 +1707,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 val assignments = if (m.schemaEvolutionEnabled) {
 // For the schema evolution case, generate assignments for missing target columns.
 // These columns will be added by ResolveMergeIntoSchemaEvolution later.
- sourceTable.output.map(sourceAttr =>
- findAttrInTarget(sourceAttr.name).map(
- targetAttr => Assignment(targetAttr, sourceAttr))
- .getOrElse(Assignment(
- UnresolvedAttribute(sourceAttr.name),
- sourceAttr)))
+ sourceTable.output.map { sourceAttr =>
+ val key = findAttrInTarget(sourceAttr.name).getOrElse(
+ UnresolvedAttribute(sourceAttr.name))
+ Assignment(key, sourceAttr)
+ }
 } else {
 sourceTable.output.flatMap { sourceAttr =>
 findAttrInTarget(sourceAttr.name).map(
 targetAttr => Assignment(targetAttr, sourceAttr))
 }
 }
@@ -1748,12 +1745,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 val assignments = if (m.schemaEvolutionEnabled) {
 // For the schema evolution case, generate assignments for missing target columns.
 // These columns will be added by ResolveMergeIntoSchemaEvolution later.
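 // Illustrative expansion (hypothetical attributes, not from this patch): for a
 // target (pk, dep) and a source (pk, dep, active), INSERT * under schema evolution
 // yields Assignment(pk, s.pk), Assignment(dep, s.dep), and
 // Assignment(UnresolvedAttribute("active"), s.active); the last key stays
 // unresolved until evolution adds `active` to the target.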
- sourceTable.output.map(sourceAttr => - findAttrInTarget(sourceAttr.name).map( - targetAttr => Assignment(targetAttr, sourceAttr)) - .getOrElse(Assignment( - UnresolvedAttribute(sourceAttr.name), - sourceAttr))) + sourceTable.output.map { sourceAttr => + val key = findAttrInTarget(sourceAttr.name).getOrElse( + UnresolvedAttribute(sourceAttr.name)) + Assignment(key, sourceAttr) + } } else { sourceTable.output.flatMap { sourceAttr => findAttrInTarget(sourceAttr.name).map( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala index a0dd3160e723..f317a2efddbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, GetStructField} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsRowLevelOperations, TableCatalog, TableChange} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.StructType /** @@ -37,97 +37,31 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // This rule should run only if all assignments are resolved, except those // that will be satisfied by schema evolution - case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.needSchemaEvolution => - val newTarget = m.targetTable.transform { - case r : DataSourceV2Relation => performSchemaEvolution(r, m) + case m@MergeIntoTable(_, _, _, _, _, _, _) if m.evaluateSchemaEvolution => + val changes = m.changesForSchemaEvolution + if (changes.isEmpty) { + m + } else { + m transformUpWithNewOutput { + case r @ DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _, _) => + val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m) + val newTarget = performSchemaEvolution(r, referencedSourceSchema, changes) + val oldTargetOutput = m.targetTable.output + val newTargetOutput = newTarget.output + val attributeMapping = oldTargetOutput.map( + oldAttr => (oldAttr, newTargetOutput.find(_.name == oldAttr.name).getOrElse(oldAttr)) + ) + newTarget -> attributeMapping } - - // Unresolve all references based on old target output - val targetOutput = m.targetTable.output - val unresolvedMergeCondition = unresolveCondition(m.mergeCondition, targetOutput) - val unresolvedMatchedActions = unresolveActions(m.matchedActions, targetOutput) - val unresolvedNotMatchedActions = unresolveActions(m.notMatchedActions, targetOutput) - val unresolvedNotMatchedBySourceActions = - unresolveActions(m.notMatchedBySourceActions, targetOutput) - - m.copy( - targetTable = newTarget, - mergeCondition = unresolvedMergeCondition, - matchedActions = unresolvedMatchedActions, - notMatchedActions = unresolvedNotMatchedActions, - 
notMatchedBySourceActions = unresolvedNotMatchedBySourceActions) - } - - private def unresolveActions(actions: Seq[MergeAction], output: Seq[Attribute]): - Seq[MergeAction] = { - actions.map { - case UpdateAction(condition, assignments) => - UpdateAction(condition.map(unresolveCondition(_, output)), - unresolveAssignmentKeys(assignments)) - case InsertAction(condition, assignments) => - InsertAction(condition.map(unresolveCondition(_, output)), - unresolveAssignmentKeys(assignments)) - case DeleteAction(condition) => - DeleteAction(condition.map(unresolveCondition(_, output))) - case other => other - } - } - - private def unresolveCondition(expr: Expression, output: Seq[Attribute]): Expression = { - val outputSet = AttributeSet(output) - expr.transform { - case attr: AttributeReference if outputSet.contains(attr) => - val nameParts = if (attr.qualifier.nonEmpty) { - attr.qualifier ++ Seq(attr.name) - } else { - Seq(attr.name) - } - UnresolvedAttribute(nameParts) - } - } - - private def unresolveAssignmentKeys(assignments: Seq[Assignment]): Seq[Assignment] = { - assignments.map { assignment => - val unresolvedKey = assignment.key match { - case _: UnresolvedAttribute => assignment.key - case gsf: GetStructField => - // Recursively collect all nested GetStructField names and the base AttributeReference - val nameParts = collectStructFieldNames(gsf) - nameParts match { - case Some(names) => UnresolvedAttribute(names) - case None => assignment.key - } - case attr: AttributeReference => - UnresolvedAttribute(Seq(attr.name)) - case attr: Attribute => - UnresolvedAttribute(Seq(attr.name)) - case other => other } - Assignment(unresolvedKey, assignment.value) - } - } - - private def collectStructFieldNames(expr: Expression): Option[Seq[String]] = { - expr match { - case GetStructField(child, _, Some(fieldName)) => - collectStructFieldNames(child) match { - case Some(childNames) => Some(childNames :+ fieldName) - case None => None - } - case attr: AttributeReference => - Some(Seq(attr.name)) - case _ => - None - } } - private def performSchemaEvolution(relation: DataSourceV2Relation, m: MergeIntoTable) - : DataSourceV2Relation = { + private def performSchemaEvolution( + relation: DataSourceV2Relation, + referencedSourceSchema: StructType, + changes: Array[TableChange]): DataSourceV2Relation = { (relation.catalog, relation.identifier) match { case (Some(c: TableCatalog), Some(i)) => - val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m) - - val changes = MergeIntoTable.schemaChanges(relation.schema, referencedSourceSchema) c.alterTable(i, changes: _*) val newTable = c.loadTable(i) val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index e48048452644..db5dacdcef38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -892,16 +892,19 @@ case class MergeIntoTable( case _ => false } - private lazy val sourceSchemaForEvolution: StructType = - MergeIntoTable.sourceSchemaForSchemaEvolution(this) - - lazy val schemaChangesNonEmpty: Boolean = - MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty - lazy val needSchemaEvolution: Boolean = + evaluateSchemaEvolution && changesForSchemaEvolution.nonEmpty 
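+ // Sketch of the intended contract (illustrative, assuming a source-only column
+ // `active`): changesForSchemaEvolution would then hold roughly
+ // TableChange.addColumn(Array("active"), BooleanType), so needSchemaEvolution
+ // holds only when evolution applies and at least one such change exists.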
+ + lazy val evaluateSchemaEvolution: Boolean = schemaEvolutionEnabled && - canEvaluateSchemaEvolution && - schemaChangesNonEmpty + canEvaluateSchemaEvolution + + lazy val schemaEvolutionEnabled: Boolean = withSchemaEvolution && { + EliminateSubqueryAliases(targetTable) match { + case r: DataSourceV2Relation if r.autoSchemaEvolution() => true + case _ => false + } + } // Guard that assignments are either resolved or candidates for evolution before // evaluating schema evolution. We need to use resolved assignment values to check @@ -924,13 +927,11 @@ case class MergeIntoTable( } } + private lazy val sourceSchemaForEvolution: StructType = + MergeIntoTable.sourceSchemaForSchemaEvolution(this) - def schemaEvolutionEnabled: Boolean = withSchemaEvolution && { - EliminateSubqueryAliases(targetTable) match { - case r: DataSourceV2Relation if r.autoSchemaEvolution() => true - case _ => false - } - } + lazy val changesForSchemaEvolution: Array[TableChange] = + MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution) override def left: LogicalPlan = targetTable override def right: LogicalPlan = sourceTable From 6c08b54cbe5edfc26e3d42a3d8e42812e6964226 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 10 Nov 2025 21:54:20 -0800 Subject: [PATCH 12/13] Fix test --- .../command/PlanResolutionSuite.scala | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 9ea8b9130ba8..dfd24a1ebe97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -1633,10 +1633,10 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { if (starInUpdate) { assert(updateAssigns.size == 2) - assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti)) - assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si)) - assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts)) - assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss)) + assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts)) + assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss)) + assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti)) + assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si)) } else { assert(updateAssigns.size == 1) assert(updateAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ts)) @@ -1648,15 +1648,25 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { target: LogicalPlan, source: LogicalPlan, insertCondAttr: Option[AttributeReference], - insertAssigns: Seq[Assignment]): Unit = { + insertAssigns: Seq[Assignment], + starInInsert: Boolean = false): Unit = { val (si, ss) = getAttributes(source) val (ti, ts) = getAttributes(target) insertCondAttr.foreach(a => assert(a.sameRef(ss))) - assert(insertAssigns.size == 2) - assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti)) - assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si)) - assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts)) - assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss)) + + if (starInInsert) { + assert(insertAssigns.size == 2) + 
assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts))
+ assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss))
+ assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti))
+ assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si))
+ } else {
+ assert(insertAssigns.size == 2)
+ assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti))
+ assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si))
+ assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts))
+ assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss))
+ }
 }

 def checkNotMatchedBySourceClausesResolution(
@@ -1735,7 +1745,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
 checkMergeConditionResolution(target, source, mergeCondition)
 checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns,
 starInUpdate = true)
- checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns)
+ checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns,
+ starInInsert = true)
 assert(withSchemaEvolution === false)

 case other =>
 fail("Expect MergeIntoTable, but got:\n" + other.treeString)
@@ -1762,7 +1773,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
 checkMergeConditionResolution(target, source, mergeCondition)
 checkMatchedClausesResolution(target, source, None, None, updateAssigns,
 starInUpdate = true)
- checkNotMatchedClausesResolution(target, source, None, insertAssigns)
+ checkNotMatchedClausesResolution(target, source, None, insertAssigns,
+ starInInsert = true)
 assert(withSchemaEvolution === false)

 case other =>
 fail("Expect MergeIntoTable, but got:\n" + other.treeString)

From 02d8dc9967d8c90e41ddb206b276cd9785e32ddd Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Tue, 11 Nov 2025 13:37:27 -0800
Subject: [PATCH 13/13] Add a unit test to check that action conditions are properly resolved in schema evolution

---
 .../connector/MergeIntoTableSuiteBase.scala | 112 ++++++++++++++++++
 1 file changed, 112 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index d7af0d7cf06f..b73c8d2458c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -2206,6 +2206,118 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
 }
 }

+ test("Merge schema evolution new column with conditions on update and insert") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ withTempView("source") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |{ "pk": 4, "salary": 400, "dep": "marketing" }
+ |{ "pk": 5, "salary": 500, "dep": "executive" }
+ |""".stripMargin)
+
+ // Two rows that could be updated (pk 4 and 5), but only one has salary > 450
+ // Two rows that could be inserted (pk 6 and 7), but only one has active = true
+ val sourceDF = Seq((4, 450, "finance", false),
+ (5, 550, "finance", true),
+ (6, 350, "sales", true),
+ (7, 250, "sales", false)).toDF("pk", "salary", "dep", "active")
+ sourceDF.createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if 
(withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND s.salary > 450 THEN + | UPDATE SET dep='updated', active=s.active + |WHEN NOT MATCHED AND s.active = true THEN + | INSERT (pk, salary, dep, active) VALUES (s.pk, s.salary, s.dep, + | s.active) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", null), + Row(3, 300, "hr", null), + Row(4, 400, "marketing", null), // pk=4 not updated (salary 450 is not > 450) + Row(5, 500, "updated", true), // pk=5 updated (salary 550 > 450) + Row(6, 350, "sales", true))) // pk=6 inserted (active = true) + // pk=7 not inserted (active = false) + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`active` cannot be resolved")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution with condition on new column from target") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "marketing" } + |{ "pk": 5, "salary": 500, "dep": "executive" } + |""".stripMargin) + + // Source has new 'active' column that doesn't exist in target + val sourceDF = Seq((4, 450, "finance", true), + (5, 550, "finance", false), + (6, 350, "sales", true)).toDF("pk", "salary", "dep", "active") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + // Condition references t.active which doesn't exist yet in target + val mergeStmt = s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND t.active IS NULL THEN + | UPDATE SET salary=s.salary, dep=s.dep, active=s.active + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep, active) + | VALUES (s.pk, s.salary, s.dep, s.active) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", null), + Row(3, 300, "hr", null), + Row(4, 450, "finance", true), // Updated (t.active was NULL) + Row(5, 550, "finance", false), // Updated (t.active was NULL) + Row(6, 350, "sales", true))) // Inserted + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`active` cannot be resolved")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + test("Merge schema evolution new column with set all columns") { Seq((true, true), (false, true), (true, false)).foreach { case (withSchemaEvolution, schemaEvolutionEnabled) =>