diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 14b50f481f387..c5a0cbf990ef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} @@ -3088,7 +3088,12 @@ class Analyzer(override val catalogManager: CatalogManager) val projection = TableOutputResolver.resolveOutputColumns( v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf) if (projection != v2Write.query) { - v2Write.withNewQuery(projection) + val cleanedTable = v2Write.table match { + case r: DataSourceV2Relation => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) + case other => other + } + v2Write.withNewQuery(projection).withNewTable(cleanedTable) } else { v2Write } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 452ba80b23441..261c008bd096e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table} import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.internal.SQLConf @@ -94,6 +94,10 @@ trait CheckAnalysis extends PredicateHelper { case p if p.analyzed => // Skip already analyzed sub-plans + case leaf: LeafNode if leaf.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => + throw new IllegalStateException( + "[BUG] leaf logical plan should not have output of char/varchar type: " + leaf) + case u: UnresolvedNamespace => u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index d3bb72badeb13..128ca1278bf54 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, SupportsNamespaces, TableCatalog, TableChange} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableCatalog, TableChange} /** * Resolves catalogs from the multi-part identifiers in SQL statements, and convert the statements @@ -35,7 +35,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableAddColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => cols.foreach(c => failNullType(c.dataType)) - cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( col.name.toArray, @@ -49,7 +48,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableReplaceColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => cols.foreach(c => failNullType(c.dataType)) - cols.foreach(c => failCharType(c.dataType)) val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { case Some(table) => // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. @@ -72,7 +70,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case a @ AlterTableAlterColumnStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => a.dataType.foreach(failNullType) - a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => TableChange.updateColumnType(colName, newDataType) @@ -145,7 +142,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ CreateTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) - assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, tbl.asIdentifier, @@ -173,7 +169,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ ReplaceTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) - assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, tbl.asIdentifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 4f33ca99c02db..c6bba370c8fef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types.DataType @@ -93,19 +94,17 @@ object TableOutputResolver { 
tableAttr.metadata == queryExpr.metadata) { Some(queryExpr) } else { - // Renaming is needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - storeAssignmentPolicy match { + val casted = storeAssignmentPolicy match { case StoreAssignmentPolicy.ANSI => - Some(Alias( - AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)), - tableAttr.name)(explicitMetadata = Option(tableAttr.metadata))) + AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) case _ => - Some(Alias( - Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)), - tableAttr.name)(explicitMetadata = Option(tableAttr.metadata))) + Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) } + val strLenChecked = CharVarcharUtils.stringLengthCheck(casted, tableAttr) + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Some(Alias(strLenChecked, tableAttr.name)(explicitMetadata = Some(tableAttr.metadata))) } storeAssignmentPolicy match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 17ab6664df75c..ec8daf3eb46d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE @@ -473,7 +473,13 @@ class SessionCatalog( val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.getTable(db, table) + removeCharVarcharFromTableSchema(externalCatalog.getTable(db, table)) + } + + // We replace char/varchar with string type in the table schema, as Spark's type system doesn't + // support char/varchar yet. 
+ private def removeCharVarcharFromTableSchema(t: CatalogTable): CatalogTable = { + t.copy(schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(t.schema)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c3855fe088db6..3e79613404bee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} -import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition @@ -2201,7 +2201,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * Create a Spark DataType. */ private def visitSparkDataType(ctx: DataTypeContext): DataType = { - HiveStringType.replaceCharType(typedVisit(ctx)) + CharVarcharUtils.replaceCharVarcharWithString(typedVisit(ctx)) } /** @@ -2276,16 +2276,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg builder.putString("comment", _) } - // Add Hive type string to metadata. - val rawDataType = typedVisit[DataType](ctx.dataType) - val cleanedDataType = HiveStringType.replaceCharType(rawDataType) - if (rawDataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString) - } - StructField( name = colName.getText, - dataType = cleanedDataType, + dataType = typedVisit[DataType](ctx.dataType), nullable = NULL == null, metadata = builder.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 5bda2b5b8db01..1081fec37490c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PartitionSpec, Res import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange} import org.apache.spark.sql.connector.expressions.Transform @@ -45,9 +46,10 @@ trait V2WriteCommand extends Command { table.skipSchemaResolution || (query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) // names and types must match, nullability must be 
compatible inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outAttr.dataType) && + DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && (outAttr.nullable || !inAttr.nullable) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala new file mode 100644 index 0000000000000..6b867b36c62d4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types._ + +object CharVarcharUtils { + + private val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING" + + /** + * Replaces CharType/VarcharType with StringType recursively in the given struct type. If a + * top-level StructField's data type is CharType/VarcharType or has nested CharType/VarcharType, + * this method will add the original type string to the StructField's metadata, so that we can + * re-construct the original data type with CharType/VarcharType later when needed. + */ + def replaceCharVarcharWithStringInSchema(st: StructType): StructType = { + StructType(st.map { field => + if (hasCharVarchar(field.dataType)) { + val metadata = new MetadataBuilder().withMetadata(field.metadata) + .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, field.dataType.sql).build() + field.copy(dataType = replaceCharVarcharWithString(field.dataType), metadata = metadata) + } else { + field + } + }) + } + + /** + * Returns true if the given data type is CharType/VarcharType or has nested CharType/VarcharType. + */ + def hasCharVarchar(dt: DataType): Boolean = { + dt.existsRecursively(f => f.isInstanceOf[CharType] || f.isInstanceOf[VarcharType]) + } + + /** + * Replaces CharType/VarcharType with StringType recursively in the given data type. 
+ */ + def replaceCharVarcharWithString(dt: DataType): DataType = dt match { + case ArrayType(et, nullable) => + ArrayType(replaceCharVarcharWithString(et), nullable) + case MapType(kt, vt, nullable) => + MapType(replaceCharVarcharWithString(kt), replaceCharVarcharWithString(vt), nullable) + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = replaceCharVarcharWithString(field.dataType)) + }) + case _: CharType => StringType + case _: VarcharType => StringType + case _ => dt + } + + /** + * Removes the metadata entry that contains the original type string of CharType/VarcharType from + * the given attribute's metadata. + */ + def cleanAttrMetadata(attr: AttributeReference): AttributeReference = { + val cleaned = new MetadataBuilder().withMetadata(attr.metadata) + .remove(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY).build() + attr.withMetadata(cleaned) + } + + /** + * Re-construct the original data type from the type string in the given metadata. + * This is needed when dealing with char/varchar columns/fields. + */ + def getRawType(metadata: Metadata): Option[DataType] = { + if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) { + Some(CatalystSqlParser.parseRawDataType( + metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY))) + } else { + None + } + } + + /** + * Re-construct the original StructType from the type strings in the metadata of StructFields. + * This is needed when dealing with char/varchar columns/fields. + */ + def getRawSchema(schema: StructType): StructType = { + StructType(schema.map { field => + getRawType(field.metadata).map(rawType => field.copy(dataType = rawType)).getOrElse(field) + }) + } + + /** + * Returns expressions to apply read-side char type padding for the given attributes. String + * values should be right-padded to N characters if they come from a CHAR(N) column/field. + */ + def charTypePadding(output: Seq[AttributeReference]): Seq[NamedExpression] = { + output.map { attr => + getRawType(attr.metadata).filter { rawType => + rawType.existsRecursively(_.isInstanceOf[CharType]) + }.map { rawType => + Alias(charTypePadding(attr, rawType), attr.name)(explicitMetadata = Some(attr.metadata)) + }.getOrElse(attr) + } + } + + private def charTypePadding(expr: Expression, dt: DataType): Expression = dt match { + case CharType(length) => StringRPad(expr, Literal(length)) + + case StructType(fields) => + CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + Seq(Literal(f.name), charTypePadding(GetStructField(expr, i, Some(f.name)), f.dataType)) + }) + + case ArrayType(et, containsNull) => charTypePaddingInArray(expr, et, containsNull) + + case MapType(kt, vt, valueContainsNull) => + val newKeys = charTypePaddingInArray(MapKeys(expr), kt, containsNull = false) + val newValues = charTypePaddingInArray(MapValues(expr), vt, valueContainsNull) + MapFromArrays(newKeys, newValues) + + case _ => expr + } + + private def charTypePaddingInArray( + arr: Expression, et: DataType, containsNull: Boolean): Expression = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + val func = LambdaFunction(charTypePadding(param, et), Seq(param)) + ArrayTransform(arr, func) + } + + /** + * Returns an expression to apply write-side string length check for the given expression. A string + * value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) column/field.
+ */ + def stringLengthCheck(expr: Expression, targetAttr: Attribute): Expression = { + getRawType(targetAttr.metadata).map { rawType => + stringLengthCheck(expr, rawType) + }.getOrElse(expr) + } + + private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match { + case CharType(length) => + val trimmed = StringTrimRight(expr) + val errorMsg = Concat(Seq( + Literal("input string '"), + expr, + Literal(s"' exceeds char type length limitation: $length"))) + // Trailing spaces do not count in the length check. We don't need to retain the trailing + // spaces, as we will pad char type columns/fields at read time. + If( + GreaterThan(Length(trimmed), Literal(length)), + Cast(RaiseError(errorMsg), StringType), + trimmed) + + case VarcharType(length) => + val trimmed = StringTrimRight(expr) + val errorMsg = Concat(Seq( + Literal("input string '"), + expr, + Literal(s"' exceeds varchar type length limitation: $length"))) + // Trailing spaces do not count in the length check. We need to retain the trailing spaces + // (truncate to length N), as there is no read-time padding for varchar type. + // TODO: create a special TrimRight function that can trim to a certain length. + If( + LessThanOrEqual(Length(expr), Literal(length)), + expr, + If( + GreaterThan(Length(trimmed), Literal(length)), + Cast(RaiseError(errorMsg), StringType), + StringRPad(trimmed, Literal(length)))) + + case StructType(fields) => + CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + Seq(Literal(f.name), stringLengthCheck(GetStructField(expr, i, Some(f.name)), f.dataType)) + }) + + case ArrayType(et, containsNull) => stringLengthCheckInArray(expr, et, containsNull) + + case MapType(kt, vt, valueContainsNull) => + val newKeys = stringLengthCheckInArray(MapKeys(expr), kt, containsNull = false) + val newValues = stringLengthCheckInArray(MapValues(expr), vt, valueContainsNull) + MapFromArrays(newKeys, newValues) + + case _ => expr + } + + private def stringLengthCheckInArray( + arr: Expression, et: DataType, containsNull: Boolean): Expression = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + val func = LambdaFunction(stringLengthCheck(param, et), Seq(param)) + ArrayTransform(arr, func) + } + + /** + * Return expressions to apply char type padding for the string comparison between the given + * attributes. When comparing two char type columns/fields, we need to pad the shorter one to + * the longer length. 
+ */ + def addPaddingInStringComparison(attrs: Seq[Attribute]): Seq[Expression] = { + val rawTypes = attrs.map(attr => getRawType(attr.metadata)) + if (rawTypes.exists(_.isEmpty)) { + attrs + } else { + val typeWithTargetCharLength = rawTypes.map(_.get).reduce(typeWithWiderCharLength) + attrs.zip(rawTypes.map(_.get)).map { case (attr, rawType) => + padCharToTargetLength(attr, rawType, typeWithTargetCharLength).getOrElse(attr) + } + } + } + + private def typeWithWiderCharLength(type1: DataType, type2: DataType): DataType = { + (type1, type2) match { + case (CharType(len1), CharType(len2)) => + CharType(math.max(len1, len2)) + case (StructType(fields1), StructType(fields2)) => + assert(fields1.length == fields2.length) + StructType(fields1.zip(fields2).map { case (left, right) => + StructField("", typeWithWiderCharLength(left.dataType, right.dataType)) + }) + case (ArrayType(et1, _), ArrayType(et2, _)) => + ArrayType(typeWithWiderCharLength(et1, et2)) + case (MapType(kt1, vt1, _), MapType(kt2, vt2, _)) => + MapType(typeWithWiderCharLength(kt1, kt2), typeWithWiderCharLength(vt1, vt2)) + case _ => NullType + } + } + + private def padCharToTargetLength( + expr: Expression, + rawType: DataType, + typeWithTargetCharLength: DataType): Option[Expression] = { + (rawType, typeWithTargetCharLength) match { + case (CharType(len), CharType(target)) if target > len => + Some(StringRPad(expr, Literal(target))) + + case (StructType(fields), StructType(targets)) => + assert(fields.length == targets.length) + var i = 0 + var needPadding = false + val createStructExprs = mutable.ArrayBuffer.empty[Expression] + while (i < fields.length) { + val field = fields(i) + val fieldExpr = GetStructField(expr, i, Some(field.name)) + val padded = padCharToTargetLength(fieldExpr, field.dataType, targets(i).dataType) + needPadding = padded.isDefined + createStructExprs += Literal(field.name) + createStructExprs += padded.getOrElse(fieldExpr) + i += 1 + } + if (needPadding) Some(CreateNamedStruct(createStructExprs)) else None + + case (ArrayType(et, containsNull), ArrayType(target, _)) => + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + padCharToTargetLength(param, et, target).map { padded => + val func = LambdaFunction(padded, Seq(param)) + ArrayTransform(expr, func) + } + + // We don't handle MapType here as it's not comparable. 
+ + case _ => None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 1a3a7207c6ca9..a14e165f7adf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -24,11 +24,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.AlterTable import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -332,21 +331,6 @@ private[sql] object CatalogV2Util { .asTableCatalog } - def failCharType(dt: DataType): Unit = { - if (HiveStringType.containsCharType(dt)) { - throw new AnalysisException( - "Cannot use CHAR type in non-Hive-Serde tables, please use STRING type instead.") - } - } - - def assertNoCharTypeInSchema(schema: StructType): Unit = { - schema.foreach { f => - if (f.metadata.contains(HIVE_TYPE_STRING)) { - failCharType(CatalystSqlParser.parseRawDataType(f.metadata.getString(HIVE_TYPE_STRING))) - } - } - } - def failNullType(dt: DataType): Unit = { def containsNullType(dt: DataType): Boolean = dt match { case ArrayType(et, _) => containsNullType(et) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 45d89498f5ae9..c613f152bcab3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCapability} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, Statistics => V2Statistics, SupportsReportStatistics} import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream} @@ -147,8 +147,10 @@ object DataSourceV2Relation { catalog: Option[CatalogPlugin], identifier: Option[Identifier], options: CaseInsensitiveStringMap): DataSourceV2Relation = { - val output = table.schema().toAttributes - DataSourceV2Relation(table, output, catalog, identifier, options) + // The v2 source may return schema containing char/varchar type. 
We replace char/varchar + // with string type here as Spark's type system doesn't support char/varchar yet. + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.schema) + DataSourceV2Relation(table, schema.toAttributes, catalog, identifier, options) } def create( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala new file mode 100644 index 0000000000000..dce4bfaa4fab5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.UTF8String + +@Experimental +case class CharType(length: Int) extends AtomicType { + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = typeTag[InternalType] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + override def defaultSize: Int = length + override def typeName: String = s"char($length)" + override def toString: String = s"CharType($length)" + private[spark] override def asNullable: CharType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 043c88f88843c..2d4c60a2a3c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -124,6 +124,8 @@ abstract class DataType extends AbstractDataType { object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r + private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r + private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r def fromDDL(ddl: String): DataType = { parseTypeWithFallback( @@ -166,7 +168,7 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) - private val nonDecimalNameToType = { + private val otherTypes = { Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) .map(t => t.typeName -> t).toMap @@ -177,7 +179,9 @@ object DataType { name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType.getOrElse( + case CHAR_TYPE(length) => CharType(length.toInt) + case VARCHAR_TYPE(length) => VarcharType(length.toInt) + case other => otherTypes.getOrElse( other, throw new 
IllegalArgumentException( s"Failed to convert the JSON string '$name' to a data type.")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala deleted file mode 100644 index a29f49ad14a77..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.types - -import scala.math.Ordering -import scala.reflect.runtime.universe.typeTag - -import org.apache.spark.unsafe.types.UTF8String - -/** - * A hive string type for compatibility. These datatypes should only used for parsing, - * and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. - */ -sealed abstract class HiveStringType extends AtomicType { - private[sql] type InternalType = UTF8String - - private[sql] val ordering = implicitly[Ordering[InternalType]] - - @transient private[sql] lazy val tag = typeTag[InternalType] - - override def defaultSize: Int = length - - private[spark] override def asNullable: HiveStringType = this - - def length: Int -} - -object HiveStringType { - def replaceCharType(dt: DataType): DataType = dt match { - case ArrayType(et, nullable) => - ArrayType(replaceCharType(et), nullable) - case MapType(kt, vt, nullable) => - MapType(replaceCharType(kt), replaceCharType(vt), nullable) - case StructType(fields) => - StructType(fields.map { field => - field.copy(dataType = replaceCharType(field.dataType)) - }) - case _: HiveStringType => StringType - case _ => dt - } - - def containsCharType(dt: DataType): Boolean = dt match { - case ArrayType(et, _) => containsCharType(et) - case MapType(kt, vt, _) => containsCharType(kt) || containsCharType(vt) - case StructType(fields) => fields.exists(f => containsCharType(f.dataType)) - case _ => dt.isInstanceOf[CharType] - } -} - -/** - * Hive char type. Similar to other HiveStringType's, these datatypes should only used for - * parsing, and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. - */ -case class CharType(length: Int) extends HiveStringType { - override def simpleString: String = s"char($length)" -} - -/** - * Hive varchar type. Similar to other HiveStringType's, these datatypes should only used for - * parsing, and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. 
- */ -case class VarcharType(length: Int) extends HiveStringType { - override def simpleString: String = s"varchar($length)" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala new file mode 100644 index 0000000000000..14454550dd981 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.UTF8String + +@Experimental +case class VarcharType(length: Int) extends AtomicType { + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = typeTag[InternalType] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + override def defaultSize: Int = length + override def typeName: String = s"varchar($length)" + override def toString: String = s"VarcharType($length)" + private[spark] override def asNullable: VarcharType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala index f29cbc2069e39..346a51ea10c82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala @@ -21,12 +21,4 @@ package org.apache.spark.sql * Contains a type system for attributes produced by relations, including complex types like * structs, arrays and maps. */ -package object types { - /** - * Metadata key used to store the raw hive type string in the metadata of StructField. This - * is relevant for datatypes that do not have a direct Spark SQL counterpart, such as CHAR and - * VARCHAR. We need to preserve the original type in order to invoke the correct object - * inspector in Hive.
- */ - val HIVE_TYPE_STRING = "HIVE_TYPE_STRING" -} +package object types diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f0a24d4a56048..6820d5d189537 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -41,9 +42,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.connector.InMemoryTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - +import org.apache.spark.sql.util.CaseInsensitiveStringMap class AnalysisSuite extends AnalysisTest with Matchers { import org.apache.spark.sql.catalyst.analysis.TestRelations._ @@ -55,6 +58,19 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } + test("fail for leaf node with char/varchar type") { + val schema1 = new StructType().add("c", CharType(5)) + val schema2 = new StructType().add("c", VarcharType(5)) + val schema3 = new StructType().add("c", ArrayType(CharType(5))) + Seq(schema1, schema2, schema3).foreach { schema => + val table = new InMemoryTable("t", schema, Array.empty, Map.empty[String, String].asJava) + intercept[IllegalStateException] { + DataSourceV2Relation( + table, schema.toAttributes, None, None, CaseInsensitiveStringMap.empty()).analyze + } + } + } + test("union project *") { val plan = (1 to 120) .map(_ => testRelation) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala index 6803fc307f919..5519f016e48d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -57,11 +57,6 @@ class TableSchemaParserSuite extends SparkFunSuite { |anotherArray:Array<char(9)>> """.stripMargin.replace("\n", "") - val builder = new MetadataBuilder - builder.putString(HIVE_TYPE_STRING, - "struct<struct:struct<deciMal:decimal(10,0),anotherDecimal:decimal(5,2)>," + - "MAP:map<timestamp,varchar(10)>,arrAy:array<double>,anotherArray:array<char(9)>>") - val expectedDataType = StructType( StructField("complexStructCol", StructType( StructField("struct", StructType( StructField("deciMal", DecimalType.USER_DEFAULT) :: StructField("anotherDecimal", DecimalType(5, 2)) :: Nil)) :: - StructField("MAP", MapType(TimestampType, StringType)) :: + StructField("MAP", MapType(TimestampType, VarcharType(10))) :: StructField("arrAy", ArrayType(DoubleType)) :: - StructField("anotherArray", ArrayType(StringType)) :: Nil), - nullable = true, - builder.build()) :: Nil) + StructField("anotherArray", ArrayType(CharType(9))) :: Nil)) :: Nil) assert(parse(tableSchemaString) === expectedDataType) } diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index b0325600e7530..badf46aad25af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -27,7 +27,7 @@ import scala.collection.mutable import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.read._ @@ -97,11 +97,12 @@ class InMemoryTable( } } + val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) partitioning.map { case IdentityTransform(ref) => - extractor(ref.fieldNames, schema, row)._1 + extractor(ref.fieldNames, cleanedSchema, row)._1 case YearsTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days: Int, DateType) => ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) case (micros: Long, TimestampType) => @@ -111,7 +112,7 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case MonthsTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days: Int, DateType) => ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) case (micros: Long, TimestampType) => @@ -121,7 +122,7 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case DaysTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days, DateType) => days case (micros: Long, TimestampType) => @@ -130,14 +131,14 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case HoursTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (micros: Long, TimestampType) => ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case BucketTransform(numBuckets, ref) => - (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets + (extractor(ref.fieldNames, cleanedSchema, row).hashCode() & Integer.MAX_VALUE) % numBuckets } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala index 7a9a7f52ff8fd..da5cfab8be3c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala @@ -28,7 +28,7 @@ class CatalogV2UtilSuite extends SparkFunSuite { val testCatalog = mock(classOf[TableCatalog]) val ident = mock(classOf[Identifier]) val table = mock(classOf[Table]) - 
when(table.schema()).thenReturn(mock(classOf[StructType])) + when(table.schema()).thenReturn(new StructType().add("i", "int")) when(testCatalog.loadTable(ident)).thenReturn(table) val r = CatalogV2Util.loadRelation(testCatalog, ident) assert(r.isDefined) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 30792c9bacd53..4a33ebf15a3ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -1182,7 +1182,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { Cast(expr, to) } + def cast(to: DataType): Column = withExpr { + Cast(expr, CharVarcharUtils.replaceCharVarcharWithString(to)) + } /** * Casts the column to a different data type, using the canonical string representation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 276d5d29bfa2c..5cacef1c72ce6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailureSafeParser} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, FailureSafeParser} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils @@ -274,11 +274,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray)) } + val cleanedUserSpecifiedSchema = userSpecifiedSchema + .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) + val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val (table, catalog, ident) = provider match { - case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty => + case _: SupportsCatalogOptions if cleanedUserSpecifiedSchema.nonEmpty => throw new IllegalArgumentException( s"$source does not support user specified schema. 
Please don't specify the schema.") case hasCatalog: SupportsCatalogOptions => @@ -290,7 +293,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { (catalog.loadTable(ident), Some(catalog), Some(ident)) case _ => // TODO: Non-catalog paths for DSV2 are currently not well defined. - val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) + val tbl = DataSourceV2Utils.getTableFromProvider( + provider, dsOptions, cleanedUserSpecifiedSchema) (tbl, None, None) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ @@ -312,13 +316,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } else { (paths, extraOptions) } + val cleanedUserSpecifiedSchema = userSpecifiedSchema + .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) // Code path for data source v1. sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = finalPaths, - userSpecifiedSchema = userSpecifiedSchema, + userSpecifiedSchema = cleanedUserSpecifiedSchema, className = source, options = finalOptions.originalMap).resolveRelation()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index bd9120a1fbe78..506ffbb4aadc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, LookupCatalog, SupportsNamespaces, SupportsPartitionManagement, TableCatalog, TableChange, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, LookupCatalog, SupportsNamespaces, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 -import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} /** * Resolves catalogs from the multi-part identifiers in SQL statements, and convert the statements @@ -50,9 +50,6 @@ class ResolveSessionCatalog( cols.foreach(c => failNullType(c.dataType)) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => - if (!DDLUtils.isHiveTable(v1Table.v1Table)) { - cols.foreach(c => failCharType(c.dataType)) - } cols.foreach { c => assertTopLevelColumn(c.name, "AlterTableAddColumnsCommand") if (!c.nullable) { @@ -62,7 +59,6 @@ class ResolveSessionCatalog( } AlterTableAddColumnsCommand(tbl.asTableIdentifier, cols.map(convertToStructField)) }.getOrElse { - cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( col.name.toArray, @@ -81,7 +77,6 @@ class ResolveSessionCatalog( case Some(_: V1Table) => throw new 
AnalysisException("REPLACE COLUMNS is only supported with v2 tables.") case Some(table) => - cols.foreach(c => failCharType(c.dataType)) // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. val deleteChanges = table.schema.fieldNames.map { name => TableChange.deleteColumn(Array(name)) @@ -104,10 +99,6 @@ class ResolveSessionCatalog( a.dataType.foreach(failNullType) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => - if (!DDLUtils.isHiveTable(v1Table.v1Table)) { - a.dataType.foreach(failCharType) - } - if (a.column.length > 1) { throw new AnalysisException( "ALTER COLUMN with qualified column is only supported with v2 tables.") @@ -133,19 +124,13 @@ class ResolveSessionCatalog( s"Available: ${v1Table.schema.fieldNames.mkString(", ")}") } } - // Add Hive type string to metadata. - val cleanedDataType = HiveStringType.replaceCharType(dataType) - if (dataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, dataType.catalogString) - } val newColumn = StructField( colName, - cleanedDataType, + dataType, nullable = true, builder.build()) AlterTableChangeColumnCommand(tbl.asTableIdentifier, colName, newColumn) }.getOrElse { - a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => TableChange.updateColumnType(colName, newDataType) @@ -269,16 +254,12 @@ class ResolveSessionCatalog( assertNoNullTypeInSchema(c.tableSchema) val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { - if (!DDLUtils.isHiveTable(Some(provider))) { - assertNoCharTypeInSchema(c.tableSchema) - } val tableDesc = buildCatalogTable(tbl.asTableIdentifier, c.tableSchema, c.partitioning, c.bucketSpec, c.properties, provider, c.options, c.location, c.comment, c.ifNotExists) val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, None) } else { - assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, tbl.asIdentifier, @@ -328,7 +309,6 @@ class ResolveSessionCatalog( if (!isV2Provider(provider)) { throw new AnalysisException("REPLACE TABLE is only supported with v2 tables.") } else { - assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, tbl.asIdentifier, @@ -716,17 +696,7 @@ class ResolveSessionCatalog( private def convertToStructField(col: QualifiedColType): StructField = { val builder = new MetadataBuilder col.comment.foreach(builder.putString("comment", _)) - - val cleanedDataType = HiveStringType.replaceCharType(col.dataType) - if (col.dataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, col.dataType.catalogString) - } - - StructField( - col.name.head, - cleanedDataType, - nullable = true, - builder.build()) + StructField(col.name.head, col.dataType, nullable = true, builder.build()) } private def isV2Provider(provider: String): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala new file mode 100644 index 0000000000000..35bb86f178eb1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryComparison, Expression, In, Literal, StringRPad} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{CharType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * This rule applies char type padding in two places: + * 1. When reading values from column/field of type CHAR(N), right-pad the values to length N. + * 2. When comparing char type column/field with string literal or char type column/field, + * right-pad the shorter one to the longer length. + */ +object ApplyCharTypePadding extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + val padded = plan.resolveOperatorsUpWithNewOutput { + case r: LogicalRelation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedOutput = r.output.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, r.copy(output = cleanedOutput)) + padded -> r.output.zip(padded.output) + } + + case r: DataSourceV2Relation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedOutput = r.output.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, r.copy(output = cleanedOutput)) + padded -> r.output.zip(padded.output) + } + + case r: HiveTableRelation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedDataCols = r.dataCols.map(CharVarcharUtils.cleanAttrMetadata) + val cleanedPartCols = r.partitionCols.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, + r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols)) + padded -> r.output.zip(padded.output) + } + } + + padded.resolveOperatorsUp { + case operator if operator.resolved => operator.transformExpressionsUp { + // String literal is treated as char type when it's compared to a char type column. + // We should pad the shorter one to the longer length. 
+ case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable => + padAttrLitCmp(attr, lit).map { newChildren => + b.withNewChildren(newChildren) + }.getOrElse(b) + + case b @ BinaryComparison(lit, attr: Attribute) if lit.foldable => + padAttrLitCmp(attr, lit).map { newChildren => + b.withNewChildren(newChildren.reverse) + }.getOrElse(b) + + case i @ In(attr: Attribute, list) + if attr.dataType == StringType && list.forall(_.foldable) => + CharVarcharUtils.getRawType(attr.metadata).flatMap { + case CharType(length) => + val literalCharLengths = list.map(_.eval().asInstanceOf[UTF8String].numChars()) + val targetLen = (length +: literalCharLengths).max + Some(i.copy( + value = addPadding(attr, length, targetLen), + list = list.zip(literalCharLengths).map { + case (lit, charLength) => addPadding(lit, charLength, targetLen) + })) + case _ => None + }.getOrElse(i) + + // For char type column or inner field comparison, pad the shorter one to the longer length. + case b @ BinaryComparison(left: Attribute, right: Attribute) => + b.withNewChildren(CharVarcharUtils.addPaddingInStringComparison(Seq(left, right))) + + case i @ In(attr: Attribute, list) if list.forall(_.isInstanceOf[Attribute]) => + val newChildren = CharVarcharUtils.addPaddingInStringComparison( + attr +: list.map(_.asInstanceOf[Attribute])) + i.copy(value = newChildren.head, list = newChildren.tail) + } + } + } + + private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = { + if (attr.dataType == StringType) { + CharVarcharUtils.getRawType(attr.metadata).flatMap { + case CharType(length) => + val str = lit.eval().asInstanceOf[UTF8String] + val stringLitLen = str.numChars() + if (length < stringLitLen) { + Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit)) + } else if (length > stringLitLen) { + Some(Seq(attr, StringRPad(lit, Literal(length)))) + } else { + None + } + case _ => None + } + } else { + None + } + } + + private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { + if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 33a3486bf6f67..0c6a80d441686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.sources.BaseRelation /** @@ -69,9 +69,17 @@ case class LogicalRelation( } object LogicalRelation { - def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, None, isStreaming) + def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = { + // The v1 source may return schema containing char/varchar type. We replace char/varchar + // with string type here as Spark's type system doesn't support char/varchar yet. 
+ val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) + LogicalRelation(relation, schema.toAttributes, None, isStreaming) + } - def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, Some(table), false) + def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = { + // The v1 source may return schema containing char/varchar type. We replace char/varchar + // with string type here as Spark's type system doesn't support char/varchar yet. + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) + LogicalRelation(relation, schema.toAttributes, Some(table), false) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 78f31fb80ecf6..6733aab947be6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -761,16 +761,6 @@ object JdbcUtils extends Logging { schema: StructType, caseSensitive: Boolean, createTableColumnTypes: String): Map[String, String] = { - def typeName(f: StructField): String = { - // char/varchar gets translated to string type. Real data type specified by the user - // is available in the field metadata as HIVE_TYPE_STRING - if (f.metadata.contains(HIVE_TYPE_STRING)) { - f.metadata.getString(HIVE_TYPE_STRING) - } else { - f.dataType.catalogString - } - } - val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) val nameEquality = if (caseSensitive) { org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution @@ -791,7 +781,7 @@ object JdbcUtils extends Logging { } } - val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap + val userSchemaMap = userSchema.fields.map(f => f.name -> f.dataType.catalogString).toMap if (caseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 538a5408723bb..a89a5de3b7e72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -189,6 +189,7 @@ abstract class BaseSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis +: + ApplyCharTypePadding +: customPostHocResolutionRules override val extendedCheckRules: Seq[LogicalPlan => Unit] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9bc4acd49a980..b9a1a465d9e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import 
org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils @@ -203,6 +203,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo extraOptions + ("path" -> path.get) } + val cleanedUserSpecifiedSchema = userSpecifiedSchema + .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) + val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf). getConstructor().newInstance() // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. @@ -210,7 +213,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo // writer or whether the query is continuous. val v1DataSource = DataSource( sparkSession, - userSpecifiedSchema = userSpecifiedSchema, + userSpecifiedSchema = cleanedUserSpecifiedSchema, className = source, options = optionsWithPath.originalMap) val v1Relation = ds match { @@ -225,7 +228,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) + val table = DataSourceV2Utils.getTableFromProvider( + provider, dsOptions, cleanedUserSpecifiedSchema) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala new file mode 100644 index 0000000000000..e192a63956232 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} + +trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { + + def format: String + + test("char type values should be padded: top-level columns") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('1', 'a')") + checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + } + } + + test("char type values should be padded: partitioned columns") { + // DS V2 doesn't support partitioned table. + if (!conf.contains(SQLConf.DEFAULT_CATALOG.key)) { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format PARTITIONED BY (c)") + sql("INSERT INTO t VALUES ('1', 'a')") + checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + } + } + } + + test("char type values should be padded: nested in struct") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c STRUCT<c: CHAR(5)>) USING $format") + sql("INSERT INTO t VALUES ('1', struct('a'))") + checkAnswer(spark.table("t"), Row("1", Row("a" + " " * 4))) + } + } + + test("char type values should be padded: nested in array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY<CHAR(5)>) USING $format") + sql("INSERT INTO t VALUES ('1', array('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Seq("a" + " " * 4, "ab" + " " * 3))) + } + } + + test("char type values should be padded: nested in map key") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP<CHAR(5), STRING>) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab")))) + } + } + + test("char type values should be padded: nested in map value") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP<STRING, CHAR(5)>) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a", "ab" + " " * 3)))) + } + } + + test("char type values should be padded: nested in both map key and value") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP<CHAR(5), CHAR(10)>) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab" + " " * 8)))) + } + } + + test("char type values should be padded: nested in struct of array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c STRUCT<c: ARRAY<CHAR(5)>>) USING $format") + sql("INSERT INTO t VALUES ('1', struct(array('a', 'ab')))") + checkAnswer(spark.table("t"), Row("1", Row(Seq("a" + " " * 4, "ab" + " " * 3)))) + } + } + + test("char type values should be padded: nested in array of struct") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY<STRUCT<c: CHAR(5)>>) USING $format") + sql("INSERT INTO t VALUES ('1', array(struct('a'), struct('ab')))") + checkAnswer(spark.table("t"), Row("1", Seq(Row("a" + " " * 4), Row("ab" + " " * 3)))) + } + } + + test("char type values should be padded: nested in array of array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY<ARRAY<CHAR(5)>>) USING $format") + sql("INSERT INTO t VALUES ('1', array(array('a', 'ab')))") + checkAnswer(spark.table("t"), Row("1", Seq(Seq("a" + " " * 4, "ab" + " " * 3)))) + } + } + + private def testTableWrite(f: String => Unit): Unit = { + withTable("t") { f("char") } + withTable("t") { f("varchar") } + } + + test("length check for input string values: top-level columns") { + testTableWrite { typeName => + sql(s"CREATE TABLE
t(c $typeName(5)) USING $format") + sql("INSERT INTO t VALUES (null)") + checkAnswer(spark.table("t"), Row(null)) + val e = intercept[SparkException](sql("INSERT INTO t VALUES ('123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: partitioned columns") { + // DS V2 doesn't support partitioned table. + if (!conf.contains(SQLConf.DEFAULT_CATALOG.key)) { + testTableWrite { typeName => + sql(s"CREATE TABLE t(i INT, c $typeName(5)) USING $format PARTITIONED BY (c)") + sql("INSERT INTO t VALUES (1, null)") + checkAnswer(spark.table("t"), Row(1, null)) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (1, '123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + } + + test("length check for input string values: nested in struct") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c STRUCT<c: $typeName(5)>) USING $format") + sql("INSERT INTO t SELECT struct(null)") + checkAnswer(spark.table("t"), Row(Row(null))) + val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY<$typeName(5)>) USING $format") + sql("INSERT INTO t VALUES (array(null))") + checkAnswer(spark.table("t"), Row(Seq(null))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array('a', '123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in map key") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP<$typeName(5), STRING>) USING $format") + val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in map value") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP<STRING, $typeName(5)>) USING $format") + sql("INSERT INTO t VALUES (map('a', null))") + checkAnswer(spark.table("t"), Row(Map("a" -> null))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in both map key and value") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP<$typeName(5), $typeName(5)>) USING $format") + val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) + assert(e1.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) + assert(e2.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in struct of array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c STRUCT<c: ARRAY<$typeName(5)>>) USING $format") + sql("INSERT INTO t SELECT struct(array(null))") + checkAnswer(spark.table("t"), Row(Row(Seq(null)))) + val e = intercept[SparkException](sql("INSERT INTO t SELECT
struct(array('123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array of struct") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY<STRUCT<c: $typeName(5)>>) USING $format") + sql("INSERT INTO t VALUES (array(struct(null)))") + checkAnswer(spark.table("t"), Row(Seq(Row(null)))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(struct('123456')))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array of array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY<ARRAY<$typeName(5)>>) USING $format") + sql("INSERT INTO t VALUES (array(array(null)))") + checkAnswer(spark.table("t"), Row(Seq(Seq(null)))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(array('123456')))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: with trailing spaces") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(5), c2 VARCHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('12 ', '12 ')") + sql("INSERT INTO t VALUES ('1234 ', '1234 ')") + checkAnswer(spark.table("t"), Seq( + Row("12" + " " * 3, "12 "), + Row("1234 ", "1234 "))) + } + } + + test("length check for input string values: with implicit cast") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(5), c2 VARCHAR(5)) USING $format") + sql("INSERT INTO t VALUES (1234, 1234)") + checkAnswer(spark.table("t"), Row("1234 ", "1234")) + val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (123456, 1)")) + assert(e1.getCause.getMessage.contains( + "input string '123456' exceeds char type length limitation: 5")) + val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (1, 123456)")) + assert(e2.getCause.getMessage.contains( + "input string '123456' exceeds varchar type length limitation: 5")) + } + } + + private def testConditions(df: DataFrame, conditions: Seq[(String, Boolean)]): Unit = { + checkAnswer(df.selectExpr(conditions.map(_._1): _*), Row.fromSeq(conditions.map(_._2))) + } + + test("char type comparison: top-level columns") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(2), c2 CHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('a', 'a')") + testConditions(spark.table("t"), Seq( + ("c1 = 'a'", true), + ("'a' = c1", true), + ("c1 = 'a '", true), + ("c1 > 'a'", false), + ("c1 IN ('a', 'b')", true), + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: partitioned columns") { + withTable("t") { + sql(s"CREATE TABLE t(i INT, c1 CHAR(2), c2 CHAR(5)) USING $format PARTITIONED BY (c1, c2)") + sql("INSERT INTO t VALUES (1, 'a', 'a')") + testConditions(spark.table("t"), Seq( + ("c1 = 'a'", true), + ("'a' = c1", true), + ("c1 = 'a '", true), + ("c1 > 'a'", false), + ("c1 IN ('a', 'b')", true), + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: join") { + withTable("t1", "t2") { + sql(s"CREATE TABLE t1(c CHAR(2)) USING $format") + sql(s"CREATE TABLE t2(c CHAR(5)) USING $format") + sql("INSERT INTO t1 VALUES ('a')") + sql("INSERT INTO t2 VALUES ('a')") + checkAnswer(sql("SELECT t1.c FROM t1 JOIN t2 ON t1.c = t2.c"), Row("a ")) + } + } + + test("char type comparison: nested in struct") { + withTable("t") { + sql(s"CREATE
TABLE t(c1 STRUCT<c: CHAR(2)>, c2 STRUCT<c: CHAR(5)>) USING $format") + sql("INSERT INTO t VALUES (struct('a'), struct('a'))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array") { + withTable("t") { + sql(s"CREATE TABLE t(c1 ARRAY<CHAR(2)>, c2 ARRAY<CHAR(5)>) USING $format") + sql("INSERT INTO t VALUES (array('a', 'b'), array('a', 'b'))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in struct of array") { + withTable("t") { + sql("CREATE TABLE t(c1 STRUCT<c: ARRAY<CHAR(2)>>, c2 STRUCT<c: ARRAY<CHAR(5)>>) " + + s"USING $format") + sql("INSERT INTO t VALUES (struct(array('a', 'b')), struct(array('a', 'b')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array of struct") { + withTable("t") { + sql("CREATE TABLE t(c1 ARRAY<STRUCT<c: CHAR(2)>>, c2 ARRAY<STRUCT<c: CHAR(5)>>) " + + s"USING $format") + sql("INSERT INTO t VALUES (array(struct('a')), array(struct('a')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array of array") { + withTable("t") { + sql("CREATE TABLE t(c1 ARRAY<ARRAY<CHAR(2)>>, c2 ARRAY<ARRAY<CHAR(5)>>) " + + s"USING $format") + sql("INSERT INTO t VALUES (array(array('a')), array(array('a')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } +} + +class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSparkSession { + override def format: String = "parquet" + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet") + } +} + +class DSV2CharVarcharTestSuite extends CharVarcharTestSuite + with SharedSparkSession { + override def format: String = "foo" + protected override def sparkConf = { + super.sparkConf + .set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + .set(SQLConf.DEFAULT_CATALOG.key, "testcat") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index f5809ebbb836e..d40c6db5b6b3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.SimpleScanSource -import org.apache.spark.sql.types.{CharType, DoubleType, HIVE_TYPE_STRING, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.types.{CharType, DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ @@ -1076,9 +1076,7 @@ class PlanResolutionSuite extends AnalysisTest { } val sql = s"ALTER TABLE v1HiveTable ALTER COLUMN i TYPE char(1)" - val builder = new MetadataBuilder - builder.putString(HIVE_TYPE_STRING, CharType(1).catalogString) - val newColumnWithCleanedType = StructField("i", StringType, true, builder.build()) + val newColumnWithCleanedType = StructField("i", CharType(1), true)
val expected = AlterTableChangeColumnCommand( TableIdentifier("v1HiveTable", Some("default")), "i", newColumnWithCleanedType) val parsed = parseAndResolve(sql) @@ -1519,44 +1517,6 @@ class PlanResolutionSuite extends AnalysisTest { } } - test("SPARK-31147: forbid CHAR type in non-Hive tables") { - def checkFailure(t: String, provider: String): Unit = { - val types = Seq( - "CHAR(2)", - "ARRAY<CHAR(2)>", - "MAP<INT, CHAR(2)>", - "MAP<CHAR(2), INT>", - "STRUCT<c: CHAR(2)>") - types.foreach { tpe => - intercept[AnalysisException] { - parseAndResolve(s"CREATE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"REPLACE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"CREATE OR REPLACE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ADD COLUMN col $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ADD COLUMN col $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ALTER COLUMN col TYPE $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t REPLACE COLUMNS (col $tpe)") - } - } - } - - checkFailure("v1Table", v1Format) - checkFailure("v2Table", v2Format) - checkFailure("testcat.tab", "foo") - } - // TODO: add tests for more commands. } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 9a95bf770772e..ca3e714665818 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -127,7 +128,7 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", - s"char_$i", + s"char_$i".padTo(18, ' '), Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -206,10 +207,6 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { (2 to 10).map(i => Row(i, i - 1)).toSeq) test("Schema and all fields") { - def hiveMetadata(dt: String): Metadata = { - new MetadataBuilder().putString(HIVE_TYPE_STRING, dt).build() - } - val expectedSchema = StructType( StructField("string$%Field", StringType, true) :: StructField("binaryField", BinaryType, true) :: @@ -224,8 +221,8 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: - StructField("varcharField", StringType, true, hiveMetadata("varchar(12)")) :: - StructField("charField", StringType, true, hiveMetadata("char(18)")) :: + StructField("varcharField", VarcharType(12), true) :: + StructField("charField", CharType(18), true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,7 +245,8 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { Nil ) - assert(expectedSchema == spark.table("tableWithSchema").schema) + assert(CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedSchema) ==
spark.table("tableWithSchema").schema) withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { checkAnswer( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index b30492802495f..da37b61688951 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -90,6 +90,7 @@ class HiveSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis +: + ApplyCharTypePadding +: HiveAnalysis +: customPostHocResolutionRules diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a78e1cebc588c..a4816aeb4f291 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -977,19 +977,14 @@ private[hive] class HiveClientImpl( private[hive] object HiveClientImpl extends Logging { /** Converts the native StructField to Hive's FieldSchema. */ def toHiveColumn(c: StructField): FieldSchema = { - val typeString = if (c.metadata.contains(HIVE_TYPE_STRING)) { - c.metadata.getString(HIVE_TYPE_STRING) - } else { - // replace NullType to HiveVoidType since Hive parse void not null. - HiveVoidType.replaceVoidType(c.dataType).catalogString - } + val typeString = HiveVoidType.replaceVoidType(c.dataType).catalogString new FieldSchema(c.name, typeString, c.getComment().orNull) } /** Get the Spark SQL native DataType from Hive's FieldSchema. */ private def getSparkSQLDataType(hc: FieldSchema): DataType = { try { - CatalystSqlParser.parseDataType(hc.getType) + CatalystSqlParser.parseRawDataType(hc.getType) } catch { case e: ParseException => throw new SparkException( @@ -1000,18 +995,10 @@ private[hive] object HiveClientImpl extends Logging { /** Builds the native StructField from Hive's FieldSchema. */ def fromHiveColumn(hc: FieldSchema): StructField = { val columnType = getSparkSQLDataType(hc) - val replacedVoidType = HiveVoidType.replaceVoidType(columnType) - val metadata = if (hc.getType != replacedVoidType.catalogString) { - new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() - } else { - Metadata.empty - } - val field = StructField( name = hc.getName, dataType = columnType, - nullable = true, - metadata = metadata) + nullable = true) Option(hc.getComment).map(field.withComment).getOrElse(field) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala new file mode 100644 index 0000000000000..55d305fda4f96 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveCharVarcharTestSuite extends CharVarcharTestSuite with TestHiveSingleton { + + // The default Hive serde doesn't support nested null values. + override def format: String = "hive OPTIONS(fileFormat='parquet')" + + private var originalPartitionMode = "" + + override protected def beforeAll(): Unit = { + super.beforeAll() + originalPartitionMode = spark.conf.get("hive.exec.dynamic.partition.mode", "") + spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict") + } + + override protected def afterAll(): Unit = { + if (originalPartitionMode == "") { + spark.conf.unset("hive.exec.dynamic.partition.mode") + } else { + spark.conf.set("hive.exec.dynamic.partition.mode", originalPartitionMode) + } + super.afterAll() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 8f71ba3337aa2..1a6f6843d3911 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -113,24 +113,19 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { .add("c9", "date") .add("c10", "timestamp") .add("c11", "string") - .add("c12", "string", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "char(10)").build()) - .add("c13", "string", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "varchar(10)").build()) + .add("c12", CharType(10), true) + .add("c13", VarcharType(10), true) .add("c14", "binary") .add("c15", "decimal") .add("c16", "decimal(10)") .add("c17", "decimal(10,2)") .add("c18", "array") .add("c19", "array") - .add("c20", "array<string>", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "array<char(10)>").build()) + .add("c20", ArrayType(CharType(10)), true) .add("c21", "map") - .add("c22", "map<int,string>", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "map<int,char(10)>").build()) + .add("c22", MapType(IntegerType, CharType(10)), true) .add("c23", "struct") - .add("c24", "struct<c:string,d:int>", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "struct<c:varchar(10),d:int>").build()) + .add("c24", new StructType().add("c", VarcharType(10)).add("d", "int"), true) assert(schema == expectedSchema) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 1f15bd685b239..2f594a8638199 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2252,8 +2252,8 @@ class HiveDDLSuite ) sql("ALTER TABLE tab ADD COLUMNS (c5 char(10))") - assert(spark.table("tab").schema.find(_.name == "c5") - .get.metadata.getString("HIVE_TYPE_STRING") == "char(10)") + assert(spark.sharedState.externalCatalog.getTable("default", "tab") + .schema.find(_.name == "c5").get.dataType == CharType(10)) } } }
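
Note for reviewers: below is a minimal, self-contained sketch (not part of the patch) of the end-to-end behavior the new suites assert, i.e. padded CHAR reads, padded comparison against a string literal, and the runtime length check on writes. It assumes a Spark build containing this change, a local[1] session, and a managed parquet table named t; the object/table names are illustrative only, and the exact exception wrapping may differ by deployment.

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

object CharVarcharDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("char-varchar-demo")
      .getOrCreate()

    spark.sql("CREATE TABLE t(c CHAR(5), v VARCHAR(5)) USING parquet")
    spark.sql("INSERT INTO t VALUES ('a', 'a')")

    // CHAR(5) values come back right-padded to length 5; VARCHAR(5) values do not.
    spark.table("t").show()

    // The string literal is padded as well, so the predicate matches the padded value.
    spark.sql("SELECT count(*) FROM t WHERE c = 'a'").show()

    // Strings longer than the declared length are rejected when writing.
    try {
      spark.sql("INSERT INTO t VALUES ('123456', 'x')")
    } catch {
      case e: SparkException =>
        // e.g. "input string '123456' exceeds char type length limitation: 5"
        println(e.getCause.getMessage)
    }

    spark.stop()
  }
}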