more fixes
cloud-fan committed Nov 27, 2020
1 parent 671471f commit 69adca5
Showing 9 changed files with 25 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/sql-ref-datatypes.md
@@ -38,7 +38,7 @@ Spark SQL and DataFrames support the following data types:
* String type
- `StringType`: Represents character string values.
- `VarcharType(length)`: A variant of `StringType` which has a length limitation. Data writing will fail if the input string exceeds the length limitation. Note: this type can only be used in table schema, not functions/operators.
- `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `VarcharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length.
- `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `CharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length.
* Binary type
- `BinaryType`: Represents byte sequence values.
* Boolean type
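To make the CHAR/VARCHAR semantics described above concrete, here is a minimal spark-shell style sketch. It is an illustration, not part of this commit's diff; it assumes an active SparkSession `spark`, and the table name `t` is made up:

// Illustrative only. Table name and values are hypothetical.
spark.sql("CREATE TABLE t (c CHAR(5), v VARCHAR(5)) USING parquet")
spark.sql("INSERT INTO t VALUES ('ab', 'ab')")
// The CHAR(5) column reads back padded to length 5 ('ab   ');
// the VARCHAR(5) column reads back as written ('ab').
spark.table("t").show()
// Writing a value longer than 5 characters into either column fails:
// spark.sql("INSERT INTO t VALUES ('toolong', 'toolong')")   // expected to error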
@@ -99,7 +99,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
}

override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
StructType(visitColTypeList(ctx.colTypeList)))
withOrigin(ctx)(schema)
}

def parseRawDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
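A hedged sketch of the effect of the `visitSingleTableSchema` change above: parsing a DDL table schema that uses CHAR/VARCHAR should now yield plain StringType fields, matching the test added at the end of this commit (the snippet is illustrative, not part of the diff):

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.StringType

// visitSingleTableSchema now routes the parsed schema through
// CharVarcharUtils.replaceCharVarcharWithStringInSchema, so CHAR/VARCHAR
// columns in a DDL table schema surface as StringType fields.
val schema = CatalystSqlParser.parseTableSchema("id CHAR(5), name VARCHAR(10)")
assert(schema.map(_.dataType) == Seq(StringType, StringType))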
@@ -137,8 +137,9 @@ object CharVarcharUtils {
}

/**
* Returns an expression to apply write-side char type padding for the given expression. A string
* value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) column/field.
* Returns an expression to apply write-side string length check for the given expression. A
* string value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N)
* column/field.
*/
def stringLengthCheck(expr: Expression, targetAttr: Attribute): Expression = {
getRawType(targetAttr.metadata).map { rawType =>
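The body of `stringLengthCheck` is collapsed in this diff. As a rough illustration of the documented write-side behavior only (the function name and error message below are made up; the real implementation builds a Catalyst expression, not a Scala method):

// Illustrative only -- not the actual expression built by stringLengthCheck.
// Conceptually, the write-side check verifies that the value fits in
// CHAR(N)/VARCHAR(N) and fails the write otherwise.
def conceptualLengthCheck(value: String, n: Int): String = {
  require(value.length <= n,
    s"input string of length ${value.length} exceeds the char/varchar length limitation: $n")
  value
}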
@@ -25,7 +25,7 @@ import org.apache.spark.unsafe.types.UTF8String

@Experimental
case class CharType(length: Int) extends AtomicType {
require(length >= 0, "The length if char type cannot be negative.")
require(length >= 0, "The length of char type cannot be negative.")

private[sql] type InternalType = UTF8String
@transient private[sql] lazy val tag = typeTag[InternalType]
@@ -32,7 +32,6 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.internal.SQLConf
@@ -133,8 +132,7 @@ object DataType {
ddl,
CatalystSqlParser.parseDataType,
"Cannot parse the data type: ",
fallbackParser = str => CharVarcharUtils.replaceCharVarcharWithStringInSchema(
CatalystSqlParser.parseTableSchema(str)))
fallbackParser = str => CatalystSqlParser.parseTableSchema(str))
}

/**
@@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String

@Experimental
case class VarcharType(length: Int) extends AtomicType {
require(length >= 0, "The length if varchar type cannot be negative.")
require(length >= 0, "The length of varchar type cannot be negative.")

private[sql] type InternalType = UTF8String
@transient private[sql] lazy val tag = typeTag[InternalType]
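Both require messages now read correctly ("length of" rather than "length if"); a tiny illustration of the guard, with arbitrary values and not part of the commit:

import org.apache.spark.sql.types.{CharType, VarcharType}

CharType(5)       // ok
VarcharType(10)   // ok
// CharType(-1)   // fails: "The length of char type cannot be negative."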
14 changes: 4 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -73,7 +73,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def schema(schema: StructType): DataFrameReader = {
this.userSpecifiedSchema = Option(schema)
this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
this
}

@@ -274,14 +274,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray))
}

val cleanedUserSpecifiedSchema = userSpecifiedSchema
.map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)

val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++
optionsWithPath.originalMap
val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
val (table, catalog, ident) = provider match {
case _: SupportsCatalogOptions if cleanedUserSpecifiedSchema.nonEmpty =>
case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty =>
throw new IllegalArgumentException(
s"$source does not support user specified schema. Please don't specify the schema.")
case hasCatalog: SupportsCatalogOptions =>
@@ -293,8 +290,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
(catalog.loadTable(ident), Some(catalog), Some(ident))
case _ =>
// TODO: Non-catalog paths for DSV2 are currently not well defined.
val tbl = DataSourceV2Utils.getTableFromProvider(
provider, dsOptions, cleanedUserSpecifiedSchema)
val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
(tbl, None, None)
}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
@@ -316,15 +312,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
} else {
(paths, extraOptions)
}
val cleanedUserSpecifiedSchema = userSpecifiedSchema
.map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)

// Code path for data source v1.
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = finalPaths,
userSpecifiedSchema = cleanedUserSpecifiedSchema,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = finalOptions.originalMap).resolveRelation())
}
@@ -64,7 +64,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* @since 2.0.0
*/
def schema(schema: StructType): DataStreamReader = {
this.userSpecifiedSchema = Option(schema)
this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
this
}

@@ -203,17 +203,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
extraOptions + ("path" -> path.get)
}

val cleanedUserSpecifiedSchema = userSpecifiedSchema
.map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)

val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).
getConstructor().newInstance()
// We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
// We can't be sure at this point whether we'll actually want to use V2, since we don't know the
// writer or whether the query is continuous.
val v1DataSource = DataSource(
sparkSession,
userSpecifiedSchema = cleanedUserSpecifiedSchema,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = optionsWithPath.originalMap)
val v1Relation = ds match {
@@ -228,8 +225,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++
optionsWithPath.originalMap
val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
val table = DataSourceV2Utils.getTableFromProvider(
provider, dsOptions, cleanedUserSpecifiedSchema)
val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) =>
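As with DataFrameReader, `DataStreamReader.schema` now applies the replacement eagerly, so streaming reads see the same behavior. A hedged sketch of the expected result (illustrative only; format, path, and column name are hypothetical, and an active SparkSession `spark` is assumed):

import org.apache.spark.sql.types.{StringType, StructType, VarcharType}

// The user-specified char/varchar field is replaced with string up front,
// so the streaming DataFrame exposes a plain StringType column.
val sdf = spark.readStream
  .schema(new StructType().add("name", VarcharType(10)))
  .format("csv")
  .load("/tmp/stream-in")
assert(sdf.schema.map(_.dataType) == Seq(StringType))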
@@ -448,6 +448,14 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession {
assert(schema.map(_.dataType) == Seq(StringType))
}

test("user-specified schema in DataFrameReader: file source from Dataset") {
val ds = spark.range(10).map(_.toString)
val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
assert(df1.schema.map(_.dataType) == Seq(StringType))
val df2 = spark.read.schema("id char(5)").csv(ds)
assert(df2.schema.map(_.dataType) == Seq(StringType))
}

test("user-specified schema in DataFrameReader: DSV1") {
def checkSchema(df: DataFrame): Unit = {
val relations = df.queryExecution.analyzed.collect {
