Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1709,14 +1709,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
val resolvedDeleteCondition = deleteCondition.map(
resolveExpressionByPlanChildren(_, m))
DeleteAction(resolvedDeleteCondition)
case UpdateAction(updateCondition, assignments) =>
case UpdateAction(updateCondition, assignments, fromStar) =>
val resolvedUpdateCondition = updateCondition.map(
resolveExpressionByPlanChildren(_, m))
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from both target and source tables.
resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
throws = throws))
throws = throws),
fromStar)
case UpdateStarAction(updateCondition) =>
// Expand star to top level source columns. If source has less columns than target,
// assignments will be added by ResolveRowLevelCommandAssignments later.
Expand All @@ -1738,7 +1739,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
updateCondition.map(resolveExpressionByPlanChildren(_, m)),
// For UPDATE *, the value must be from source table.
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
throws = throws))
throws = throws),
fromStar = true)
case o => o
}
val newNotMatchedActions = m.notMatchedActions.map {
Expand Down Expand Up @@ -1783,14 +1785,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
val resolvedDeleteCondition = deleteCondition.map(
resolveExpressionByPlanOutput(_, targetTable))
DeleteAction(resolvedDeleteCondition)
case UpdateAction(updateCondition, assignments) =>
case UpdateAction(updateCondition, assignments, fromStar) =>
val resolvedUpdateCondition = updateCondition.map(
resolveExpressionByPlanOutput(_, targetTable))
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from the target table only.
resolveAssignments(assignments, m, MergeResolvePolicy.TARGET,
throws = throws))
throws = throws),
fromStar)
case o => o
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE}
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._

Expand All @@ -50,13 +52,18 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
*
* @param attrs table attributes
* @param assignments assignments to align
* @param fromStar whether the assignments were resolved from an UPDATE SET * clause.
* These updates may assign struct fields individually
* (preserving existing fields).
* @param coerceNestedTypes whether to coerce nested types to match the target type
* for complex types
* @param missingSourcePaths paths that exist in target but not in source
* @return aligned update assignments that match table attributes
*/
def alignUpdateAssignments(
attrs: Seq[Attribute],
assignments: Seq[Assignment],
fromStar: Boolean,
coerceNestedTypes: Boolean): Seq[Assignment] = {

val errors = new mutable.ArrayBuffer[String]()
Expand All @@ -68,7 +75,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
assignments,
addError = err => errors += err,
colPath = Seq(attr.name),
coerceNestedTypes)
coerceNestedTypes,
fromStar)
}

if (errors.nonEmpty) {
Expand Down Expand Up @@ -152,7 +160,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
assignments: Seq[Assignment],
addError: String => Unit,
colPath: Seq[String],
coerceNestedTypes: Boolean = false): Expression = {
coerceNestedTypes: Boolean = false,
updateStar: Boolean = false): Expression = {

val (exactAssignments, otherAssignments) = assignments.partition { assignment =>
assignment.key.semanticEquals(colExpr)
Expand All @@ -174,9 +183,31 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
} else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
} else if (exactAssignments.nonEmpty) {
val value = exactAssignments.head.value
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath, coerceMode)
if (SQLConf.get.mergeUpdateStructsByField && updateStar) {
val value = exactAssignments.head.value
col.dataType match {
case structType: StructType =>
// Expand assignments to leaf fields
val structAssignment =
applyNestedFieldAssignments(col, colExpr, value, addError, colPath,
coerceNestedTypes)

// Wrap with null check for missing source fields
fixNullExpansion(col, value, structType, structAssignment,
colPath, addError)
case _ =>
// For non-struct types, resolve directly
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath,
coerceMode)
}
} else {
val value = exactAssignments.head.value
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath, coerceMode)
resolvedValue
}
} else {
applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes)
}
Expand Down Expand Up @@ -210,13 +241,165 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
}
}

private def applyNestedFieldAssignments(
Copy link
Member Author

@szehon-ho szehon-ho Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is like applyFieldAssignments above, but recurses into nested structs

col: Attribute,
colExpr: Expression,
value: Expression,
addError: String => Unit,
colPath: Seq[String],
coerceNestedTyptes: Boolean): Expression = {

col.dataType match {
case structType: StructType =>
val fieldAttrs = DataTypeUtils.toAttributes(structType)

val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) =>
val fieldPath = colPath :+ fieldAttr.name
val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name))

