From 69adca50f12fde7deefc7bef2e31d459526cac3c Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 27 Nov 2020 14:43:27 +0800
Subject: [PATCH] more fixes

---
 docs/sql-ref-datatypes.md                          |  2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  4 +++-
 .../spark/sql/catalyst/util/CharVarcharUtils.scala |  5 +++--
 .../org/apache/spark/sql/types/CharType.scala      |  2 +-
 .../org/apache/spark/sql/types/DataType.scala      |  4 +---
 .../org/apache/spark/sql/types/VarcharType.scala   |  2 +-
 .../org/apache/spark/sql/DataFrameReader.scala     | 14 ++++----------
 .../spark/sql/streaming/DataStreamReader.scala     | 10 +++-------
 .../apache/spark/sql/CharVarcharTestSuite.scala    |  8 ++++++++
 9 files changed, 25 insertions(+), 26 deletions(-)

diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md
index fa829623545e1..0087867a8c7f7 100644
--- a/docs/sql-ref-datatypes.md
+++ b/docs/sql-ref-datatypes.md
@@ -38,7 +38,7 @@ Spark SQL and DataFrames support the following data types:
 * String type
   - `StringType`: Represents character string values.
   - `VarcharType(length)`: A variant of `StringType` which has a length limitation. Data writing will fail if the input string exceeds the length limitation. Note: this type can only be used in table schema, not functions/operators.
-  - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `VarcharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length.
+  - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `CharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length.
 * Binary type
   - `BinaryType`: Represents byte sequence values.
 * Boolean type
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 3d428b39a6db0..d173756a45f32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -99,7 +99,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
   }
 
   override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
-    withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
+    val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
+      StructType(visitColTypeList(ctx.colTypeList)))
+    withOrigin(ctx)(schema)
   }
 
   def parseRawDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index e8b09fd1247d2..0cbe5abdbbd7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -137,8 +137,9 @@ object CharVarcharUtils {
   }
 
   /**
-   * Returns an expression to apply write-side char type padding for the given expression. A string
-   * value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) column/field.
+   * Returns an expression to apply write-side string length check for the given expression. A
+   * string value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N)
+   * column/field.
    */
   def stringLengthCheck(expr: Expression, targetAttr: Attribute): Expression = {
     getRawType(targetAttr.metadata).map { rawType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala
index b329b5a964c87..67ab1cc2f3321 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala
@@ -25,7 +25,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 @Experimental
 case class CharType(length: Int) extends AtomicType {
-  require(length >= 0, "The length if char type cannot be negative.")
+  require(length >= 0, "The length of char type cannot be negative.")
 
   private[sql] type InternalType = UTF8String
   @transient private[sql] lazy val tag = typeTag[InternalType]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 73022de572747..e4ee6eb377a4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
 import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.internal.SQLConf
@@ -133,8 +132,7 @@ object DataType {
       ddl,
       CatalystSqlParser.parseDataType,
       "Cannot parse the data type: ",
-      fallbackParser = str => CharVarcharUtils.replaceCharVarcharWithStringInSchema(
-        CatalystSqlParser.parseTableSchema(str)))
+      fallbackParser = str => CatalystSqlParser.parseTableSchema(str))
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
index dd52b76ee2783..8d78640c1e125 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
@@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 @Experimental
 case class VarcharType(length: Int) extends AtomicType {
-  require(length >= 0, "The length if varchar type cannot be negative.")
+  require(length >= 0, "The length of varchar type cannot be negative.")
 
   private[sql] type InternalType = UTF8String
   @transient private[sql] lazy val tag = typeTag[InternalType]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 3b5532ccb910f..49b3335bf1769 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -73,7 +73,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
  def schema(schema: StructType): DataFrameReader = {
-    this.userSpecifiedSchema = Option(schema)
+    this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
     this
   }
 
@@ -274,14 +274,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray))
     }
 
-    val cleanedUserSpecifiedSchema = userSpecifiedSchema
-      .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)
-
     val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++
       optionsWithPath.originalMap
     val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
     val (table, catalog, ident) = provider match {
-      case _: SupportsCatalogOptions if cleanedUserSpecifiedSchema.nonEmpty =>
+      case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty =>
         throw new IllegalArgumentException(
           s"$source does not support user specified schema. Please don't specify the schema.")
       case hasCatalog: SupportsCatalogOptions =>
@@ -293,8 +290,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         (catalog.loadTable(ident), Some(catalog), Some(ident))
       case _ =>
         // TODO: Non-catalog paths for DSV2 are currently not well defined.
-        val tbl = DataSourceV2Utils.getTableFromProvider(
-          provider, dsOptions, cleanedUserSpecifiedSchema)
+        val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
         (tbl, None, None)
     }
     import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
@@ -316,15 +312,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     } else {
       (paths, extraOptions)
     }
-    val cleanedUserSpecifiedSchema = userSpecifiedSchema
-      .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)
 
     // Code path for data source v1.
     sparkSession.baseRelationToDataFrame(
       DataSource.apply(
         sparkSession,
         paths = finalPaths,
-        userSpecifiedSchema = cleanedUserSpecifiedSchema,
+        userSpecifiedSchema = userSpecifiedSchema,
         className = source,
         options = finalOptions.originalMap).resolveRelation())
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index b9a1a465d9e52..4e755682242d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -64,7 +64,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * @since 2.0.0
    */
  def schema(schema: StructType): DataStreamReader = {
-    this.userSpecifiedSchema = Option(schema)
+    this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
     this
   }
 
@@ -203,9 +203,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
       extraOptions + ("path" -> path.get)
     }
 
-    val cleanedUserSpecifiedSchema = userSpecifiedSchema
-      .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)
-
     val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).
       getConstructor().newInstance()
     // We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
@@ -213,7 +210,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
     // writer or whether the query is continuous.
     val v1DataSource = DataSource(
       sparkSession,
-      userSpecifiedSchema = cleanedUserSpecifiedSchema,
+      userSpecifiedSchema = userSpecifiedSchema,
       className = source,
       options = optionsWithPath.originalMap)
     val v1Relation = ds match {
@@ -228,8 +225,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
         val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++
           optionsWithPath.originalMap
         val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
-        val table = DataSourceV2Utils.getTableFromProvider(
-          provider, dsOptions, cleanedUserSpecifiedSchema)
+        val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
         import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
         table match {
           case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index d5100c237f732..abb13270d20e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -448,6 +448,14 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession {
     assert(schema.map(_.dataType) == Seq(StringType))
   }
 
+  test("user-specified schema in DataFrameReader: file source from Dataset") {
+    val ds = spark.range(10).map(_.toString)
+    val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
+    assert(df1.schema.map(_.dataType) == Seq(StringType))
+    val df2 = spark.read.schema("id char(5)").csv(ds)
+    assert(df2.schema.map(_.dataType) == Seq(StringType))
+  }
+
   test("user-specified schema in DataFrameReader: DSV1") {
     def checkSchema(df: DataFrame): Unit = {
       val relations = df.queryExecution.analyzed.collect {
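
Usage note (illustrative, not part of the patch): the change moves the char/varchar-to-string replacement into DataFrameReader.schema and DataStreamReader.schema themselves, so the per-call-site cleanedUserSpecifiedSchema copies removed above are no longer needed. A minimal Scala sketch of the user-facing behaviour that the new test pins down, assuming an active SparkSession bound to `spark`:

  // A CHAR(n)/VARCHAR(n) column in a user-specified read schema surfaces as StringType.
  import org.apache.spark.sql.types.{CharType, StringType, StructType}
  import spark.implicits._                            // encoder for the string Dataset below

  val ds = spark.range(10).map(_.toString)            // small in-memory Dataset[String] to read as CSV
  val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
  assert(df1.schema.head.dataType == StringType)      // char(5) comes back as string
  val df2 = spark.read.schema("id char(5)").csv(ds)   // same via a DDL-string schema
  assert(df2.schema.head.dataType == StringType)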