From 357232c8b51315a4bf2adcc83531ccebd209ce4f Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 23 Jan 2020 19:14:00 -0800 Subject: [PATCH 01/10] Added checks to normalize columns --- .../sql/catalyst/analysis/Analyzer.scala | 122 ++++++++++++++++++ .../apache/spark/sql/types/StructType.scala | 82 +++++++----- 2 files changed, 174 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 45547bff8a9d6..c50efa02a1389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf @@ -3001,6 +3002,127 @@ class Analyzer( } } } + + object ResolveAlterTableChanges extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case a @ AlterTable(_, _, t: DataSourceV2Relation, changes) => + val schema = t.schema + val normalizedChanges = changes.flatMap { + case add: AddColumn if add.position() != null && add.position().isInstanceOf[After] => + val parent = add.fieldNames().init + val target = schema.findNestedField(parent, includeCollections = true) + if (target.isEmpty) { + Some(add) + } else { + val sf = target.get._2 + sf.dataType match { + case struct: StructType => + val after = add.position().asInstanceOf[After] + struct.fieldNames.find(n => conf.resolver(n, after.column())) match { + case Some(colName) => + Some(TableChange.addColumn( + add.fieldNames(), + add.dataType(), + add.isNullable, + add.comment, + ColumnPosition.after(colName))) + case None => + throw new AnalysisException("Couldn't find the reference column for " + + s"AFTER ${after.column()} at ${UnresolvedAttribute(parent).name}") + } + case other => + throw new AnalysisException( + s"Columns can only be added to struct types. Found ${other.simpleString}.") + } + } + + case typeChange: UpdateColumnType => + // Hive style syntax provides the column type, even if it may not have changed + val fieldOpt = schema.findNestedField( + typeChange.fieldNames(), includeCollections = true, conf.resolver) + + if (fieldOpt.isEmpty) { + // We couldn't resolve the field. 
Leave it to CheckAnalysis + Some(typeChange) + } else { + val (fieldNames, field) = fieldOpt.get + if (field.dataType == typeChange.newDataType()) { + // The user didn't want the field to change, so remove this change + None + } else { + Some(TableChange.updateColumnType( + (fieldNames :+ field.name).toArray, typeChange.newDataType())) + } + } + case n: UpdateColumnNullability => + resolveFieldNames( + schema, + n.fieldNames(), + TableChange.updateColumnNullability(_, n.nullable())).orElse(Some(n)) + + case position: UpdateColumnPosition => + position.position() match { + case after: After => // resolve this column as well + val fieldOpt = schema.findNestedField( + position.fieldNames(), includeCollections = true, conf.resolver) + + if (fieldOpt.isEmpty) { + Some(position) + } else { + val (normalizedPath, field) = fieldOpt.get + val targetCol = schema.findNestedField( + normalizedPath :+ after.column(), includeCollections = true, conf.resolver) + if (targetCol.isEmpty) { + throw new AnalysisException("Couldn't find the reference column for " + + s"AFTER ${after.column()} at ${UnresolvedAttribute(normalizedPath).name}") + } else { + Some(TableChange.updateColumnPosition( + (normalizedPath :+ field.name).toArray, + ColumnPosition.after(targetCol.get._2.name))) + } + } + case _ => + resolveFieldNames( + schema, + position.fieldNames(), + TableChange.updateColumnPosition(_, position.position())).orElse(Some(position)) + } + + case comment: UpdateColumnComment => + resolveFieldNames( + schema, + comment.fieldNames(), + TableChange.updateColumnComment(_, comment.newComment())).orElse(Some(comment)) + + case rename: RenameColumn => + resolveFieldNames( + schema, + rename.fieldNames(), + TableChange.renameColumn(_, rename.newName())).orElse(Some(rename)) + + case delete: DeleteColumn => + resolveFieldNames(schema, delete.fieldNames(), TableChange.deleteColumn) + .orElse(Some(delete)) + + case column: ColumnChange => + // This is informational for future developers + throw new UnsupportedOperationException( + "Please add an implementation for a column change here") + case other => Some(other) + } + + a.copy(changes = normalizedChanges) + } + + private def resolveFieldNames( + schema: StructType, + fieldNames: Array[String], + copy: Array[String] => TableChange): Option[TableChange] = { + val fieldOpt = schema.findNestedField( + fieldNames, includeCollections = true, conf.resolver) + fieldOpt.map { case (path, field) => copy(path +: field.name) } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 476d47a2942b2..db7bca06d661a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -25,6 +25,8 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.Stable +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString, StringUtils} @@ -308,52 +310,72 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } /** - * Returns a field in this struct and its child structs. 
+ * Returns the normalized path to a field and the field in this struct and its child structs. * * If includeCollections is true, this will return fields that are nested in maps and arrays. */ private[sql] def findNestedField( fieldNames: Seq[String], - includeCollections: Boolean = false): Option[StructField] = { - fieldNames.headOption.flatMap(nameToField.get) match { - case Some(field) => - (fieldNames.tail, field.dataType, includeCollections) match { - case (Seq(), _, _) => - Some(field) + includeCollections: Boolean = false, + resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = { + def prettyFieldName(nameParts: Seq[String]): String = { + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + } + + def findField( + struct: StructType, + searchPath: Seq[String], + normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = { + searchPath.headOption.flatMap { searchName => + val found = this.fields.filter(f => resolver(searchName, f.name)) + if (found.length > 1) { + val names = found.map(f => prettyFieldName(normalizedPath :+ f.name)) + .mkString("[", ", ", " ]") + throw new AnalysisException( + s"Ambiguous field name: ${prettyFieldName(normalizedPath :+ searchName)}. Found " + + s"multiple columns that can match: $names") + } else if (found.isEmpty) { + None + } else { + val field = found.head + (fieldNames.tail, field.dataType, includeCollections) match { + case (Seq(), _, _) => + Some(normalizedPath -> field) - case (names, struct: StructType, _) => - struct.findNestedField(names, includeCollections) + case (names, struct: StructType, _) => + findField(struct, names, normalizedPath :+ field.name) - case (_, _, false) => - None // types nested in maps and arrays are not used + case (_, _, false) => + None // types nested in maps and arrays are not used - case (Seq("key"), MapType(keyType, _, _), true) => - // return the key type as a struct field to include nullability - Some(StructField("key", keyType, nullable = false)) + case (Seq("key"), MapType(keyType, _, _), true) => + // return the key type as a struct field to include nullability + Some(normalizedPath -> StructField("key", keyType, nullable = false)) - case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) => - struct.findNestedField(names, includeCollections) + case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) => + findField(struct, names, normalizedPath ++ Seq(field.name, "key")) - case (Seq("value"), MapType(_, valueType, isNullable), true) => - // return the value type as a struct field to include nullability - Some(StructField("value", valueType, nullable = isNullable)) + case (Seq("value"), MapType(_, valueType, isNullable), true) => + // return the value type as a struct field to include nullability + Some(normalizedPath -> StructField("value", valueType, nullable = isNullable)) - case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) => - struct.findNestedField(names, includeCollections) + case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) => + findField(struct, names, normalizedPath ++ Seq(field.name, "value")) - case (Seq("element"), ArrayType(elementType, isNullable), true) => - // return the element type as a struct field to include nullability - Some(StructField("element", elementType, nullable = isNullable)) + case (Seq("element"), ArrayType(elementType, isNullable), true) => + // return the element type as a struct field to include nullability + Some(normalizedPath -> StructField("element", 
elementType, nullable = isNullable)) - case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) => - struct.findNestedField(names, includeCollections) + case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) => + findField(struct, names, normalizedPath ++ Seq(field.name, "element")) - case _ => - None + case _ => + None + } } - case _ => - None + } } + findField(this, fieldNames, Nil) } protected[sql] def toAttributes: Seq[AttributeReference] = From 54266d925c0850ede2aa6466a8a46407abdd9570 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 23 Jan 2020 19:26:11 -0800 Subject: [PATCH 02/10] add rule --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c50efa02a1389..6e13b0dec7d53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -241,6 +241,8 @@ class Analyzer( TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), + Batch("Normalize Alter Table", Once, + ResolveAlterTableChanges :: Nil), Batch("Remove Unresolved Hints", Once, new ResolveHints.RemoveAllHints(conf)), Batch("Nondeterministic", Once, From f4db1b264995bf78c4ad0fabb11df6c115ac5b76 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 24 Jan 2020 09:36:55 -0800 Subject: [PATCH 03/10] fix --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6e13b0dec7d53..92acad7f0f9cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -241,8 +241,7 @@ class Analyzer( TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), - Batch("Normalize Alter Table", Once, - ResolveAlterTableChanges :: Nil), + Batch("Normalize Alter Table", Once, ResolveAlterTableChanges), Batch("Remove Unresolved Hints", Once, new ResolveHints.RemoveAllHints(conf)), Batch("Nondeterministic", Once, From 5c369afd06cd470d02d3e8271c5689d395baffd5 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 24 Jan 2020 11:01:13 -0800 Subject: [PATCH 04/10] builds --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 92acad7f0f9cc..3ac3ef0a2b455 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3121,7 +3121,7 @@ class Analyzer( copy: Array[String] => TableChange): Option[TableChange] = { val fieldOpt = schema.findNestedField( fieldNames, includeCollections = true, conf.resolver) - fieldOpt.map { 
case (path, field) => copy(path +: field.name) } + fieldOpt.map { case (path, field) => copy((path :+ field.name).toArray) } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d6fc1dc6ddc3d..63b893d809db8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -434,7 +434,7 @@ trait CheckAnalysis extends PredicateHelper { throw new AnalysisException( s"Cannot $operation missing field in ${table.name} schema: ${fieldName.quoted}") } - field.get + field.get._2 } alter.changes.foreach { From 9c3022dc156d1f02b3dff56f62d535785d56c000 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Jan 2020 20:44:59 -0800 Subject: [PATCH 05/10] Tests ready too --- .../sql/catalyst/analysis/Analyzer.scala | 73 ++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 40 ++- .../apache/spark/sql/types/StructType.scala | 4 +- .../sql/catalyst/analysis/AnalysisTest.scala | 9 +- ...eateTablePartitioningValidationSuite.scala | 4 +- .../V2CommandsCaseSensitivitySuite.scala | 227 ++++++++++++++++++ .../command/PlanResolutionSuite.scala | 26 +- 7 files changed, 344 insertions(+), 39 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3ac3ef0a2b455..4d7f374a7a2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3006,35 +3006,40 @@ class Analyzer( object ResolveAlterTableChanges extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case a @ AlterTable(_, _, t: DataSourceV2Relation, changes) => + case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved => val schema = t.schema val normalizedChanges = changes.flatMap { - case add: AddColumn if add.position() != null && add.position().isInstanceOf[After] => + case add: AddColumn => val parent = add.fieldNames().init - val target = schema.findNestedField(parent, includeCollections = true) - if (target.isEmpty) { - Some(add) - } else { - val sf = target.get._2 - sf.dataType match { - case struct: StructType => - val after = add.position().asInstanceOf[After] - struct.fieldNames.find(n => conf.resolver(n, after.column())) match { - case Some(colName) => - Some(TableChange.addColumn( - add.fieldNames(), - add.dataType(), - add.isNullable, - add.comment, - ColumnPosition.after(colName))) - case None => - throw new AnalysisException("Couldn't find the reference column for " + - s"AFTER ${after.column()} at ${UnresolvedAttribute(parent).name}") - } - case other => - throw new AnalysisException( - s"Columns can only be added to struct types. 
Found ${other.simpleString}.") + if (parent.nonEmpty) { + val target = schema.findNestedField(parent, includeCollections = true, conf.resolver) + if (target.isEmpty) { + Some(add) + } else { + val (normalizedName, sf) = target.get + sf.dataType match { + case struct: StructType => + val pos = findColumnPosition(add.position(), parent.quoted, struct) + Some(TableChange.addColumn( + (normalizedName ++ Seq(sf.name, add.fieldNames().last)).toArray, + add.dataType(), + add.isNullable, + add.comment, + pos)) + + case other => + throw new AnalysisException( + s"Columns can only be added to struct types. Found ${other.simpleString}.") + } } + } else { + val pos = findColumnPosition(add.position(), "root", schema) + Some(TableChange.addColumn( + add.fieldNames(), + add.dataType(), + add.isNullable, + add.comment, + pos)) } case typeChange: UpdateColumnType => @@ -3123,6 +3128,24 @@ class Analyzer( fieldNames, includeCollections = true, conf.resolver) fieldOpt.map { case (path, field) => copy((path :+ field.name).toArray) } } + + private def findColumnPosition( + position: ColumnPosition, + field: String, + struct: StructType): ColumnPosition = { + position match { + case null => null + case after: After => + struct.fieldNames.find(n => conf.resolver(n, after.column())) match { + case Some(colName) => + ColumnPosition.after(colName) + case None => + throw new AnalysisException("Couldn't find the reference column for " + + s"$after at $field") + } + case other => other + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 63b893d809db8..32ccfe008ae6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnType} +import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -431,17 +431,37 @@ trait CheckAnalysis extends PredicateHelper { // include collections because structs nested in maps and arrays may be altered val field = table.schema.findNestedField(fieldName, includeCollections = true) if (field.isEmpty) { - throw new AnalysisException( - s"Cannot $operation missing field in ${table.name} schema: ${fieldName.quoted}") + alter.failAnalysis( + s"Cannot $operation missing field ${fieldName.quoted} in ${table.name} schema: " + + table.schema.treeString) } field.get._2 } + def positionArgumentExists(position: ColumnPosition, dataType: DataType): Boolean = { + (position, dataType) match { + case (after: After, struct: StructType) => + struct.fieldNames.contains(after.column()) + case (after: After, _) => false + case _ => true + } + } alter.changes.foreach { case add: AddColumn => val parent = add.fieldNames.init if (parent.nonEmpty) { - findField("add to", parent) + val parentField = findField("add to", 
parent) + if (!positionArgumentExists(add.position(), parentField.dataType)) { + alter.failAnalysis( + s"Couldn't resolve positional argument ${add.position()} amongst " + + s"${parent.quoted}") + } + } else { + if (!positionArgumentExists(add.position(), table.schema)) { + alter.failAnalysis( + s"Couldn't resolve positional argument ${add.position()} amongst " + + s"${table.schema.treeString}") + } } TypeUtils.failWithIntervalType(add.dataType()) case update: UpdateColumnType => @@ -467,7 +487,7 @@ trait CheckAnalysis extends PredicateHelper { // update is okay } if (!Cast.canUpCast(field.dataType, update.newDataType)) { - throw new AnalysisException( + alter.failAnalysis( s"Cannot update ${table.name} field $fieldName: " + s"${field.dataType.simpleString} cannot be cast to " + s"${update.newDataType.simpleString}") @@ -476,9 +496,17 @@ trait CheckAnalysis extends PredicateHelper { val field = findField("update", update.fieldNames) val fieldName = update.fieldNames.quoted if (!update.nullable && field.nullable) { - throw new AnalysisException( + alter.failAnalysis( s"Cannot change nullable column to non-nullable: $fieldName") } + case updatePos: UpdateColumnPosition => + findField("update", updatePos.fieldNames) + val parent = findField("update", updatePos.fieldNames.init) + if (!positionArgumentExists(updatePos.position(), parent.dataType)) { + alter.failAnalysis( + s"Couldn't resolve positional argument ${updatePos.position()} amongst " + + s"${updatePos.fieldNames.init.quoted}") + } case rename: RenameColumn => findField("rename", rename.fieldNames) case update: UpdateColumnComment => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index db7bca06d661a..2e314978413c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -327,7 +327,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru searchPath: Seq[String], normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = { searchPath.headOption.flatMap { searchName => - val found = this.fields.filter(f => resolver(searchName, f.name)) + val found = struct.fields.filter(f => resolver(searchName, f.name)) if (found.length > 1) { val names = found.map(f => prettyFieldName(normalizedPath :+ f.name)) .mkString("[", ", ", " ]") @@ -338,7 +338,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru None } else { val field = found.head - (fieldNames.tail, field.dataType, includeCollections) match { + (searchPath.tail, field.dataType, includeCollections) match { case (Seq(), _, _) => Some(normalizedPath -> field) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 7d196f8b8edd2..3f8d409992381 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -26,12 +26,15 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule import 
org.apache.spark.sql.internal.SQLConf trait AnalysisTest extends PlanTest { - protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true) - protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false) + protected lazy val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true) + protected lazy val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false) + + protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) @@ -43,7 +46,7 @@ trait AnalysisTest extends PlanTest { catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true) new Analyzer(catalog, conf) { - override val extendedResolutionRules = EliminateSubqueryAliases :: Nil + override val extendedResolutionRules = EliminateSubqueryAliases +: extendedAnalysisRules } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index 9dd43ea70eb4b..f433229595e9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -133,7 +133,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { } } -private object CreateTablePartitioningValidationSuite { +private[sql] object CreateTablePartitioningValidationSuite { val catalog: TableCatalog = { val cat = new InMemoryTableCatalog() cat.initialize("test", CaseInsensitiveStringMap.empty()) @@ -146,7 +146,7 @@ private object CreateTablePartitioningValidationSuite { .add("point", new StructType().add("x", DoubleType).add("y", DoubleType)) } -private case object TestRelation2 extends LeafNode with NamedRelation { +private[sql] case object TestRelation2 extends LeafNode with NamedRelation { override def name: String = "source_relation" override def output: Seq[AttributeReference] = CreateTablePartitioningValidationSuite.schema.toAttributes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala new file mode 100644 index 0000000000000..e980131090648 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, TestRelation2} +import org.apache.spark.sql.catalyst.analysis.CreateTablePartitioningValidationSuite +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{Identifier, TableChange} +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition +import org.apache.spark.sql.connector.expressions.Expressions +import org.apache.spark.sql.execution.datasources.PreprocessTableCreation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{LongType, StringType} + +class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTest { + import CreateTablePartitioningValidationSuite._ + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = { + Seq(PreprocessTableCreation(spark)) + } + + test("CreateTableAsSelect: using top level field for partitioning") { + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + Seq("ID", "iD").foreach { ref => + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + Expressions.identity(ref) :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + if (caseSensitive) { + assertAnalysisError(plan, Seq("Couldn't find column", ref), caseSensitive) + } else { + assertAnalysisSuccess(plan, caseSensitive) + } + } + } + } + } + + test("CreateTableAsSelect: using nested column for partitioning") { + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + Expressions.bucket(4, ref) :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + if (caseSensitive) { + val field = ref.split("\\.") + assertAnalysisError(plan, Seq("Couldn't find column", field.head), caseSensitive) + } else { + assertAnalysisSuccess(plan, caseSensitive) + } + } + } + } + } + + test("ReplaceTableAsSelect: using top level field for partitioning") { + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + Seq("ID", "iD").foreach { ref => + val plan = ReplaceTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + Expressions.identity(ref) :: Nil, + TestRelation2, + Map.empty, + Map.empty, + orCreate = true) + + if (caseSensitive) { + assertAnalysisError(plan, Seq("Couldn't find column", ref), caseSensitive) + } else { + assertAnalysisSuccess(plan, caseSensitive) + } + } + } + } + } + + test("ReplaceTableAsSelect: using nested column for partitioning") { + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => + val plan = ReplaceTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + Expressions.bucket(4, ref) :: Nil, + TestRelation2, + Map.empty, + Map.empty, + orCreate = true) + + if (caseSensitive) { + val field = ref.split("\\.") + assertAnalysisError(plan, Seq("Couldn't find column", 
field.head), caseSensitive) + } else { + assertAnalysisSuccess(plan, caseSensitive) + } + } + } + } + } + + test("AlterTable: add column - nested") { + Seq("POINT.Z", "poInt.z", "poInt.Z").foreach { ref => + val field = ref.split("\\.") + alterTableTest( + TableChange.addColumn(field, LongType), + Seq("add to", field.head) + ) + } + } + + test("AlterTable: add column resolution - positional") { + Seq("ID", "iD").foreach { ref => + alterTableTest( + TableChange.addColumn( + Array("f"), LongType, true, null, ColumnPosition.after(ref)), + Seq("reference column", ref) + ) + } + } + + test("AlterTable: add column resolution - nested positional") { + Seq("X", "Y").foreach { ref => + alterTableTest( + TableChange.addColumn( + Array("point", "z"), LongType, true, null, ColumnPosition.after(ref)), + Seq("reference column", ref) + ) + } + } + + test("AlterTable: drop column resolution") { + Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => + alterTableTest( + TableChange.deleteColumn(ref), + Seq("Cannot delete missing field", ref.quoted) + ) + } + } + + test("AlterTable: rename column resolution") { + Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => + alterTableTest( + TableChange.renameColumn(ref, "newName"), + Seq("Cannot rename missing field", ref.quoted) + ) + } + } + + test("AlterTable: drop column nullability resolution") { + Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => + alterTableTest( + TableChange.updateColumnNullability(ref, true), + Seq("Cannot update missing field", ref.quoted) + ) + } + } + + test("AlterTable: change column type resolution") { + Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => + alterTableTest( + TableChange.updateColumnType(ref, StringType), + Seq("Cannot update missing field", ref.quoted) + ) + } + } + + test("AlterTable: change column comment resolution") { + Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => + alterTableTest( + TableChange.updateColumnComment(ref, "Here's a comment for ya"), + Seq("Cannot update missing field", ref.quoted) + ) + } + } + + private def alterTableTest(change: TableChange, error: Seq[String]): Unit = { + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val plan = AlterTable( + catalog, + Identifier.of(Array(), "table_name"), + TestRelation2, + Seq(change) + ) + + if (caseSensitive) { + assertAnalysisError(plan, error, caseSensitive) + } else { + assertAnalysisSuccess(plan, caseSensitive) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 8f17ce7f32c82..4b45f02c63430 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, Assignment, CreateTableAsSelect, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SubqueryAlias, UpdateAction, UpdateTable} import 
org.apache.spark.sql.connector.InMemoryTableProvider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, TableChange, V1Table} +import org.apache.spark.sql.connector.catalog.TableChange.{UpdateColumnComment, UpdateColumnType} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf @@ -145,7 +146,8 @@ class PlanResolutionSuite extends AnalysisTest { analyzer.ResolveTables, analyzer.ResolveReferences, analyzer.ResolveSubqueryColumnAliases, - analyzer.ResolveReferences) + analyzer.ResolveReferences, + analyzer.ResolveAlterTableChanges) rules.foldLeft(parsePlan(query)) { case (plan, rule) => rule.apply(plan) } @@ -1072,6 +1074,28 @@ class PlanResolutionSuite extends AnalysisTest { } } + test("alter table: hive style change column") { + Seq("v2Table", "testcat.tab").foreach { tblName => + parseAndResolve(s"ALTER TABLE $tblName CHANGE COLUMN i i int COMMENT 'an index'") match { + case AlterTable(_, _, _: DataSourceV2Relation, changes) => + assert(changes.length == 1, "Should only have a comment change") + assert(changes.head.isInstanceOf[UpdateColumnComment], + s"Expected only a UpdateColumnComment change but got: ${changes.head}") + case _ => fail("expect AlterTable") + } + + parseAndResolve(s"ALTER TABLE $tblName CHANGE COLUMN i i long COMMENT 'an index'") match { + case AlterTable(_, _, _: DataSourceV2Relation, changes) => + assert(changes.length == 2, "Should have a comment change and type change") + assert(changes.exists(_.isInstanceOf[UpdateColumnComment]), + s"Expected UpdateColumnComment change but got: ${changes}") + assert(changes.exists(_.isInstanceOf[UpdateColumnType]), + s"Expected UpdateColumnType change but got: ${changes}") + case _ => fail("expect AlterTable") + } + } + } + test("MERGE INTO TABLE") { def checkResolution( target: LogicalPlan, From 2a9c1a01758043123e140c8297d99a4938a4acac Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 28 Jan 2020 09:35:03 -0800 Subject: [PATCH 06/10] add comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca8dce19f1cfd..d21391f358ce7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3005,6 +3005,7 @@ class Analyzer( } } + /** Rule to mostly resolve, normalize and rewrite column names based on case sensitivity. */ object ResolveAlterTableChanges extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved => @@ -3013,8 +3014,10 @@ class Analyzer( case add: AddColumn => val parent = add.fieldNames().init if (parent.nonEmpty) { + // Adding a nested field, need to normalize the parent column and position val target = schema.findNestedField(parent, includeCollections = true, conf.resolver) if (target.isEmpty) { + // Leave unresolved. Throws error in CheckAnalysis Some(add) } else { val (normalizedName, sf) = target.get @@ -3034,6 +3037,7 @@ class Analyzer( } } } else { + // Adding to the root. 
Just need to normalize position val pos = findColumnPosition(add.position(), "root", schema) Some(TableChange.addColumn( add.fieldNames(), @@ -3062,6 +3066,7 @@ class Analyzer( } } case n: UpdateColumnNullability => + // Need to resolve column resolveFieldNames( schema, n.fieldNames(), @@ -3069,7 +3074,8 @@ class Analyzer( case position: UpdateColumnPosition => position.position() match { - case after: After => // resolve this column as well + case after: After => + // Need to resolve column as well as position reference val fieldOpt = schema.findNestedField( position.fieldNames(), includeCollections = true, conf.resolver) @@ -3081,7 +3087,7 @@ class Analyzer( normalizedPath :+ after.column(), includeCollections = true, conf.resolver) if (targetCol.isEmpty) { throw new AnalysisException("Couldn't find the reference column for " + - s"AFTER ${after.column()} at ${UnresolvedAttribute(normalizedPath).name}") + s"$after at ${normalizedPath.quoted}") } else { Some(TableChange.updateColumnPosition( (normalizedPath :+ field.name).toArray, @@ -3089,6 +3095,7 @@ class Analyzer( } } case _ => + // Need to resolve column resolveFieldNames( schema, position.fieldNames(), @@ -3121,6 +3128,10 @@ class Analyzer( a.copy(changes = normalizedChanges) } + /** + * Returns the table change if the field can be resolved, returns None if the column is not + * found. An error will be thrown in CheckAnalysis for columns that can't be resolved. + */ private def resolveFieldNames( schema: StructType, fieldNames: Array[String], From c035012e8adacab2b247cdd06e9fcbf742c771ea Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 29 Jan 2020 10:25:49 -0800 Subject: [PATCH 07/10] fix array and map types --- .../scala/org/apache/spark/sql/types/StructType.scala | 8 +++++--- .../org/apache/spark/sql/connector/AlterTableTests.scala | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 2e314978413c9..f65a67dfdf14c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -350,21 +350,23 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru case (Seq("key"), MapType(keyType, _, _), true) => // return the key type as a struct field to include nullability - Some(normalizedPath -> StructField("key", keyType, nullable = false)) + Some((normalizedPath :+ field.name) -> StructField("key", keyType, nullable = false)) case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) => findField(struct, names, normalizedPath ++ Seq(field.name, "key")) case (Seq("value"), MapType(_, valueType, isNullable), true) => // return the value type as a struct field to include nullability - Some(normalizedPath -> StructField("value", valueType, nullable = isNullable)) + Some((normalizedPath :+ field.name) -> + StructField("value", valueType, nullable = isNullable)) case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) => findField(struct, names, normalizedPath ++ Seq(field.name, "value")) case (Seq("element"), ArrayType(elementType, isNullable), true) => // return the element type as a struct field to include nullability - Some(normalizedPath -> StructField("element", elementType, nullable = isNullable)) + Some((normalizedPath :+ field.name) -> + StructField("element", elementType, nullable = isNullable)) case 
(Seq("element", names @ _*), ArrayType(struct: StructType, _), true) => findField(struct, names, normalizedPath ++ Seq(field.name, "element")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index ee7f205b3fa52..57365f25768c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -146,9 +146,9 @@ trait AlterTableTests extends SharedSparkSession { .add("point", new StructType().add("x", IntegerType)) .add("b", StringType)) - val e1 = intercept[SparkException]( + val e1 = intercept[AnalysisException]( sql(s"ALTER TABLE $t ADD COLUMN c string AFTER non_exist")) - assert(e1.getMessage().contains("AFTER column not found")) + assert(e1.getMessage().contains("Couldn't find the reference column")) sql(s"ALTER TABLE $t ADD COLUMN point.y int FIRST") assert(getTableMetadata(t).schema == new StructType() @@ -169,7 +169,7 @@ trait AlterTableTests extends SharedSparkSession { val e2 = intercept[SparkException]( sql(s"ALTER TABLE $t ADD COLUMN point.x2 int AFTER non_exist")) - assert(e2.getMessage().contains("AFTER column not found")) + assert(e2.getMessage().contains("Couldn't find the reference column")) } } From c824d15cc31542c8ce9b60aed2fe2cabef6935da Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 29 Jan 2020 12:50:44 -0800 Subject: [PATCH 08/10] fix last two tests --- .../apache/spark/sql/connector/AlterTableTests.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 57365f25768c6..6fa30e5743f9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -167,7 +167,7 @@ trait AlterTableTests extends SharedSparkSession { .add("z", IntegerType)) .add("b", StringType)) - val e2 = intercept[SparkException]( + val e2 = intercept[AnalysisException]( sql(s"ALTER TABLE $t ADD COLUMN point.x2 int AFTER non_exist")) assert(e2.getMessage().contains("Couldn't find the reference column")) } @@ -595,9 +595,9 @@ trait AlterTableTests extends SharedSparkSession { .add("z", IntegerType)) .add("b", IntegerType)) - val e1 = intercept[SparkException]( + val e1 = intercept[AnalysisException]( sql(s"ALTER TABLE $t ALTER COLUMN b AFTER non_exist")) - assert(e1.getMessage.contains("AFTER column not found")) + assert(e1.getMessage.contains("Couldn't find the reference column")) sql(s"ALTER TABLE $t ALTER COLUMN point.y FIRST") assert(getTableMetadata(t).schema == new StructType() @@ -617,9 +617,9 @@ trait AlterTableTests extends SharedSparkSession { .add("y", IntegerType)) .add("b", IntegerType)) - val e2 = intercept[SparkException]( + val e2 = intercept[AnalysisException]( sql(s"ALTER TABLE $t ALTER COLUMN point.y AFTER non_exist")) - assert(e2.getMessage.contains("AFTER column not found")) + assert(e2.getMessage.contains("Couldn't find the reference column")) // `AlterTable.resolved` checks column existence. 
intercept[AnalysisException]( From 8c1402d3bdb5754725feb83aa2edb869b76904a1 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 29 Jan 2020 16:18:31 -0800 Subject: [PATCH 09/10] I think it works now --- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 18 +++++++++++++----- .../spark/sql/connector/AlterTableTests.scala | 4 ++-- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d21391f358ce7..e7e2926c4c39b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3086,8 +3086,8 @@ class Analyzer( val targetCol = schema.findNestedField( normalizedPath :+ after.column(), includeCollections = true, conf.resolver) if (targetCol.isEmpty) { - throw new AnalysisException("Couldn't find the reference column for " + - s"$after at ${normalizedPath.quoted}") + // Leave unchanged to CheckAnalysis + Some(position) } else { Some(TableChange.updateColumnPosition( (normalizedPath :+ field.name).toArray, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 32ccfe008ae6e..9fcdbf49a7fe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -501,11 +501,19 @@ trait CheckAnalysis extends PredicateHelper { } case updatePos: UpdateColumnPosition => findField("update", updatePos.fieldNames) - val parent = findField("update", updatePos.fieldNames.init) - if (!positionArgumentExists(updatePos.position(), parent.dataType)) { - alter.failAnalysis( - s"Couldn't resolve positional argument ${updatePos.position()} amongst " + - s"${updatePos.fieldNames.init.quoted}") + if (updatePos.fieldNames().length == 1) { + if (!positionArgumentExists(updatePos.position(), table.schema)) { + alter.failAnalysis( + s"Couldn't resolve positional argument ${updatePos.position()} amongst " + + s"${table.schema.fieldNames.mkString("[", ", ", "]")}") + } + } else { + val parent = findField("update", updatePos.fieldNames.init) + if (!positionArgumentExists(updatePos.position(), parent.dataType)) { + alter.failAnalysis( + s"Couldn't resolve positional argument ${updatePos.position()} amongst " + + s"${updatePos.fieldNames.init.quoted}") + } } case rename: RenameColumn => findField("rename", rename.fieldNames) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 6fa30e5743f9b..d5983f966e6b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -597,7 +597,7 @@ trait AlterTableTests extends SharedSparkSession { val e1 = intercept[AnalysisException]( sql(s"ALTER TABLE $t ALTER COLUMN b AFTER non_exist")) - assert(e1.getMessage.contains("Couldn't find the reference column")) + assert(e1.getMessage.contains("Couldn't resolve positional argument")) sql(s"ALTER TABLE $t ALTER COLUMN point.y FIRST") assert(getTableMetadata(t).schema == new StructType() @@ -619,7 +619,7 
@@ trait AlterTableTests extends SharedSparkSession { val e2 = intercept[AnalysisException]( sql(s"ALTER TABLE $t ALTER COLUMN point.y AFTER non_exist")) - assert(e2.getMessage.contains("Couldn't find the reference column")) + assert(e2.getMessage.contains("Couldn't resolve positional argument")) // `AlterTable.resolved` checks column existence. intercept[AnalysisException]( From 2b364e23927fbcdd8a6d219c9ae6bd861dbe3f17 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 30 Jan 2020 10:26:47 -0800 Subject: [PATCH 10/10] address comments and add new tests --- .../sql/catalyst/analysis/Analyzer.scala | 3 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 74 ++++++++++--------- .../apache/spark/sql/types/StructType.scala | 3 +- .../spark/sql/connector/AlterTableTests.scala | 55 ++++++++++++++ .../V2CommandsCaseSensitivitySuite.scala | 2 +- 5 files changed, 97 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e7e2926c4c39b..daa18e5160abd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3032,8 +3032,7 @@ class Analyzer( pos)) case other => - throw new AnalysisException( - s"Columns can only be added to struct types. Found ${other.simpleString}.") + Some(add) } } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9fcdbf49a7fe0..4ec737fd9b70d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -425,7 +425,7 @@ trait CheckAnalysis extends PredicateHelper { case _ => } - case alter: AlterTable if alter.childrenResolved => + case alter: AlterTable if alter.table.resolved => val table = alter.table def findField(operation: String, fieldName: Array[String]): StructField = { // include collections because structs nested in maps and arrays may be altered @@ -437,32 +437,44 @@ trait CheckAnalysis extends PredicateHelper { } field.get._2 } - def positionArgumentExists(position: ColumnPosition, dataType: DataType): Boolean = { - (position, dataType) match { - case (after: After, struct: StructType) => - struct.fieldNames.contains(after.column()) - case (after: After, _) => false - case _ => true + def positionArgumentExists(position: ColumnPosition, struct: StructType): Unit = { + position match { + case after: After => + if (!struct.fieldNames.contains(after.column())) { + alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " + + s"${struct.fieldNames.mkString("[", ", ", "]")}") + } + case _ => + } + } + def findParentStruct(operation: String, fieldNames: Array[String]): StructType = { + val parent = fieldNames.init + val field = if (parent.nonEmpty) { + findField(operation, parent).dataType + } else { + table.schema + } + field match { + case s: StructType => s + case o => alter.failAnalysis(s"Cannot $operation ${fieldNames.quoted}, because " + + s"its parent is not a StructType. 
Found $o") + } + } + def checkColumnNotExists( + operation: String, + fieldNames: Array[String], + struct: StructType): Unit = { + if (struct.findNestedField(fieldNames, includeCollections = true).isDefined) { + alter.failAnalysis(s"Cannot $operation column, because ${fieldNames.quoted} " + + s"already exists in ${struct.treeString}") } } alter.changes.foreach { case add: AddColumn => - val parent = add.fieldNames.init - if (parent.nonEmpty) { - val parentField = findField("add to", parent) - if (!positionArgumentExists(add.position(), parentField.dataType)) { - alter.failAnalysis( - s"Couldn't resolve positional argument ${add.position()} amongst " + - s"${parent.quoted}") - } - } else { - if (!positionArgumentExists(add.position(), table.schema)) { - alter.failAnalysis( - s"Couldn't resolve positional argument ${add.position()} amongst " + - s"${table.schema.treeString}") - } - } + checkColumnNotExists("add", add.fieldNames(), table.schema) + val parent = findParentStruct("add", add.fieldNames()) + positionArgumentExists(add.position(), parent) TypeUtils.failWithIntervalType(add.dataType()) case update: UpdateColumnType => val field = findField("update", update.fieldNames) @@ -501,22 +513,12 @@ trait CheckAnalysis extends PredicateHelper { } case updatePos: UpdateColumnPosition => findField("update", updatePos.fieldNames) - if (updatePos.fieldNames().length == 1) { - if (!positionArgumentExists(updatePos.position(), table.schema)) { - alter.failAnalysis( - s"Couldn't resolve positional argument ${updatePos.position()} amongst " + - s"${table.schema.fieldNames.mkString("[", ", ", "]")}") - } - } else { - val parent = findField("update", updatePos.fieldNames.init) - if (!positionArgumentExists(updatePos.position(), parent.dataType)) { - alter.failAnalysis( - s"Couldn't resolve positional argument ${updatePos.position()} amongst " + - s"${updatePos.fieldNames.init.quoted}") - } - } + val parent = findParentStruct("update", updatePos.fieldNames()) + positionArgumentExists(updatePos.position(), parent) case rename: RenameColumn => findField("rename", rename.fieldNames) + checkColumnNotExists( + "rename", rename.fieldNames().init :+ rename.newName(), table.schema) case update: UpdateColumnComment => findField("update", update.fieldNames) case delete: DeleteColumn => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index f65a67dfdf14c..e8eeecd48e803 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -319,7 +319,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru includeCollections: Boolean = false, resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = { def prettyFieldName(nameParts: Seq[String]): String = { - nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + nameParts.quoted } def findField( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index d5983f966e6b4..3cdac59c20fc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -312,6 +312,30 @@ trait AlterTableTests extends SharedSparkSession { } } + 
test("AlterTable: add column - new column should not exist") { + val t = s"${catalogAndNamespace}table_name" + withTable(t) { + sql( + s"""CREATE TABLE $t ( + |id int, + |point struct, + |arr array>, + |mk map, string>, + |mv map> + |) + |USING $v2Format""".stripMargin) + + Seq("id", "point.x", "arr.element.x", "mk.key.x", "mv.value.x").foreach { field => + + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD COLUMNS $field double") + } + assert(e.getMessage.contains("add")) + assert(e.getMessage.contains(s"$field already exists")) + } + } + } + test("AlterTable: update column type int -> long") { val t = s"${catalogAndNamespace}table_name" withTable(t) { @@ -849,6 +873,37 @@ trait AlterTableTests extends SharedSparkSession { } } + test("AlterTable: rename column - new name should not exist") { + val t = s"${catalogAndNamespace}table_name" + withTable(t) { + sql( + s"""CREATE TABLE $t ( + |id int, + |user_id int, + |point struct, + |arr array>, + |mk map, string>, + |mv map> + |) + |USING $v2Format""".stripMargin) + + Seq( + "id" -> "user_id", + "point.x" -> "y", + "arr.element.x" -> "y", + "mk.key.x" -> "y", + "mv.value.x" -> "y").foreach { case (field, newName) => + + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE $t RENAME COLUMN $field TO $newName") + } + assert(e.getMessage.contains("rename")) + assert(e.getMessage.contains((field.split("\\.").init :+ newName).mkString("."))) + assert(e.getMessage.contains("already exists")) + } + } + } + test("AlterTable: drop column") { val t = s"${catalogAndNamespace}table_name" withTable(t) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index e980131090648..289f9dc427795 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -136,7 +136,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes val field = ref.split("\\.") alterTableTest( TableChange.addColumn(field, LongType), - Seq("add to", field.head) + Seq("add", field.head) ) } }