[SPARK-43742][SQL] Refactor default column value resolution
### What changes were proposed in this pull request?

This PR refactors default column value resolution so that we don't need an extra DS v2 API for external v2 sources. The general idea is to split default column value resolution into two parts:
1. Resolve the column "DEFAULT" to the column's default expression. This applies to `Project`/`UnresolvedInlineTable` under `InsertIntoStatement`, and to assignment expressions in `UpdateTable`/`MergeIntoTable`.
2. Fill missing columns of the input query with column default values. This does not apply to UPDATE or to non-INSERT actions of MERGE, as they use the existing column value from the target table as the default.

The first part should be done for all data sources, as it's part of column resolution. The second part should not be applied to v2 data sources with `ACCEPT_ANY_SCHEMA`, as they are free to define how to handle missing columns.
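
For illustration, here is a minimal sketch of what the two parts mean from the user's point of view. The table name, columns, and the parquet provider are made up for this example and are not taken from this PR:

```scala
import org.apache.spark.sql.SparkSession

// Minimal illustration of the two resolution steps (hypothetical table `t`).
object DefaultColumnExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // A table with a column default value.
    spark.sql("CREATE TABLE t (id INT, status STRING DEFAULT 'active') USING parquet")

    // Part 1: the column "DEFAULT" in the input query is resolved to the
    // column's default expression, i.e. the literal 'active'.
    spark.sql("INSERT INTO t VALUES (1, DEFAULT)")

    // Part 2: a column missing from the user-specified column list is filled
    // with its default value (unless the target is a v2 source that declares
    // ACCEPT_ANY_SCHEMA and handles missing columns itself).
    spark.sql("INSERT INTO t (id) VALUES (2)")

    // For UPDATE and non-INSERT MERGE actions only part 1 applies: DEFAULT in
    // an assignment resolves to the column default, while unassigned columns
    // keep their current values. Left commented out because it needs a source
    // that supports row-level operations.
    // spark.sql("UPDATE t SET status = DEFAULT WHERE id = 2")

    spark.stop()
  }
}
```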

More concretely, this PR:
1. Puts the column "DEFAULT" resolution logic in the rule `ResolveReferences`, with two new virtual rules, following #38888.
2. Puts the missing-column handling in `TableOutputResolver`, which is shared by the v1 and v2 insertion resolution rules. External v2 data sources can add custom catalyst rules to handle missing columns themselves (a sketch follows this list).
3. Removes the old rule `ResolveDefaultColumns`. Note that, with this refactoring, we no longer need to manually look up the table: column default values are handled after the target table of INSERT/UPDATE/MERGE is resolved.
4. Removes the rule `ResolveUserSpecifiedColumns` and merges it into `PreprocessTableInsertion`. Both rules resolve v1 insertion, and it's tricky to reason about their interaction; it's clearer to resolve the insertion in one pass.
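
As a sketch of item 2 (not code from this PR): an external v2 source with `ACCEPT_ANY_SCHEMA` could inject its own resolution rule to decide how missing columns are handled. The rule and extension class names below are hypothetical, and padding with NULL literals is just one possible policy:

```scala
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule: if the input query has fewer columns than the target
// table, pad the query with NULL literals for the missing columns.
// (A real rule would also check that the write targets this source's tables.)
object PadMissingColumnsWithNulls extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    case a: AppendData
        if a.query.resolved && a.query.output.size < a.table.output.size =>
      val existing = a.query.output.map(_.name).toSet
      val padding = a.table.output
        .filterNot(attr => existing.contains(attr.name))
        .map(attr => Alias(Literal(null, attr.dataType), attr.name)())
      a.copy(query = Project(a.query.output ++ padding, a.query))
  }
}

// Hypothetical extension class that registers the rule.
class PadMissingColumnsExtension extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectResolutionRule(_ => PadMissingColumnsWithNulls)
  }
}
```

Such a rule would be registered through the existing `spark.sql.extensions` mechanism, so no new DS v2 API is needed.
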
### Why are the changes needed?

Code cleanup and removal of an unneeded DS v2 API.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Updated tests.

Closes #41262 from cloud-fan/def-val.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
cloud-fan authored and dongjoon-hyun committed May 28, 2023
1 parent 1e9b0f6 commit cc24978
Showing 25 changed files with 855 additions and 1,405 deletions.
27 changes: 12 additions & 15 deletions core/src/main/resources/error/error-classes.json
@@ -834,7 +834,18 @@
},
"INSERT_COLUMN_ARITY_MISMATCH" : {
"message" : [
"<tableName> requires that the data to be inserted have the same number of columns as the target table: target table has <targetColumns> column(s) but the inserted data has <insertedColumns> column(s), including <staticPartCols> partition column(s) having constant value(s)."
"Cannot write to '<tableName>', <reason>:",
"Table columns: <tableColumns>.",
"Data columns: <dataColumns>."
],
"sqlState" : "21S01"
},
"INSERT_PARTITION_COLUMN_ARITY_MISMATCH" : {
"message" : [
"Cannot write to '<tableName>', <reason>:",
"Table columns: <tableColumns>.",
"Partition columns with static values: <staticPartCols>.",
"Data columns: <dataColumns>."
],
"sqlState" : "21S01"
},
@@ -3489,20 +3500,6 @@
"Cannot resolve column name \"<colName>\" among (<fieldNames>)."
]
},
"_LEGACY_ERROR_TEMP_1202" : {
"message" : [
"Cannot write to '<tableName>', too many data columns:",
"Table columns: <tableColumns>.",
"Data columns: <dataColumns>."
]
},
"_LEGACY_ERROR_TEMP_1203" : {
"message" : [
"Cannot write to '<tableName>', not enough data columns:",
"Table columns: <tableColumns>.",
"Data columns: <dataColumns>."
]
},
"_LEGACY_ERROR_TEMP_1204" : {
"message" : [
"Cannot write incompatible data to table '<tableName>':",

This file was deleted.

@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -55,8 +55,7 @@ import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssig
import org.apache.spark.sql.internal.connector.V1Function
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.DAY
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.util.collection.{Utils => CUtils}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and
@@ -280,7 +279,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
KeepLegacyOutputs),
Batch("Resolution", fixedPoint,
new ResolveCatalogs(catalogManager) ::
ResolveUserSpecifiedColumns ::
ResolveInsertInto ::
ResolveRelations ::
ResolvePartitionSpec ::
@@ -313,7 +311,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
TimeWindowing ::
SessionWindowing ::
ResolveWindowTime ::
ResolveDefaultColumns(ResolveRelations.resolveRelationOrTempView) ::
ResolveInlineTables ::
ResolveLambdaVariables ::
ResolveTimeZone ::
@@ -1080,7 +1077,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

def apply(plan: LogicalPlan)
: LogicalPlan = plan.resolveOperatorsUpWithPruning(AlwaysProcess.fn, ruleId) {
case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(table, _, _, _, _, _) =>
val relation = table match {
case u: UnresolvedRelation if !u.isStreaming =>
resolveRelation(u).getOrElse(u)
@@ -1280,53 +1277,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}

/** Handle INSERT INTO for DSv2 */
object ResolveInsertInto extends Rule[LogicalPlan] {

/** Add a project to use the table column names for INSERT INTO BY NAME */
private def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = {
SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver)

if (i.userSpecifiedCols.size != i.query.output.size) {
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
i.userSpecifiedCols.size, i.query.output.size, i.query)
}
val projectByName = i.userSpecifiedCols.zip(i.query.output)
.map { case (userSpecifiedCol, queryOutputCol) =>
val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver)
.getOrElse(
throw QueryCompilationErrors.unresolvedAttributeError(
"UNRESOLVED_COLUMN", userSpecifiedCol, i.table.output.map(_.name), i.origin))
(queryOutputCol.dataType, resolvedCol.dataType) match {
case (input: StructType, expected: StructType) =>
// Rename inner fields of the input column to pass the by-name INSERT analysis.
Alias(Cast(queryOutputCol, renameFieldsInStruct(input, expected)), resolvedCol.name)()
case _ =>
Alias(queryOutputCol, resolvedCol.name)()
}
}
Project(projectByName, i.query)
}

private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = {
if (input.length == expected.length) {
val newFields = input.zip(expected).map { case (f1, f2) =>
(f1.dataType, f2.dataType) match {
case (s1: StructType, s2: StructType) =>
f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2))
case _ =>
f1.copy(name = f2.name)
}
}
StructType(newFields)
} else {
input
}
}

object ResolveInsertInto extends ResolveInsertionBase {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
AlwaysProcess.fn, ruleId) {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
if i.query.resolved =>
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) if i.query.resolved =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
@@ -1529,6 +1483,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
// Don't wait for other rules to resolve the child plans of `InsertIntoStatement`: the column
// "DEFAULT" must be resolved there while the child plans are still unresolved.
case i: InsertIntoStatement => ResolveColumnDefaultInInsert(i)

// Wait for other rules to resolve child plans first
case p: LogicalPlan if !p.childrenResolved => p

@@ -1648,6 +1606,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// implementation and should be resolved based on the table schema.
o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table))

case u: UpdateTable => ResolveReferencesInUpdate(u)

case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _)
if !m.resolved && targetTable.resolved && sourceTable.resolved =>

@@ -1798,23 +1758,32 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case MergeResolvePolicy.SOURCE => Project(Nil, mergeInto.sourceTable)
case MergeResolvePolicy.TARGET => Project(Nil, mergeInto.targetTable)
}
resolveMergeExprOrFail(c, resolvePlan)
val resolvedExpr = resolveExprInAssignment(c, resolvePlan)
val withDefaultResolved = if (conf.enableDefaultColumns) {
resolveColumnDefaultInAssignmentValue(
resolvedKey,
resolvedExpr,
QueryCompilationErrors
.defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates())
} else {
resolvedExpr
}
checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
withDefaultResolved
case o => o
}
Assignment(resolvedKey, resolvedValue)
}
}

private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
val resolved = resolveExpressionByPlanChildren(e, p)
resolved.references.filter { attribute: Attribute =>
!attribute.resolved &&
// We exclude attribute references named "DEFAULT" from consideration since they are
// handled exclusively by the ResolveDefaultColumns analysis rule. That rule checks the
// MERGE command for such references and either replaces each one with a corresponding
// value, or returns a custom error message.
normalizeFieldName(attribute.name) != normalizeFieldName(CURRENT_DEFAULT_COLUMN_NAME)
}.foreach { a =>
val resolved = resolveExprInAssignment(e, p)
checkResolvedMergeExpr(resolved, p)
resolved
}

private def checkResolvedMergeExpr(e: Expression, p: LogicalPlan): Unit = {
e.references.filter(!_.resolved).foreach { a =>
// Note: This will throw error only on unresolved attribute issues,
// not other resolution errors like mismatched data types.
val cols = p.inputSet.toSeq.map(_.sql).mkString(", ")
@@ -1824,10 +1793,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
"sqlExpr" -> a.sql,
"cols" -> cols))
}
resolved match {
case Alias(child: ExtractValue, _) => child
case other => other
}
}

// Expand the star expression using the input plan first. If failed, try resolve
@@ -3359,53 +3324,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* A special rule to reorder columns for DSv1 when users specify a column list in INSERT INTO.
* DSv2 is handled by [[ResolveInsertInto]] separately.
*/
object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
AlwaysProcess.fn, ruleId) {
case i: InsertIntoStatement if !i.table.isInstanceOf[DataSourceV2Relation] &&
i.table.resolved && i.query.resolved && i.userSpecifiedCols.nonEmpty =>
val resolved = resolveUserSpecifiedColumns(i)
val projection = addColumnListOnQuery(i.table.output, resolved, i.query)
i.copy(userSpecifiedCols = Nil, query = projection)
}

private def resolveUserSpecifiedColumns(i: InsertIntoStatement): Seq[NamedExpression] = {
SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver)

i.userSpecifiedCols.map { col =>
i.table.resolve(Seq(col), resolver).getOrElse {
val candidates = i.table.output.map(_.qualifiedName)
val orderedCandidates = StringUtils.orderSuggestedIdentifiersBySimilarity(col, candidates)
throw QueryCompilationErrors
.unresolvedAttributeError("UNRESOLVED_COLUMN", col, orderedCandidates, i.origin)
}
}
}

private def addColumnListOnQuery(
tableOutput: Seq[Attribute],
cols: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = {
if (cols.size != query.output.size) {
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
cols.size, query.output.size, query)
}
val nameToQueryExpr = CUtils.toMap(cols, query.output)
// Static partition columns in the table output should not appear in the column list
// they will be handled in another rule ResolveInsertInto
val reordered = tableOutput.flatMap { nameToQueryExpr.get(_).orElse(None) }
if (reordered == query.output) {
query
} else {
Project(reordered, query)
}
}
}

private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{DataType, StructType}
@@ -103,8 +104,11 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
case assignment if assignment.key.semanticEquals(attr) => assignment
}
val resolvedValue = if (matchingAssignments.isEmpty) {
errors += s"No assignment for '${attr.name}'"
attr
val defaultExpr = getDefaultValueExprOrNullLit(attr, conf)
if (defaultExpr.isEmpty) {
errors += s"No assignment for '${attr.name}'"
}
defaultExpr.getOrElse(attr)
} else if (matchingAssignments.length > 1) {
val conflictingValuesStr = matchingAssignments.map(_.value.sql).mkString(", ")
errors += s"Multiple assignments for '${attr.name}': $conflictingValuesStr"
@@ -161,6 +161,21 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}

def checkAnalysis0(plan: LogicalPlan): Unit = {
// The target table is not a child plan of the insert command. We should report errors for table
// not found first, instead of errors in the input query of the insert command, by doing a
// top-down traversal.
plan.foreach {
case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) =>
u.tableNotFound(u.multipartIdentifier)

// TODO (SPARK-27484): handle streaming write commands when we have them.
case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] =>
val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier
write.table.tableNotFound(tblName)

case _ =>
}

// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
@@ -197,14 +212,6 @@
errorClass = "_LEGACY_ERROR_TEMP_2313",
messageParameters = Map("name" -> u.name))

case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) =>
u.tableNotFound(u.multipartIdentifier)

// TODO (SPARK-27484): handle streaming write commands when we have them.
case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] =>
val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier
write.table.tableNotFound(tblName)

case command: V2PartitionCommand =>
command.table match {
case r @ ResolvedTable(_, _, table, _) => table match {
@@ -384,6 +384,14 @@ trait ColumnResolutionHelper extends Logging {
allowOuter = allowOuter)
}

def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = {
resolveExpressionByPlanChildren(expr, hostPlan) match {
// Assignment keys and values do not need the alias when resolving nested columns.
case Alias(child: ExtractValue, _) => child
case other => other
}
}

private def resolveExpressionByPlanId(
e: Expression,
q: LogicalPlan): Expression = {