// Try to find a corresponding field in the source value by name
val sourceFieldValue: Expression = value.dataType match {
case valueStructType: StructType =>
valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match {
case Some(matchingField) =>
// Found matching field in source, extract it
val fieldIndex = valueStructType.fieldIndex(matchingField.name)
GetStructField(value, fieldIndex, Some(matchingField.name))
case None =>
// Field doesn't exist in source, use target's current value with null check
TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath)
}
case _ =>
// Value is not a struct, cannot extract field
addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'")
Literal(null, fieldAttr.dataType)
}

// Recurse or resolve based on field type
fieldAttr.dataType match {
case nestedStructType: StructType =>
// Field is a struct, recurse
applyNestedFieldAssignments(fieldAttr, targetFieldExpr, sourceFieldValue,
addError, fieldPath, coerceNestedTyptes)
case _ =>
// Field is not a struct, resolve with TableOutputResolver
val coerceMode = if (coerceNestedTyptes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError,
fieldPath, coerceMode)
}
}
toNamedStruct(structType, updatedFieldExprs)

case otherType =>
addError(
"Updating nested fields is only supported for StructType but " +
s"'${colPath.quoted}' is of type $otherType")
colExpr
}
}

private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = {
val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) =>
Seq(Literal(field.name), expr)
}.toImmutableArraySeq
CreateNamedStruct(namedStructExprs)
}

private def getMissingSourcePaths(targetType: StructType,
sourceType: DataType,
colPath: Seq[String],
addError: String => Unit): Seq[Seq[String]] = {
val nestedTargetPaths = DataTypeUtils.extractLeafFieldPaths(targetType, Seq.empty)
val nestedSourcePaths = sourceType match {
case sourceStructType: StructType =>
DataTypeUtils.extractLeafFieldPaths(sourceStructType, Seq.empty)
case _ =>
addError(s"Value for struct type: " +
s"${colPath.quoted} must be a struct but was ${sourceType.simpleString}")
Seq()
}
nestedSourcePaths.diff(nestedTargetPaths)
}

/**
* Creates a null check for a field at the given path within a struct expression.
* Navigates through the struct hierarchy following the path and returns an IsNull check
* for the final field.
*
* @param rootExpr the root expression to navigate from
* @param path the field path to navigate (sequence of field names)
* @return an IsNull expression checking if the field at the path is null
*/
private def createNullCheckForFieldPath(
rootExpr: Expression,
path: Seq[String]): Expression = {
var currentExpr: Expression = rootExpr
path.foreach { fieldName =>
currentExpr.dataType match {
case st: StructType =>
st.fields.find(f => conf.resolver(f.name, fieldName)) match {
case Some(field) =>
val fieldIndex = st.fieldIndex(field.name)
currentExpr = GetStructField(currentExpr, fieldIndex, Some(field.name))
case None =>
// Field not found, shouldn't happen
}
case _ =>
// Not a struct, shouldn't happen
}
}
IsNull(currentExpr)
}

/**
* As UPDATE SET * can assign struct fields individually (preserving existing fields),
* this will lead to null expansion, ie, a struct is created where all fields are null.
* Wraps a struct assignment with null checks for the source and missing source fields.
* Return null if all are null.
*
* @param col the target column attribute
* @param value the source value expression
* @param structType the target struct type
* @param structAssignment the struct assignment result to wrap
* @param colPath the column path for error reporting
* @param addError error reporting function
* @return the wrapped expression with null checks
*/
private def fixNullExpansion(
col: Attribute,
value: Expression,
structType: StructType,
structAssignment: Expression,
colPath: Seq[String],
addError: String => Unit): Expression = {
// As StoreAssignmentPolicy.LEGACY is not allowed in DSv2, always add null check for
// non-nullable column
if (!col.nullable) {
AssertNotNull(value)
} else {
// Check if source struct is null
val valueIsNull = IsNull(value)

// Check if missing source paths (paths in target but not in source) are not null
// These will be null for the case of UPDATE SET * and
val missingSourcePaths = getMissingSourcePaths(structType, value.dataType, colPath, addError)
val condition = if (missingSourcePaths.nonEmpty) {
// Check if all target attributes at missing source paths are null
val missingFieldNullChecks = missingSourcePaths.map { path =>
createNullCheckForFieldPath(col, path)
}
// Combine all null checks with AND
val allMissingFieldsNull = missingFieldNullChecks.reduce[Expression]((a, b) => And(a, b))
And(valueIsNull, allMissingFieldsNull)
} else {
valueIsNull
}

// Return: If (condition) THEN NULL ELSE structAssignment
If(condition, Literal(null, structAssignment.dataType), structAssignment)
}
}

