Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-30615][SQL] Introduce Analyzer rule for V2 AlterTable column change resolution #27350

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -240,6 +241,7 @@ class Analyzer(
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
Batch("Normalize Alter Table", Once, ResolveAlterTableChanges),
Batch("Remove Unresolved Hints", Once,
new ResolveHints.RemoveAllHints(conf)),
Batch("Nondeterministic", Once,
Expand Down Expand Up @@ -3002,6 +3004,160 @@ class Analyzer(
}
}
}

/**
 * Rule to mostly resolve, normalize and rewrite column names of V2 ALTER TABLE changes
 * based on the analyzer's case-sensitivity setting (`conf.resolver`).
 *
 * Changes whose columns cannot be resolved are left untouched so that CheckAnalysis can
 * report a proper error. No-op type changes (Hive-style syntax repeating the current
 * type) are dropped entirely.
 */
object ResolveAlterTableChanges extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved =>
      val schema = t.schema
      val normalizedChanges = changes.flatMap {
        case add: AddColumn =>
          val parent = add.fieldNames().init
          if (parent.nonEmpty) {
            // Adding a nested field, need to normalize the parent column and position
            val target = schema.findNestedField(parent, includeCollections = true, conf.resolver)
            if (target.isEmpty) {
              // Leave unresolved. Throws error in CheckAnalysis
              Some(add)
            } else {
              val (normalizedName, sf) = target.get
              sf.dataType match {
                case struct: StructType =>
                  val pos = findColumnPosition(add.position(), parent.quoted, struct)
                  Some(TableChange.addColumn(
                    (normalizedName ++ Seq(sf.name, add.fieldNames().last)).toArray,
                    add.dataType(),
                    add.isNullable,
                    add.comment,
                    pos))

                case _ =>
                  // Parent is not a struct. Leave unresolved; CheckAnalysis reports the error.
                  Some(add)
              }
            }
          } else {
            // Adding to the root. Just need to normalize position
            val pos = findColumnPosition(add.position(), "root", schema)
            Some(TableChange.addColumn(
              add.fieldNames(),
              add.dataType(),
              add.isNullable,
              add.comment,
              pos))
          }

        case typeChange: UpdateColumnType =>
          // Hive style syntax provides the column type, even if it may not have changed
          val fieldOpt = schema.findNestedField(
            typeChange.fieldNames(), includeCollections = true, conf.resolver)

          if (fieldOpt.isEmpty) {
            // We couldn't resolve the field. Leave it to CheckAnalysis
            Some(typeChange)
          } else {
            val (fieldNames, field) = fieldOpt.get
            if (field.dataType == typeChange.newDataType()) {
              // The user didn't want the field to change, so remove this change
              None
            } else {
              Some(TableChange.updateColumnType(
                (fieldNames :+ field.name).toArray, typeChange.newDataType()))
            }
          }

        case n: UpdateColumnNullability =>
          // Need to resolve column
          resolveFieldNames(
            schema,
            n.fieldNames(),
            TableChange.updateColumnNullability(_, n.nullable())).orElse(Some(n))

        case position: UpdateColumnPosition =>
          position.position() match {
            case after: After =>
              // Need to resolve column as well as position reference
              val fieldOpt = schema.findNestedField(
                position.fieldNames(), includeCollections = true, conf.resolver)

              if (fieldOpt.isEmpty) {
                Some(position)
              } else {
                val (normalizedPath, field) = fieldOpt.get
                val targetCol = schema.findNestedField(
                  normalizedPath :+ after.column(), includeCollections = true, conf.resolver)
                if (targetCol.isEmpty) {
                  // Leave unchanged to CheckAnalysis
                  Some(position)
                } else {
                  Some(TableChange.updateColumnPosition(
                    (normalizedPath :+ field.name).toArray,
                    ColumnPosition.after(targetCol.get._2.name)))
                }
              }
            case _ =>
              // Need to resolve column
              resolveFieldNames(
                schema,
                position.fieldNames(),
                TableChange.updateColumnPosition(_, position.position())).orElse(Some(position))
          }

        case comment: UpdateColumnComment =>
          resolveFieldNames(
            schema,
            comment.fieldNames(),
            TableChange.updateColumnComment(_, comment.newComment())).orElse(Some(comment))

        case rename: RenameColumn =>
          // Name conflicts with existing fields are left to CheckAnalysis.
          resolveFieldNames(
            schema,
            rename.fieldNames(),
            TableChange.renameColumn(_, rename.newName())).orElse(Some(rename))

        case delete: DeleteColumn =>
          resolveFieldNames(schema, delete.fieldNames(), TableChange.deleteColumn)
            .orElse(Some(delete))

        case column: ColumnChange =>
          // This is informational for future developers
          throw new UnsupportedOperationException(
            "Please add an implementation for a column change here")

        case other => Some(other)
      }

      a.copy(changes = normalizedChanges)
  }

  /**
   * Returns the table change if the field can be resolved, returns None if the column is not
   * found. An error will be thrown in CheckAnalysis for columns that can't be resolved.
   */
  private def resolveFieldNames(
      schema: StructType,
      fieldNames: Array[String],
      copy: Array[String] => TableChange): Option[TableChange] = {
    val fieldOpt = schema.findNestedField(
      fieldNames, includeCollections = true, conf.resolver)
    fieldOpt.map { case (path, field) => copy((path :+ field.name).toArray) }
  }

  /**
   * Normalizes an `After` position reference against `struct` using `conf.resolver`, or throws
   * an AnalysisException if the reference column does not exist. A null position (column added
   * at the end) and `first()` positions are passed through unchanged.
   */
  private def findColumnPosition(
      position: ColumnPosition,
      field: String,
      struct: StructType): ColumnPosition = {
    position match {
      // Null means no position was given: the column is added at the end.
      case null => null
      case after: After =>
        struct.fieldNames.find(n => conf.resolver(n, after.column())) match {
          case Some(colName) =>
            ColumnPosition.after(colName)
          case None =>
            throw new AnalysisException("Couldn't find the reference column for " +
              s"$after at $field")
        }
      case other => other
    }
  }
}
}

