[SPARK-34302][SQL] Migrate ALTER TABLE ... CHANGE COLUMN command to use UnresolvedTable to resolve the identifier

### What changes were proposed in this pull request?

This PR proposes to migrate the `ALTER TABLE ... CHANGE COLUMN` command to use `UnresolvedTable` as a `child` to resolve the table identifier. This allows consistent resolution rules (temp view first, etc.) to be applied to both v1 and v2 commands. More info about the consistent resolution rule proposal can be found in [JIRA](https://issues.apache.org/jira/browse/SPARK-29900) or the [proposal doc](https://docs.google.com/document/d/1hvLjGA8y_W_hhilpngXVub1Ebv8RsMap986nENCFnrg/edit?usp=sharing).
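As a rough sketch of the resolution flow (not the literal code in this PR; the `testcat` catalog, namespace, and provider names are assumptions for illustration):

```scala
// Assumes a Spark 3.2+ session with a v2 catalog registered as "testcat".
spark.sql("CREATE TABLE testcat.ns.tbl (id INT) USING foo")

// The parser now produces an AlterTableAlterColumn node whose child is
// UnresolvedTable(Seq("testcat", "ns", "tbl"), ...). The analyzer resolves
// that child to a ResolvedTable before field names are normalized, so the
// v1 and v2 commands share a single identifier-lookup path.
spark.sql("ALTER TABLE testcat.ns.tbl CHANGE COLUMN id id BIGINT")
```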

### Why are the changes needed?

This is a part of effort to make the relation lookup behavior consistent: [SPARK-29900](https://issues.apache.org/jira/browse/SPARK-29900).

### Does this PR introduce _any_ user-facing change?

After this PR, the `ALTER TABLE ... CHANGE COLUMN` command will have consistent resolution behavior across the v1 and v2 code paths.
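For example (a hypothetical session; table/view names are made up and the exact error text may differ), a temporary view that shadows a table name is now resolved first, so the command fails cleanly rather than behaving differently between v1 and v2:

```scala
// Illustration of the "temp view first" resolution order.
spark.sql("CREATE TABLE t (id INT) USING parquet")
spark.sql("CREATE TEMPORARY VIEW t AS SELECT 1 AS id")

// `t` now consistently resolves to the temp view, so this fails with an
// analysis error saying the command expects a table, not a view:
spark.sql("ALTER TABLE t CHANGE COLUMN id id BIGINT")
```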

### How was this patch tested?

Updated existing tests.
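A condensed sketch of the kind of coverage involved (hypothetical, in the style of Spark's SQL command suites; not the literal test code):

```scala
// Hypothetical condensed test; the real updated tests live in the
// existing ALTER TABLE suites. Exact error text may differ by version.
test("ALTER TABLE ... CHANGE COLUMN resolves temp views first") {
  withTable("t") {
    withTempView("t") {
      sql("CREATE TABLE t (id INT) USING parquet")
      sql("CREATE TEMPORARY VIEW t AS SELECT 1 AS id")
      val e = intercept[AnalysisException] {
        sql("ALTER TABLE t CHANGE COLUMN id id BIGINT")
      }
      assert(e.getMessage.contains("expects a table"))
    }
  }
}
```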

Closes #33113 from imback82/alter_change_column.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
imback82 authored and cloud-fan committed Jun 29, 2021
1 parent 622fc68 commit 620fde4
Showing 13 changed files with 301 additions and 311 deletions.
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
+import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn}
 import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
@@ -299,7 +299,7 @@ class Analyzer(override val catalogManager: CatalogManager)
     Batch("Post-Hoc Resolution", Once,
       Seq(ResolveCommandsWithIfExists) ++
       postHocResolutionRules: _*),
-    Batch("Normalize Alter Table Field Names", Once, ResolveFieldNames),
+    Batch("Normalize Alter Table Commands", Once, ResolveAlterTableCommands),
     Batch("Normalize Alter Table", Once, ResolveAlterTableChanges),
     Batch("Remove Unresolved Hints", Once,
       new ResolveHints.RemoveAllHints),
@@ -3527,13 +3527,35 @@ class Analyzer(override val catalogManager: CatalogManager)
    * Rule to mostly resolve, normalize and rewrite column names based on case sensitivity
    * for alter table commands.
    */
-  object ResolveFieldNames extends Rule[LogicalPlan] {
+  object ResolveAlterTableCommands extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case a: AlterTableCommand if a.table.resolved =>
-        a.transformExpressions {
+        val table = a.table.asInstanceOf[ResolvedTable]
+        val transformed = a.transformExpressions {
           case u: UnresolvedFieldName =>
-            val table = a.table.asInstanceOf[ResolvedTable]
-            resolveFieldNames(table.schema, u.name).map(ResolvedFieldName(_)).getOrElse(u)
+            resolveFieldNames(table.schema, u.name).getOrElse(u)
+          case u: UnresolvedFieldPosition => u.position match {
+            case after: After =>
+              resolveFieldNames(table.schema, u.fieldName.init :+ after.column())
+                .map { resolved =>
+                  ResolvedFieldPosition(ColumnPosition.after(resolved.field.name))
+                }.getOrElse(u)
+            case _ => ResolvedFieldPosition(u.position)
+          }
         }
+
+        transformed match {
+          case alter @ AlterTableAlterColumn(
+              _: ResolvedTable, ResolvedFieldName(_, field), Some(dataType), _, _, _) =>
+            // Hive style syntax provides the column type, even if it may not have changed.
+            val dt = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType)
+            if (dt == dataType) {
+              // The user didn't want the field to change, so remove this change.
+              alter.copy(dataType = None)
+            } else {
+              alter
+            }
+          case other => other
+        }
     }

@@ -3543,10 +3565,10 @@ class Analyzer(override val catalogManager: CatalogManager)
    */
   private def resolveFieldNames(
       schema: StructType,
-      fieldNames: Seq[String]): Option[Seq[String]] = {
+      fieldNames: Seq[String]): Option[ResolvedFieldName] = {
     val fieldOpt = schema.findNestedField(
       fieldNames, includeCollections = true, conf.resolver)
-    fieldOpt.map { case (path, field) => path :+ field.name }
+    fieldOpt.map { case (path, field) => ResolvedFieldName(path, field) }
   }
 }

@@ -3598,68 +3620,6 @@ class Analyzer(override val catalogManager: CatalogManager)
             Some(addColumn(schema, "root", Nil))
           }

-        case typeChange: UpdateColumnType =>
-          // Hive style syntax provides the column type, even if it may not have changed
-          val fieldOpt = schema.findNestedField(
-            typeChange.fieldNames(), includeCollections = true, conf.resolver)
-
-          if (fieldOpt.isEmpty) {
-            // We couldn't resolve the field. Leave it to CheckAnalysis
-            Some(typeChange)
-          } else {
-            val (fieldNames, field) = fieldOpt.get
-            val dt = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType)
-            if (dt == typeChange.newDataType()) {
-              // The user didn't want the field to change, so remove this change
-              None
-            } else {
-              Some(TableChange.updateColumnType(
-                (fieldNames :+ field.name).toArray, typeChange.newDataType()))
-            }
-          }
-        case n: UpdateColumnNullability =>
-          // Need to resolve column
-          resolveFieldNames(
-            schema,
-            n.fieldNames(),
-            TableChange.updateColumnNullability(_, n.nullable())).orElse(Some(n))
-
-        case position: UpdateColumnPosition =>
-          position.position() match {
-            case after: After =>
-              // Need to resolve column as well as position reference
-              val fieldOpt = schema.findNestedField(
-                position.fieldNames(), includeCollections = true, conf.resolver)
-
-              if (fieldOpt.isEmpty) {
-                Some(position)
-              } else {
-                val (normalizedPath, field) = fieldOpt.get
-                val targetCol = schema.findNestedField(
-                  normalizedPath :+ after.column(), includeCollections = true, conf.resolver)
-                if (targetCol.isEmpty) {
-                  // Leave unchanged to CheckAnalysis
-                  Some(position)
-                } else {
-                  Some(TableChange.updateColumnPosition(
-                    (normalizedPath :+ field.name).toArray,
-                    ColumnPosition.after(targetCol.get._2.name)))
-                }
-              }
-            case _ =>
-              // Need to resolve column
-              resolveFieldNames(
-                schema,
-                position.fieldNames(),
-                TableChange.updateColumnPosition(_, position.position())).orElse(Some(position))
-          }
-
-        case comment: UpdateColumnComment =>
-          resolveFieldNames(
-            schema,
-            comment.fieldNames(),
-            TableChange.updateColumnComment(_, comment.newComment())).orElse(Some(comment))
-
-        case delete: DeleteColumn =>
-          resolveFieldNames(schema, delete.fieldNames(), TableChange.deleteColumn)
-            .orElse(Some(delete))
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils}
 import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
+import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -444,12 +444,42 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
         write.query.schema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))

       case alter: AlterTableCommand if alter.table.resolved =>
+        val table = alter.table.asInstanceOf[ResolvedTable]
+        def findField(fieldName: Seq[String]): StructField = {
+          // Include collections because structs nested in maps and arrays may be altered.
+          val field = table.schema.findNestedField(fieldName, includeCollections = true)
+          if (field.isEmpty) {
+            alter.failAnalysis(s"Cannot ${alter.operation} missing field ${fieldName.quoted} " +
+              s"in ${table.name} schema: ${table.schema.treeString}")
+          }
+          field.get._2
+        }
+        def findParentStruct(fieldNames: Seq[String]): StructType = {
+          val parent = fieldNames.init
+          val field = if (parent.nonEmpty) {
+            findField(parent).dataType
+          } else {
+            table.schema
+          }
+          field match {
+            case s: StructType => s
+            case o => alter.failAnalysis(s"Cannot ${alter.operation} ${fieldNames.quoted}, " +
+              s"because its parent is not a StructType. Found $o")
+          }
+        }
         alter.transformExpressions {
-          case u: UnresolvedFieldName =>
-            val table = alter.table.asInstanceOf[ResolvedTable]
-            alter.failAnalysis(
-              s"Cannot ${alter.operation} missing field ${u.name.quoted} in ${table.name} " +
-                s"schema: ${table.schema.treeString}")
+          case UnresolvedFieldName(name) =>
+            alter.failAnalysis(s"Cannot ${alter.operation} missing field ${name.quoted} in " +
+              s"${table.name} schema: ${table.schema.treeString}")
+          case UnresolvedFieldPosition(fieldName, position: After) =>
+            val parent = findParentStruct(fieldName)
+            val allFields = parent match {
+              case s: StructType => s.fieldNames
+              case o => alter.failAnalysis(s"Cannot ${alter.operation} ${fieldName.quoted}, " +
+                s"because its parent is not a StructType. Found $o")
+            }
+            alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " +
+              s"${allFields.mkString("[", ", ", "]")}")
         }
         checkAlterTableCommand(alter)

@@ -522,66 +552,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
           positionArgumentExists(add.position(), parent, fieldsAdded)
           TypeUtils.failWithIntervalType(add.dataType())
           colsToAdd(parentName) = fieldsAdded :+ add.fieldNames().last
-        case update: UpdateColumnType =>
-          val field = {
-            val f = findField("update", update.fieldNames)
-            CharVarcharUtils.getRawType(f.metadata)
-              .map(dt => f.copy(dataType = dt))
-              .getOrElse(f)
-          }
-          val fieldName = update.fieldNames.quoted
-          update.newDataType match {
-            case _: StructType =>
-              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
-                s"update a struct by updating its fields")
-            case _: MapType =>
-              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
-                s"update a map by updating $fieldName.key or $fieldName.value")
-            case _: ArrayType =>
-              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
-                s"update the element by updating $fieldName.element")
-            case u: UserDefinedType[_] =>
-              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
-                s"update a UserDefinedType[${u.sql}] by updating its fields")
-            case _: CalendarIntervalType | _: YearMonthIntervalType |
-                 _: DayTimeIntervalType =>
-              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName to " +
-                s"interval type")
-            case _ =>
-              // update is okay
-          }
-
-          // We don't need to handle nested types here which shall fail before
-          def canAlterColumnType(from: DataType, to: DataType): Boolean = (from, to) match {
-            case (CharType(l1), CharType(l2)) => l1 == l2
-            case (CharType(l1), VarcharType(l2)) => l1 <= l2
-            case (VarcharType(l1), VarcharType(l2)) => l1 <= l2
-            case _ => Cast.canUpCast(from, to)
-          }
-
-          if (!canAlterColumnType(field.dataType, update.newDataType)) {
-            alter.failAnalysis(
-              s"Cannot update ${table.name} field $fieldName: " +
-                s"${field.dataType.simpleString} cannot be cast to " +
-                s"${update.newDataType.simpleString}")
-          }
-        case update: UpdateColumnNullability =>
-          val field = findField("update", update.fieldNames)
-          val fieldName = update.fieldNames.quoted
-          if (!update.nullable && field.nullable) {
-            alter.failAnalysis(
-              s"Cannot change nullable column to non-nullable: $fieldName")
-          }
-        case updatePos: UpdateColumnPosition =>
-          findField("update", updatePos.fieldNames)
-          val parent = findParentStruct("update", updatePos.fieldNames())
-          val parentName = updatePos.fieldNames().init
-          positionArgumentExists(
-            updatePos.position(),
-            parent,
-            colsToAdd.getOrElse(parentName, Nil))
-        case update: UpdateColumnComment =>
-          findField("update", update.fieldNames)
-        case delete: DeleteColumn =>
-          findField("delete", delete.fieldNames)
         // REPLACE COLUMNS has deletes followed by adds. Remember the deleted columns
@@ -1088,8 +1058,51 @@
     }

     alter match {
-      case AlterTableRenameColumn(table: ResolvedTable, ResolvedFieldName(name), newName) =>
-        checkColumnNotExists(name.init :+ newName, table.schema)
+      case AlterTableRenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName) =>
+        checkColumnNotExists(col.path :+ newName, table.schema)
+      case a @ AlterTableAlterColumn(table: ResolvedTable, col: ResolvedFieldName, _, _, _, _) =>
+        val fieldName = col.name.quoted
+        if (a.dataType.isDefined) {
+          val field = CharVarcharUtils.getRawType(col.field.metadata)
+            .map(dt => col.field.copy(dataType = dt))
+            .getOrElse(col.field)
+          val newDataType = a.dataType.get
+          newDataType match {
+            case _: StructType =>
+              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                "update a struct by updating its fields")
+            case _: MapType =>
+              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                s"update a map by updating $fieldName.key or $fieldName.value")
+            case _: ArrayType =>
+              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                s"update the element by updating $fieldName.element")
+            case u: UserDefinedType[_] =>
+              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                s"update a UserDefinedType[${u.sql}] by updating its fields")
+            case _: CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType =>
+              alter.failAnalysis(s"Cannot update ${table.name} field $fieldName to interval type")
+            case _ => // update is okay
+          }
+
+          // We don't need to handle nested types here which shall fail before.
+          def canAlterColumnType(from: DataType, to: DataType): Boolean = (from, to) match {
+            case (CharType(l1), CharType(l2)) => l1 == l2
+            case (CharType(l1), VarcharType(l2)) => l1 <= l2
+            case (VarcharType(l1), VarcharType(l2)) => l1 <= l2
+            case _ => Cast.canUpCast(from, to)
+          }
+
+          if (!canAlterColumnType(field.dataType, newDataType)) {
+            alter.failAnalysis(s"Cannot update ${table.name} field $fieldName: " +
+              s"${field.dataType.simpleString} cannot be cast to ${newDataType.simpleString}")
+          }
+        }
+        if (a.nullable.isDefined) {
+          if (!a.nullable.get && col.field.nullable) {
+            alter.failAnalysis(s"Cannot change nullable column to non-nullable: $fieldName")
+          }
+        }
       case _ =>
     }
   }
@@ -66,28 +66,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
       }
       createAlterTable(nameParts, catalog, tbl, changes)

-    case a @ AlterTableAlterColumnStatement(
-        nameParts @ NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _) =>
-      a.dataType.foreach(failNullType)
-      val colName = a.column.toArray
-      val typeChange = a.dataType.map { newDataType =>
-        TableChange.updateColumnType(colName, newDataType)
-      }
-      val nullabilityChange = a.nullable.map { nullable =>
-        TableChange.updateColumnNullability(colName, nullable)
-      }
-      val commentChange = a.comment.map { newComment =>
-        TableChange.updateColumnComment(colName, newComment)
-      }
-      val positionChange = a.position.map { newPosition =>
-        TableChange.updateColumnPosition(colName, newPosition)
-      }
-      createAlterTable(
-        nameParts,
-        catalog,
-        tbl,
-        typeChange.toSeq ++ nullabilityChange ++ commentChange ++ positionChange)
-
     case c @ CreateTableStatement(
         NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) =>
       assertNoNullTypeInSchema(c.tableSchema)
