Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, IsNull, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
Expand Down Expand Up @@ -182,31 +181,11 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
} else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
} else if (exactAssignments.nonEmpty) {
if (updateStar) {
val value = exactAssignments.head.value
col.dataType match {
case structType: StructType =>
// Expand assignments to leaf fields
val structAssignment =
applyNestedFieldAssignments(col, colExpr, value, addError, colPath,
coerceNestedTypes)

// Wrap with null check for missing source fields
fixNullExpansion(col, value, structType, structAssignment,
colPath, addError)
case _ =>
// For non-struct types, resolve directly
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath,
coerceMode)
}
} else {
val value = exactAssignments.head.value
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath, coerceMode)
resolvedValue
}
val value = exactAssignments.head.value
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath, coerceMode)
resolvedValue
} else {
applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes)
}
Expand Down Expand Up @@ -240,63 +219,6 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
}
}

/**
 * Expands an assignment of a (possibly struct-typed) source value onto a struct-typed
 * target column, field by field, so that fields present in the target but missing from
 * the source keep the target's current value.
 *
 * For each field of the target struct type:
 *  - if the source struct has a field with a matching name (per `conf.resolver`),
 *    that source field is extracted and used as the new value;
 *  - otherwise the target's existing field value is kept (with a nullability check);
 *  - struct-typed fields recurse into this method; leaf fields are resolved through
 *    `TableOutputResolver.resolveUpdate`.
 *
 * The per-field expressions are finally folded back into a named struct.
 *
 * @param col the target column (or nested field, on recursion) being assigned
 * @param colExpr expression producing the target's current value
 * @param value the source value being assigned
 * @param addError callback used to report resolution errors
 * @param colPath path of field names from the root column, for error messages
 * @param coerceNestedTyptes when true, nested types are coerced via RECURSE mode
 *        (NOTE(review): parameter name is misspelled — should be `coerceNestedTypes`)
 * @return the expanded assignment expression, or `colExpr` unchanged on error
 */
private def applyNestedFieldAssignments(
col: Attribute,
colExpr: Expression,
value: Expression,
addError: String => Unit,
colPath: Seq[String],
coerceNestedTyptes: Boolean): Expression = {

col.dataType match {
case structType: StructType =>
// Target is a struct: materialize one attribute per target field.
val fieldAttrs = DataTypeUtils.toAttributes(structType)

val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) =>
val fieldPath = colPath :+ fieldAttr.name
// Expression reading the target's current value for this field (by ordinal).
val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name))

// Try to find a corresponding field in the source value by name
val sourceFieldValue: Expression = value.dataType match {
case valueStructType: StructType =>
// Name matching honors the session's resolver (case sensitivity setting).
valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match {
case Some(matchingField) =>
// Found matching field in source, extract it
val fieldIndex = valueStructType.fieldIndex(matchingField.name)
GetStructField(value, fieldIndex, Some(matchingField.name))
case None =>
// Field doesn't exist in source, use target's current value with null check
TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath)
}
case _ =>
// Value is not a struct, cannot extract field
addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'")
// Placeholder so expansion can continue after the error is recorded.
Literal(null, fieldAttr.dataType)
}

// Recurse or resolve based on field type
fieldAttr.dataType match {
// NOTE(review): the binding `nestedStructType` is unused — the match only
// dispatches on the type; consider `case _: StructType =>`.
case nestedStructType: StructType =>
// Field is a struct, recurse
applyNestedFieldAssignments(fieldAttr, targetFieldExpr, sourceFieldValue,
addError, fieldPath, coerceNestedTyptes)
case _ =>
// Field is not a struct, resolve with TableOutputResolver
val coerceMode = if (coerceNestedTyptes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError,
fieldPath, coerceMode)
}
}
// Reassemble the per-field expressions into a single named struct value.
toNamedStruct(structType, updatedFieldExprs)

case otherType =>
// Nested-field expansion is only defined for struct targets.
addError(
"Updating nested fields is only supported for StructType but " +
s"'${colPath.quoted}' is of type $otherType")
colExpr
}
}

private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = {
val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) =>
Seq(Literal(field.name), expr)
Expand Down Expand Up @@ -350,55 +272,6 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
IsNull(currentExpr)
}

/**
* As UPDATE SET * can assign struct fields individually (preserving existing fields),
* this will lead to null expansion, ie, a struct is created where all fields are null.
* Wraps a struct assignment with null checks for the source and missing source fields.
* Return null if all are null.
*
* @param col the target column attribute
* @param value the source value expression
* @param structType the target struct type
* @param structAssignment the struct assignment result to wrap
* @param colPath the column path for error reporting
* @param addError error reporting function
* @return the wrapped expression with null checks
*/
private def fixNullExpansion(
col: Attribute,
value: Expression,
structType: StructType,
structAssignment: Expression,
colPath: Seq[String],
// NOTE(review): `addError` is accepted but never used in this body; it is only
// forwarded to getMissingSourcePaths below — verify that is intentional.
addError: String => Unit): Expression = {
// As StoreAssignmentPolicy.LEGACY is not allowed in DSv2, always add null check for
// non-nullable column
if (!col.nullable) {
// NOTE(review): this branch returns AssertNotNull(value) — i.e. the raw source
// value — and discards `structAssignment`. Confirm the expanded assignment is
// intentionally dropped for non-nullable columns.
AssertNotNull(value)
} else {
// Check if source struct is null
val valueIsNull = IsNull(value)

// Check if missing source paths (paths in target but not in source) are not null.
// These will be null for the case of UPDATE SET *.
val missingSourcePaths = getMissingSourcePaths(structType, value.dataType, colPath, addError)
val condition = if (missingSourcePaths.nonEmpty) {
// Check if all target attributes at missing source paths are null
val missingFieldNullChecks = missingSourcePaths.map { path =>
createNullCheckForFieldPath(col, path)
}
// Combine all null checks with AND
val allMissingFieldsNull = missingFieldNullChecks.reduce[Expression]((a, b) => And(a, b))
// Only collapse to NULL when the source is null AND every pre-existing target
// value at the missing paths is also null — otherwise keep the expansion.
And(valueIsNull, allMissingFieldsNull)
} else {
valueIsNull
}

// Return: If (condition) THEN NULL ELSE structAssignment
If(condition, Literal(null, structAssignment.dataType), structAssignment)
}
}

/**
* Checks whether assignments are aligned and compatible with table columns.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6699,10 +6699,11 @@ object SQLConf {
buildConf("spark.sql.merge.nested.type.coercion.enabled")
.internal()
.doc("If enabled, allow MERGE INTO to coerce source nested types if they have less " +
"nested fields than the target table's nested types.")
"nested fields than the target table's nested types. This is experimental and " +
"the semantics may change.")
.version("4.1.0")
.booleanConf
.createWithDefault(true)
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
Expand Down
Loading