Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-37047][SQL] Add lpad and rpad functions for binary strings #34154

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -101,4 +101,82 @@ public static byte[] concat(byte[]... inputs) {
}
return result;
}

// Helper method for implementing `lpad` and `rpad`.
// If the padding pattern's length is 0, return the first `len` bytes of the input byte
// sequence if it is longer than `len` bytes, or a copy of the byte sequence, otherwise.
private static byte[] padWithEmptyPattern(byte[] bytes, int len) {
len = Math.min(bytes.length, len);
final byte[] result = new byte[len];
Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, result, Platform.BYTE_ARRAY_OFFSET, len);
return result;
}

// Helper method for implementing `lpad` and `rpad`.
// Fills the resulting byte sequence with the pattern. The resulting byte sequence is
// passed as the first argument and it is filled from position `firstPos` (inclusive)
// to position `beyondPos` (not inclusive).
private static void fillWithPattern(byte[] result, int firstPos, int beyondPos, byte[] pad) {
for (int pos = firstPos; pos < beyondPos; pos += pad.length) {
final int jMax = Math.min(pad.length, beyondPos - pos);
for (int j = 0; j < jMax; ++j) {
result[pos + j] = (byte) pad[j];
}
}
}

// Left-pads the input byte sequence using the provided padding pattern.
// In the special case that the padding pattern is empty, the resulting byte sequence
// contains the first `len` bytes of the input if they exist, or is a copy of the input
// byte sequence otherwise.
// For padding patterns with positive byte length, the resulting byte sequence's byte length is
// equal to `len`. If the input byte sequence is not less than `len` bytes, its first `len` bytes
// are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern.
public static byte[] lpad(byte[] bytes, int len, byte[] pad) {
mkaravel marked this conversation as resolved.
Show resolved Hide resolved
if (bytes == null || pad == null) return null;
// If the input length is 0, return the empty byte sequence.
if (len == 0) return EMPTY_BYTE;
// The padding pattern is empty.
if (pad.length == 0) return padWithEmptyPattern(bytes, len);
// The general case.
// 1. Copy the first `len` bytes of the input byte sequence into the output if they exist.
final byte[] result = new byte[len];
final int minLen = Math.min(len, bytes.length);
Platform.copyMemory(
bytes, Platform.BYTE_ARRAY_OFFSET,
result, Platform.BYTE_ARRAY_OFFSET + len - minLen,
minLen);
// 2. If the input has less than `len` bytes, fill in the rest using the provided pattern.
if (bytes.length < len) {
fillWithPattern(result, 0, len - bytes.length, pad);
}
return result;
}

// Right-pads the input byte sequence using the provided padding pattern.
// In the special case that the padding pattern is empty, the resulting byte sequence
// contains the first `len` bytes of the input if they exist, or is a copy of the input
// byte sequence otherwise.
// For padding patterns with positive byte length, the resulting byte sequence's byte length is
// equal to `len`. If the input byte sequence is not less than `len` bytes, its first `len` bytes
// are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern.
public static byte[] rpad(byte[] bytes, int len, byte[] pad) {
if (bytes == null || pad == null) return null;
// If the input length is 0, return the empty byte sequence.
if (len == 0) return EMPTY_BYTE;
// The padding pattern is empty.
if (pad.length == 0) return padWithEmptyPattern(bytes, len);
// The general case.
// 1. Copy the first `len` bytes of the input sequence into the output if they exist.
final byte[] result = new byte[len];
Platform.copyMemory(
bytes, Platform.BYTE_ARRAY_OFFSET,
result, Platform.BYTE_ARRAY_OFFSET,
Math.min(len, bytes.length));
// 2. If the input has less than `len` bytes, fill in the rest using the provided pattern.
if (bytes.length < len) {
fillWithPattern(result, bytes.length, len, pad);
}
return result;
}
}
2 changes: 2 additions & 0 deletions docs/sql-migration-guide.md
Expand Up @@ -24,6 +24,8 @@ license: |

## Upgrading from Spark SQL 3.2 to 3.3

- Since Spark 3.3, the functions `lpad` and `rpad` have been overloaded to support byte sequences. When the first argument is a byte sequence, the optional padding pattern must also be a byte sequence and the result is a BINARY value. The default padding pattern in this case is the zero byte.