/**
Expand Down
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnType}
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -425,24 +425,56 @@ trait CheckAnalysis extends PredicateHelper {
case _ =>
}

case alter: AlterTable if alter.childrenResolved =>
case alter: AlterTable if alter.table.resolved =>
val table = alter.table
def findField(operation: String, fieldName: Array[String]): StructField = {
// include collections because structs nested in maps and arrays may be altered
val field = table.schema.findNestedField(fieldName, includeCollections = true)
if (field.isEmpty) {
throw new AnalysisException(
s"Cannot $operation missing field in ${table.name} schema: ${fieldName.quoted}")
alter.failAnalysis(
s"Cannot $operation missing field ${fieldName.quoted} in ${table.name} schema: " +
table.schema.treeString)
}
field.get._2
}
def positionArgumentExists(position: ColumnPosition, struct: StructType): Unit = {
position match {
case after: After =>
if (!struct.fieldNames.contains(after.column())) {
alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " +
s"${struct.fieldNames.mkString("[", ", ", "]")}")
}
case _ =>
}
}
def findParentStruct(operation: String, fieldNames: Array[String]): StructType = {
val parent = fieldNames.init
val field = if (parent.nonEmpty) {
findField(operation, parent).dataType
} else {
table.schema
}
field match {
case s: StructType => s
case o => alter.failAnalysis(s"Cannot $operation ${fieldNames.quoted}, because " +
s"its parent is not a StructType. Found $o")
}
}
def checkColumnNotExists(
operation: String,
fieldNames: Array[String],
struct: StructType): Unit = {
if (struct.findNestedField(fieldNames, includeCollections = true).isDefined) {
alter.failAnalysis(s"Cannot $operation column, because ${fieldNames.quoted} " +
s"already exists in ${struct.treeString}")
}
field.get
}

alter.changes.foreach {
case add: AddColumn =>
val parent = add.fieldNames.init
if (parent.nonEmpty) {
findField("add to", parent)
}
checkColumnNotExists("add", add.fieldNames(), table.schema)
val parent = findParentStruct("add", add.fieldNames())
positionArgumentExists(add.position(), parent)
TypeUtils.failWithIntervalType(add.dataType())
case update: UpdateColumnType =>
val field = findField("update", update.fieldNames)
Expand All @@ -467,7 +499,7 @@ trait CheckAnalysis extends PredicateHelper {
// update is okay
}
if (!Cast.canUpCast(field.dataType, update.newDataType)) {
throw new AnalysisException(
alter.failAnalysis(
s"Cannot update ${table.name} field $fieldName: " +
s"${field.dataType.simpleString} cannot be cast to " +
s"${update.newDataType.simpleString}")
Expand All @@ -476,11 +508,17 @@ trait CheckAnalysis extends PredicateHelper {
val field = findField("update", update.fieldNames)
val fieldName = update.fieldNames.quoted
if (!update.nullable && field.nullable) {
throw new AnalysisException(
alter.failAnalysis(
s"Cannot change nullable column to non-nullable: $fieldName")
}
case updatePos: UpdateColumnPosition =>
findField("update", updatePos.fieldNames)
val parent = findParentStruct("update", updatePos.fieldNames())
positionArgumentExists(updatePos.position(), parent)
case rename: RenameColumn =>
findField("rename", rename.fieldNames)
checkColumnNotExists(
"rename", rename.fieldNames().init :+ rename.newName(), table.schema)
case update: UpdateColumnComment =>
findField("update", update.fieldNames)
case delete: DeleteColumn =>
Expand Down
Expand Up @@ -25,6 +25,8 @@ import org.json4s.JsonDSL._

import org.apache.spark.SparkException
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString, StringUtils}
Expand Down Expand Up @@ -308,52 +310,75 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
}

/**
 * Returns the normalized path to a field and the field in this struct and its child structs.
 *
 * Field name matching is done via `resolver` (exact equality by default, so callers can pass
 * a case-insensitive resolver). Throws AnalysisException if more than one field matches a
 * name part (ambiguous reference under a case-insensitive resolver).
 *
 * If includeCollections is true, this will return fields that are nested in maps and arrays
 * via the special path parts "key", "value" and "element".
 */