/**
* Checks whether assignments are aligned and compatible with table columns.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
validateStoreAssignmentPolicy()
val newTable = cleanAttrMetadata(u.table)
val newAssignments = AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments,
coerceNestedTypes = false)
fromStar = false, coerceNestedTypes = false)
u.copy(table = newTable, assignments = newAssignments)

case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned =>
Expand All @@ -53,10 +53,11 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && m.rewritable && !m.aligned &&
!m.needSchemaEvolution =>
validateStoreAssignmentPolicy()
val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes
val coerceNestedTypes = SQLConf.get.mergeCoerceNestedTypes
m.copy(
targetTable = cleanAttrMetadata(m.targetTable),
matchedActions = alignActions(m.targetTable.output, m.matchedActions, coerceNestedTypes),
matchedActions = alignActions(m.targetTable.output, m.matchedActions,
coerceNestedTypes),
notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions,
coerceNestedTypes),
notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions,
Expand Down Expand Up @@ -117,9 +118,9 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
actions: Seq[MergeAction],
coerceNestedTypes: Boolean): Seq[MergeAction] = {
actions.map {
case u @ UpdateAction(_, assignments) =>
case u @ UpdateAction(_, assignments, fromStar) =>
u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs, assignments,
coerceNestedTypes))
fromStar, coerceNestedTypes))
case d: DeleteAction =>
d
case i @ InsertAction(_, assignments) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// original row ID values must be preserved and passed back to the table to encode updates
// if there are any assignments to row ID attributes, add extra columns for original values
val updateAssignments = (matchedActions ++ notMatchedBySourceActions).flatMap {
case UpdateAction(_, assignments) => assignments
case UpdateAction(_, assignments, _) => assignments
case _ => Nil
}
buildOriginalRowIdValues(rowIdAttrs, updateAssignments)
Expand Down Expand Up @@ -434,7 +434,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// converts a MERGE action into an instruction on top of the joined plan for group-based plans
private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = {
action match {
case UpdateAction(cond, assignments) =>
case UpdateAction(cond, assignments, _) =>
val rowValues = assignments.map(_.value)
val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
Expand Down Expand Up @@ -466,12 +466,12 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
splitUpdates: Boolean): Instruction = {

action match {
case UpdateAction(cond, assignments) if splitUpdates =>
case UpdateAction(cond, assignments, _) if splitUpdates =>
val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues)
val otherOutput = deltaReinsertOutput(assignments, metadataAttrs, originalRowIdValues)
Split(cond.getOrElse(TrueLiteral), output, otherOutput)

case UpdateAction(cond, assignments) =>
case UpdateAction(cond, assignments, _) =>
val output = deltaUpdateOutput(assignments, metadataAttrs, originalRowIdValues)
Keep(Update, cond.getOrElse(TrueLiteral), output)

Expand All @@ -495,7 +495,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
val actions = merge.matchedActions ++ merge.notMatchedActions ++ merge.notMatchedBySourceActions
actions.foreach {
case DeleteAction(Some(cond)) => checkMergeIntoCondition("DELETE", cond)
case UpdateAction(Some(cond), _) => checkMergeIntoCondition("UPDATE", cond)
case UpdateAction(Some(cond), _, _) => checkMergeIntoCondition("UPDATE", cond)
case InsertAction(Some(cond), _) => checkMergeIntoCondition("INSERT", cond)
case _ => // OK
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {

private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = {
mergeActions.map {
case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateAction(Some(cond), _, _) =>
u.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond)))
Expand Down
Loading