diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6b0665c1b7f3..fb8a84a85fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1709,14 +1709,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val resolvedDeleteCondition = deleteCondition.map( resolveExpressionByPlanChildren(_, m)) DeleteAction(resolvedDeleteCondition) - case UpdateAction(updateCondition, assignments) => + case UpdateAction(updateCondition, assignments, fromStar) => val resolvedUpdateCondition = updateCondition.map( resolveExpressionByPlanChildren(_, m)) UpdateAction( resolvedUpdateCondition, // The update value can access columns from both target and source tables. resolveAssignments(assignments, m, MergeResolvePolicy.BOTH, - throws = throws)) + throws = throws), + fromStar) case UpdateStarAction(updateCondition) => // Expand star to top level source columns. If source has less columns than target, // assignments will be added by ResolveRowLevelCommandAssignments later. @@ -1738,7 +1739,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor updateCondition.map(resolveExpressionByPlanChildren(_, m)), // For UPDATE *, the value must be from source table. resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE, - throws = throws)) + throws = throws), + fromStar = true) case o => o } val newNotMatchedActions = m.notMatchedActions.map { @@ -1783,14 +1785,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val resolvedDeleteCondition = deleteCondition.map( resolveExpressionByPlanOutput(_, targetTable)) DeleteAction(resolvedDeleteCondition) - case UpdateAction(updateCondition, assignments) => + case UpdateAction(updateCondition, assignments, fromStar) => val resolvedUpdateCondition = updateCondition.map( resolveExpressionByPlanOutput(_, targetTable)) UpdateAction( resolvedUpdateCondition, // The update value can access columns from the target table only. 
resolveAssignments(assignments, m, MergeResolvePolicy.TARGET, - throws = throws)) + throws = throws), + fromStar) case o => o } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index 145c9077a4c2..6cbc17c67381 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -21,13 +21,15 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE} -import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal} +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.Assignment import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -50,13 +52,18 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { * * @param attrs table attributes * @param assignments assignments to align + * @param fromStar whether the assignments were resolved from an UPDATE SET * clause. + * These updates may assign struct fields individually + * (preserving existing fields). 
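+   *                 For example (illustrative schemas): with a target struct s(a, b) and a source struct s(a), UPDATE SET * updates s.a from the source and keeps the target's existing s.b.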
* @param coerceNestedTypes whether to coerce nested types to match the target type * for complex types * @return aligned update assignments that match table attributes */ def alignUpdateAssignments( attrs: Seq[Attribute], assignments: Seq[Assignment], + fromStar: Boolean, coerceNestedTypes: Boolean): Seq[Assignment] = { val errors = new mutable.ArrayBuffer[String]() @@ -68,7 +75,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { assignments, addError = err => errors += err, colPath = Seq(attr.name), - coerceNestedTypes) + coerceNestedTypes, + fromStar) } if (errors.nonEmpty) { @@ -152,7 +160,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { assignments: Seq[Assignment], addError: String => Unit, colPath: Seq[String], - coerceNestedTypes: Boolean = false): Expression = { + coerceNestedTypes: Boolean = false, + fromStar: Boolean = false): Expression = { val (exactAssignments, otherAssignments) = assignments.partition { assignment => assignment.key.semanticEquals(colExpr) @@ -174,9 +183,29 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { } else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) { TableOutputResolver.checkNullability(colExpr, col, conf, colPath) } else if (exactAssignments.nonEmpty) { - val value = exactAssignments.head.value - val coerceMode = if (coerceNestedTypes) RECURSE else NONE - TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath, coerceMode) + if (SQLConf.get.mergeUpdateStructsByField && fromStar) { + val value = exactAssignments.head.value + col.dataType match { + case structType: StructType => + // Expand assignments to leaf fields + val structAssignment = + applyNestedFieldAssignments(col, colExpr, value, addError, colPath, + coerceNestedTypes) + + // Wrap with null check for missing source fields + fixNullExpansion(col, value, structType, structAssignment, + colPath, addError) + case _ => + // For non-struct types, resolve directly + val coerceMode = if (coerceNestedTypes) RECURSE else NONE + TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath, + coerceMode) + } + } else { + val value = exactAssignments.head.value + val coerceMode = if (coerceNestedTypes) RECURSE else NONE + TableOutputResolver.resolveUpdate("", value, col, conf, addError, + colPath, coerceMode) + } } else { applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes) } @@ -210,6 +241,63 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { } } + private def applyNestedFieldAssignments( + col: Attribute, + colExpr: Expression, + value: Expression, + addError: String => Unit, + colPath: Seq[String], + coerceNestedTypes: Boolean): Expression = { + + col.dataType match { + case structType: StructType => + val fieldAttrs = DataTypeUtils.toAttributes(structType) + + val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) => + val fieldPath = colPath :+ fieldAttr.name + val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name)) + + // Try to find a corresponding field in the source value by name + val sourceFieldValue: Expression = value.dataType match { + case valueStructType: StructType => + valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match { + case Some(matchingField) => + // Found matching field in source, extract it + val fieldIndex = 
valueStructType.fieldIndex(matchingField.name) + GetStructField(value, fieldIndex, Some(matchingField.name)) + case None => + // Field doesn't exist in source, use target's current value with null check + TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath) + } + case _ => + // Value is not a struct, cannot extract field + addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'") + Literal(null, fieldAttr.dataType) + } + + // Recurse or resolve based on field type + fieldAttr.dataType match { + case _: StructType => + // Field is a struct, recurse + applyNestedFieldAssignments(fieldAttr, targetFieldExpr, sourceFieldValue, + addError, fieldPath, coerceNestedTypes) + case _ => + // Field is not a struct, resolve with TableOutputResolver + val coerceMode = if (coerceNestedTypes) RECURSE else NONE + TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError, + fieldPath, coerceMode) + } + } + toNamedStruct(structType, updatedFieldExprs) + + case otherType => + addError( + "Updating nested fields is only supported for StructType but " + + s"'${colPath.quoted}' is of type $otherType") + colExpr + } + } + private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = { val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) => Seq(Literal(field.name), expr) @@ -217,6 +305,101 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { CreateNamedStruct(namedStructExprs) } + private def getMissingSourcePaths(targetType: StructType, + sourceType: DataType, + colPath: Seq[String], + addError: String => Unit): Seq[Seq[String]] = { + val nestedTargetPaths = DataTypeUtils.extractLeafFieldPaths(targetType, Seq.empty) + val nestedSourcePaths = sourceType match { + case sourceStructType: StructType => + DataTypeUtils.extractLeafFieldPaths(sourceStructType, Seq.empty) + case _ => + addError(s"Value for struct column '${colPath.quoted}' " + + s"must be a struct but was ${sourceType.simpleString}") + Seq() + } + nestedTargetPaths.diff(nestedSourcePaths) + } + + /** + * Creates a null check for a field at the given path within a struct expression. + * Navigates through the struct hierarchy following the path and returns an IsNull check + * for the final field. + * + * @param rootExpr the root expression to navigate from + * @param path the field path to navigate (sequence of field names) + * @return an IsNull expression checking if the field at the path is null + */ + private def createNullCheckForFieldPath( + rootExpr: Expression, + path: Seq[String]): Expression = { + var currentExpr: Expression = rootExpr + path.foreach { fieldName => + currentExpr.dataType match { + case st: StructType => + st.fields.find(f => conf.resolver(f.name, fieldName)) match { + case Some(field) => + val fieldIndex = st.fieldIndex(field.name) + currentExpr = GetStructField(currentExpr, fieldIndex, Some(field.name)) + case None => + // Field not found, shouldn't happen + } + case _ => + // Not a struct, shouldn't happen + } + } + IsNull(currentExpr) + } + + /** + * As UPDATE SET * can assign struct fields individually (preserving existing fields), + * this can lead to null expansion, i.e., a struct whose fields are all null is created + * instead of a null struct. Wraps a struct assignment with null checks for the source value + * and for the target fields missing from the source. Returns null if all of them are null. 
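+ * For example (illustrative): with a target struct s(a, b) and a source struct s(a), if the source s is null and the preserved target s.b is also null, the result is a null struct rather than named_struct('a', null, 'b', null).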
+ * + * @param col the target column attribute + * @param value the source value expression + * @param structType the target struct type + * @param structAssignment the struct assignment result to wrap + * @param colPath the column path for error reporting + * @param addError error reporting function + * @return the wrapped expression with null checks + */ + private def fixNullExpansion( + col: Attribute, + value: Expression, + structType: StructType, + structAssignment: Expression, + colPath: Seq[String], + addError: String => Unit): Expression = { + // As StoreAssignmentPolicy.LEGACY is not allowed in DSv2, always add a null check for + // non-nullable columns + if (!col.nullable) { + AssertNotNull(value) + } else { + // Check if source struct is null + val valueIsNull = IsNull(value) + + // Check if missing source paths (paths in target but not in source) are all null + // These will be null for the case of UPDATE SET * with a null source struct + val missingSourcePaths = getMissingSourcePaths(structType, value.dataType, colPath, addError) + val condition = if (missingSourcePaths.nonEmpty) { + // Check if all target attributes at missing source paths are null + val missingFieldNullChecks = missingSourcePaths.map { path => + createNullCheckForFieldPath(col, path) + } + // Combine all null checks with AND + val allMissingFieldsNull = missingFieldNullChecks.reduce[Expression]((a, b) => And(a, b)) + And(valueIsNull, allMissingFieldsNull) + } else { + valueIsNull + } + + // Return: If (condition) THEN NULL ELSE structAssignment + If(condition, Literal(null, structAssignment.dataType), structAssignment) + } + } + /** * Checks whether assignments are aligned and compatible with table columns. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala index 3eb528954b35..93ef98e3183a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala @@ -44,7 +44,7 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { validateStoreAssignmentPolicy() val newTable = cleanAttrMetadata(u.table) val newAssignments = AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments, - coerceNestedTypes = false) + fromStar = false, coerceNestedTypes = false) u.copy(table = newTable, assignments = newAssignments) case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned => @@ -53,10 +53,11 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && m.rewritable && !m.aligned && !m.needSchemaEvolution => validateStoreAssignmentPolicy() - val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes + val coerceNestedTypes = SQLConf.get.mergeCoerceNestedTypes m.copy( targetTable = cleanAttrMetadata(m.targetTable), - matchedActions = alignActions(m.targetTable.output, m.matchedActions, coerceNestedTypes), + matchedActions = alignActions(m.targetTable.output, m.matchedActions, + coerceNestedTypes), notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions, coerceNestedTypes), notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions, @@ -117,9 +118,9 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { actions: 
Seq[MergeAction], coerceNestedTypes: Boolean): Seq[MergeAction] = { actions.map { - case u @ UpdateAction(_, assignments) => + case u @ UpdateAction(_, assignments, fromStar) => u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs, assignments, - coerceNestedTypes)) + fromStar, coerceNestedTypes)) case d: DeleteAction => d case i @ InsertAction(_, assignments) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index 8b5b690aa740..1d2e2fef2096 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -334,7 +334,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper // original row ID values must be preserved and passed back to the table to encode updates // if there are any assignments to row ID attributes, add extra columns for original values val updateAssignments = (matchedActions ++ notMatchedBySourceActions).flatMap { - case UpdateAction(_, assignments) => assignments + case UpdateAction(_, assignments, _) => assignments case _ => Nil } buildOriginalRowIdValues(rowIdAttrs, updateAssignments) @@ -434,7 +434,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper // converts a MERGE action into an instruction on top of the joined plan for group-based plans private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = { action match { - case UpdateAction(cond, assignments) => + case UpdateAction(cond, assignments, _) => val rowValues = assignments.map(_.value) val metadataValues = nullifyMetadataOnUpdate(metadataAttrs) val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues @@ -466,12 +466,12 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper splitUpdates: Boolean): Instruction = { action match { - case UpdateAction(cond, assignments) if splitUpdates => + case UpdateAction(cond, assignments, _) if splitUpdates => val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues) val otherOutput = deltaReinsertOutput(assignments, metadataAttrs, originalRowIdValues) Split(cond.getOrElse(TrueLiteral), output, otherOutput) - case UpdateAction(cond, assignments) => + case UpdateAction(cond, assignments, _) => val output = deltaUpdateOutput(assignments, metadataAttrs, originalRowIdValues) Keep(Update, cond.getOrElse(TrueLiteral), output) @@ -495,7 +495,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper val actions = merge.matchedActions ++ merge.notMatchedActions ++ merge.notMatchedBySourceActions actions.foreach { case DeleteAction(Some(cond)) => checkMergeIntoCondition("DELETE", cond) - case UpdateAction(Some(cond), _) => checkMergeIntoCondition("UPDATE", cond) + case UpdateAction(Some(cond), _, _) => checkMergeIntoCondition("UPDATE", cond) case InsertAction(Some(cond), _) => checkMergeIntoCondition("INSERT", cond) case _ => // OK } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 7134c3daf3ba..9a676571d107 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -149,7 +149,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = { mergeActions.map { - case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case u @ UpdateAction(Some(cond), _, _) => + u.copy(condition = Some(replaceNullWithFalse(cond))) case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index dcce22040244..26ce138523e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -873,7 +873,7 @@ case class MergeIntoTable( lazy val aligned: Boolean = { val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions actions.forall { - case UpdateAction(_, assignments) => + case UpdateAction(_, assignments, _) => AssignmentUtils.aligned(targetTable.output, assignments) case _: DeleteAction => true @@ -926,10 +926,7 @@ case class MergeIntoTable( case a: UpdateAction => a.assignments case a: InsertAction => a.assignments }.flatten - - val sourcePaths = MergeIntoTable.extractAllFieldPaths(sourceTable.schema) - // Only allow unresolved assignment keys to be candidates for schema evolution - // if they are directly assigned from source fields, ie UPDATE SET new = source.new + val sourcePaths = DataTypeUtils.extractAllFieldPaths(sourceTable.schema) assignments.forall { assignment => assignment.resolved || (assignment.value.resolved && sourcePaths.exists { @@ -1083,19 +1080,6 @@ object MergeIntoTable { filterSchema(merge.sourceTable.schema, Seq.empty) } - private def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty): - Seq[Seq[String]] = { - schema.flatMap { field => - val fieldPath = basePath :+ field.name - field.dataType match { - case struct: StructType => - fieldPath +: extractAllFieldPaths(struct, fieldPath) - case _ => - Seq(fieldPath) - } - } - } - // Helper method to extract field path from an Expression. 
private def extractFieldPath(expr: Expression, allowUnresolved: Boolean): Seq[String] = { expr match { @@ -1142,7 +1126,8 @@ case class DeleteAction(condition: Option[Expression]) extends MergeAction { case class UpdateAction( condition: Option[Expression], - assignments: Seq[Assignment]) extends MergeAction { + assignments: Seq[Assignment], + fromStar: Boolean = false) extends MergeAction { override def children: Seq[Expression] = condition.toSeq ++ assignments override protected def withNewChildrenInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala index c6e51aab4584..e7bd5bd1aa2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala @@ -249,5 +249,38 @@ object DataTypeUtils { case v: Long => fromDecimal(Decimal(BigDecimal(v))) case _ => forType(literal.dataType) } + + /** + * Extracts all struct field paths from a nested StructType. + */ + def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty): + Seq[Seq[String]] = { + schema.flatMap { field => + val fieldPath = basePath :+ field.name + field.dataType match { + case struct: StructType => + fieldPath +: extractAllFieldPaths(struct, fieldPath) + case _ => + Seq(fieldPath) + } + } + } + + /** + * Extracts only leaf-level field paths from a nested StructType. + * Unlike extractAllFieldPaths, this method does not include intermediate struct paths. + */ + def extractLeafFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty): + Seq[Seq[String]] = { + schema.flatMap { field => + val fieldPath = basePath :+ field.name + field.dataType match { + case struct: StructType => + extractLeafFieldPaths(struct, fieldPath) + case _ => + Seq(fieldPath) + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index efd34a84cfdc..b9623566f648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6686,8 +6686,8 @@ object SQLConf { .booleanConf .createWithDefault(true) - val MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED = - buildConf("spark.sql.merge.source.nested.type.coercion.enabled") + val MERGE_INTO_NESTED_TYPE_COERCION_ENABLED = + buildConf("spark.sql.merge.nested.type.coercion.enabled") .internal() .doc("If enabled, allow MERGE INTO to coerce source nested types if they have less" + "nested fields than the target table's nested types.") @@ -6695,6 +6695,18 @@ .booleanConf .createWithDefault(true) + val MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD = + buildConf("spark.sql.merge.nested.type.assign.by.field") + .internal() + .doc("If enabled and spark.sql.merge.nested.type.coercion.enabled is true, " + + "allow MERGE INTO with UPDATE SET * action to set nested structs field by field. " + + "In updated rows, target structs will preserve the original value for fields missing " + + "in the source struct. If disabled, the entire target struct will be replaced, " + + "and fields missing in the source struct will be null.") + .version("4.1.0") + .booleanConf + .createWithDefault(true) + /** * Holds information about keys that have been deprecated. 
* @ -7892,8 +7904,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def legacyXMLParserEnabled: Boolean = getConf(SQLConf.LEGACY_XML_PARSER_ENABLED) - def coerceMergeNestedTypes: Boolean = - getConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED) + def mergeCoerceNestedTypes: Boolean = + getConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED) + + def mergeUpdateStructsByField: Boolean = + getConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 85b0faed4c38..7051a0b455e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -3231,60 +3231,180 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase test("merge into schema evolution replace column with nested struct and set all columns") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING, STRING>>>, - |dep STRING""".stripMargin, - """{ "pk": 1, "s": { "c1": 2, "c2": { "a": [1,2], "m": { "a": "b" } } }, "dep": "hr" }""") + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + withTempView("source") { + // Create table using Spark SQL + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING, STRING>>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - // missing column 'a' - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) // new column - ))) - ))), - StructField("dep", StringType) - )) - val data = Seq( - Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + // Insert data using DataFrame API with objects + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + StructField("dep", StringType) + )) + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + 
spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + if (withSchemaEvolution) { + sql(mergeStmt) + if (updateByFields) { + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + } + } + + test("merge into schema evolution replace column with nested struct and update " + + "top level struct") { + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + withTempView("source") { + // Create table using Spark SQL + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING, STRING>>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + + // Insert data using DataFrame API with objects + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new 
column + ))) + ))), + StructField("dep", StringType) + )) + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET s = src.s + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } @@ -4343,50 +4463,6 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } - test("merge into with source missing fields in top-level struct") { - withTempView("source") { - // Target table has struct with 3 fields at top level - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>, - |dep STRING""".stripMargin, - """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") - - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Row(10, "b"), "hr"), - Row(2, Row(20, "c"), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") - - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) - - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, "a", true), "sales"), - Row(1, Row(10, "b", null), "hr"), - Row(2, Row(20, "c", null), "engineering"))) - } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") - } - test("merge into with source missing fields in struct nested in array") { withTempView("source") { // Target table has struct with 3 fields (c1, c2, c3) in array @@ -4540,22 +4616,459 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - test("merge into with source missing fields in nested struct") { - Seq(true, false).foreach { nestedTypeCoercion => - withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key - -> nestedTypeCoercion.toString) { + test("merge into with source missing fields in top-level struct") { + withTempView("source") { + // Target table has struct with 3 fields at top level + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true 
}, "dep": "sales"}""") + + // Source table has struct with only 2 fields (c1, c2) - missing c3 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), // missing c3 field + StructField("dep", StringType))) + val data = Seq( + Row(1, Row(10, "b"), "hr"), + Row(2, Row(20, "c"), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a", true), "sales"), + Row(1, Row(10, "b", null), "hr"), + Row(2, Row(20, "c", null), "engineering"))) + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + + test("merge with null struct") { + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": "a" }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": "b" }, "dep": "hr" }""" + .stripMargin) + + // Source table matches target table schema + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType) + ))), + StructField("dep", StringType) + )) + + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a"), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + + test("merge with null struct - update field") { + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": "a" }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": "b" }, "dep": "hr" }""" + .stripMargin) + + // Source table matches target table schema + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType) + ))), + StructField("dep", StringType) + )) + + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET s = source.s + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a"), "sales"), + Row(1, 
null, "hr"), + Row(2, null, "finance"))) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + + test("merge with null struct into non-nullable struct column") { + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { withTempView("source") { - // Target table has nested struct: s.c1, s.c2.a, s.c2.b createAndInitTable( s"""pk INT NOT NULL, - |s STRUCT>, + |s STRUCT NOT NULL, |dep STRING""".stripMargin, - """{ "pk": 1, "s": { "c1": 2, "c2": { "a": 10, "b": true } } } - |{ "pk": 2, "s": { "c1": 2, "c2": { "a": 30, "b": false } } }""".stripMargin) + """{ "pk": 0, "s": { "c1": 1, "c2": "a" }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": "b" }, "dep": "hr" }""" + .stripMargin) - // Source table is missing field 'b' in nested struct s.c2 + // Source table has null for the struct column val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType) + ))), + StructField("dep", StringType) + )) + + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + // Should throw an exception when trying to insert/update null into NOT NULL column + val exception = intercept[Exception] { + sql( + s"""MERGE INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + } + assert(exception.getMessage.contains( + "NULL value appeared in non-nullable field")) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + + test("merge with with null struct with missing nested field") { + Seq(true, false).foreach { updateByFields => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf( + SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString, + SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has nested struct with fields c1 and c2 + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT>, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }""" + .stripMargin) + + // Source table has null for the nested struct + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType) + // missing field 'b' + ))) + ))), + StructField("dep", StringType) + )) + + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val mergeStmt = + s"""MERGE INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + + if (coerceNestedTypes) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + // Without coercion, the merge should fail due to missing field + val exception = 
intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c2`.`b`.")) + } + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + } + + test("merge null struct with schema evolution - source with missing and extra nested fields") { + Seq(true, false).foreach { updateByFields => + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf( + SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString, + SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has nested struct with fields c1 and c2 + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }""" + .stripMargin) + + // Source table has missing field 'b' and extra field 'c' in nested struct + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + // missing field 'b' + StructField("c", StringType) // extra field 'c' + ))) + ))), + StructField("dep", StringType) + )) + + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + + if (coerceNestedTypes) { + if (withSchemaEvolution) { + // extra nested field is added + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x", null)), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + // extra nested field is not added + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot write extra fields `c` to the struct `s`.`c2`")) + } + } else { + // Without source struct coercion, the merge should fail + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c2`.`b`.")) + } + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + } + } + + test("merge null struct with non-nullable nested field - source with missing " + + "and extra nested fields") { + + withSQLConf( + SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> "true", + SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> "true") { + withTempView("source") { + // Target table has 
nested struct with NON-NULLABLE field b + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING NOT NULL>>, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }""" + .stripMargin) + + // Source table has missing field 'b' and extra field 'c' in nested struct + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + // missing field 'b' (which is non-nullable in target) + StructField("c", StringType) // extra field 'c' + ))) + ))), + StructField("dep", StringType) + )) + + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val mergeStmt = + s"""MERGE WITH SCHEMA EVOLUTION + |INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + + // The merge should fail because the non-nullable field 'b' cannot be populated + // from the source + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains("Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c2`.`b`.")) + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + + test("merge with null struct using default value") { + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + withTempView("source") { + // Target table has nested struct with a default value + sql( + s"""CREATE TABLE $tableNameAsString ( + | pk INT NOT NULL, + | s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>> DEFAULT + | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')), + | dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + + // Insert initial data using DataFrame API + val initialSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType) + ))) + ))), + StructField("dep", StringType) + )) + val initialData = Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, Row(2, Row(20, "y")), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema) + .writeTo(tableNameAsString).append() + + // Source table has null for the nested struct + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), StructField("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( @@ -4565,45 +5078,179 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase ))), StructField("dep", StringType) )) + val data = Seq( - Row(1, Row(10, Row(20)), "sales"), - Row(2, Row(20, Row(30)), "engineering") + Row(1, null, "engineering"), + Row(2, null, "finance") ) spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) .createOrReplaceTempView("source") - // Missing field b should be filled with NULL - val mergeStmt = s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT 
MATCHED THEN - | INSERT * - |""".stripMargin + sql( + s"""MERGE INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } - if (nestedTypeCoercion) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(20, null)), "sales"), - Row(2, Row(20, Row(30, null)), "engineering"))) - } else { - val exception = intercept[Exception] { + test("merge with source missing struct column with default value") { + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + withTempView("source") { + // Target table has nested struct with a default value + sql( + s"""CREATE TABLE $tableNameAsString ( + | pk INT NOT NULL, + | s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>> DEFAULT + | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')), + | dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + + // Insert initial data using DataFrame API + val initialSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType) + ))) + ))), + StructField("dep", StringType) + )) + val initialData = Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, Row(2, Row(20, "y")), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema) + .writeTo(tableNameAsString).append() + + // Source table is completely missing the struct column 's' + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("dep", StringType) + )) + + val data = Seq( + Row(1, "engineering"), + Row(2, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + // When inserting without specifying the struct column, default should be used + sql( + s"""MERGE INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET dep = source.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, dep) VALUES (source.pk, source.dep) + |""".stripMargin) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, Row(2, Row(20, "y")), "engineering"), + Row(2, Row(999, Row(999, "default")), "finance"))) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + + test("merge into with source missing fields in nested struct") { + Seq(true, false).foreach { nestedTypeCoercion => + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key + -> updateByFields.toString, + SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key + -> nestedTypeCoercion.toString) { + withTempView("source") { + // Target table has nested struct: s.c1, s.c2.a, s.c2.b + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: BOOLEAN>>, + |dep STRING""".stripMargin, + """{ "pk": 1, "s": { "c1": 2, "c2": { "a": 10, "b": true } } } + |{ "pk": 2, "s": { "c1": 2, "c2": { "a": 30, "b": false } } }""".stripMargin) + + // Source table is missing field 'b' in nested struct s.c2 + val sourceTableSchema = StructType(Seq( + 
StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType) + // missing field 'b' + ))) + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(1, Row(10, Row(20)), "sales"), + Row(2, Row(20, Row(30)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + // Missing field b should be filled with NULL + val mergeStmt = s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + + if (nestedTypeCoercion) { sql(mergeStmt) + if (updateByFields) { + // When updating by fields, only non-null fields are updated + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(20, true)), "sales"), + Row(2, Row(20, Row(30, false)), "engineering"))) + } else { + // When updating by top level column, the missing field is set to NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(20, null)), "sales"), + Row(2, Row(20, Row(30, null)), "engineering"))) + } + } else { + val exception = intercept[Exception] { + sql(mergeStmt) + } + assert(exception.getMessage.contains( + """Cannot write incompatible data for the table ``""".stripMargin)) } - assert(exception.getMessage.contains( - """Cannot write incompatible data for the table ``""".stripMargin)) } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } } - - test("merge with named_struct missing non-nullable field backup") { + test("merge with named_struct missing non-nullable field") { withTempView("source") { createAndInitTable( s"""pk INT NOT NULL, @@ -4631,7 +5278,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase .createOrReplaceTempView("source") Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key -> + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> coerceNestedTypes.toString) { // Test UPDATE with named_struct missing non-nullable field c2 val e = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala index 8420e5e4d880..f635131dc3f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala @@ -40,7 +40,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { | UPDATE SET t.txt = "error", t.i = CAST(null AS INT)""".stripMargin) matchedActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(i: AttributeReference, AssertNotNull(iValue: AttributeReference, _)), @@ -80,7 +80,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } notMatchedBySourceActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(i: AttributeReference, AssertNotNull(_: Cast, _)), @@ -138,7 +138,7 @@ class AlignMergeAssignmentsSuite extends 
AlignAssignmentsSuiteBase { } matchedActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(i: AttributeReference, iValue: AttributeReference), @@ -184,7 +184,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } notMatchedBySourceActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(i: AttributeReference, iValue: AttributeReference), @@ -217,7 +217,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { |""".stripMargin) matchedActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(i: AttributeReference, iValue: AttributeReference), @@ -280,7 +280,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } matchedActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(i: AttributeReference, iValue: AttributeReference), @@ -342,7 +342,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } notMatchedBySourceActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(i: AttributeReference, iValue: AttributeReference), @@ -463,7 +463,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } matchedActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(c: AttributeReference, cValue: StaticInvoke), @@ -531,7 +531,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } notMatchedBySourceActions match { - case Seq(UpdateAction(None, assignments)) => + case Seq(UpdateAction(None, assignments, _)) => assignments match { case Seq( Assignment(c: AttributeReference, cValue: StaticInvoke), @@ -691,7 +691,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i")) Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key -> + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> coerceNestedTypes.toString) { val mergeStmt = s"""MERGE INTO nested_struct_table t USING nested_struct_table src @@ -745,7 +745,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { val actions = if (matchedActions.nonEmpty) matchedActions else notMatchedBySourceActions actions match { - case Seq(UpdateAction(_, assignments)) => + case Seq(UpdateAction(_, assignments, _)) => assignments match { case Seq( Assignment( @@ -858,7 +858,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i")) Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key -> + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> coerceNestedTypes.toString) { val mergeStmt = s"""MERGE INTO nested_struct_table t USING nested_struct_table src @@ -947,7 +947,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { | UPDATE SET t.i = DEFAULT""".stripMargin) matchedActions match { - case Seq(UpdateAction(None, assignments)) => + case 
+      case Seq(UpdateAction(None, assignments, _)) =>
         assignments match {
           case Seq(
             Assignment(b: AttributeReference, bValue: AttributeReference),
@@ -1001,7 +1001,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase {
     }
 
     notMatchedBySourceActions match {
-      case Seq(UpdateAction(None, assignments)) =>
+      case Seq(UpdateAction(None, assignments, _)) =>
         assignments match {
           case Seq(
             Assignment(b: AttributeReference, bValue: AttributeReference),
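All of the `, _` edits in AlignMergeAssignmentsSuite above are the same mechanical fallout: UpdateAction gained a third constructor field, and Scala pattern matches must bind or wildcard every field even when a default argument keeps construction sites source-compatible. A stripped-down sketch of the effect (simplified types, not the Catalyst classes):

// Hypothetical, simplified shape of the change: adding a field with a
// default keeps constructor call sites compiling, but every extractor
// pattern must now account for the extra field, as the `, _` edits do.
sealed trait MergeAction
case class UpdateAction(
    condition: Option[String],
    assignments: Seq[(String, String)],
    fromStar: Boolean = false) extends MergeAction

def describe(action: MergeAction): String = action match {
  case UpdateAction(None, assigns, fromStar) =>
    s"unconditional update of ${assigns.size} columns (fromStar=$fromStar)"
  case UpdateAction(Some(cond), assigns, _) =>
    s"update of ${assigns.size} columns when $cond"
}

Binding the new field (first case) is only needed where a rule actually consumes it; test matchers can simply wildcard it, which is what these suites do.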
StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), - updateAssigns)), + updateAssigns, _)), Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))), insertAssigns)), Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), - notMatchedBySourceUpdateAssigns)), + notMatchedBySourceUpdateAssigns, _)), withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) @@ -1871,12 +1871,12 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), - updateAssigns)), + updateAssigns, _)), Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))), insertAssigns)), Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(nul: AttributeReference, StringLiteral("update"))), - notMatchedBySourceUpdateAssigns)), + notMatchedBySourceUpdateAssigns, _)), withSchemaEvolution) => checkMergeConditionResolution(target, source, mergeCondition) checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns) @@ -1927,7 +1927,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { case UpdateAction(Some(EqualTo(_: AttributeReference, StringLiteral("update"))), Seq( Assignment(_: AttributeReference, Literal(null, StringType)), - Assignment(_: AttributeReference, _: AttributeReference))) => + Assignment(_: AttributeReference, _: AttributeReference)), _) => case other => fail("unexpected second matched action " + other) } assert(m.notMatchedActions.length == 1) @@ -1947,7 +1947,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { } m.notMatchedBySourceActions(1) match { case UpdateAction(Some(EqualTo(_: AttributeReference, StringLiteral("update"))), - Seq(Assignment(_: AttributeReference, Literal(null, StringType)))) => + Seq(Assignment(_: AttributeReference, Literal(null, StringType))), _) => case other => fail("unexpected second not matched by source action " + other) } @@ -1999,7 +1999,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { val second = m.matchedActions(1) second match { case UpdateAction(Some(EqualTo(_: AttributeReference, Literal(31, IntegerType))), - Seq(Assignment(_: AttributeReference, Literal(42, IntegerType)))) => + Seq(Assignment(_: AttributeReference, Literal(42, IntegerType))), _) => case other => fail("unexpected second matched action " + other) } assert(m.notMatchedActions.length == 1) @@ -2017,7 +2017,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { } m.notMatchedBySourceActions(1) match { case UpdateAction(Some(EqualTo(_: AttributeReference, Literal(31, IntegerType))), - Seq(Assignment(_: AttributeReference, Literal(42, IntegerType)))) => + Seq(Assignment(_: AttributeReference, Literal(42, IntegerType))), _) => case other => fail("unexpected second not matched by source action " + other) } assert(m.withSchemaEvolution === false) @@ -2158,11 +2158,11 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { AsDataSourceV2Relation(target), AsDataSourceV2Relation(source), _, - Seq(DeleteAction(Some(_)), 
+          Seq(DeleteAction(Some(_)), UpdateAction(None, firstUpdateAssigns, _)),
           Seq(InsertAction(
             Some(EqualTo(il: AttributeReference, StringLiteral("a"))),
             insertAssigns)),
-          Seq(DeleteAction(Some(_)), UpdateAction(None, secondUpdateAssigns)),
+          Seq(DeleteAction(Some(_)), UpdateAction(None, secondUpdateAssigns, _)),
           withSchemaEvolution) =>
         val ti = target.output.find(_.name == "i").get
         val ts = target.output.find(_.name == "s").get
@@ -2282,7 +2282,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
     }
     notMatchedBySourceActions(1) match {
       case UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("a"))),
-          Seq(Assignment(us: AttributeReference, IntegerLiteral(1)))) =>
+          Seq(Assignment(us: AttributeReference, IntegerLiteral(1))), _) =>
         // UPDATE condition and assignment are resolved with target table only, so column `s`
         // and `i` are not ambiguous.
         val ts = target.output.find(_.name == "s").get
@@ -2342,7 +2342,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
           AsDataSourceV2Relation(target),
           AsDataSourceV2Relation(source),
           EqualTo(IntegerLiteral(1), IntegerLiteral(1)),
-          Seq(UpdateAction(None, updateAssigns)), // Matched actions
+          Seq(UpdateAction(None, updateAssigns, _)), // Matched actions
           Seq(), // Not matched actions
           Seq(), // Not matched by source actions
           withSchemaEvolution) =>
@@ -2395,7 +2395,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
     assert(m.matchedActions.length == 1)
     m.matchedActions.head match {
       case UpdateAction(_, Seq(
-          Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+          Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke)), _) =>
         assert(s1.arguments.length == 2)
         assert(s1.functionName == "charTypeWriteSideCheck")
         assert(s2.arguments.length == 2)
@@ -2421,7 +2421,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
     assert(m.notMatchedBySourceActions.length == 1)
     m.notMatchedBySourceActions.head match {
       case UpdateAction(_, Seq(
-          Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+          Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke)), _) =>
         assert(s1.arguments.length == 2)
         assert(s1.functionName == "charTypeWriteSideCheck")
         assert(s2.arguments.length == 2)
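For completeness, a hedged sketch of how a suite built on the same bases might exercise the renamed MERGE_INTO_NESTED_TYPE_COERCION_ENABLED flag end to end. The `target`/`source` names are placeholders, and `withSQLConf`, `sql`, and `intercept` are assumed to come from the shared test harness rather than being defined here:

// Run the same MERGE with nested-type coercion toggled both ways.
Seq(true, false).foreach { coerce =>
  withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> coerce.toString) {
    val stmt =
      s"""MERGE INTO target t
         |USING source s
         |ON t.pk = s.pk
         |WHEN MATCHED THEN UPDATE SET *
         |WHEN NOT MATCHED THEN INSERT *
         |""".stripMargin
    if (coerce) {
      // Mismatched nested structs are coerced (missing fields padded or preserved).
      sql(stmt)
    } else {
      // Without coercion, the analyzer rejects the incompatible nested schema.
      intercept[AnalysisException](sql(stmt))
    }
  }
}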