private[sql] def findNestedField(
    fieldNames: Seq[String],
    includeCollections: Boolean = false,
    resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = {
  // Renders a field path with quoting for display in error messages.
  def prettyFieldName(nameParts: Seq[String]): String = {
    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
    nameParts.quoted
  }

  def findField(
      struct: StructType,
      searchPath: Seq[String],
      normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = {
    searchPath.headOption.flatMap { searchName =>
      val found = struct.fields.filter(f => resolver(searchName, f.name))
      if (found.length > 1) {
        val names = found.map(f => prettyFieldName(normalizedPath :+ f.name))
          .mkString("[", ", ", "]")
        throw new AnalysisException(
          s"Ambiguous field name: ${prettyFieldName(normalizedPath :+ searchName)}. Found " +
            s"multiple columns that can match: $names")
      } else if (found.isEmpty) {
        None
      } else {
        val field = found.head
        (searchPath.tail, field.dataType, includeCollections) match {
          case (Seq(), _, _) =>
            Some(normalizedPath -> field)

          case (names, struct: StructType, _) =>
            findField(struct, names, normalizedPath :+ field.name)

          case (_, _, false) =>
            None // types nested in maps and arrays are not used

          case (Seq("key"), MapType(keyType, _, _), true) =>
            // return the key type as a struct field to include nullability
            Some((normalizedPath :+ field.name) -> StructField("key", keyType, nullable = false))

          case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) =>
            findField(struct, names, normalizedPath ++ Seq(field.name, "key"))

          case (Seq("value"), MapType(_, valueType, isNullable), true) =>
            // return the value type as a struct field to include nullability
            Some((normalizedPath :+ field.name) ->
              StructField("value", valueType, nullable = isNullable))

          case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) =>
            findField(struct, names, normalizedPath ++ Seq(field.name, "value"))

          case (Seq("element"), ArrayType(elementType, isNullable), true) =>
            // return the element type as a struct field to include nullability
            Some((normalizedPath :+ field.name) ->
              StructField("element", elementType, nullable = isNullable))

          case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) =>
            findField(struct, names, normalizedPath ++ Seq(field.name, "element"))

          case _ =>
            None
        }
      }
    }
  }
  findField(this, fieldNames, Nil)
}

protected[sql] def toAttributes: Seq[AttributeReference] =
Expand Down