Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ public UTF8String trimRight() {
* Trims at most `numSpaces` space characters (ASCII 32) from the end of this string.
*/
public UTF8String trimTrailingSpaces(int numSpaces) {
assert numSpaces > 0;
int endIdx = numBytes - 1;
int trimTo = numBytes - numSpaces;
while (endIdx >= trimTo && getByte(endIdx) == 0x20) endIdx--;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,55 +22,34 @@
public class CharVarcharCodegenUtils {
private static final UTF8String SPACE = UTF8String.fromString(" ");

/**
* Trailing spaces do not count in the length check. We don't need to retain the trailing
* spaces, as we will pad char type columns/fields at read time.
*/
public static UTF8String charTypeWriteSideCheck(UTF8String inputStr, int limit) {
if (inputStr == null) {
return null;
private static UTF8String trimTrailingSpaces(
UTF8String inputStr, int numChars, int limit) {
int numTailSpacesToTrim = numChars - limit;
UTF8String trimmed = inputStr.trimTrailingSpaces(numTailSpacesToTrim);
if (trimmed.numChars() > limit) {
throw new RuntimeException("Exceeds char/varchar type length limitation: " + limit);
} else {
UTF8String trimmed = inputStr.trimRight();
if (trimmed.numChars() > limit) {
throw new RuntimeException("Exceeds char type length limitation: " + limit);
}
return trimmed;
}
}

public static UTF8String charTypeReadSideCheck(UTF8String inputStr, int limit) {
if (inputStr == null) return null;
if (inputStr.numChars() > limit) {
throw new RuntimeException("Exceeds char type length limitation: " + limit);
public static UTF8String charTypeWriteSideCheck(UTF8String inputStr, int limit) {
int numChars = inputStr.numChars();
if (numChars == limit) {
return inputStr;
} else if (numChars < limit) {
return inputStr.rpad(limit, SPACE);
} else {
return trimTrailingSpaces(inputStr, numChars, limit);
}
return inputStr.rpad(limit, SPACE);
}

public static UTF8String varcharTypeWriteSideCheck(UTF8String inputStr, int limit) {
if (inputStr == null) {
return null;
int numChars = inputStr.numChars();
if (numChars <= limit) {
return inputStr;
} else {
int numChars = inputStr.numChars();
if (numChars <= limit) {
return inputStr;
} else {
// Trailing spaces do not count in the length check. We need to retain the trailing spaces
// (truncate to length N), as there is no read-time padding for varchar type.
int maxAllowedNumTailSpaces = numChars - limit;
UTF8String trimmed = inputStr.trimTrailingSpaces(maxAllowedNumTailSpaces);
if (trimmed.numChars() > limit) {
throw new RuntimeException("Exceeds varchar type length limitation: " + limit);
} else {
return trimmed;
}
}
}
}

public static UTF8String varcharTypeReadSideCheck(UTF8String inputStr, int limit) {
if (inputStr != null && inputStr.numChars() > limit) {
throw new RuntimeException("Exceeds varchar type length limitation: " + limit);
return trimTrailingSpaces(inputStr, numChars, limit);
}
return inputStr;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -279,6 +280,8 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveUnion ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Apply Char Padding", Once,
ApplyCharTypePadding),
Batch("Post-Hoc Resolution", Once,
Seq(ResolveCommandsWithIfExists) ++
postHocResolutionRules: _*),
Expand Down Expand Up @@ -3926,3 +3929,77 @@ object UpdateOuterReferences extends Rule[LogicalPlan] {
}
}
}

/**
* This rule performs string padding for char type comparison.
*
* When comparing char type column/field with string literal or char type column/field,
* right-pad the shorter one to the longer length.
*/
object ApplyCharTypePadding extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUp {
case operator if operator.resolved => operator.transformExpressionsUp {
// String literal is treated as char type when it's compared to a char type column.
// We should pad the shorter one to the longer length.
case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable =>
padAttrLitCmp(attr, lit).map { newChildren =>
b.withNewChildren(newChildren)
}.getOrElse(b)

case b @ BinaryComparison(lit, attr: Attribute) if lit.foldable =>
padAttrLitCmp(attr, lit).map { newChildren =>
b.withNewChildren(newChildren.reverse)
}.getOrElse(b)

case i @ In(attr: Attribute, list)
if attr.dataType == StringType && list.forall(_.foldable) =>
CharVarcharUtils.getRawType(attr.metadata).flatMap {
case CharType(length) =>
val literalCharLengths = list.map(_.eval().asInstanceOf[UTF8String].numChars())
val targetLen = (length +: literalCharLengths).max
Some(i.copy(
value = addPadding(attr, length, targetLen),
list = list.zip(literalCharLengths).map {
case (lit, charLength) => addPadding(lit, charLength, targetLen)
}))
case _ => None
}.getOrElse(i)

// For char type column or inner field comparison, pad the shorter one to the longer length.
case b @ BinaryComparison(left: Attribute, right: Attribute) =>
b.withNewChildren(CharVarcharUtils.addPaddingInStringComparison(Seq(left, right)))

case i @ In(attr: Attribute, list) if list.forall(_.isInstanceOf[Attribute]) =>
val newChildren = CharVarcharUtils.addPaddingInStringComparison(
attr +: list.map(_.asInstanceOf[Attribute]))
i.copy(value = newChildren.head, list = newChildren.tail)
}
}
}

private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = {
if (attr.dataType == StringType) {
CharVarcharUtils.getRawType(attr.metadata).flatMap {
case CharType(length) =>
val str = lit.eval().asInstanceOf[UTF8String]
val stringLitLen = str.numChars()
if (length < stringLitLen) {
Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit))
} else if (length > stringLitLen) {
Some(Seq(attr, StringRPad(lit, Literal(length))))
} else {
None
}
case _ => None
}
} else {
None
}
}

private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] {
case unresolvedPartSpec: UnresolvedPartitionSpec =>
val normalizedSpec = normalizePartitionSpec(
unresolvedPartSpec.spec,
partSchema.map(_.name),
partSchema,
tableName,
conf.resolver)
checkSpec(normalizedSpec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object CharVarcharUtils extends Logging {
StructType(st.map { field =>
if (hasCharVarchar(field.dataType)) {
val metadata = new MetadataBuilder().withMetadata(field.metadata)
.putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, field.dataType.sql).build()
.putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, field.dataType.catalogString).build()
field.copy(dataType = replaceCharVarcharWithString(field.dataType), metadata = metadata)
} else {
field
Expand Down Expand Up @@ -114,17 +114,20 @@ object CharVarcharUtils extends Logging {
attr.withMetadata(cleaned)
}

def getRawTypeString(metadata: Metadata): Option[String] = {
if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) {
Some(metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY))
} else {
None
}
}

/**
* Re-construct the original data type from the type string in the given metadata.
* This is needed when dealing with char/varchar columns/fields.
*/
def getRawType(metadata: Metadata): Option[DataType] = {
if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) {
Some(CatalystSqlParser.parseDataType(
metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
} else {
None
}
getRawTypeString(metadata).map(CatalystSqlParser.parseDataType)
}

/**
Expand All @@ -137,73 +140,6 @@ object CharVarcharUtils extends Logging {
StructType(fields)
}

/**
* Returns expressions to apply read-side char type padding for the given attributes.
*
* For a CHAR(N) column/field and the length of string value is M
* If M > N, raise runtime error
* If M <= N, the value should be right-padded to N characters.
*
* For a VARCHAR(N) column/field and the length of string value is M
* If M > N, raise runtime error
* If M <= N, the value should be remained.
*/
def paddingWithLengthCheck(output: Seq[AttributeReference]): Seq[NamedExpression] = {
output.map { attr =>
getRawType(attr.metadata).filter { rawType =>
rawType.existsRecursively(dt => dt.isInstanceOf[CharType] || dt.isInstanceOf[VarcharType])
}.map { rawType =>
Alias(paddingWithLengthCheck(attr, rawType), attr.name)(
explicitMetadata = Some(attr.metadata))
}.getOrElse(attr)
}
}

private def paddingWithLengthCheck(expr: Expression, dt: DataType): Expression = dt match {
case CharType(length) =>
StaticInvoke(
classOf[CharVarcharCodegenUtils],
StringType,
"charTypeReadSideCheck",
expr :: Literal(length) :: Nil,
propagateNull = false)

case VarcharType(length) =>
StaticInvoke(
classOf[CharVarcharCodegenUtils],
StringType,
"varcharTypeReadSideCheck",
expr :: Literal(length) :: Nil,
propagateNull = false)

case StructType(fields) =>
val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) =>
Seq(Literal(f.name),
paddingWithLengthCheck(GetStructField(expr, i, Some(f.name)), f.dataType))
})
if (expr.nullable) {
If(IsNull(expr), Literal(null, struct.dataType), struct)
} else {
struct
}

case ArrayType(et, containsNull) => charTypePaddingInArray(expr, et, containsNull)

case MapType(kt, vt, valueContainsNull) =>
val newKeys = charTypePaddingInArray(MapKeys(expr), kt, containsNull = false)
val newValues = charTypePaddingInArray(MapValues(expr), vt, valueContainsNull)
MapFromArrays(newKeys, newValues)

case _ => expr
}

private def charTypePaddingInArray(
arr: Expression, et: DataType, containsNull: Boolean): Expression = {
val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull)
val func = LambdaFunction(paddingWithLengthCheck(param, et), Seq(param))
ArrayTransform(arr, func)
}

/**
* Returns an expression to apply write-side string length check for the given expression. A
* string value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N)
Expand All @@ -223,15 +159,15 @@ object CharVarcharUtils extends Logging {
StringType,
"charTypeWriteSideCheck",
expr :: Literal(length) :: Nil,
propagateNull = false)
returnNullable = false)

case VarcharType(length) =>
StaticInvoke(
classOf[CharVarcharCodegenUtils],
StringType,
"varcharTypeWriteSideCheck",
expr :: Literal(length) :: Nil,
propagateNull = false)
returnNullable = false)

case StructType(fields) =>
val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ package org.apache.spark.sql.util
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types.{CharType, StructType, VarcharType}
import org.apache.spark.unsafe.types.UTF8String

object PartitioningUtils {
/**
Expand All @@ -30,14 +35,33 @@ object PartitioningUtils {
*/
def normalizePartitionSpec[T](
partitionSpec: Map[String, T],
partColNames: Seq[String],
partCols: StructType,
tblName: String,
resolver: Resolver): Map[String, T] = {
val rawSchema = CharVarcharUtils.getRawSchema(partCols)
val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) =>
val normalizedKey = partColNames.find(resolver(_, key)).getOrElse {
val normalizedFiled = rawSchema.find(f => resolver(f.name, key)).getOrElse {
throw new AnalysisException(s"$key is not a valid partition column in table $tblName.")
}
normalizedKey -> value

val normalizedVal = normalizedFiled.dataType match {
case CharType(len) if value != null && value != DEFAULT_PARTITION_NAME =>
val v = value match {
case Some(str: String) => Some(charTypeWriteSideCheck(str, len))
case str: String => charTypeWriteSideCheck(str, len)
case other => other
}
v.asInstanceOf[T]
case VarcharType(len) if value != null && value != DEFAULT_PARTITION_NAME =>
val v = value match {
case Some(str: String) => Some(varcharTypeWriteSideCheck(str, len))
case str: String => varcharTypeWriteSideCheck(str, len)
case other => other
}
v.asInstanceOf[T]
case _ => value
}
normalizedFiled.name -> normalizedVal
}

SchemaUtils.checkColumnNameDuplication(
Expand All @@ -46,6 +70,16 @@ object PartitioningUtils {
normalizedPartSpec.toMap
}

private def charTypeWriteSideCheck(inputStr: String, limit: Int): String = {
val toUtf8 = UTF8String.fromString(inputStr)
CharVarcharCodegenUtils.charTypeWriteSideCheck(toUtf8, limit).toString
}

private def varcharTypeWriteSideCheck(inputStr: String, limit: Int): String = {
val toUtf8 = UTF8String.fromString(inputStr)
CharVarcharCodegenUtils.varcharTypeWriteSideCheck(toUtf8, limit).toString
}

/**
* Verify if the input partition spec exactly matches the existing defined partition spec
* The columns must be the same but the orders could be different.
Expand Down
Loading