From 36cf3c2f5c0edad8663bf4261d2a117cf7aab9df Mon Sep 17 00:00:00 2001 From: Menelaos Karavelas Date: Wed, 29 Sep 2021 23:26:37 -0700 Subject: [PATCH 1/7] [WIP] Add lpad and rpad functions for binary strings --- .../apache/spark/unsafe/types/ByteArray.java | 78 ++++++++++++++++++ .../expressions/stringExpressions.scala | 79 +++++++++++++++---- .../org/apache/spark/sql/functions.scala | 22 ++++++ 3 files changed, 164 insertions(+), 15 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 4383ee1533c2b..45ac8e88a554a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -101,4 +101,82 @@ public static byte[] concat(byte[]... inputs) { } return result; } + + // Helper method for implementing `lpad` and `rpad`. + // If the padding pattern's length is 0, return the first `len` bytes of the input + // binary string if it longer than `len` bytes, or a copy of the binary string, otherwise. + protected static byte[] padWithEmptyPattern(byte[] bytes, int len) { + len = Math.min(bytes.length, len); + final byte[] result = new byte[len]; + Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, result, Platform.BYTE_ARRAY_OFFSET, len); + return result; + } + + // Helper method for implementing `lpad` and `rpad`. + // Fills the resulting binary string with the pattern. The resulting binary string + // is passed as the first argument and it is filled from position `firstPos` (inclusive) + // to position `beyondPos` (not inclusive). 
+ protected static void fillWithPattern(byte[] result, int firstPos, int beyondPos, byte[] pad) { + for (int pos = firstPos; pos < beyondPos; pos += pad.length) { + final int jMax = Math.min(pad.length, beyondPos - pos); + for (int j = 0; j < jMax; ++j) { + result[pos + j] = (byte) pad[j]; + } + } + } + + // Left-pads the input binary string using the provided padding pattern. + // In the special case that the padding pattern is empty, the resulting binary string + // contains the first `len` bytes of the input if they exist, or is a copy of the input + // binary stringkm otherwise. + // For padding patterns with positive byte length, the resulting binary string's byte length is + // equal to `len`. If the input binary string is not less than `len` bytes, its first `len` bytes + // are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern. + public static byte[] lpad(byte[] bytes, int len, byte[] pad) { + if (bytes == null || pad == null) return null; + // If the input length is 0, return the empty binary string. + if (len == 0) return EMPTY_BYTE; + // The padding pattern is empty. + if (pad.length == 0) return padWithEmptyPattern(bytes, len); + // The general case. + // 1. Copy the first `len` bytes of the input string into the output if they exist. + final byte[] result = new byte[len]; + final int minLen = Math.min(len, bytes.length); + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + len - minLen, + minLen); + // 2. If the input has less than `len` bytes, fill in the rest using the provided pattern. + if (bytes.length < len) { + fillWithPattern(result, 0, len - bytes.length, pad); + } + return result; + } + + // Right-pads the input binary string using the provided padding pattern. 
+ // In the special case that the padding pattern is empty, the resulting binary string + // contains the first `len` bytes of the input if they exist, or is a copy of the input + // binary stringkm otherwise. + // For padding patterns with positive byte length, the resulting binary string's byte length is + // equal to `len`. If the input binary string is not less than `len` bytes, its first `len` bytes + // are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern. + public static byte[] rpad(byte[] bytes, int len, byte[] pad) { + if (bytes == null || pad == null) return null; + // If the input length is 0, return the empty binary string. + if (len == 0) return EMPTY_BYTE; + // The padding pattern is empty. + if (pad.length == 0) return padWithEmptyPattern(bytes, len); + // The general case. + // 1. Copy the first `len` bytes of the input string into the output if they exist. + final byte[] result = new byte[len]; + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET, + Math.min(len, bytes.length)); + // 2. If the input has less than `len` bytes, fill in the rest using the provided pattern. + if (bytes.length < len) { + fillWithPattern(result, bytes.length, len, pad); + } + return result; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5956c3e882118..2cca58eee1680 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1328,14 +1328,31 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } +/** + * Helper class for implementing StringLPad and StringRPad. 
+ * Returns the default expression to be used in StringLPad or StringRPad based on the type of + * the input expression. + * For character string expressions the default padding expression is the string literal ' '. + * For binary string expressions the default padding expression is the byte literal 0x00. + */ +object StringPadDefaultValue { + def get(str: Expression): Expression = { + str.dataType match { + case StringType => Literal(" ") + case BinaryType => Literal(Array[Byte](0x00)) + } + } +} + /** * Returns str, left-padded with pad to a length of len. */ @ExpressionDescription( usage = """ _FUNC_(str, len[, pad]) - Returns `str`, left-padded with `pad` to a length of `len`. - If `str` is longer than `len`, the return value is shortened to `len` characters. - If `pad` is not specified, `str` will be padded to the left with space characters. + If `str` is longer than `len`, the return value is shortened to `len` characters or bytes. + If `pad` is not specified, `str` will be padded to the left with space characters if it is + a character string, and with zeros if it is a binary string. 
""", examples = """ Examples: @@ -1345,6 +1362,10 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) h > SELECT _FUNC_('hi', 5); hi + > SELECT hex(_FUNC_(unhex('aabb'), 5)); + 000000AABB + > SELECT hex(_FUNC_(unhex('aabb'), 5, unhex('1122'))); + 112211AABB """, since = "1.5.0", group = "string_funcs") @@ -1352,21 +1373,33 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { - this(str, len, Literal(" ")) + this(str, len, StringPadDefaultValue.get(str)) } override def first: Expression = str override def second: Expression = len override def third: Expression = pad - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { - str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + override def dataType: DataType = str.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, BinaryType), IntegerType, TypeCollection(StringType, BinaryType)) + + override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { + str.dataType match { + case StringType => string.asInstanceOf[UTF8String] + .lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + case BinaryType => ByteArray.lpad(string.asInstanceOf[Array[Byte]], + len.asInstanceOf[Int], pad.asInstanceOf[Array[Byte]]) + } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") + defineCodeGen(ctx, ev, (string, len, pad) => { + str.dataType match { + case StringType => s"$string.lpad($len, $pad)" + case BinaryType => s"${classOf[ByteArray].getName}.lpad($string, $len, $pad)" + } + }) } override def prettyName: String = "lpad" @@ -1383,7 
+1416,8 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera usage = """ _FUNC_(str, len[, pad]) - Returns `str`, right-padded with `pad` to a length of `len`. If `str` is longer than `len`, the return value is shortened to `len` characters. - If `pad` is not specified, `str` will be padded to the right with space characters. + If `pad` is not specified, `str` will be padded to the right with space characters if it is + a character string, and with zeros if it is a binary string. """, examples = """ Examples: @@ -1393,6 +1427,10 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera h > SELECT _FUNC_('hi', 5); hi + > SELECT hex(_FUNC_(unhex('aabb'), 5)); + AABB000000 + > SELECT hex(_FUNC_(unhex('aabb'), 5, unhex('1122'))); + AABB112211 """, since = "1.5.0", group = "string_funcs") @@ -1400,22 +1438,33 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { - this(str, len, Literal(" ")) + this(str, len, StringPadDefaultValue.get(str)) } override def first: Expression = str override def second: Expression = len override def third: Expression = pad - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + override def dataType: DataType = str.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, BinaryType), IntegerType, TypeCollection(StringType, BinaryType)) - override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { - str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { + str.dataType match { + case StringType => string.asInstanceOf[UTF8String] + .rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) + case BinaryType => 
ByteArray.rpad(string.asInstanceOf[Array[Byte]], + len.asInstanceOf[Int], pad.asInstanceOf[Array[Byte]]) + } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") + defineCodeGen(ctx, ev, (string, len, pad) => { + str.dataType match { + case StringType => s"$string.rpad($len, $pad)" + case BinaryType => s"${classOf[ByteArray].getName}.rpad($string, $len, $pad)" + } + }) } override def prettyName: String = "rpad" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7bca29f3d80c0..1aea498f4960a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2715,6 +2715,17 @@ object functions { StringLPad(str.expr, lit(len).expr, lit(pad).expr) } + /** + * Left-pad the binary column with pad to a byte length of len. If the binary column is longer + * than len, the return value is shortened to len bytes. + * + * @group string_funcs + * @since 3.3.0 + */ + def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { + StringLPad(str.expr, lit(len).expr, lit(pad).expr) + } + /** * Trim the spaces from left end for the specified string value. * @@ -2793,6 +2804,17 @@ object functions { StringRPad(str.expr, lit(len).expr, lit(pad).expr) } + /** + * Right-pad the binary column with pad to a byte length of len. If the binary column is longer + * than len, the return value is shortened to len bytes. + * + * @group string_funcs + * @since 3.3.0 + */ + def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { + StringRPad(str.expr, lit(len).expr, lit(pad).expr) + } + /** * Repeats a string column n times, and returns it as a new string column. 
* From 93246745cd3ce4fdebc89b3bef84654780e207aa Mon Sep 17 00:00:00 2001 From: Menelaos Karavelas Date: Mon, 18 Oct 2021 10:45:57 -0700 Subject: [PATCH 2/7] Added some unit tests --- .../sql-functions/sql-expression-schema.md | 2 + .../sql-tests/inputs/string-functions.sql | 22 +++ .../results/ansi/string-functions.sql.out | 146 +++++++++++++++++- .../results/string-functions.sql.out | 146 +++++++++++++++++- 4 files changed, 314 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index fc5134764905b..ca58d894d3af7 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -270,9 +270,11 @@ | org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct | | org.apache.spark.sql.catalyst.expressions.StringInstr | instr | SELECT instr('SparkSQL', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.StringLPad | lpad | SELECT lpad('hi', 5, '??') | struct | +| org.apache.spark.sql.catalyst.expressions.StringLPad | lpad | SELECT hex(lpad(unhex('aabb'), 7, unhex('010203'))) | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | locate | SELECT locate('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | position | SELECT position('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringRPad | rpad | SELECT rpad('hi', 5, '??') | struct | +| org.apache.spark.sql.catalyst.expressions.StringRPad | rpad | SELECT hex(rpad(unhex('aabb'), 7, unhex('010203'))) | struct | | org.apache.spark.sql.catalyst.expressions.StringRepeat | repeat | SELECT repeat('123', 2) | struct | | org.apache.spark.sql.catalyst.expressions.StringReplace | replace | SELECT replace('ABCabc', 'abc', 'DEF') | struct | | org.apache.spark.sql.catalyst.expressions.StringSpace | 
space | SELECT concat(space(2), '1') | struct | diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index d44055d72e3bc..9c1f911aa7e5a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -62,6 +62,28 @@ SELECT btrim(encode('xxxbarxxx', 'utf-8'), encode('x', 'utf-8')); SELECT lpad('hi', 'invalid_length'); SELECT rpad('hi', 'invalid_length'); +-- lpad for BINARY inputs +SELECT hex(lpad(unhex(''), 5, unhex('1f'))); +SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))); +SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))); +SELECT hex(lpad(unhex(''), 5, unhex('1f2e'))); +SELECT hex(lpad(unhex('aa'), 5, unhex('1f2e'))); +SELECT hex(lpad(unhex('aa'), 6, unhex('1f2e'))); +SELECT hex(lpad(unhex(''), 6, unhex(''))); +SELECT hex(lpad(unhex('aabbcc'), 6, unhex(''))); +SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))); + +-- rpad for BINARY inputs +SELECT hex(rpad(unhex(''), 5, unhex('1f'))); +SELECT hex(rpad(unhex('aa'), 5, unhex('1f'))); +SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))); +SELECT hex(rpad(unhex(''), 5, unhex('1f2e'))); +SELECT hex(rpad(unhex('aa'), 5, unhex('1f2e'))); +SELECT hex(rpad(unhex('aa'), 6, unhex('1f2e'))); +SELECT hex(rpad(unhex(''), 6, unhex(''))); +SELECT hex(rpad(unhex('aabbcc'), 6, unhex(''))); +SELECT hex(rpad(unhex('aabbcc'), 2, unhex('ff'))); + -- decode select decode(); select decode(encode('abc', 'utf-8')); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 3f01c8f755adb..8dc72d5f73b3a 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 50 
+-- Number of queries: 68 -- !query @@ -350,6 +350,150 @@ java.lang.NumberFormatException invalid input syntax for type numeric: invalid_length +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2EAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query 
output + + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + -- !query select decode() -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 80e88d0566411..430430a477be4 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 50 +-- Number of queries: 68 -- !query @@ -340,6 +340,150 @@ struct NULL +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(lpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2EAA + + +-- !query +SELECT hex(lpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1FAA + + +-- !query +SELECT hex(lpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +1F1F1F1F1F + + +-- !query +SELECT 
hex(rpad(unhex('aa'), 5, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))) +-- !query schema +struct +-- !query output +AA1F1F1F1F1F + + +-- !query +SELECT hex(rpad(unhex(''), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex('aa'), 5, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E + + +-- !query +SELECT hex(rpad(unhex('aa'), 6, unhex('1f2e'))) +-- !query schema +struct +-- !query output +AA1F2E1F2E1F + + +-- !query +SELECT hex(rpad(unhex(''), 6, unhex(''))) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 6, unhex(''))) +-- !query schema +struct +-- !query output +AABBCC + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2, unhex('ff'))) +-- !query schema +struct +-- !query output +AABB + + -- !query select decode() -- !query schema From 7a52529805a63a597b8ebc0117ca9f9761d6f6c6 Mon Sep 17 00:00:00 2001 From: Menelaos Karavelas Date: Mon, 18 Oct 2021 11:33:05 -0700 Subject: [PATCH 3/7] Add test cases where we omit the padding pattern. 
--- .../sql-tests/inputs/string-functions.sql | 6 +++ .../results/ansi/string-functions.sql.out | 50 ++++++++++++++++++- .../results/string-functions.sql.out | 50 ++++++++++++++++++- 3 files changed, 104 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 9c1f911aa7e5a..beacdbfcd593a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -63,6 +63,9 @@ SELECT lpad('hi', 'invalid_length'); SELECT rpad('hi', 'invalid_length'); -- lpad for BINARY inputs +SELECT hex(lpad(unhex(''), 5)); +SELECT hex(lpad(unhex('aabb'), 5)); +SELECT hex(lpad(unhex('aabbcc'), 2)); SELECT hex(lpad(unhex(''), 5, unhex('1f'))); SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))); SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))); @@ -74,6 +77,9 @@ SELECT hex(lpad(unhex('aabbcc'), 6, unhex(''))); SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))); -- rpad for BINARY inputs +SELECT hex(rpad(unhex(''), 5)); +SELECT hex(rpad(unhex('aabb'), 5)); +SELECT hex(rpad(unhex('aabbcc'), 2)); SELECT hex(rpad(unhex(''), 5, unhex('1f'))); SELECT hex(rpad(unhex('aa'), 5, unhex('1f'))); SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 8dc72d5f73b3a..56717aff7bc68 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 68 +-- Number of queries: 74 -- !query @@ -350,6 +350,30 @@ java.lang.NumberFormatException invalid input syntax for type numeric: invalid_length +-- !query +SELECT hex(lpad(unhex(''), 5)) +-- !query schema +struct 
+-- !query output +0000000000 + + +-- !query +SELECT hex(lpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +000000AABB + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + -- !query SELECT hex(lpad(unhex(''), 5, unhex('1f'))) -- !query schema @@ -422,6 +446,30 @@ struct AABB +-- !query +SELECT hex(rpad(unhex(''), 5)) +-- !query schema +struct +-- !query output +0000000000 + + +-- !query +SELECT hex(rpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +AABB000000 + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + -- !query SELECT hex(rpad(unhex(''), 5, unhex('1f'))) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 430430a477be4..e202521e9b2b2 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 68 +-- Number of queries: 74 -- !query @@ -340,6 +340,30 @@ struct NULL +-- !query +SELECT hex(lpad(unhex(''), 5)) +-- !query schema +struct +-- !query output +0000000000 + + +-- !query +SELECT hex(lpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +000000AABB + + +-- !query +SELECT hex(lpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + -- !query SELECT hex(lpad(unhex(''), 5, unhex('1f'))) -- !query schema @@ -412,6 +436,30 @@ struct AABB +-- !query +SELECT hex(rpad(unhex(''), 5)) +-- !query schema +struct +-- !query output +0000000000 + + +-- !query +SELECT hex(rpad(unhex('aabb'), 5)) +-- !query schema +struct +-- !query output +AABB000000 + + +-- !query +SELECT hex(rpad(unhex('aabbcc'), 2)) +-- !query schema +struct +-- !query output +AABB + + -- !query SELECT 
hex(rpad(unhex(''), 5, unhex('1f'))) -- !query schema From 5a4e1145b700fd9edb660f4481779ff963458ed7 Mon Sep 17 00:00:00 2001 From: Menelaos Karavelas Date: Mon, 18 Oct 2021 23:01:05 -0700 Subject: [PATCH 4/7] Updated the sql-expression-schema.md file. --- .../src/test/resources/sql-functions/sql-expression-schema.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index ca58d894d3af7..fc5134764905b 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -270,11 +270,9 @@ | org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct | | org.apache.spark.sql.catalyst.expressions.StringInstr | instr | SELECT instr('SparkSQL', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.StringLPad | lpad | SELECT lpad('hi', 5, '??') | struct | -| org.apache.spark.sql.catalyst.expressions.StringLPad | lpad | SELECT hex(lpad(unhex('aabb'), 7, unhex('010203'))) | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | locate | SELECT locate('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | position | SELECT position('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringRPad | rpad | SELECT rpad('hi', 5, '??') | struct | -| org.apache.spark.sql.catalyst.expressions.StringRPad | rpad | SELECT hex(rpad(unhex('aabb'), 7, unhex('010203'))) | struct | | org.apache.spark.sql.catalyst.expressions.StringRepeat | repeat | SELECT repeat('123', 2) | struct | | org.apache.spark.sql.catalyst.expressions.StringReplace | replace | SELECT replace('ABCabc', 'abc', 'DEF') | struct | | org.apache.spark.sql.catalyst.expressions.StringSpace | space | SELECT concat(space(2), '1') | struct | From 1f34b1c02e9181190238eea2719da9793cc1bdd3 Mon Sep 17 
00:00:00 2001 From: Menelaos Karavelas Date: Tue, 19 Oct 2021 08:30:39 -0700 Subject: [PATCH 5/7] Address review comments. --- .../apache/spark/unsafe/types/ByteArray.java | 40 +++++++++---------- .../expressions/stringExpressions.scala | 4 +- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 45ac8e88a554a..9bde7134ce67c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -103,9 +103,9 @@ public static byte[] concat(byte[]... inputs) { } // Helper method for implementing `lpad` and `rpad`. - // If the padding pattern's length is 0, return the first `len` bytes of the input - // binary string if it longer than `len` bytes, or a copy of the binary string, otherwise. - protected static byte[] padWithEmptyPattern(byte[] bytes, int len) { + // If the padding pattern's length is 0, return the first `len` bytes of the input byte + // sequence if it is longer than `len` bytes, or a copy of the byte sequence, otherwise. + private static byte[] padWithEmptyPattern(byte[] bytes, int len) { len = Math.min(bytes.length, len); final byte[] result = new byte[len]; Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, result, Platform.BYTE_ARRAY_OFFSET, len); @@ -113,10 +113,10 @@ protected static byte[] padWithEmptyPattern(byte[] bytes, int len) { } // Helper method for implementing `lpad` and `rpad`. - // Fills the resulting binary string with the pattern. The resulting binary string - // is passed as the first argument and it is filled from position `firstPos` (inclusive) + // Fills the resulting byte sequence with the pattern. The resulting byte sequence is + // passed as the first argument and it is filled from position `firstPos` (inclusive) // to position `beyondPos` (not inclusive). 
- protected static void fillWithPattern(byte[] result, int firstPos, int beyondPos, byte[] pad) { + private static void fillWithPattern(byte[] result, int firstPos, int beyondPos, byte[] pad) { for (int pos = firstPos; pos < beyondPos; pos += pad.length) { final int jMax = Math.min(pad.length, beyondPos - pos); for (int j = 0; j < jMax; ++j) { @@ -125,21 +125,21 @@ protected static void fillWithPattern(byte[] result, int firstPos, int beyondPos } } - // Left-pads the input binary string using the provided padding pattern. - // In the special case that the padding pattern is empty, the resulting binary string + // Left-pads the input byte sequence using the provided padding pattern. + // In the special case that the padding pattern is empty, the resulting byte sequence // contains the first `len` bytes of the input if they exist, or is a copy of the input - // binary stringkm otherwise. - // For padding patterns with positive byte length, the resulting binary string's byte length is - // equal to `len`. If the input binary string is not less than `len` bytes, its first `len` bytes + // byte sequence otherwise. + // For padding patterns with positive byte length, the resulting byte sequence's byte length is + // equal to `len`. If the input byte sequence is not less than `len` bytes, its first `len` bytes // are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern. public static byte[] lpad(byte[] bytes, int len, byte[] pad) { if (bytes == null || pad == null) return null; - // If the input length is 0, return the empty binary string. + // If the input length is 0, return the empty byte sequence. if (len == 0) return EMPTY_BYTE; // The padding pattern is empty. if (pad.length == 0) return padWithEmptyPattern(bytes, len); // The general case. - // 1. Copy the first `len` bytes of the input string into the output if they exist. + // 1. Copy the first `len` bytes of the input byte sequence into the output if they exist. 
final byte[] result = new byte[len]; final int minLen = Math.min(len, bytes.length); Platform.copyMemory( @@ -153,21 +153,21 @@ public static byte[] lpad(byte[] bytes, int len, byte[] pad) { return result; } - // Right-pads the input binary string using the provided padding pattern. - // In the special case that the padding pattern is empty, the resulting binary string + // Right-pads the input byte sequence using the provided padding pattern. + // In the special case that the padding pattern is empty, the resulting byte sequence // contains the first `len` bytes of the input if they exist, or is a copy of the input - // binary stringkm otherwise. - // For padding patterns with positive byte length, the resulting binary string's byte length is - // equal to `len`. If the input binary string is not less than `len` bytes, its first `len` bytes + // byte sequence otherwise. + // For padding patterns with positive byte length, the resulting byte sequence's byte length is + // equal to `len`. If the input byte sequence is not less than `len` bytes, its first `len` bytes // are returned. Otherwise, the remaining missing bytes are filled in with the provided pattern. public static byte[] rpad(byte[] bytes, int len, byte[] pad) { if (bytes == null || pad == null) return null; - // If the input length is 0, return the empty binary string. + // If the input length is 0, return the empty byte sequence. if (len == 0) return EMPTY_BYTE; // The padding pattern is empty. if (pad.length == 0) return padWithEmptyPattern(bytes, len); // The general case. - // 1. Copy the first `len` bytes of the input string into the output if they exist. + // 1. Copy the first `len` bytes of the input sequence into the output if they exist. 
final byte[] result = new byte[len]; Platform.copyMemory( bytes, Platform.BYTE_ARRAY_OFFSET, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2cca58eee1680..08f18f76f98cb 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1333,7 +1333,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns the default expression to be used in StringLPad or StringRPad based on the type of * the input expression. * For character string expressions the default padding expression is the string literal ' '. - * For binary string expressions the default padding expression is the byte literal 0x00. + * For byte sequence expressions the default padding expression is the byte literal 0x00. */ object StringPadDefaultValue { def get(str: Expression): Expression = { @@ -1352,7 +1352,7 @@ object StringPadDefaultValue { _FUNC_(str, len[, pad]) - Returns `str`, left-padded with `pad` to a length of `len`. If `str` is longer than `len`, the return value is shortened to `len` characters or bytes. If `pad` is not specified, `str` will be padded to the left with space characters if it is - a character string, and with zeros if it is a binary string. + a character string, and with zeros if it is a byte sequence. """, examples = """ Examples: From fdc9207e3f7fa27db5746a6ff804e2e324b99e36 Mon Sep 17 00:00:00 2001 From: Menelaos Karavelas Date: Wed, 20 Oct 2021 08:44:48 -0700 Subject: [PATCH 6/7] Fixed indentation. Simplified the way the default padding value for BINARY is defined. 
--- .../org/apache/spark/unsafe/types/ByteArray.java | 12 ++++++------ .../sql/catalyst/expressions/stringExpressions.scala | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 9bde7134ce67c..cff9061aabb13 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -143,9 +143,9 @@ public static byte[] lpad(byte[] bytes, int len, byte[] pad) { final byte[] result = new byte[len]; final int minLen = Math.min(len, bytes.length); Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, - result, Platform.BYTE_ARRAY_OFFSET + len - minLen, - minLen); + bytes, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + len - minLen, + minLen); // 2. If the input has less than `len` bytes, fill in the rest using the provided pattern. if (bytes.length < len) { fillWithPattern(result, 0, len - bytes.length, pad); @@ -170,9 +170,9 @@ public static byte[] rpad(byte[] bytes, int len, byte[] pad) { // 1. Copy the first `len` bytes of the input sequence into the output if they exist. final byte[] result = new byte[len]; Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, - result, Platform.BYTE_ARRAY_OFFSET, - Math.min(len, bytes.length)); + bytes, Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET, + Math.min(len, bytes.length)); // 2. If the input has less than `len` bytes, fill in the rest using the provided pattern. 
if (bytes.length < len) { fillWithPattern(result, bytes.length, len, pad); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 08f18f76f98cb..710f1ed588242 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1339,7 +1339,7 @@ object StringPadDefaultValue { def get(str: Expression): Expression = { str.dataType match { case StringType => Literal(" ") - case BinaryType => Literal(Array[Byte](0x00)) + case BinaryType => Literal(Array[Byte](0)) } } } From 9027989529ff3bbf5f79d2d2022981b0a36db97f Mon Sep 17 00:00:00 2001 From: Menelaos Karavelas Date: Wed, 20 Oct 2021 09:12:46 -0700 Subject: [PATCH 7/7] Add entry to the SQL migration guide for the breaking change. --- docs/sql-migration-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index e0b18a374c836..3bea6b9751ab2 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -24,6 +24,8 @@ license: | ## Upgrading from Spark SQL 3.2 to 3.3 + - Since Spark 3.3, the functions `lpad` and `rpad` have been overloaded to support byte sequences. When the first argument is a byte sequence, the optional padding pattern must also be a byte sequence and the result is a BINARY value. The default padding pattern in this case is the zero byte. + - Since Spark 3.3, Spark turns a non-nullable schema into nullable for API `DataFrameReader.schema(schema: StructType).json(jsonDataset: Dataset[String])` and `DataFrameReader.schema(schema: StructType).csv(csvDataset: Dataset[String])` when the schema is specified by the user and contains non-nullable fields. 
- Since Spark 3.3, when the date or timestamp pattern is not specified, Spark converts an input string to a date/timestamp using the `CAST` expression approach. The changes affect CSV/JSON datasources and parsing of partition values. In Spark 3.2 or earlier, when the date or timestamp pattern is not set, Spark uses the default patterns: `yyyy-MM-dd` for dates and `yyyy-MM-dd HH:mm:ss` for timestamps. After the changes, Spark still recognizes the pattern together with