From b62338ce85511e28b06914eb6401c2ca5971dc41 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 10 Apr 2026 16:36:06 -0700 Subject: [PATCH] [SPARK-56125][SQL] Refactor MERGE INTO schema evolution to use assignment-based approach --- .../analysis/ResolveSchemaEvolution.scala | 99 ++++------- .../catalyst/plans/logical/v2Commands.scala | 161 ++++++++---------- 2 files changed, 105 insertions(+), 155 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala index e312c1fc00b8..925c0afed464 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala @@ -108,13 +108,21 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] { targetTable: LogicalPlan, originalSource: StructType, isByName: Boolean): Array[TableChange] = { + val onError: () => Nothing = () => + throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError( + targetTable.schema, originalSource, null) val candidateChanges = computeSchemaChanges( - targetTable.schema, - originalSource, targetTable.schema, originalSource, fieldPath = Nil, - isByName) + isByName, + onError) + filterSupportedChanges(targetTable, candidateChanges) + } + + def filterSupportedChanges( + targetTable: LogicalPlan, + candidateChanges: Array[TableChange]): Array[TableChange] = { targetTable match { case ExtractV2Table(t: SupportsSchemaEvolution) => candidateChanges.filter { @@ -131,133 +139,86 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] { } } - private def computeSchemaChanges( + private[catalyst] def computeSchemaChanges( currentType: DataType, newType: DataType, - originalTarget: StructType, - originalSource: StructType, fieldPath: List[String], - isByName: Boolean): Array[TableChange] = { + isByName: Boolean, + onError: () => Nothing): Array[TableChange] = { (currentType, newType) match { case (StructType(currentFields), StructType(newFields)) => if (isByName) { - computeSchemaChangesByName( - currentFields, newFields, originalTarget, originalSource, fieldPath) + computeSchemaChangesByName(currentFields, newFields, fieldPath, onError) } else { - computeSchemaChangesByPosition( - currentFields, newFields, originalTarget, originalSource, fieldPath) + computeSchemaChangesByPosition(currentFields, newFields, fieldPath, onError) } case (ArrayType(currentElementType, _), ArrayType(newElementType, _)) => computeSchemaChanges( - currentElementType, - newElementType, - originalTarget, - originalSource, - fieldPath :+ "element", - isByName) + currentElementType, newElementType, fieldPath :+ "element", isByName, onError) case (MapType(currentKeyType, currentValueType, _), MapType(newKeyType, newValueType, _)) => - val keyChanges = computeSchemaChanges( - currentKeyType, - newKeyType, - originalTarget, - originalSource, - fieldPath :+ "key", - isByName) - val valueChanges = computeSchemaChanges( - currentValueType, - newValueType, - originalTarget, - originalSource, - fieldPath :+ "value", - isByName) - keyChanges ++ valueChanges + computeSchemaChanges( + currentKeyType, newKeyType, fieldPath :+ "key", isByName, onError) ++ + computeSchemaChanges( + currentValueType, newValueType, fieldPath :+ "value", isByName, onError) case (currentType: AtomicType, newType: AtomicType) if currentType != newType => Array(TableChange.updateColumnType(fieldPath.toArray, newType)) case (currentType, newType) if currentType == newType => - // No change needed Array.empty[TableChange] case (_, NullType) => - // Don't try to change to NullType. Array.empty[TableChange] case (_: AtomicType | NullType, newType: AtomicType) => Array(TableChange.updateColumnType(fieldPath.toArray, newType)) case _ => - // Do not support change between atomic and complex types for now - throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError( - originalTarget, originalSource, null) + onError() } } - /** - * Match fields by name: look up each target field in the source by name to collect schema - * differences. Nested struct fields are also matched by name. - */ private def computeSchemaChangesByName( currentFields: Array[StructField], newFields: Array[StructField], - originalTarget: StructType, - originalSource: StructType, - fieldPath: List[String]): Array[TableChange] = { + fieldPath: List[String], + onError: () => Nothing): Array[TableChange] = { val currentFieldMap = toFieldMap(currentFields) val newFieldMap = toFieldMap(newFields) - // Collect field updates val updates = currentFields .filter(f => newFieldMap.contains(f.name)) .flatMap { f => computeSchemaChanges( - f.dataType, - newFieldMap(f.name).dataType, - originalTarget, - originalSource, - fieldPath :+ f.name, - isByName = true) + f.dataType, newFieldMap(f.name).dataType, fieldPath :+ f.name, + isByName = true, onError) } - // Collect newly added fields val adds = newFields .filterNot(f => currentFieldMap.contains(f.name)) .map { f => - // Make the type nullable, since existing rows in the table will have NULLs for this column. TableChange.addColumn((fieldPath :+ f.name).toArray, f.dataType.asNullable) } updates ++ adds } - /** - * Match fields by position: pair target and source fields in order to collect schema - * differences. Nested struct fields are also matched by position. - */ private def computeSchemaChangesByPosition( currentFields: Array[StructField], newFields: Array[StructField], - originalTarget: StructType, - originalSource: StructType, - fieldPath: List[String]): Array[TableChange] = { - // Update existing field types by pairing fields at the same position. + fieldPath: List[String], + onError: () => Nothing): Array[TableChange] = { val updates = currentFields.zip(newFields).flatMap { case (currentField, newField) => computeSchemaChanges( - currentField.dataType, - newField.dataType, - originalTarget, - originalSource, - fieldPath :+ currentField.name, - isByName = false) + currentField.dataType, newField.dataType, fieldPath :+ currentField.name, + isByName = false, onError) } - // Extra source fields beyond the target's field count are new additions. val adds = newFields.drop(currentFields.length) .map { f => - // Make the type nullable, since existing rows in the table will have NULLs for this column. TableChange.addColumn((fieldPath :+ f.name).toArray, f.dataType.asNullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 500d648d23ac..0299f0b6fc72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE} import org.apache.spark.sql.errors.DataTypeErrors.toSQLType +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructType} @@ -1041,9 +1042,8 @@ case class MergeIntoTable( override lazy val pendingSchemaChanges: Seq[TableChange] = { if (schemaEvolutionEnabled && schemaEvolutionReady) { - val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(this) - ResolveSchemaEvolution.computeSupportedSchemaChanges( - table, referencedSourceSchema, isByName = true).toSeq + val candidateChanges = MergeIntoTable.pendingSchemaChanges(this) + ResolveSchemaEvolution.filterSupportedChanges(table, candidateChanges.toArray).toSeq } else { Seq.empty } @@ -1097,112 +1097,101 @@ object MergeIntoTable { .toSet } - // A pruned version of source schema that only contains columns/nested fields - // explicitly and directly assigned to a target counterpart in MERGE INTO actions, - // which are relevant for schema evolution. - // Examples: - // * UPDATE SET target.a = source.a - // * UPDATE SET nested.a = source.nested.a - // * INSERT (a, nested.b) VALUES (source.a, source.nested.b) - // New columns/nested fields in this schema that are not existing in target schema - // will be added for schema evolution. - def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = { + private def pendingSchemaChanges(merge: MergeIntoTable): Seq[TableChange] = { + val onError: () => Nothing = () => + throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError( + merge.targetTable.schema, merge.sourceTable.schema, null) + + schemaEvolutionTriggeringAssignments(merge).flatMap { + + // New column: the key didn't resolve against the target, so the column is missing. + // Applies to top-level fields (e.g. SET new_col = source.new_col) and nested fields + // where the leaf is missing (e.g. SET addr.zip = source.addr.zip). + case a @ Assignment(UnresolvedAttribute(fieldPath), _) + if !containsColumn(merge.targetTable, fieldPath) => + Seq(TableChange.addColumn(fieldPath.toArray, a.value.dataType.asNullable)) + + // Type mismatch on an existing column: the key is resolved but the source type differs. + // For atomic types this produces an updateColumnType; for structs this recurses to + // find nested additions or type changes (e.g. SET addr = source.addr where source.addr + // has an extra field or a widened child type). + case a if a.key.dataType != a.value.dataType => + ResolveSchemaEvolution.computeSchemaChanges( + a.key.dataType, + a.value.dataType, + fieldPath = extractFieldPath(a.key).toList, + isByName = true, + onError).toSeq + + // Types already match - no schema change needed. + case _ => Seq.empty + }.distinct + } + + // Schema evolution affects only fields referenced in MERGE INTO assignments. + // Candidate assignments are those in which the key is a direct assignment from the value, + // where the key is a (potentially missing) field in the target and the value is the + // same-named field in the source. + // + // Explicit assignment examples: + // UPDATE SET target.a = source.a + // UPDATE SET target.nested.a = source.nested.a + // INSERT (a, nested.b) VALUES (source.a, source.nested.b) + // + // Star actions (UPDATE SET * / INSERT *) also qualify because they will be resolved to + // per-column assignments, e.g.: + // UPDATE SET * => UPDATE SET target.a = source.a, target.b = source.b, ... + // INSERT * => INSERT (a, b, ...) VALUES (source.a, source.b, ...) + private def schemaEvolutionTriggeringAssignments( + merge: MergeIntoTable): Seq[Assignment] = { val actions = merge.matchedActions ++ merge.notMatchedActions - val assignments = actions.collect { + val assignments = actions.flatMap { case a: UpdateAction => a.assignments case a: InsertAction => a.assignments - }.flatten - - val containsStarAction = actions.exists { - case _: UpdateStarAction => true - case _: InsertStarAction => true - case _ => false + case _ => Seq.empty } + assignments.filter(isSchemaEvolutionTrigger(_, merge.sourceTable)) + } - def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType = - StructType(sourceSchema.flatMap { field => - val fieldPath = basePath :+ field.name - - field.dataType match { - // Specifically assigned to in one clause: - // always keep, including all nested attributes - case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field) - // If this is a struct and one of the children is being assigned to in a merge clause, - // keep it and continue filtering children. - case struct: StructType if assignments.exists(assign => - isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - // The field isn't assigned to directly or indirectly (i.e. its children) in any non-* - // clause. Check if it should be kept with any * action. - case struct: StructType if containsStarAction => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - case _ if containsStarAction => Some(field) - // The field and its children are not assigned to in any * or non-* action, drop it. - case _ => None - } - }) - - filterSchema(merge.sourceTable.schema, Seq.empty) - } - - // Helper method to extract field path from an Expression. - private def extractFieldPath(expr: Expression, allowUnresolved: Boolean): Seq[String] = { + private def extractFieldPath(expr: Expression): Seq[String] = { expr match { - case UnresolvedAttribute(nameParts) if allowUnresolved => nameParts + case UnresolvedAttribute(nameParts) => nameParts case a: AttributeReference => Seq(a.name) - case Alias(child, _) => extractFieldPath(child, allowUnresolved) + case Alias(child, _) => extractFieldPath(child) case GetStructField(child, ordinal, nameOpt) => - extractFieldPath(child, allowUnresolved) :+ nameOpt.getOrElse(s"col$ordinal") + extractFieldPath(child) :+ nameOpt.getOrElse(s"col$ordinal") case _ => Seq.empty } } - // Helper method to check if a given field path is a prefix of another path. - private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = - prefix.length <= path.length && prefix.zip(path).forall { - case (prefixNamePart, pathNamePart) => - SQLConf.get.resolver(prefixNamePart, pathNamePart) - } - - // Helper method to check if an assignment key is equal to a source column - // and if the assignment value is that same source column. - // Example: UPDATE SET target.a = source.a - private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = { - // key must be a non-qualified field path that may be added to target schema via evolution - val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true) - // value should always be resolved (from source) - val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false) - assignmentKeyExpr == assignmentValueExpr && assignmentKeyExpr == sourceFieldPath + private def containsColumn(table: LogicalPlan, fieldPath: Seq[String]): Boolean = { + table.schema.findNestedField(fieldPath, resolver = SQLConf.get.resolver).isDefined } private def areSchemaEvolutionReady( assignments: Seq[Assignment], source: LogicalPlan): Boolean = { - assignments.forall(assign => assign.resolved || isSchemaEvolutionCandidate(assign, source)) + assignments.forall(assign => assign.resolved || isSchemaEvolutionTrigger(assign, source)) } - private def isSchemaEvolutionCandidate(assignment: Assignment, source: LogicalPlan): Boolean = { - assignment.value.resolved && isSameColumnAssignment(assignment, source) - } - - // Helper method to check if an assignment key is equal to a source column - // and if the assignment value is that same source column. + // Checks if an assignment key maps to the same-named source column, meaning the + // assignment is a direct copy from source to target that may trigger schema evolution. // // Top-level example: UPDATE SET target.a = source.a - // key: AttributeReference("a", ...) -> path Seq("a") - // value: AttributeReference("a", ...) from source + // key: AttributeReference("a") or UnresolvedAttribute("a") + // value: AttributeReference("a") from source // - // Nested example: UPDATE SET addr.city = source.addr.city - // key: GetStructField(GetStructField(AttributeReference("addr", ...), 0), 1) - // value: GetStructField(GetStructField(AttributeReference("addr", ...), 0), 1) from source + // Nested example: UPDATE SET target.addr.city = source.addr.city + // key: GetStructField(AttributeReference("addr"), ..., "city") + // value: GetStructField(AttributeReference("addr"), ..., "city") from source // - // references contains only root attributes, so subsetOf(source.outputSet) works for both. - private def isSameColumnAssignment(assignment: Assignment, source: LogicalPlan): Boolean = { - // key must be a non-qualified field path that may be added to target schema via evolution - val keyPath = extractFieldPath(assignment.key, allowUnresolved = true) - // value should always be resolved (from source) - val valuePath = extractFieldPath(assignment.value, allowUnresolved = false) - keyPath == valuePath && assignment.value.references.subsetOf(source.outputSet) + // `references` contains only root attributes, so subsetOf(source.outputSet) works for both. + private def isSchemaEvolutionTrigger(assignment: Assignment, source: LogicalPlan): Boolean = { + assignment.value.resolved && { + val keyPath = extractFieldPath(assignment.key) + val valuePath = extractFieldPath(assignment.value) + keyPath == valuePath && assignment.value.references.subsetOf(source.outputSet) + } } }