diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 4383ee1533c2b..cff9061aabb13 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -101,4 +101,82 @@ public static byte[] concat(byte[]... inputs) { } return result; } + + // Helper method for implementing `lpad` and `rpad`. + // If the padding pattern's length is 0, return the first `len` bytes of the input byte + // sequence if it is longer than `len` bytes, or a copy of the byte sequence, otherwise. + private static byte[] padWithEmptyPattern(byte[] bytes, int len) { + len = Math.min(bytes.length, len); + final byte[] result = new byte[len]; + Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, result, Platform.BYTE_ARRAY_OFFSET, len); + return result; + } + + // Helper method for implementing `lpad` and `rpad`. + // Fills the resulting byte sequence with the pattern. The resulting byte sequence is + // passed as the first argument and it is filled from position `firstPos` (inclusive) + // to position `beyondPos` (not inclusive). + private static void fillWithPattern(byte[] result, int firstPos, int beyondPos, byte[] pad) { + for (int pos = firstPos; pos < beyondPos; pos += pad.length) { + final int jMax = Math.min(pad.length, beyondPos - pos); + for (int j = 0; j < jMax; ++j) { + result[pos + j] = (byte) pad[j]; + } + } + } + + // Left-pads the input byte sequence using the provided padding pattern. + // In the special case that the padding pattern is empty, the resulting byte sequence + // contains the first `len` bytes of the input if they exist, or is a copy of the input + // byte sequence otherwise. + // For padding patterns with positive byte length, the resulting byte sequence's byte length is + // equal to `len`. If the input byte sequence is not less than `len` bytes, its first `len` bytes + // are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern. + public static byte[] lpad(byte[] bytes, int len, byte[] pad) { + if (bytes == null || pad == null) return null; + // If the input length is 0, return the empty byte sequence. + if (len == 0) return EMPTY_BYTE; + // The padding pattern is empty. + if (pad.length == 0) return padWithEmptyPattern(bytes, len); + // The general case. + // 1. Copy the first `len` bytes of the input byte sequence into the output if they exist. + final byte[] result = new byte[len]; + final int minLen = Math.min(len, bytes.length); + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + len - minLen, + minLen); + // 2. If the input has less than `len` bytes, fill in the rest using the provided pattern. + if (bytes.length < len) { + fillWithPattern(result, 0, len - bytes.length, pad); + } + return result; + } + + // Right-pads the input byte sequence using the provided padding pattern. + // In the special case that the padding pattern is empty, the resulting byte sequence + // contains the first `len` bytes of the input if they exist, or is a copy of the input + // byte sequence otherwise. + // For padding patterns with positive byte length, the resulting byte sequence's byte length is + // equal to `len`. If the input byte sequence is not less than `len` bytes, its first `len` bytes + // are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern. + public static byte[] rpad(byte[] bytes, int len, byte[] pad) { + if (bytes == null || pad == null) return null; + // If the input length is 0, return the empty byte sequence. + if (len == 0) return EMPTY_BYTE; + // The padding pattern is empty. + if (pad.length == 0) return padWithEmptyPattern(bytes, len); + // The general case. + // 1. Copy the first `len` bytes of the input sequence into the output if they exist. + final byte[] result = new byte[len]; + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET, + Math.min(len, bytes.length)); + // 2. If the input has less than `len` bytes, fill in the rest using the provided pattern. + if (bytes.length < len) { + fillWithPattern(result, bytes.length, len, pad); + } + return result; + } } diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index e0b18a374c836..3bea6b9751ab2 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -24,6 +24,8 @@ license: | ## Upgrading from Spark SQL 3.2 to 3.3 + - Since Spark 3.3, the functions `lpad` and `rpad` have been overloaded to support byte sequences. When the first argument is a byte sequence, the optional padding pattern must also be a byte sequence and the result is a BINARY value. The default padding pattern in this case is the zero byte. + - Since Spark 3.3, Spark turns a non-nullable schema into nullable for API `DataFrameReader.schema(schema: StructType).json(jsonDataset: Dataset[String])` and `DataFrameReader.schema(schema: StructType).csv(csvDataset: Dataset[String])` when the schema is specified by the user and contains non-nullable fields. - Since Spark 3.3, when the date or timestamp pattern is not specified, Spark converts an input string to a date/timestamp using the `CAST` expression approach. The changes affect CSV/JSON datasources and parsing of partition values. In Spark 3.2 or earlier, when the date or timestamp pattern is not set, Spark uses the default patterns: `yyyy-MM-dd` for dates and `yyyy-MM-dd HH:mm:ss` for timestamps. After the changes, Spark still recognizes the pattern together with diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5956c3e882118..710f1ed588242 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1328,14 +1328,31 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } +/** + * Helper class for implementing StringLPad and StringRPad. + * Returns the default expression to be used in StringLPad or StringRPad based on the type of + * the input expression. + * For character string expressions the default padding expression is the string literal ' '. + * For byte sequence expressions the default padding expression is the byte literal 0x00. + */ +object StringPadDefaultValue { + def get(str: Expression): Expression = { + str.dataType match { + case StringType => Literal(" ") + case BinaryType => Literal(Array[Byte](0)) + } + } +} + /** * Returns str, left-padded with pad to a length of len. */ @ExpressionDescription( usage = """ _FUNC_(str, len[, pad]) - Returns `str`, left-padded with `pad` to a length of `len`. - If `str` is longer than `len`, the return value is shortened to `len` characters. - If `pad` is not specified, `str` will be padded to the left with space characters. + If `str` is longer than `len`, the return value is shortened to `len` characters or bytes. + If `pad` is not specified, `str` will be padded to the left with space characters if it is + a character string, and with zeros if it is a byte sequence. """, examples = """ Examples: @@ -1345,6 +1362,10 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) h > SELECT _FUNC_('hi', 5); hi + > SELECT hex(_FUNC_(unhex('aabb'), 5)); + 000000AABB + > SELECT hex(_FUNC_(unhex('aabb'), 5, unhex('1122'))); + 112211AABB """, since = "1.5.0", group = "string_funcs") @@ -1352,21 +1373,33 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { - this(str, len, Literal(" ")) + this(str, len, StringPadDefaultValue.get(str)) } override def first: Expression = str override def second: Expression = len override def third: Expression = pad - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { - str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + override def dataType: DataType = str.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, BinaryType), IntegerType, TypeCollection(StringType, BinaryType)) + + override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { + str.dataType match { + case StringType => string.asInstanceOf[UTF8String] + .lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + case BinaryType => ByteArray.lpad(string.asInstanceOf[Array[Byte]], + len.asInstanceOf[Int], pad.asInstanceOf[Array[Byte]]) + } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") + defineCodeGen(ctx, ev, (string, len, pad) => { + str.dataType match { + case StringType => s"$string.lpad($len, $pad)" + case BinaryType => s"${classOf[ByteArray].getName}.lpad($string, $len, $pad)" + } + }) } override def prettyName: String = "lpad" @@ -1383,7 +1416,8 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera usage = """ _FUNC_(str, len[, pad]) - Returns `str`, right-padded with `pad` to a length of `len`. If `str` is longer than `len`, the return value is shortened to `len` characters. - If `pad` is not specified, `str` will be padded to the right with space characters. + If `pad` is not specified, `str` will be padded to the right with space characters if it is + a character string, and with zeros if it is a binary string. """, examples = """ Examples: @@ -1393,6 +1427,10 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera h > SELECT _FUNC_('hi', 5); hi + > SELECT hex(_FUNC_(unhex('aabb'), 5)); + AABB000000 + > SELECT hex(_FUNC_(unhex('aabb'), 5, unhex('1122'))); + AABB112211 """, since = "1.5.0", group = "string_funcs") @@ -1400,22 +1438,33 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { - this(str, len, Literal(" ")) + this(str, len, StringPadDefaultValue.get(str)) } override def first: Expression = str override def second: Expression = len override def third: Expression = pad - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + override def dataType: DataType = str.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, BinaryType), IntegerType, TypeCollection(StringType, BinaryType)) - override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { - str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { + str.dataType match { + case StringType => string.asInstanceOf[UTF8String] + .rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + case BinaryType => ByteArray.rpad(string.asInstanceOf[Array[Byte]], + len.asInstanceOf[Int], pad.asInstanceOf[Array[Byte]]) + } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") + defineCodeGen(ctx, ev, (string, len, pad) => { + str.dataType match { + case StringType => s"$string.rpad($len, $pad)" + case BinaryType => s"${classOf[ByteArray].getName}.rpad($string, $len, $pad)" + } + }) } override def prettyName: String = "rpad" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7bca29f3d80c0..1aea498f4960a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2715,6 +2715,17 @@ object functions { StringLPad(str.expr, lit(len).expr, lit(pad).expr) } + /** + * Left-pad the binary column with pad to a byte length of len. If the binary column is longer + * than len, the return value is shortened to len bytes. + * + * @group string_funcs + * @since 3.3.0 + */ + def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { + StringLPad(str.expr, lit(len).expr, lit(pad).expr) + } + /** * Trim the spaces from left end for the specified string value. * @@ -2793,6 +2804,17 @@ object functions { StringRPad(str.expr, lit(len).expr, lit(pad).expr) } + /** + * Right-pad the binary column with pad to a byte length of len. If the binary column is longer + * than len, the return value is shortened to len bytes. + * + * @group string_funcs + * @since 3.3.0 + */ + def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { + StringRPad(str.expr, lit(len).expr, lit(pad).expr) + } + /** * Repeats a string column n times, and returns it as a new string column. * diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index d44055d72e3bc..beacdbfcd593a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -62,6 +62,34 @@ SELECT btrim(encode('xxxbarxxx', 'utf-8'), encode('x', 'utf-8')); SELECT lpad('hi', 'invalid_length'); SELECT rpad('hi', 'invalid_length'); +-- lpad for BINARY inputs +SELECT hex(lpad(unhex(''), 5)); +SELECT hex(lpad(unhex('aabb'), 5)); +SELECT hex(lpad(unhex('aabbcc'), 2)); +SELECT hex(lpad(unhex(''), 5, unhex('1f'))); +SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))); +SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))); +SELECT hex(lpad(unhex(''), 5, unhex('1f2e'))); +SELECT hex(lpad(unhex('aa'), 5, unhex('1f2e'))); +SELECT hex(lpad(unhex('aa'), 6, unhex('1f2e'))); +SELECT hex(lpad(unhex(''), 6, unhex(''))); +SELECT hex(lpad(unhex('aabbcc'), 6, unhex(''))); +SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))); + +-- rpad for BINARY inputs +SELECT hex(rpad(unhex(''), 5)); +SELECT hex(rpad(unhex('aabb'), 5)); +SELECT hex(rpad(unhex('aabbcc'), 2)); +SELECT hex(rpad(unhex(''), 5, unhex('1f'))); +SELECT hex(rpad(unhex('aa'), 5, unhex('1f'))); +SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))); +SELECT hex(rpad(unhex(''), 5, unhex('1f2e'))); +SELECT hex(rpad(unhex('aa'), 5, unhex('1f2e'))); +SELECT hex(rpad(unhex('aa'), 6, unhex('1f2e'))); +SELECT hex(rpad(unhex(''), 6, unhex(''))); +SELECT hex(rpad(unhex('aabbcc'), 6, unhex(''))); +SELECT hex(rpad(unhex('aabbcc'), 2, unhex('ff'))); + -- decode select decode(); select decode(encode('abc', 'utf-8')); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 3f01c8f755adb..56717aff7bc68 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 50 +-- Number of queries: 74 -- !query @@ -350,6 +350,198 @@ java.lang.NumberFormatException invalid input syntax for type numeric: invalid_length +-- !query +SELECT hex(lpad(unhex(''), 5)) +-- !query schema +struct +-- !query output +0000000000 + + +-- !query +SELECT hex(lpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +000000AABB + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2EAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(rpad(unhex(''), 5)) +-- !query schema +struct +-- !query output +0000000000 + + +-- !query +SELECT hex(rpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +AABB000000 + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + -- !query select decode() -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 80e88d0566411..e202521e9b2b2 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 50 +-- Number of queries: 74 -- !query @@ -340,6 +340,198 @@ struct NULL +-- !query +SELECT hex(lpad(unhex(''), 5)) +-- !query schema +struct +-- !query output +0000000000 + + +-- !query +SELECT hex(lpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +000000AABB + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2EAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(rpad(unhex(''), 5)) +-- !query schema +struct +-- !query output +0000000000 + + +-- !query +SELECT hex(rpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +AABB000000 + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + -- !query select decode() -- !query schema