diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 98c514925fa0..88cb2ded59c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1670,7 +1670,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case u: UpdateTable => resolveReferencesInUpdate(u)
 
       case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
-        if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution =>
+        if !m.resolved && targetTable.resolved && sourceTable.resolved =>
+
+        // Do not throw an exception in the schema evolution case.
+        // This gives unresolved assignment keys a chance to be resolved in a second pass,
+        // against new columns/nested fields added by schema evolution.
+        val throws = !m.schemaEvolutionEnabled
         EliminateSubqueryAliases(targetTable) match {
           case r: NamedRelation if r.skipSchemaResolution =>
@@ -1680,6 +1685,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             m
           case _ =>
+            def findAttrInTarget(name: String): Option[Attribute] = {
+              targetTable.output.find(targetAttr => conf.resolver(name, targetAttr.name))
+            }
             val newMatchedActions = m.matchedActions.map {
               case DeleteAction(deleteCondition) =>
                 val resolvedDeleteCondition = deleteCondition.map(
@@ -1691,18 +1699,30 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                 UpdateAction(
                   resolvedUpdateCondition,
                   // The update value can access columns from both target and source tables.
-                  resolveAssignments(assignments, m, MergeResolvePolicy.BOTH))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
+                    throws = throws))
               case UpdateStarAction(updateCondition) =>
-                // Use only source columns. Missing columns in target will be handled in
-                // ResolveRowLevelCommandAssignments.
-                val assignments = targetTable.output.flatMap{ targetAttr =>
-                  sourceTable.output.find(
-                    sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
-                    .map(Assignment(targetAttr, _))}
+                // Expand star to top-level source columns. If the source has fewer columns than
+                // the target, assignments will be added by ResolveRowLevelCommandAssignments later.
+                val assignments = if (m.schemaEvolutionEnabled) {
+                  // For the schema evolution case, generate assignments for missing target columns.
+                  // These columns will be added by ResolveMergeIntoSchemaEvolution later.
+                  sourceTable.output.map { sourceAttr =>
+                    val key = findAttrInTarget(sourceAttr.name).getOrElse(
+                      UnresolvedAttribute(sourceAttr.name))
+                    Assignment(key, sourceAttr)
+                  }
+                } else {
+                  sourceTable.output.flatMap { sourceAttr =>
+                    findAttrInTarget(sourceAttr.name).map(
+                      targetAttr => Assignment(targetAttr, sourceAttr))
+                  }
+                }
                 UpdateAction(
                   updateCondition.map(resolveExpressionByPlanChildren(_, m)),
                   // For UPDATE *, the value must be from source table.
-                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+                    throws = throws))
               case o => o
             }
             val newNotMatchedActions = m.notMatchedActions.map {
@@ -1713,21 +1733,33 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   resolveExpressionByPlanOutput(_, m.sourceTable))
                 InsertAction(
                   resolvedInsertCondition,
-                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+                    throws = throws))
               case InsertStarAction(insertCondition) =>
                 // The insert action is used when not matched, so its condition and value can only
                 // access columns from the source table.
                 val resolvedInsertCondition = insertCondition.map(
                   resolveExpressionByPlanOutput(_, m.sourceTable))
-                // Use only source columns. Missing columns in target will be handled in
-                // ResolveRowLevelCommandAssignments.
-                val assignments = targetTable.output.flatMap{ targetAttr =>
-                  sourceTable.output.find(
-                    sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
-                    .map(Assignment(targetAttr, _))}
+                // Expand star to top-level source columns. If the source has fewer columns than
+                // the target, assignments will be added by ResolveRowLevelCommandAssignments later.
+                val assignments = if (m.schemaEvolutionEnabled) {
+                  // For the schema evolution case, generate assignments for missing target columns.
+                  // These columns will be added by ResolveMergeIntoSchemaEvolution later.
+                  sourceTable.output.map { sourceAttr =>
+                    val key = findAttrInTarget(sourceAttr.name).getOrElse(
+                      UnresolvedAttribute(sourceAttr.name))
+                    Assignment(key, sourceAttr)
+                  }
+                } else {
+                  sourceTable.output.flatMap { sourceAttr =>
+                    findAttrInTarget(sourceAttr.name).map(
+                      targetAttr => Assignment(targetAttr, sourceAttr))
+                  }
+                }
                 InsertAction(
                   resolvedInsertCondition,
-                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+                    throws = throws))
              case o => o
             }
             val newNotMatchedBySourceActions = m.notMatchedBySourceActions.map {
@@ -1741,7 +1773,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                 UpdateAction(
                   resolvedUpdateCondition,
                   // The update value can access columns from the target table only.
-                  resolveAssignments(assignments, m, MergeResolvePolicy.TARGET))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.TARGET,
+                    throws = throws))
               case o => o
             }
 
@@ -1818,11 +1851,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     def resolveAssignments(
         assignments: Seq[Assignment],
         mergeInto: MergeIntoTable,
-        resolvePolicy: MergeResolvePolicy.Value): Seq[Assignment] = {
+        resolvePolicy: MergeResolvePolicy.Value,
+        throws: Boolean): Seq[Assignment] = {
       assignments.map { assign =>
         val resolvedKey = assign.key match {
           case c if !c.resolved =>
-            resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable))
+            resolveMergeExpr(c, Project(Nil, mergeInto.targetTable), throws)
           case o => o
         }
         val resolvedValue = assign.value match {
@@ -1842,7 +1876,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             } else {
               resolvedExpr
             }
-            checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
+            if (throws) {
+              checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
+            }
             withDefaultResolved
           case o => o
         }
@@ -1850,9 +1886,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       }
     }
 
-    private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
-      val resolved = resolveExprInAssignment(e, p)
-      checkResolvedMergeExpr(resolved, p)
+    private def resolveMergeExpr(e: Expression, p: LogicalPlan, throws: Boolean): Expression = {
+      val resolved = resolveExprInAssignment(e, p, throws)
+      if (throws) {
+        checkResolvedMergeExpr(resolved, p)
+      }
       resolved
     }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 53c92ca5425d..34541a8840cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -425,7 +425,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
   def resolveExpressionByPlanChildren(
       e: Expression,
       q: LogicalPlan,
-      includeLastResort: Boolean = false): Expression = {
+      includeLastResort: Boolean = false,
+      throws: Boolean = true): Expression = {
     resolveExpression(
       tryResolveDataFrameColumns(e, q.children),
       resolveColumnByName = nameParts => {
@@ -435,7 +436,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         assert(q.children.length == 1)
         q.children.head.output
       },
-      throws = true,
+      throws,
       includeLastResort = includeLastResort)
   }
 
@@ -475,8 +476,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     resolveVariables(resolveOuterRef(e))
   }
 
-  def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = {
-    resolveExpressionByPlanChildren(expr, hostPlan) match {
+  def resolveExprInAssignment(
+      expr: Expression,
+      hostPlan: LogicalPlan,
+      throws: Boolean = true): Expression = {
+    resolveExpressionByPlanChildren(expr,
+      hostPlan,
+      includeLastResort = false,
+      throws = throws) match {
       // Assignment key and value does not need the alias when resolving nested columns.
      case Alias(child: ExtractValue, _) => child
      case other => other
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
index 7e7776098a04..f317a2efddbe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsRowLevelOperations, TableCatalog, TableChange}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.StructType
 
 /**
@@ -34,24 +35,38 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case m @ MergeIntoTable(_, _, _, _, _, _, _)
-      if m.needSchemaEvolution =>
-      val newTarget = m.targetTable.transform {
-        case r : DataSourceV2Relation => performSchemaEvolution(r, m.sourceTable)
+    // This rule should run only if all assignments are resolved, except those
+    // that will be satisfied by schema evolution.
+    case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.evaluateSchemaEvolution =>
+      val changes = m.changesForSchemaEvolution
+      if (changes.isEmpty) {
+        m
+      } else {
+        m transformUpWithNewOutput {
+          case r @ DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _, _) =>
+            val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m)
+            val newTarget = performSchemaEvolution(r, referencedSourceSchema, changes)
+            val oldTargetOutput = m.targetTable.output
+            val newTargetOutput = newTarget.output
+            val attributeMapping = oldTargetOutput.map(
+              oldAttr => (oldAttr, newTargetOutput.find(_.name == oldAttr.name).getOrElse(oldAttr))
+            )
+            newTarget -> attributeMapping
         }
-      m.copy(targetTable = newTarget)
+      }
   }
 
-  private def performSchemaEvolution(relation: DataSourceV2Relation, source: LogicalPlan)
-    : DataSourceV2Relation = {
+  private def performSchemaEvolution(
+      relation: DataSourceV2Relation,
+      referencedSourceSchema: StructType,
+      changes: Array[TableChange]): DataSourceV2Relation = {
     (relation.catalog, relation.identifier) match {
      case (Some(c: TableCatalog), Some(i)) =>
-        val changes = MergeIntoTable.schemaChanges(relation.schema, source.schema)
        c.alterTable(i, changes: _*)
        val newTable = c.loadTable(i)
        val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns())
        // Check if there are any remaining changes not applied.
-        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, source.schema)
+        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, referencedSourceSchema)
         if (remainingChanges.nonEmpty) {
           throw QueryCompilationErrors.unsupportedTableChangesInAutoSchemaEvolutionError(
             remainingChanges, i.toQualifiedNameParts(c))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index cd0c2742df3d..db5dacdcef38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -893,16 +893,46 @@ case class MergeIntoTable(
   }
 
   lazy val needSchemaEvolution: Boolean =
+    evaluateSchemaEvolution && changesForSchemaEvolution.nonEmpty
+
+  lazy val evaluateSchemaEvolution: Boolean =
     schemaEvolutionEnabled &&
-      MergeIntoTable.schemaChanges(targetTable.schema, sourceTable.schema).nonEmpty
+      canEvaluateSchemaEvolution
 
-  private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
+  lazy val schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
     EliminateSubqueryAliases(targetTable) match {
       case r: DataSourceV2Relation if r.autoSchemaEvolution() => true
       case _ => false
     }
   }
 
+  // Guard that assignments are either resolved or candidates for evolution before
+  // evaluating schema evolution. We need to use resolved assignment values to check
+  // candidates; see MergeIntoTable.sourceSchemaForSchemaEvolution for details.
+  lazy val canEvaluateSchemaEvolution: Boolean = {
+    if (!targetTable.resolved || !sourceTable.resolved) {
+      false
+    } else {
+      val actions = matchedActions ++ notMatchedActions
+      val assignments = actions.collect {
+        case a: UpdateAction => a.assignments
+        case a: InsertAction => a.assignments
+      }.flatten
+
+      val sourcePaths = MergeIntoTable.extractAllFieldPaths(sourceTable.schema)
+      assignments.forall { assignment =>
+        assignment.resolved ||
+          sourcePaths.exists { path => MergeIntoTable.isEqual(assignment, path) }
+      }
+    }
+  }
+
+  private lazy val sourceSchemaForEvolution: StructType =
+    MergeIntoTable.sourceSchemaForSchemaEvolution(this)
+
+  lazy val changesForSchemaEvolution: Array[TableChange] =
+    MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution)
+
   override def left: LogicalPlan = targetTable
   override def right: LogicalPlan = sourceTable
   override protected def withNewChildrenInternal(
@@ -911,6 +941,7 @@ case class MergeIntoTable(
 }
 
 object MergeIntoTable {
+
   def getWritePrivileges(
       matchedActions: Iterable[MergeAction],
       notMatchedActions: Iterable[MergeAction],
@@ -948,11 +979,12 @@ object MergeIntoTable {
          case currentField: StructField if newFieldMap.contains(currentField.name) =>
            schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType,
              originalTarget, originalSource, fieldPath ++ Seq(currentField.name))
-        }}.flatten
+        }
+      }.flatten
 
        // Identify the newly added fields and append to the end
        val currentFieldMap = toFieldMap(currentFields)
-       val adds = newFields.filterNot (f => currentFieldMap.contains (f.name))
+       val adds = newFields.filterNot(f => currentFieldMap.contains(f.name))
          .map(f => TableChange.addColumn(fieldPath ++ Set(f.name), f.dataType))
 
        updates ++ adds
@@ -990,6 +1022,97 @@ object MergeIntoTable {
       CaseInsensitiveMap(fieldMap)
     }
   }
+
+  // A pruned version of the source schema that only contains columns/nested fields
+  // explicitly and directly assigned to a target counterpart in MERGE INTO actions,
+  // which are relevant for schema evolution.
+  // Examples:
+  // * UPDATE SET target.a = source.a
+  // * UPDATE SET nested.a = source.nested.a
+  // * INSERT (a, nested.b) VALUES (source.a, source.nested.b)
+  // New columns/nested fields in this schema that do not exist in the target schema
+  // will be added by schema evolution.
+  def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = {
+    val actions = merge.matchedActions ++ merge.notMatchedActions
+    val assignments = actions.collect {
+      case a: UpdateAction => a.assignments
+      case a: InsertAction => a.assignments
+    }.flatten
+
+    val containsStarAction = actions.exists {
+      case _: UpdateStarAction => true
+      case _: InsertStarAction => true
+      case _ => false
+    }
+
+    def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
+      StructType(sourceSchema.flatMap { field =>
+        val fieldPath = basePath :+ field.name
+
+        field.dataType match {
+          // Specifically assigned to in one clause:
+          // always keep, including all nested attributes
+          case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
+          // If this is a struct and one of the children is being assigned to in a merge clause,
+          // keep it and continue filtering children.
+          case struct: StructType if assignments.exists(assign =>
+            isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          // The field isn't assigned to directly or indirectly (i.e. its children) in any non-*
+          // clause. Check if it should be kept with any * action.
+          case struct: StructType if containsStarAction =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          case _ if containsStarAction => Some(field)
+          // The field and its children are not assigned to in any * or non-* action, drop it.
+          case _ => None
+        }
+      })
+
+    filterSchema(merge.sourceTable.schema, Seq.empty)
+  }
+
+  private def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty):
+      Seq[Seq[String]] = {
+    schema.flatMap { field =>
+      val fieldPath = basePath :+ field.name
+      field.dataType match {
+        case struct: StructType =>
+          fieldPath +: extractAllFieldPaths(struct, fieldPath)
+        case _ =>
+          Seq(fieldPath)
+      }
+    }
+  }
+
+  // Helper method to extract the field path from an Expression.
+  private def extractFieldPath(expr: Expression, allowUnresolved: Boolean): Seq[String] = {
+    expr match {
+      case UnresolvedAttribute(nameParts) if allowUnresolved => nameParts
+      case a: AttributeReference => Seq(a.name)
+      case GetStructField(child, ordinal, nameOpt) =>
+        extractFieldPath(child, allowUnresolved) :+ nameOpt.getOrElse(s"col$ordinal")
+      case _ => Seq.empty
+    }
+  }
+
+  // Helper method to check if a given field path is a prefix of another path.
+  private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
+    prefix.length <= path.length && prefix.zip(path).forall {
+      case (prefixNamePart, pathNamePart) =>
+        SQLConf.get.resolver(prefixNamePart, pathNamePart)
+    }
+
+  // Helper method to check if an assignment key is equal to a source column
+  // and if the assignment value is that same source column.
+  // Example: UPDATE SET target.a = source.a
+  private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = {
+    // The key must be a non-qualified field path that may be added to the target schema
+    // via evolution.
+    val assignmentKeyPath = extractFieldPath(assignment.key, allowUnresolved = true)
+    // The value should always be resolved (from the source).
+    val assignmentValuePath = extractFieldPath(assignment.value, allowUnresolved = false)
+    assignmentKeyPath == assignmentValuePath &&
+      assignmentKeyPath == sourceFieldPath
+  }
 }
 
 sealed abstract class MergeAction extends Expression with Unevaluable {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 98706c4afeae..b73c8d2458c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -2206,6 +2206,118 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
+  test("Merge schema evolution new column with conditions on update and insert") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |{ "pk": 3, "salary": 300, "dep": "hr" }
+            |{ "pk": 4, "salary": 400, "dep": "marketing" }
+            |{ "pk": 5, "salary": 500, "dep": "executive" }
+            |""".stripMargin)
+
+        // Two rows that could be updated (pk 4 and 5), but only one has salary > 450.
+        // Two rows that could be inserted (pk 6 and 7), but only one has active = true.
+        val sourceDF = Seq((4, 450, "finance", false),
+          (5, 550, "finance", true),
+          (6, 350, "sales", true),
+          (7, 250, "sales", false)).toDF("pk", "salary", "dep", "active")
sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND s.salary > 450 THEN + | UPDATE SET dep='updated', active=s.active + |WHEN NOT MATCHED AND s.active = true THEN + | INSERT (pk, salary, dep, active) VALUES (s.pk, s.salary, s.dep, + | s.active) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", null), + Row(3, 300, "hr", null), + Row(4, 400, "marketing", null), // pk=4 not updated (salary 450 is not > 450) + Row(5, 500, "updated", true), // pk=5 updated (salary 550 > 450) + Row(6, 350, "sales", true))) // pk=6 inserted (active = true) + // pk=7 not inserted (active = false) + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`active` cannot be resolved")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution with condition on new column from target") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "marketing" } + |{ "pk": 5, "salary": 500, "dep": "executive" } + |""".stripMargin) + + // Source has new 'active' column that doesn't exist in target + val sourceDF = Seq((4, 450, "finance", true), + (5, 550, "finance", false), + (6, 350, "sales", true)).toDF("pk", "salary", "dep", "active") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + // Condition references t.active which doesn't exist yet in target + val mergeStmt = s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND t.active IS NULL THEN + | UPDATE SET salary=s.salary, dep=s.dep, active=s.active + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep, active) + | VALUES (s.pk, s.salary, s.dep, s.active) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", null), + Row(3, 300, "hr", null), + Row(4, 450, "finance", true), // Updated (t.active was NULL) + Row(5, 550, "finance", false), // Updated (t.active was NULL) + Row(6, 350, "sales", true))) // Inserted + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`active` cannot be resolved")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + test("Merge schema evolution new column with set all columns") { Seq((true, true), (false, true), (true, false)).foreach { case (withSchemaEvolution, schemaEvolutionEnabled) => @@ -3510,155 +3622,800 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } - test("merge into with source missing fields in 
top-level struct") { - withTempView("source") { - // Target table has struct with 3 fields at top level - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT, - |dep STRING""".stripMargin, - """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") - - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Row(10, "b"), "hr"), - Row(2, Row(20, "c"), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") - - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + test("Merge schema evolution should not evolve referencing new column via transform") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, "a", true), "sales"), - Row(1, Row(10, "b", null), "hr"), - Row(2, Row(20, "c", null), "engineering"))) - } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") - } + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") - test("merge into with source missing fields in struct nested in array") { - withTempView("source") { - // Target table has struct with 3 fields (c1, c2, c3) in array - createAndInitTable( - s"""pk INT NOT NULL, - |a ARRAY>, - |dep STRING""".stripMargin, - """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" } - |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }""" - .stripMargin) + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET extra=substring(s.extra, 1, 2) + |""".stripMargin - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("a", ArrayType( - StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType))))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Array(Row(10, "c")), "hr"), - Row(2, Array(Row(30, "e")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`extra` cannot be resolved")) - // Missing field c3 should be 
filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Array(Row(1, "a", true)), "sales"), - Row(1, Array(Row(10, "c", null)), "hr"), - Row(2, Array(Row(30, "e", null)), "engineering"))) + sql(s"DROP TABLE $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - test("merge into with source missing fields in struct nested in map key") { - withTempView("source") { - // Target table has struct with 2 fields in map key - val targetSchema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))), - StructType(Seq(StructField("c3", StringType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) + test("Merge schema evolution should not evolve if not directly referencing new column: update") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) - val targetData = Seq( - Row(0, Map(Row(10, true) -> Row("x")), "hr"), - Row(1, Map(Row(20, false) -> Row("y")), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") - // Source table has struct with only 1 field (c1) in map key - missing c2 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType))), // missing c2 - StructType(Seq(StructField("c3", StringType))))), - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Map(Row(10) -> Row("z")), "sales"), - Row(2, Map(Row(20) -> Row("w")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET dep='software' + |""".stripMargin - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) - // Missing field c2 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Map(Row(10, true) -> Row("x")), "hr"), - Row(1, Map(Row(10, null) -> Row("z")), "sales"), - Row(2, Map(Row(20, null) -> Row("w")), "engineering"))) + sql(s"DROP TABLE $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - test("merge into with source missing fields in struct nested in map value") { - withTempView("source") { - // Target table has struct with 2 fields in map value - val targetSchema = - StructType(Seq( + test("Merge schema evolution should not evolve if not directly referencing new column: insert") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, 
salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'newdep') + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 250, "newdep"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not directly referencing new column:" + + "update and insert") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET dep='software' + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'newdep') + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 250, "newdep"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not having just column name: update") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.extra = s.extra + |""".stripMargin + + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(exception.message.contains(" A column, variable, or function parameter with name " + + "`t`.`extra` cannot be resolved")) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should only evolve referenced column when source " + + "has multiple new columns") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", 50, "blah"), + (3, 250, "dummy", 75, 
"blah")).toDF("pk", "salary", "dep", "bonus", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, bonus = s.bonus + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep, bonus) VALUES (s.pk, s.salary, 'newdep', s.bonus) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 150, "software", 50), + Row(3, 250, "newdep", 75))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should only evolve referenced struct field when source " + + "has multiple new struct fields") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType), // new field 1 + StructField("extra", StringType) // new field 2 + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row(150, "dummy", 50, "blah"), "active"), + Row(3, Row(250, "dummy", 75, "blah"), "active") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET info.bonus = s.info.bonus + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + // Only 'bonus' field should be added, not 'extra' + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active", null), "hr"), + Row(2, Row(200, "inactive", 50), "software"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "FIELD_NOT_FOUND") + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve when assigning existing target column " + + "from source column that does not exist in target") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", 50), + (3, 250, "dummy", 75)).toDF("pk", "salary", "dep", "bonus") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN 
+             |WHEN MATCHED THEN
+             |  UPDATE SET salary = s.bonus
+             |WHEN NOT MATCHED THEN
+             |  INSERT (pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep')
+             |""".stripMargin
+
+        sql(mergeStmt)
+        // bonus column should NOT be added to target schema
+        // Only salary is updated with bonus value
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, 100, "hr"),
+            Row(2, 50, "software"),
+            Row(3, 75, "newdep")))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should not evolve struct if not directly referencing new field " +
+    "in top level struct: insert") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary INT, status STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType) // new field not in target
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(2, Row(150, "dummy", 50), "active"),
+          Row(3, Row(250, "dummy", 75), "active")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN NOT MATCHED THEN
+             |  INSERT (pk, info, dep) VALUES (s.pk,
+             |    named_struct('salary', s.info.salary, 'status', 'active'), 'marketing')
+             |""".stripMargin
+
+        sql(mergeStmt)
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, Row(100, "active"), "hr"),
+            Row(2, Row(200, "inactive"), "software"),
+            Row(3, Row(250, "active"), "marketing")))
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should not evolve if not directly referencing new field " +
+    "in top level struct: UPDATE") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary INT, status STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType) // new field not in target
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(2, Row(150, "dummy", 50), "active"),
+          Row(3, Row(250, "dummy", 75), "active")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             |  UPDATE SET info.status='inactive'
+             |""".stripMargin
+
+        sql(mergeStmt)
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, Row(100, "active"), "hr"),
Row(200, "inactive"), "software"))) + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should evolve when directly assigning struct with new field:" + + "UPDATE") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row(150, "updated", 50), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET info = s.info + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + // Schema should evolve - bonus field should be added + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active", null), "hr"), + Row(2, Row(150, "updated", 50), "software"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.getMessage.contains("Cannot safely cast") || + exception.getMessage.contains("incompatible")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should evolve when directly assigning struct with new field: " + + "INSERT") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(3, Row(150, "new", 50), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, info, dep) VALUES (s.pk, s.info, s.dep) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + // Schema should evolve - bonus field should be added + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active", null), "hr"), + Row(2, Row(200, "inactive", null), "software"), + Row(3, Row(150, "new", 50), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + 
assert(exception.getMessage.contains("Cannot safely cast") || + exception.getMessage.contains("incompatible")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not directly referencing " + + "new field in nested struct") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("employee", StructType(Seq( + StructField("name", StringType), + StructField("details", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType) + ))) + ))), + StructField("dep", StringType) + )) + + createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) + + val targetData = Seq( + Row(1, Row("Alice", Row(100, "active")), "hr"), + Row(2, Row("Bob", Row(200, "active")), "software") + ) + spark.createDataFrame( + spark.sparkContext.parallelize(targetData), targetSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row("Alice", Row(100, "active")), "hr"), + Row(2, Row("Bob", Row(200, "active")), "software"))) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("employee", StructType(Seq( + StructField("name", StringType), + StructField("details", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))) + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row("Bob", Row(150, "active", 50)), "dummy"), + Row(3, Row("Charlie", Row(250, "active", 75)), "dummy") + ) + spark.createDataFrame( + spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = + if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET employee.details.status='inactive' + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row("Alice", Row(100, "active")), "hr"), + Row(2, Row("Bob", Row(200, "inactive")), "software"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should evolve referencing new column assigned to something else") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", "blah"), + (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET extra=s.dep + |""".stripMargin + + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`extra` cannot be resolved")) + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("merge into with 
source missing fields in top-level struct") { + withTempView("source") { + // Target table has struct with 3 fields at top level + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") + + // Source table has struct with only 2 fields (c1, c2) - missing c3 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), // missing c3 field + StructField("dep", StringType))) + val data = Seq( + Row(1, Row(10, "b"), "hr"), + Row(2, Row(20, "c"), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a", true), "sales"), + Row(1, Row(10, "b", null), "hr"), + Row(2, Row(20, "c", null), "engineering"))) + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + + test("merge into with source missing fields in struct nested in array") { + withTempView("source") { + // Target table has struct with 3 fields (c1, c2, c3) in array + createAndInitTable( + s"""pk INT NOT NULL, + |a ARRAY>, + |dep STRING""".stripMargin, + """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" } + |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }""" + .stripMargin) + + // Source table has struct with only 2 fields (c1, c2) - missing c3 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("a", ArrayType( + StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType))))), // missing c3 field + StructField("dep", StringType))) + val data = Seq( + Row(1, Array(Row(10, "c")), "hr"), + Row(2, Array(Row(30, "e")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Array(Row(1, "a", true)), "sales"), + Row(1, Array(Row(10, "c", null)), "hr"), + Row(2, Array(Row(30, "e", null)), "engineering"))) + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + + test("merge into with source missing fields in struct nested in map key") { + withTempView("source") { + // Target table has struct with 2 fields in map key + val targetSchema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))), + StructType(Seq(StructField("c3", StringType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) + + val targetData = Seq( + Row(0, Map(Row(10, true) -> Row("x")), "hr"), + Row(1, Map(Row(20, false) -> Row("y")), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + 
.writeTo(tableNameAsString).append() + + // Source table has struct with only 1 field (c1) in map key - missing c2 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType))), // missing c2 + StructType(Seq(StructField("c3", StringType))))), + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Map(Row(10) -> Row("z")), "sales"), + Row(2, Map(Row(20) -> Row("w")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + // Missing field c2 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Map(Row(10, true) -> Row("x")), "hr"), + Row(1, Map(Row(10, null) -> Row("z")), "sales"), + Row(2, Map(Row(20, null) -> Row("w")), "engineering"))) + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + + test("merge into with source missing fields in struct nested in map value") { + withTempView("source") { + // Target table has struct with 2 fields in map value + val targetSchema = + StructType(Seq( StructField("pk", IntegerType, nullable = false), StructField("m", MapType( StructType(Seq(StructField("c1", IntegerType))), @@ -3820,6 +4577,62 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(s"DROP TABLE IF EXISTS $tableNameAsString") } + test("Merge schema evolution should error on non-existent column in UPDATE and INSERT") { + withTable(tableNameAsString) { + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |salary INT, + |dep STRING""".stripMargin, + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("salary", IntegerType), + StructField("dep", StringType) + )) + + val data = Seq( + Row(2, 250, "engineering"), + Row(3, 300, "finance") + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val updateException = intercept[AnalysisException] { + sql( + s"""MERGE WITH SCHEMA EVOLUTION + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET non_existent = s.nonexistent_column + |""".stripMargin) + } + assert(updateException.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(updateException.message.contains("A column, variable, or function parameter " + + "with name `non_existent` cannot be resolved")) + + val insertException = intercept[AnalysisException] { + sql( + s"""MERGE WITH SCHEMA EVOLUTION + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep, non_existent) VALUES (s.pk, s.salary, s.dep, s.dep) + |""".stripMargin) + } + assert(insertException.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(insertException.message.contains("A column, variable, or function parameter " + + "with name `non_existent` cannot be resolved")) + } + } + } + private def findMergeExec(query: String): MergeRowsExec = { val plan = executeAndKeepPlan { sql(query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 9ea8b9130ba8..dfd24a1ebe97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -1633,10 +1633,10 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
 
     if (starInUpdate) {
       assert(updateAssigns.size == 2)
-      assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti))
-      assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si))
-      assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts))
-      assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss))
+      assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts))
+      assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss))
+      assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti))
+      assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si))
     } else {
       assert(updateAssigns.size == 1)
       assert(updateAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ts))
@@ -1648,15 +1648,25 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
       target: LogicalPlan,
       source: LogicalPlan,
       insertCondAttr: Option[AttributeReference],
-      insertAssigns: Seq[Assignment]): Unit = {
+      insertAssigns: Seq[Assignment],
+      starInInsert: Boolean = false): Unit = {
     val (si, ss) = getAttributes(source)
     val (ti, ts) = getAttributes(target)
     insertCondAttr.foreach(a => assert(a.sameRef(ss)))
-    assert(insertAssigns.size == 2)
-    assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti))
-    assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si))
-    assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts))
-    assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss))
+
+    if (starInInsert) {
+      assert(insertAssigns.size == 2)
+      assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts))
+      assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss))
+      assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti))
+      assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si))
+    } else {
+      assert(insertAssigns.size == 2)
+      assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti))
+      assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si))
+      assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts))
+      assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss))
+    }
   }
 
   def checkNotMatchedBySourceClausesResolution(
@@ -1735,7 +1745,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
         checkMergeConditionResolution(target, source, mergeCondition)
         checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns,
           starInUpdate = true)
-        checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns)
+        checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns,
+          starInInsert = true)
         assert(withSchemaEvolution === false)
 
       case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
@@ -1762,7 +1773,8 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
         checkMergeConditionResolution(target, source, mergeCondition)
         checkMatchedClausesResolution(target, source, None, None, updateAssigns,
          starInUpdate = true)
-        checkNotMatchedClausesResolution(target, source, None, insertAssigns)
+        checkNotMatchedClausesResolution(target, source, None, insertAssigns,
+          starInInsert = true)
        assert(withSchemaEvolution === false)
 
      case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
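
Illustrative sketch, not part of the patch: the rule that gates evolution in this change is the path-equality check in MergeIntoTable.isEqual, under which a new source field is added to the target only when an assignment writes that exact source field to an identically named target path. Below is a minimal standalone Scala model of that check; the names FieldPath and isEligible are hypothetical simplifications, and matching the patch, name-part sequences are compared with plain ==:

object SchemaEvolutionEligibilitySketch {
  // A field path is a sequence of name parts, e.g. Seq("nested", "a").
  type FieldPath = Seq[String]

  // Hypothetical stand-in for the patch's isEqual: the assignment key path, the
  // assignment value path, and the candidate source field path must all match.
  def isEligible(key: FieldPath, value: FieldPath, sourceField: FieldPath): Boolean =
    key == value && key == sourceField

  def main(args: Array[String]): Unit = {
    // UPDATE SET nested.a = source.nested.a -> eligible, nested.a may be added
    assert(isEligible(Seq("nested", "a"), Seq("nested", "a"), Seq("nested", "a")))
    // UPDATE SET extra = s.dep -> not eligible, key and value paths differ
    assert(!isEligible(Seq("extra"), Seq("dep"), Seq("extra")))
    // UPDATE SET extra = substring(s.extra, 1, 2) -> not eligible: the value is not
    // a bare column reference, so its extracted path is empty, mirroring how the
    // patch's extractFieldPath returns Seq.empty for non-attribute expressions.
    assert(!isEligible(Seq("extra"), Seq.empty, Seq("extra")))
  }
}

This mirrors why the tests above expect UNRESOLVED_COLUMN when a new column is assigned a transformed or differently named value: such assignments are not evolution candidates, so the unresolved assignment key is reported instead of triggering a table change.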