- Since Spark 3.3, Spark turns a non-nullable schema into nullable for API `DataFrameReader.schema(schema: StructType).json(jsonDataset: Dataset[String])` and `DataFrameReader.schema(schema: StructType).csv(csvDataset: Dataset[String])` when the schema is specified by the user and contains non-nullable fields.

- Since Spark 3.3, when the date or timestamp pattern is not specified, Spark converts an input string to a date/timestamp using the `CAST` expression approach. The changes affect CSV/JSON datasources and parsing of partition values. In Spark 3.2 or earlier, when the date or timestamp pattern is not set, Spark uses the default patterns: `yyyy-MM-dd` for dates and `yyyy-MM-dd HH:mm:ss` for timestamps. After the changes, Spark still recognizes the pattern together with
Expand Down
Expand Up @@ -1328,14 +1328,31 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)

}

/**
* Helper class for implementing StringLPad and StringRPad.
* Returns the default expression to be used in StringLPad or StringRPad based on the type of
* the input expression.
* For character string expressions the default padding expression is the string literal ' '.
* For byte sequence expressions the default padding expression is the byte literal 0x00.
*/
object StringPadDefaultValue {
def get(str: Expression): Expression = {
str.dataType match {
case StringType => Literal(" ")
case BinaryType => Literal(Array[Byte](0))
}
}
}

/**
* Returns str, left-padded with pad to a length of len.
*/
@ExpressionDescription(
usage = """
_FUNC_(str, len[, pad]) - Returns `str`, left-padded with `pad` to a length of `len`.
If `str` is longer than `len`, the return value is shortened to `len` characters.
If `pad` is not specified, `str` will be padded to the left with space characters.
If `str` is longer than `len`, the return value is shortened to `len` characters or bytes.
If `pad` is not specified, `str` will be padded to the left with space characters if it is
a character string, and with zeros if it is a byte sequence.
""",
examples = """
Examples:
Expand All @@ -1345,28 +1362,44 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
h
> SELECT _FUNC_('hi', 5);
hi
> SELECT hex(_FUNC_(unhex('aabb'), 5));
000000AABB
> SELECT hex(_FUNC_(unhex('aabb'), 5, unhex('1122')));
112211AABB
""",
since = "1.5.0",
group = "string_funcs")
case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

def this(str: Expression, len: Expression) = {
this(str, len, Literal(" "))
this(str, len, StringPadDefaultValue.get(str))
}

override def first: Expression = str
override def second: Expression = len
override def third: Expression = pad
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)

override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
override def dataType: DataType = str.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringType, BinaryType), IntegerType, TypeCollection(StringType, BinaryType))

override def nullSafeEval(string: Any, len: Any, pad: Any): Any = {
str.dataType match {
case StringType => string.asInstanceOf[UTF8String]
.lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
case BinaryType => ByteArray.lpad(string.asInstanceOf[Array[Byte]],
len.asInstanceOf[Int], pad.asInstanceOf[Array[Byte]])
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)")
defineCodeGen(ctx, ev, (string, len, pad) => {
str.dataType match {
case StringType => s"$string.lpad($len, $pad)"
case BinaryType => s"${classOf[ByteArray].getName}.lpad($string, $len, $pad)"
}
})
}

override def prettyName: String = "lpad"
Expand All @@ -1383,7 +1416,8 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
usage = """
_FUNC_(str, len[, pad]) - Returns `str`, right-padded with `pad` to a length of `len`.
If `str` is longer than `len`, the return value is shortened to `len` characters.
If `pad` is not specified, `str` will be padded to the right with space characters.
If `pad` is not specified, `str` will be padded to the right with space characters if it is
a character string, and with zeros if it is a binary string.
""",
examples = """
Examples:
Expand All @@ -1393,29 +1427,44 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
h
> SELECT _FUNC_('hi', 5);
hi
> SELECT hex(_FUNC_(unhex('aabb'), 5));
AABB000000
> SELECT hex(_FUNC_(unhex('aabb'), 5, unhex('1122')));
AABB112211
""",
since = "1.5.0",
group = "string_funcs")
case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

