@@ -1670,7 +1670,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case u: UpdateTable => resolveReferencesInUpdate(u)

case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution =>
if !m.resolved && targetTable.resolved && sourceTable.resolved =>

// Do not throw an exception in the schema evolution case.
// This gives unresolved assignment keys a chance to be resolved in a second
// pass, against the new columns/nested fields added by schema evolution.
val throws = !m.schemaEvolutionEnabled

EliminateSubqueryAliases(targetTable) match {
case r: NamedRelation if r.skipSchemaResolution =>
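For context, the deferred-throw flag above serves a two-pass analysis of statements like the one below; a minimal sketch, assuming a SparkSession `spark`, Spark's WITH SCHEMA EVOLUTION merge clause, and hypothetical tables where the source carries a `status` column the target does not yet have:

// First pass: the expanded assignment key `status` does not resolve against
// the target; with throws = false the analyzer defers the error instead of
// failing. ResolveMergeIntoSchemaEvolution then widens the target schema and
// a second pass resolves the key.
spark.sql("""
  MERGE WITH SCHEMA EVOLUTION INTO target t
  USING source s
  ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET *
  WHEN NOT MATCHED THEN INSERT *
""")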
@@ -1680,6 +1685,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
m

case _ =>
def findAttrInTarget(name: String): Option[Attribute] = {
targetTable.output.find(targetAttr => conf.resolver(name, targetAttr.name))
}
val newMatchedActions = m.matchedActions.map {
case DeleteAction(deleteCondition) =>
val resolvedDeleteCondition = deleteCondition.map(
@@ -1691,18 +1699,30 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from both target and source tables.
resolveAssignments(assignments, m, MergeResolvePolicy.BOTH))
resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
throws = throws))
case UpdateStarAction(updateCondition) =>
// Use only source columns. Missing columns in target will be handled in
// ResolveRowLevelCommandAssignments.
val assignments = targetTable.output.flatMap{ targetAttr =>
sourceTable.output.find(
sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
.map(Assignment(targetAttr, _))}
// Expand the star to top-level source columns. If the source has fewer
// columns than the target, the remaining assignments are added by
// ResolveRowLevelCommandAssignments later.
val assignments = if (m.schemaEvolutionEnabled) {
// For the schema evolution case, also generate assignments for columns
// missing from the target. These columns will be added by
// ResolveMergeIntoSchemaEvolution later.
sourceTable.output.map { sourceAttr =>
val key = findAttrInTarget(sourceAttr.name).getOrElse(
UnresolvedAttribute(sourceAttr.name))
Assignment(key, sourceAttr)
}
} else {
sourceTable.output.flatMap { sourceAttr =>
findAttrInTarget(sourceAttr.name).map(
targetAttr => Assignment(targetAttr, sourceAttr))
}
}
UpdateAction(
updateCondition.map(resolveExpressionByPlanChildren(_, m)),
// For UPDATE *, the value must come from the source table.
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
throws = throws))
case o => o
}
val newNotMatchedActions = m.notMatchedActions.map {
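To make the star expansion concrete, here is a minimal sketch of the schema-evolution branch above (the INSERT * branch below follows the same pattern), assuming catalyst on the classpath, case-insensitive resolution, and hypothetical outputs where the target has (id, name) and the source has (id, name, status):

import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.types.{IntegerType, StringType}

val targetOutput: Seq[Attribute] = Seq(
  AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())
val sourceOutput: Seq[Attribute] = Seq(
  AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)(),
  AttributeReference("status", StringType)())

def findAttrInTarget(name: String): Option[Attribute] =
  targetOutput.find(t => caseInsensitiveResolution(name, t.name))

// Schema-evolution path: every source column yields an assignment; a column
// absent from the target gets an UnresolvedAttribute key, to be resolved once
// the evolved target schema is in place.
val assignments = sourceOutput.map { sourceAttr =>
  val key = findAttrInTarget(sourceAttr.name)
    .getOrElse(UnresolvedAttribute(sourceAttr.name))
  Assignment(key, sourceAttr)
}
// assignments: id -> id, name -> name, (unresolved) status -> status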
@@ -1713,21 +1733,33 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
resolveExpressionByPlanOutput(_, m.sourceTable))
InsertAction(
resolvedInsertCondition,
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
throws = throws))
case InsertStarAction(insertCondition) =>
// The insert action is used when not matched, so its condition and value can only
// access columns from the source table.
val resolvedInsertCondition = insertCondition.map(
resolveExpressionByPlanOutput(_, m.sourceTable))
// Use only source columns. Missing columns in target will be handled in
// ResolveRowLevelCommandAssignments.
val assignments = targetTable.output.flatMap{ targetAttr =>
sourceTable.output.find(
sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
.map(Assignment(targetAttr, _))}
// Expand the star to top-level source columns. If the source has fewer
// columns than the target, the remaining assignments are added by
// ResolveRowLevelCommandAssignments later.
val assignments = if (m.schemaEvolutionEnabled) {
// For the schema evolution case, also generate assignments for columns
// missing from the target. These columns will be added by
// ResolveMergeIntoSchemaEvolution later.
sourceTable.output.map { sourceAttr =>
val key = findAttrInTarget(sourceAttr.name).getOrElse(
UnresolvedAttribute(sourceAttr.name))
Assignment(key, sourceAttr)
}
} else {
sourceTable.output.flatMap { sourceAttr =>
findAttrInTarget(sourceAttr.name).map(
targetAttr => Assignment(targetAttr, sourceAttr))
}
}
InsertAction(
resolvedInsertCondition,
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
throws = throws))
case o => o
}
val newNotMatchedBySourceActions = m.notMatchedBySourceActions.map {
@@ -1741,7 +1773,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from the target table only.
resolveAssignments(assignments, m, MergeResolvePolicy.TARGET))
resolveAssignments(assignments, m, MergeResolvePolicy.TARGET,
throws = throws))
case o => o
}

@@ -1818,11 +1851,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
def resolveAssignments(
assignments: Seq[Assignment],
mergeInto: MergeIntoTable,
resolvePolicy: MergeResolvePolicy.Value): Seq[Assignment] = {
resolvePolicy: MergeResolvePolicy.Value,
throws: Boolean): Seq[Assignment] = {
assignments.map { assign =>
val resolvedKey = assign.key match {
case c if !c.resolved =>
resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable))
resolveMergeExpr(c, Project(Nil, mergeInto.targetTable), throws)
case o => o
}
val resolvedValue = assign.value match {
Expand All @@ -1842,17 +1876,21 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
} else {
resolvedExpr
}
checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
if (throws) {
checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
}
withDefaultResolved
case o => o
}
Assignment(resolvedKey, resolvedValue)
}
}

private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
val resolved = resolveExprInAssignment(e, p)
checkResolvedMergeExpr(resolved, p)
private def resolveMergeExpr(e: Expression, p: LogicalPlan, throws: Boolean): Expression = {
val resolved = resolveExprInAssignment(e, p, throws)
if (throws) {
checkResolvedMergeExpr(resolved, p)
}
resolved
}
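The throws gating in resolveAssignments and resolveMergeExpr amounts to fail-fast versus defer-and-retry. A minimal model of that behavior, independent of catalyst (all names hypothetical):

// Fail-fast (throws = true) raises immediately; deferred (throws = false)
// hands the unresolved name back so a later pass can retry it.
def resolveName(name: String, schema: Set[String], throws: Boolean): Either[String, String] =
  if (schema.contains(name)) Right(name)
  else if (throws) throw new IllegalArgumentException(s"unresolved column: $name")
  else Left(name)

val beforeEvolution = Set("id", "name")
resolveName("status", beforeEvolution, throws = false)  // Left("status"): deferred
val afterEvolution = beforeEvolution + "status"         // schema evolution added the column
resolveName("status", afterEvolution, throws = false)   // Right("status"): second pass succeeds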

@@ -425,7 +425,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
def resolveExpressionByPlanChildren(
e: Expression,
q: LogicalPlan,
includeLastResort: Boolean = false): Expression = {
includeLastResort: Boolean = false,
throws: Boolean = true): Expression = {
resolveExpression(
tryResolveDataFrameColumns(e, q.children),
resolveColumnByName = nameParts => {
@@ -435,7 +436,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
assert(q.children.length == 1)
q.children.head.output
},
throws = true,
throws,
includeLastResort = includeLastResort)
}

@@ -475,8 +476,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
resolveVariables(resolveOuterRef(e))
}

def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = {
resolveExpressionByPlanChildren(expr, hostPlan) match {
def resolveExprInAssignment(
expr: Expression,
hostPlan: LogicalPlan,
throws: Boolean = true): Expression = {
resolveExpressionByPlanChildren(expr,
hostPlan,
includeLastResort = false,
throws = throws) match {
// Assignment keys and values do not need the alias when resolving nested columns.
case Alias(child: ExtractValue, _) => child
case other => other
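Why the alias is stripped: resolving a nested assignment key such as address.city yields an aliased extraction, but an assignment key should be the bare ExtractValue chain over the target attribute. A minimal illustration with a hypothetical struct column:

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExtractValue, GetStructField}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical target column: address STRUCT<city: STRING>.
val address = AttributeReference("address",
  StructType(Seq(StructField("city", StringType))))()

// Column resolution wraps the extraction in an Alias named after the field;
// assignment resolution keeps only the extraction itself.
val resolved = Alias(GetStructField(address, 0, Some("city")), "city")()
val key = resolved match {
  case Alias(child: ExtractValue, _) => child
  case other => other
}
// key is the bare GetStructField, usable as the target of
// UPDATE SET address.city = ...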
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsRowLevelOperations, TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.StructType


/**
@@ -34,24 +35,38 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case m @ MergeIntoTable(_, _, _, _, _, _, _)
if m.needSchemaEvolution =>
val newTarget = m.targetTable.transform {
case r : DataSourceV2Relation => performSchemaEvolution(r, m.sourceTable)
// This rule should run only once all assignments are resolved, except those
// that will be satisfied by schema evolution.
case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.evaluateSchemaEvolution =>
val changes = m.changesForSchemaEvolution
if (changes.isEmpty) {
m
} else {
m transformUpWithNewOutput {
case r @ DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _, _) =>
val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m)
val newTarget = performSchemaEvolution(r, referencedSourceSchema, changes)
val oldTargetOutput = m.targetTable.output
val newTargetOutput = newTarget.output
val attributeMapping = oldTargetOutput.map(
oldAttr => (oldAttr, newTargetOutput.find(_.name == oldAttr.name).getOrElse(oldAttr))
)
[Contributor review comment] Given we can only append new columns, this can be simply oldTargetOutput.zip(newTargetOutput)
newTarget -> attributeMapping
}
m.copy(targetTable = newTarget)
}
}
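On the attribute mapping (and the review note above): since auto schema evolution only appends columns, the evolved output is the old output followed by the new attributes, so by-name lookup and a positional zip agree. A minimal sketch with hypothetical outputs:

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType}

// Hypothetical: evolution appended `status` to a target that had (id, name).
val oldTargetOutput: Seq[Attribute] = Seq(
  AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())
val newTargetOutput: Seq[Attribute] =
  oldTargetOutput.map(_.newInstance()) :+ AttributeReference("status", StringType)()

// By-name mapping, as the rule builds it:
val byName = oldTargetOutput.map(o =>
  o -> newTargetOutput.find(_.name == o.name).getOrElse(o))
// Positional mapping, as suggested in review; equivalent while columns are
// only ever appended:
val byPosition = oldTargetOutput.zip(newTargetOutput)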

private def performSchemaEvolution(relation: DataSourceV2Relation, source: LogicalPlan)
: DataSourceV2Relation = {
private def performSchemaEvolution(
relation: DataSourceV2Relation,
referencedSourceSchema: StructType,
changes: Array[TableChange]): DataSourceV2Relation = {
(relation.catalog, relation.identifier) match {
case (Some(c: TableCatalog), Some(i)) =>
val changes = MergeIntoTable.schemaChanges(relation.schema, source.schema)
c.alterTable(i, changes: _*)
val newTable = c.loadTable(i)
val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns())
// Check whether any changes were left unapplied.
val remainingChanges = MergeIntoTable.schemaChanges(newSchema, source.schema)
val remainingChanges = MergeIntoTable.schemaChanges(newSchema, referencedSourceSchema)
if (remainingChanges.nonEmpty) {
throw QueryCompilationErrors.unsupportedTableChangesInAutoSchemaEvolutionError(
remainingChanges, i.toQualifiedNameParts(c))
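For a sense of what the changes argument contains: auto schema evolution is assumed to emit only additive changes. A hedged sketch of such a list (field names hypothetical; TableChange.add is the DataSource V2 API applied through alterTable above):

import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.types.StringType

// Append a top-level column and a nested field; anything non-additive left
// over after alterTable triggers the error above.
val changes: Array[TableChange] = Array(
  TableChange.add(Array("status"), StringType),
  TableChange.add(Array("address", "zip"), StringType))
// Applied as: catalog.alterTable(ident, changes: _*), then loadTable(ident)
// to pick up the evolved schema.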