Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,21 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] {
targetTable: LogicalPlan,
originalSource: StructType,
isByName: Boolean): Array[TableChange] = {
val onError: () => Nothing = () =>
throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(
targetTable.schema, originalSource, null)
val candidateChanges = computeSchemaChanges(
targetTable.schema,
originalSource,
targetTable.schema,
originalSource,
fieldPath = Nil,
isByName)
isByName,
onError)
filterSupportedChanges(targetTable, candidateChanges)
}

def filterSupportedChanges(
targetTable: LogicalPlan,
candidateChanges: Array[TableChange]): Array[TableChange] = {
targetTable match {
case ExtractV2Table(t: SupportsSchemaEvolution) =>
candidateChanges.filter {
Expand All @@ -131,133 +139,86 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] {
}
}

/**
 * Walks `currentType` and `newType` in parallel and collects the table changes needed to move
 * the former towards the latter.
 *
 * Struct pairs are handed to the by-name or by-position field matcher depending on `isByName`.
 * Array and map pairs recurse with "element", "key" or "value" appended to `fieldPath`.
 * Identical types and a `NullType` source produce no change; an atomic source type over an
 * atomic or null current type produces an `updateColumnType`; every other pairing is reported
 * as an incompatible-schema failure.
 *
 * NOTE(review): this span carries two signatures and two sets of call arguments (one threading
 * `originalTarget`/`originalSource`, one threading `onError`). That looks like both sides of a
 * diff captured together -- confirm which set is current before relying on either.
 */
private def computeSchemaChanges(
private[catalyst] def computeSchemaChanges(
currentType: DataType,
newType: DataType,
originalTarget: StructType,
originalSource: StructType,
fieldPath: List[String],
isByName: Boolean): Array[TableChange] = {
isByName: Boolean,
onError: () => Nothing): Array[TableChange] = {
(currentType, newType) match {
// Struct vs struct: compare field by field, matching either by name or by ordinal.
case (StructType(currentFields), StructType(newFields)) =>
if (isByName) {
computeSchemaChangesByName(
currentFields, newFields, originalTarget, originalSource, fieldPath)
computeSchemaChangesByName(currentFields, newFields, fieldPath, onError)
} else {
computeSchemaChangesByPosition(
currentFields, newFields, originalTarget, originalSource, fieldPath)
computeSchemaChangesByPosition(currentFields, newFields, fieldPath, onError)
}

// Array vs array: only the element types are compared (the containsNull flags are ignored
// by the patterns); the element is addressed as "element" in the field path.
case (ArrayType(currentElementType, _), ArrayType(newElementType, _)) =>
computeSchemaChanges(
currentElementType,
newElementType,
originalTarget,
originalSource,
fieldPath :+ "element",
isByName)
currentElementType, newElementType, fieldPath :+ "element", isByName, onError)

// Map vs map: key and value types are compared independently under "key" and "value";
// key changes come first in the concatenated result.
case (MapType(currentKeyType, currentValueType, _),
MapType(newKeyType, newValueType, _)) =>
val keyChanges = computeSchemaChanges(
currentKeyType,
newKeyType,
originalTarget,
originalSource,
fieldPath :+ "key",
isByName)
val valueChanges = computeSchemaChanges(
currentValueType,
newValueType,
originalTarget,
originalSource,
fieldPath :+ "value",
isByName)
keyChanges ++ valueChanges
computeSchemaChanges(
currentKeyType, newKeyType, fieldPath :+ "key", isByName, onError) ++
computeSchemaChanges(
currentValueType, newValueType, fieldPath :+ "value", isByName, onError)

// Two different atomic types: request a column type update at the current path.
case (currentType: AtomicType, newType: AtomicType) if currentType != newType =>
Array(TableChange.updateColumnType(fieldPath.toArray, newType))

case (currentType, newType) if currentType == newType =>
// No change needed
Array.empty[TableChange]

case (_, NullType) =>
// Don't try to change to NullType.
Array.empty[TableChange]

// Atomic or null current type replaced by an atomic source type. Equal types were already
// handled by the equality arm above, so this arm only sees genuinely differing types.
case (_: AtomicType | NullType, newType: AtomicType) =>
Array(TableChange.updateColumnType(fieldPath.toArray, newType))

case _ =>
// Do not support change between atomic and complex types for now
throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(
originalTarget, originalSource, null)
onError()
}
}

/**
 * Match fields by name: look up each target field in the source by name to collect schema
 * differences. Nested struct fields are also matched by name.
 *
 * The result lists type updates for fields present on both sides first, followed by one
 * `addColumn` per source field that has no counterpart in the target. Fields that exist only
 * in the target produce no change.
 *
 * NOTE(review): the parameter list and the recursive call below each appear twice (an
 * `originalTarget`/`originalSource` form and an `onError` form) -- presumably diff residue;
 * confirm which form is current.
 */
private def computeSchemaChangesByName(
currentFields: Array[StructField],
newFields: Array[StructField],
originalTarget: StructType,
originalSource: StructType,
fieldPath: List[String]): Array[TableChange] = {
fieldPath: List[String],
onError: () => Nothing): Array[TableChange] = {
// toFieldMap is defined outside this view; presumably it indexes fields by name -- confirm
// whether the lookup is case-sensitive.
val currentFieldMap = toFieldMap(currentFields)
val newFieldMap = toFieldMap(newFields)

// Collect field updates
// Iterates in target order and only recurses into fields the source also has.
val updates = currentFields
.filter(f => newFieldMap.contains(f.name))
.flatMap { f =>
computeSchemaChanges(
f.dataType,
newFieldMap(f.name).dataType,
originalTarget,
originalSource,
fieldPath :+ f.name,
isByName = true)
f.dataType, newFieldMap(f.name).dataType, fieldPath :+ f.name,
isByName = true, onError)
}

// Collect newly added fields
// Iterates in source order, so additions keep the order they have in the source schema.
val adds = newFields
.filterNot(f => currentFieldMap.contains(f.name))
.map { f =>
// Make the type nullable, since existing rows in the table will have NULLs for this column.
TableChange.addColumn((fieldPath :+ f.name).toArray, f.dataType.asNullable)
}

updates ++ adds
}