def this(str: Expression, len: Expression) = {
this(str, len, Literal(" "))
this(str, len, StringPadDefaultValue.get(str))
}

override def first: Expression = str
override def second: Expression = len
override def third: Expression = pad

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
override def dataType: DataType = str.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringType, BinaryType), IntegerType, TypeCollection(StringType, BinaryType))

override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
override def nullSafeEval(string: Any, len: Any, pad: Any): Any = {
str.dataType match {
case StringType => string.asInstanceOf[UTF8String]
.rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
case BinaryType => ByteArray.rpad(string.asInstanceOf[Array[Byte]],
len.asInstanceOf[Int], pad.asInstanceOf[Array[Byte]])
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)")
defineCodeGen(ctx, ev, (string, len, pad) => {
str.dataType match {
case StringType => s"$string.rpad($len, $pad)"
case BinaryType => s"${classOf[ByteArray].getName}.rpad($string, $len, $pad)"
}
})
}

override def prettyName: String = "rpad"
Expand Down
22 changes: 22 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Expand Up @@ -2715,6 +2715,17 @@ object functions {
StringLPad(str.expr, lit(len).expr, lit(pad).expr)
}

/**
* Left-pad the binary column with pad to a byte length of len. If the binary column is longer
* than len, the return value is shortened to len bytes.
*
* @group string_funcs
* @since 3.3.0
*/
def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
StringLPad(str.expr, lit(len).expr, lit(pad).expr)
}

/**
* Trim the spaces from left end for the specified string value.
*
Expand Down Expand Up @@ -2793,6 +2804,17 @@ object functions {
StringRPad(str.expr, lit(len).expr, lit(pad).expr)
}

/**
* Right-pad the binary column with pad to a byte length of len. If the binary column is longer
* than len, the return value is shortened to len bytes.
*
* @group string_funcs
* @since 3.3.0
*/
def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
StringRPad(str.expr, lit(len).expr, lit(pad).expr)
}

/**
* Repeats a string column n times, and returns it as a new string column.
*
Expand Down
28 changes: 28 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
Expand Up @@ -62,6 +62,34 @@ SELECT btrim(encode('xxxbarxxx', 'utf-8'), encode('x', 'utf-8'));
SELECT lpad('hi', 'invalid_length');
SELECT rpad('hi', 'invalid_length');

-- lpad for BINARY inputs
SELECT hex(lpad(unhex(''), 5));
SELECT hex(lpad(unhex('aabb'), 5));
SELECT hex(lpad(unhex('aabbcc'), 2));
SELECT hex(lpad(unhex(''), 5, unhex('1f')));
SELECT hex(lpad(unhex('aa'), 5, unhex('1f')));
SELECT hex(lpad(unhex('aa'), 6, unhex('1f')));
SELECT hex(lpad(unhex(''), 5, unhex('1f2e')));
SELECT hex(lpad(unhex('aa'), 5, unhex('1f2e')));
SELECT hex(lpad(unhex('aa'), 6, unhex('1f2e')));
SELECT hex(lpad(unhex(''), 6, unhex('')));
SELECT hex(lpad(unhex('aabbcc'), 6, unhex('')));
SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff')));

-- rpad for BINARY inputs
SELECT hex(rpad(unhex(''), 5));
SELECT hex(rpad(unhex('aabb'), 5));
SELECT hex(rpad(unhex('aabbcc'), 2));
SELECT hex(rpad(unhex(''), 5, unhex('1f')));
SELECT hex(rpad(unhex('aa'), 5, unhex('1f')));
SELECT hex(rpad(unhex('aa'), 6, unhex('1f')));
SELECT hex(rpad(unhex(''), 5, unhex('1f2e')));
SELECT hex(rpad(unhex('aa'), 5, unhex('1f2e')));
SELECT hex(rpad(unhex('aa'), 6, unhex('1f2e')));
SELECT hex(rpad(unhex(''), 6, unhex('')));
SELECT hex(rpad(unhex('aabbcc'), 6, unhex('')));
SELECT hex(rpad(unhex('aabbcc'), 2, unhex('ff')));

-- decode
select decode();
select decode(encode('abc', 'utf-8'));
Expand Down