/**
* Match fields by position: pair target and source fields in order to collect schema
* differences. Nested struct fields are also matched by position.
*/
private def computeSchemaChangesByPosition(
currentFields: Array[StructField],
newFields: Array[StructField],
originalTarget: StructType,
originalSource: StructType,
fieldPath: List[String]): Array[TableChange] = {
// Update existing field types by pairing fields at the same position.
fieldPath: List[String],
onError: () => Nothing): Array[TableChange] = {
val updates = currentFields.zip(newFields).flatMap { case (currentField, newField) =>
computeSchemaChanges(
currentField.dataType,
newField.dataType,
originalTarget,
originalSource,
fieldPath :+ currentField.name,
isByName = false)
currentField.dataType, newField.dataType, fieldPath :+ currentField.name,
isByName = false, onError)
}

// Extra source fields beyond the target's field count are new additions.
val adds = newFields.drop(currentFields.length)
.map { f =>
// Make the type nullable, since existing rows in the table will have NULLs for this column.
TableChange.addColumn((fieldPath :+ f.name).toArray, f.dataType.asNullable)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructType}
Expand Down Expand Up @@ -1041,9 +1042,8 @@ case class MergeIntoTable(

override lazy val pendingSchemaChanges: Seq[TableChange] = {
if (schemaEvolutionEnabled && schemaEvolutionReady) {
val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(this)
ResolveSchemaEvolution.computeSupportedSchemaChanges(
table, referencedSourceSchema, isByName = true).toSeq
val candidateChanges = MergeIntoTable.pendingSchemaChanges(this)
ResolveSchemaEvolution.filterSupportedChanges(table, candidateChanges.toArray).toSeq
} else {
Seq.empty
}
Expand Down Expand Up @@ -1097,112 +1097,101 @@ object MergeIntoTable {
.toSet
}

// A pruned version of source schema that only contains columns/nested fields
// explicitly and directly assigned to a target counterpart in MERGE INTO actions,
// which are relevant for schema evolution.
// Examples:
// * UPDATE SET target.a = source.a
// * UPDATE SET nested.a = source.nested.a
// * INSERT (a, nested.b) VALUES (source.a, source.nested.b)
// New columns/nested fields in this schema that do not exist in the target schema
// will be added by schema evolution.
def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = {
// Derives the candidate table changes implied by the schema-evolution-triggering assignments
// of `merge`. Each triggering assignment contributes either a column addition, a set of
// type/nested changes, or nothing; duplicates across assignments are removed.
private def pendingSchemaChanges(merge: MergeIntoTable): Seq[TableChange] = {
  // Invoked by the recursive type comparison when it meets a pair of types it cannot
  // reconcile. The schemas are read only at failure time.
  val failOnIncompatibleSchemas: () => Nothing = () =>
    throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(
      merge.targetTable.schema, merge.sourceTable.schema, null)

  def changesFor(assignment: Assignment): Seq[TableChange] = assignment match {
    // The key never resolved and the target has no such (possibly nested) field: the column
    // is missing and must be added, e.g. SET new_col = source.new_col or
    // SET addr.zip = source.addr.zip. The added column is made nullable.
    case Assignment(UnresolvedAttribute(nameParts), value)
        if !containsColumn(merge.targetTable, nameParts) =>
      Seq(TableChange.addColumn(nameParts.toArray, value.dataType.asNullable))

    // The column exists but the source type differs. Atomic types yield a type update;
    // structs are walked by name to find nested additions or nested type changes,
    // e.g. SET addr = source.addr where source.addr has an extra field.
    case Assignment(key, value) if key.dataType != value.dataType =>
      ResolveSchemaEvolution.computeSchemaChanges(
        key.dataType,
        value.dataType,
        fieldPath = extractFieldPath(key).toList,
        isByName = true,
        failOnIncompatibleSchemas).toSeq

    // Key and value already agree on the type: nothing to evolve.
    case _ =>
      Seq.empty
  }

  schemaEvolutionTriggeringAssignments(merge).flatMap(changesFor).distinct
}

// Schema evolution affects only fields referenced in MERGE INTO assignments.
// Candidate assignments are those in which the key is a direct assignment from the value,
// where the key is a (potentially missing) field in the target and the value is the
// same-named field in the source.
//
// Explicit assignment examples:
// UPDATE SET target.a = source.a
// UPDATE SET target.nested.a = source.nested.a
// INSERT (a, nested.b) VALUES (source.a, source.nested.b)
//
// Star actions (UPDATE SET * / INSERT *) also qualify because they will be resolved to
// per-column assignments, e.g.:
// UPDATE SET * => UPDATE SET target.a = source.a, target.b = source.b, ...
// INSERT * => INSERT (a, b, ...) VALUES (source.a, source.b, ...)
// Gathers the assignments of the MATCHED and NOT MATCHED actions of `merge` and keeps those
// that `isSchemaEvolutionTrigger` accepts with respect to the merge source.
//
// NOTE(review): `val assignments` is opened twice below (a `collect { ... }.flatten` form and a
// `flatMap { ... }` form) and a `containsStarAction` definition is interleaved with it -- this
// looks like both sides of a diff captured together; confirm which form is current.
private def schemaEvolutionTriggeringAssignments(
merge: MergeIntoTable): Seq[Assignment] = {
// Only the matched and not-matched action lists are consulted here.
val actions = merge.matchedActions ++ merge.notMatchedActions
val assignments = actions.collect {
val assignments = actions.flatMap {
case a: UpdateAction => a.assignments
case a: InsertAction => a.assignments
}.flatten

val containsStarAction = actions.exists {
case _: UpdateStarAction => true
case _: InsertStarAction => true
case _ => false
// In the flatMap form, actions other than UpdateAction/InsertAction contribute nothing.
case _ => Seq.empty
}
assignments.filter(isSchemaEvolutionTrigger(_, merge.sourceTable))
}

// Prunes `sourceSchema` down to the fields that matter for schema evolution. `basePath` is
// the path of the struct being filtered; the top-level call passes an empty path.
def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType = {
  StructType(sourceSchema.flatMap { candidate =>
    val candidatePath = basePath :+ candidate.name

    // The field itself is assigned in some clause.
    def assignedAsWhole: Boolean = assignments.exists(isEqual(_, candidatePath))

    // Some assignment key lives at or below this field.
    def hasAssignedDescendant: Boolean = assignments.exists { assign =>
      isPrefix(candidatePath, extractFieldPath(assign.key, allowUnresolved = true))
    }

    candidate.dataType match {
      // Directly assigned: keep the field untouched, nested attributes included.
      case _ if assignedAsWhole =>
        Some(candidate)
      // A struct is kept and filtered recursively when one of its children is assigned in a
      // non-* clause, or when any * action is present.
      case nested: StructType if hasAssignedDescendant || containsStarAction =>
        Some(candidate.copy(dataType = filterSchema(nested, candidatePath)))
      // Any other field survives only because of a * action; otherwise it is dropped.
      case _ =>
        if (containsStarAction) Some(candidate) else None
    }
  })
}

filterSchema(merge.sourceTable.schema, Seq.empty)
}

// Helper method to extract field path from an Expression.
//
// Returns the name parts of an unresolved attribute, the single name of an attribute
// reference, and the parent path plus the field name for a struct field access. Aliases are
// looked through. Any other expression shape yields an empty path.
//
// NOTE(review): the signature and the recursive arms each appear twice below (with and without
// `allowUnresolved`) -- presumably diff residue; confirm which form is current.
private def extractFieldPath(expr: Expression, allowUnresolved: Boolean): Seq[String] = {
private def extractFieldPath(expr: Expression): Seq[String] = {
expr match {
case UnresolvedAttribute(nameParts) if allowUnresolved => nameParts
case UnresolvedAttribute(nameParts) => nameParts
case a: AttributeReference => Seq(a.name)
case Alias(child, _) => extractFieldPath(child, allowUnresolved)
case Alias(child, _) => extractFieldPath(child)
// When the field carries no name, a positional placeholder "col<ordinal>" is used. If the
// child is an unsupported expression its path is empty, so only the leaf name is returned.
case GetStructField(child, ordinal, nameOpt) =>
extractFieldPath(child, allowUnresolved) :+ nameOpt.getOrElse(s"col$ordinal")
extractFieldPath(child) :+ nameOpt.getOrElse(s"col$ordinal")
// Unsupported expression shape: signalled with an empty path rather than an error.
case _ => Seq.empty
}
}

// Tells whether `prefix` is a leading segment of `path`, comparing name parts pairwise with
// the session's configured resolver. An empty `prefix` is a prefix of every path.
private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = {
  if (prefix.length > path.length) {
    // A longer sequence can never be a prefix of a shorter one.
    false
  } else {
    prefix.zip(path).forall { case (expectedPart, actualPart) =>
      SQLConf.get.resolver(expectedPart, actualPart)
    }
  }
}

// Helper method to check if an assignment key is equal to a source column
// and if the assignment value is that same source column.
// Example: UPDATE SET target.a = source.a
private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = {
// key must be a non-qualified field path that may be added to target schema via evolution
val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true)
// value should always be resolved (from source)
val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false)
assignmentKeyExpr == assignmentValueExpr && assignmentKeyExpr == sourceFieldPath
// Reports whether the schema of `table` has a (possibly nested) field at `fieldPath`, using
// the session's configured resolver to compare names.
private def containsColumn(table: LogicalPlan, fieldPath: Seq[String]): Boolean = {
  val tableSchema = table.schema
  val located = tableSchema.findNestedField(fieldPath, resolver = SQLConf.get.resolver)
  located.nonEmpty
}

// Returns true when every assignment is either fully resolved or accepted by the
// schema-evolution predicate for `source`. An empty assignment list is trivially ready.
//
// NOTE(review): two `forall` lines follow, calling `isSchemaEvolutionCandidate` and
// `isSchemaEvolutionTrigger` respectively -- presumably the two sides of a rename captured
// together; only the last expression determines the result, confirm which one is current.
private def areSchemaEvolutionReady(
assignments: Seq[Assignment],
source: LogicalPlan): Boolean = {
assignments.forall(assign => assign.resolved || isSchemaEvolutionCandidate(assign, source))
assignments.forall(assign => assign.resolved || isSchemaEvolutionTrigger(assign, source))
}

// An assignment is a schema evolution candidate when its value is already resolved and it
// copies a source column into the same-named target column.
private def isSchemaEvolutionCandidate(assignment: Assignment, source: LogicalPlan): Boolean = {
  if (assignment.value.resolved) {
    isSameColumnAssignment(assignment, source)
  } else {
    // An unresolved value can never qualify; the same-column check is not attempted.
    false
  }
}

// Helper method to check if an assignment key is equal to a source column
// and if the assignment value is that same source column.
// Checks if an assignment key maps to the same-named source column, meaning the
// assignment is a direct copy from source to target that may trigger schema evolution.
//
// Top-level example: UPDATE SET target.a = source.a
// key: AttributeReference("a", ...) -> path Seq("a")
// value: AttributeReference("a", ...) from source
// key: AttributeReference("a") or UnresolvedAttribute("a")
// value: AttributeReference("a") from source
//
// Nested example: UPDATE SET addr.city = source.addr.city
// key: GetStructField(GetStructField(AttributeReference("addr", ...), 0), 1)
// value: GetStructField(GetStructField(AttributeReference("addr", ...), 0), 1) from source
// Nested example: UPDATE SET target.addr.city = source.addr.city
// key: GetStructField(AttributeReference("addr"), ..., "city")
// value: GetStructField(AttributeReference("addr"), ..., "city") from source
//
// references contains only root attributes, so subsetOf(source.outputSet) works for both.
private def isSameColumnAssignment(assignment: Assignment, source: LogicalPlan): Boolean = {
// key must be a non-qualified field path that may be added to target schema via evolution
val keyPath = extractFieldPath(assignment.key, allowUnresolved = true)
// value should always be resolved (from source)
val valuePath = extractFieldPath(assignment.value, allowUnresolved = false)
keyPath == valuePath && assignment.value.references.subsetOf(source.outputSet)
// `references` contains only root attributes, so subsetOf(source.outputSet) works for both.
// Decides whether `assignment` is a direct copy of a source column into the same-named target
// column, which is the assignment shape that may trigger schema evolution.
//
// All of the following must hold:
//  - the value is already resolved;
//  - key and value yield the same, non-empty field path;
//  - every attribute referenced by the value is produced by `source`.
private def isSchemaEvolutionTrigger(assignment: Assignment, source: LogicalPlan): Boolean = {
  val value = assignment.value
  value.resolved && {
    val keyPath = extractFieldPath(assignment.key)
    // extractFieldPath answers Seq.empty for every expression shape it does not recognise, so
    // two unrelated unsupported expressions would otherwise compare equal here. An empty path
    // names no column and therefore cannot describe a same-column copy.
    keyPath.nonEmpty &&
      keyPath == extractFieldPath(value) &&
      value.references.subsetOf(source.outputSet)
  }
}
}

Expand Down