From 52d7b0373ce971acd6516609db9fa96f8f427513 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Mon, 20 Jul 2015 16:08:26 +0800 Subject: [PATCH 1/8] add substring_index function --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringOperations.scala | 74 +++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 25 ++++++- .../spark/sql/StringFunctionsSuite.scala | 28 +++++++ 4 files changed, 127 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e3d8d2adf2135..ec46ca5477d90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -177,6 +177,7 @@ object FunctionRegistry { expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[Substring_index]("substring_index"), expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 1f18a6e9ff8a5..269eb81bdbce3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -21,6 +21,8 @@ import java.text.DecimalFormat import java.util.Locale import java.util.regex.{MatchResult, Pattern} +import org.apache.commons.lang.StringUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -355,6 +357,78 @@ case class StringInstr(str: Expression, substr: Expression) } } +/** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. + */ +case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable + override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) + override def prettyName: String = "substring_index" + override def toString: String = s"substring_index($strExpr, $delimExpr, $countExpr)" + + override def eval(input: InternalRow): Any = { + val str = strExpr.eval(input) + val delim = delimExpr.eval(input) + val count = countExpr.eval(input) + if (str == null || delim == null || count == null) { + null + } else { + subStrIndex( + str.asInstanceOf[UTF8String], + delim.asInstanceOf[UTF8String], + count.asInstanceOf[Int]) + } + } + + private def ordinalIndexOf(str: UTF8String, delim: UTF8String, count: Int): Int = { + var found = 0 + var index = -1 + do { + index = str.indexOf(delim, index + 1) + if (index < 0) { + return index + } + found += 1 + } while (found < count) + index + } + + private def subStrIndex(strUtf8: UTF8String, delimUtf8: UTF8String, count: Int): UTF8String = { + if (strUtf8 == null || delimUtf8 == null || count == null) { + return null + } + if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) { + return UTF8String.fromString("") + } + val res: UTF8String = + if (count > 0) { + val idx = ordinalIndexOf(strUtf8, delimUtf8, count) + if (idx != -1) { + strUtf8.substring(0, idx) + } else { + strUtf8 + } + } else { + val str = strUtf8.toString + val delim = delimUtf8.toString + val idx = StringUtils.lastOrdinalIndexOf(str, delim, -count) + if (idx != -1) { + UTF8String.fromString(str.substring(idx + 1)) + } else { + UTF8String.fromString(str) + } + } + res + } +} + /** * A function that returns the position of the first occurrence of substr * in given string after position pos. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e5ff8ae7e3179..faa82ad19fd05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1777,8 +1777,31 @@ object functions { def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) /** - * Locate the position of the first occurrence of substr in a string column. + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. * + * @group string_funcs + * @since 1.5.0 + */ + def substring_index(str: String, delim: String, count: Int): Column = + substring_index(Column(str), delim, count) + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. + * + * @group string_funcs + * @since 1.5.0 + */ + def substring_index(str: Column, delim: String, count: Int): Column = + Substring_index(str.expr, lit(delim).expr, lit(count).expr) + + /** + * Locate the position of the first occurrence of substr. * NOTE: The position is not zero based, but 1 based index, returns 0 if substr * could not be found in str. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 3702e73b4e74f..bc5ce2c49a7ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -156,6 +156,34 @@ class StringFunctionsSuite extends QueryTest { Row(1)) } + test("string substring_index function") { + val df = Seq(("ac,ab,ad,ab,cc", "aa", "zz")).toDF("a", "b", "c") + checkAnswer( + df.select(substring_index($"a", ",", 2)), + Row("ac,ab")) + checkAnswer( + df.select(substring_index($"a", "ab", 2)), + Row("ac,ab,ad,")) + checkAnswer( + df.select(substring_index(lit(""), "ab", 2)), + Row("")) + checkAnswer( + df.select(substring_index(lit(null), "ab", 2)), + Row(null)) + checkAnswer( + df.select(substring_index(lit("大千世界大千世界"), "千", 2)), + Row("大千世界大")) + checkAnswer( + df.selectExpr("""substring_index(a, ",", 2)"""), + Row("ac,ab")) + checkAnswer( + df.selectExpr("""substring_index(a, ",", -2)"""), + Row("ab,cc")) + checkAnswer( + df.selectExpr("""substring_index(a, ",", 10)"""), + Row("ac,ab,ad,ab,cc")) + } + test("string locate function") { val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") From d92951b0ea1adcc67d0b5483bcbd6277f747ac89 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 21 Jul 2015 16:21:01 +0800 Subject: [PATCH 2/8] add lastIndexOf --- .../expressions/stringOperations.scala | 54 ++++++++++------ .../spark/sql/StringFunctionsSuite.scala | 2 + .../apache/spark/unsafe/types/UTF8String.java | 63 +++++++++++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 16 +++++ 4 files changed, 115 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 269eb81bdbce3..6c20677c2b0d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -387,16 +387,33 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr } } - private def ordinalIndexOf(str: UTF8String, delim: UTF8String, count: Int): Int = { + private def lastOrdinalIndexOf( + str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { + ordinalIndexOf(str, searchStr, ordinal, true) + } + + private def ordinalIndexOf( + str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { + if (str == null || searchStr == null || ordinal <= 0) { + return -1 + } + val strNumChars = str.numChars() + if (searchStr.numBytes() == 0) { + return if (lastIndex) {strNumChars} else {0} + } var found = 0 - var index = -1 + var index = if (lastIndex) {strNumChars} else {0} do { - index = str.indexOf(delim, index + 1) + if (lastIndex) { + index = str.lastIndexOf(searchStr, index - 1) + } else { + index = str.indexOf(searchStr, index + 1) + } if (index < 0) { return index } found += 1 - } while (found < count) + } while (found < ordinal) index } @@ -407,24 +424,21 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) { return UTF8String.fromString("") } - val res: UTF8String = - if (count > 0) { - val idx = ordinalIndexOf(strUtf8, delimUtf8, count) - if (idx != -1) { - strUtf8.substring(0, idx) - } else { - strUtf8 - } + val res = if (count > 0) { + val idx = ordinalIndexOf(strUtf8, delimUtf8, count) + if (idx != -1) { + strUtf8.substring(0, idx) } else { - val str = strUtf8.toString - val delim = delimUtf8.toString - val idx = StringUtils.lastOrdinalIndexOf(str, delim, -count) - if (idx != -1) { - UTF8String.fromString(str.substring(idx + 1)) - } else { - UTF8String.fromString(str) - } + strUtf8 } + } else { + val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count) + if (idx != -1) { + strUtf8.substring(idx + 1, strUtf8.numChars()) + } else { + strUtf8 + } + } res } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index bc5ce2c49a7ad..9a0eb16ec700b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -170,9 +170,11 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.select(substring_index(lit(null), "ab", 2)), Row(null)) + // scalastyle:off checkAnswer( df.select(substring_index(lit("大千世界大千世界"), "千", 2)), Row("大千世界大")) + // scalastyle:on checkAnswer( df.selectExpr("""substring_index(a, ",", 2)"""), Row("ac,ab")) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 946d355f1fc28..f3d14178ce3cc 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -352,6 +352,69 @@ public int indexOf(UTF8String v, int start) { return -1; } + private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR}; + + private ByteType checkByteType(Byte b) { + int firstTwoBits = (b >>> 6) & 0x03; + if (firstTwoBits == 3) { + return ByteType.FIRSTBYTE; + } else if (firstTwoBits == 2) { + return ByteType.MIDBYTE; + } else { + return ByteType.SINGLEBYTECHAR; + } + } + + /** + * Return the first byte position for a given byte which shared the same code point. + * @param bytePos any byte within the code point + * @return the first byte position of a given code point, throw exception if not a valid UTF8 str + */ + private int firstOfCurrentCodePoint(int bytePos) { + while (bytePos >= 0) { + if (ByteType.FIRSTBYTE == checkByteType(getByte(bytePos)) + || ByteType.SINGLEBYTECHAR == checkByteType(getByte(bytePos))) { + return bytePos; + } + bytePos--; + } + throw new RuntimeException("Invalid utf8 string"); + } + + private int endByte(int startCodePoint) { + int i = numBytes -1; // position in byte + int c = numChars() - 1; // position in character + while (i >=0 && c > startCodePoint) { + i = firstOfCurrentCodePoint(i) - 1; + c -= 1; + } + return i; + } + + public int lastIndexOf(UTF8String v, int startCodePoint) { + if (v.numBytes == 0) { + return 0; + } + if (numBytes == 0) { + return -1; + } + int fromIndexEnd = endByte(startCodePoint); + int count = startCodePoint; + int vNumChars = v.numChars(); + do { + if (fromIndexEnd - v.numBytes + 1 < 0 ) { + return -1; + } + if (ByteArrayMethods.arrayEquals( + base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) { + return count - vNumChars + 1; + } + fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; + count--; + } while (fromIndexEnd >= 0); + return -1; + } + /** * Returns str, right-padded with pad to a length of len * For example: diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e2a5628ff4d93..bee232b11a11e 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -221,6 +221,22 @@ public void indexOf() { assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); } + @Test + public void lastIndexOf() { + assertEquals(0, fromString("").lastIndexOf(fromString(""), 0)); + assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0)); + assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4)); + assertEquals(2, fromString("hello").lastIndexOf(fromString("ll"), 4)); + assertEquals(-1, fromString("hello").lastIndexOf(fromString("ll"), 0)); + assertEquals(5, fromString("数据砖头数据砖头").lastIndexOf(fromString("据砖"), 7)); + assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3)); + } + @Test public void reverse() { assertEquals(fromString("olleh"), fromString("hello").reverse()); From 12e108f13cfe4bc6ef967adb6f42891ed01aa521 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 22 Jul 2015 09:21:05 +0800 Subject: [PATCH 3/8] refine unittest --- .../expressions/stringOperations.scala | 2 +- .../spark/sql/StringFunctionsSuite.scala | 61 +++++++++++++------ .../apache/spark/unsafe/types/UTF8String.java | 4 +- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 6c20677c2b0d7..35ec2c991a94b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -434,7 +434,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr } else { val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count) if (idx != -1) { - strUtf8.substring(idx + 1, strUtf8.numChars()) + strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars()) } else { strUtf8 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 9a0eb16ec700b..08ea5664d212c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -157,33 +157,60 @@ class StringFunctionsSuite extends QueryTest { } test("string substring_index function") { - val df = Seq(("ac,ab,ad,ab,cc", "aa", "zz")).toDF("a", "b", "c") + val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c") checkAnswer( - df.select(substring_index($"a", ",", 2)), - Row("ac,ab")) + df.select(substring_index($"a", ".", 3)), + Row("www.apache.org")) checkAnswer( - df.select(substring_index($"a", "ab", 2)), - Row("ac,ab,ad,")) + df.select(substring_index($"a", ".", 2)), + Row("www.apache")) checkAnswer( - df.select(substring_index(lit(""), "ab", 2)), + df.select(substring_index($"a", ".", 1)), + Row("www")) + checkAnswer( + df.select(substring_index($"a", ".", 0)), + Row("")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -1)), + Row("org")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -2)), + Row("apache.org")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -3)), + Row("www.apache.org")) + // str is empty string + checkAnswer( + df.select(substring_index(lit(""), ".", 1)), + Row("")) + // empty string delim + checkAnswer( + df.select(substring_index(lit("www.apache.org"), "", 1)), Row("")) + // delim does not exist in str checkAnswer( - df.select(substring_index(lit(null), "ab", 2)), + df.select(substring_index(lit("www.apache.org"), "#", 1)), + Row("www.apache.org")) + // delim is 2 chars + checkAnswer( + df.select(substring_index(lit("www||apache||org"), "||", 2)), + Row("www||apache")) + checkAnswer( + df.select(substring_index(lit("www||apache||org"), "||", -2)), + Row("apache||org")) + // null + checkAnswer( + df.select(substring_index(lit(null), "||", 2)), + Row(null)) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), null, 2)), Row(null)) + // non ascii chars // scalastyle:off checkAnswer( - df.select(substring_index(lit("大千世界大千世界"), "千", 2)), + df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""), Row("大千世界大")) // scalastyle:on - checkAnswer( - df.selectExpr("""substring_index(a, ",", 2)"""), - Row("ac,ab")) - checkAnswer( - df.selectExpr("""substring_index(a, ",", -2)"""), - Row("ab,cc")) - checkAnswer( - df.selectExpr("""substring_index(a, ",", 10)"""), - Row("ac,ab,ad,ab,cc")) } test("string locate function") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f3d14178ce3cc..78d767ee4de12 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -381,7 +381,7 @@ private int firstOfCurrentCodePoint(int bytePos) { throw new RuntimeException("Invalid utf8 string"); } - private int endByte(int startCodePoint) { + private int indexEnd(int startCodePoint) { int i = numBytes -1; // position in byte int c = numChars() - 1; // position in character while (i >=0 && c > startCodePoint) { @@ -398,7 +398,7 @@ public int lastIndexOf(UTF8String v, int startCodePoint) { if (numBytes == 0) { return -1; } - int fromIndexEnd = endByte(startCodePoint); + int fromIndexEnd = indexEnd(startCodePoint); int count = startCodePoint; int vNumChars = v.numChars(); do { From ac863e9048e1aedf587aaf5e90d09b91bbf8ec25 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 22 Jul 2015 15:52:56 +0800 Subject: [PATCH 4/8] reduce the calling of numChars --- .../expressions/stringOperations.scala | 78 ++-------- .../apache/spark/unsafe/types/UTF8String.java | 140 +++++++++++++++++- .../spark/unsafe/types/UTF8StringSuite.java | 1 + 3 files changed, 148 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 35ec2c991a94b..2d921124c38ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -21,10 +21,7 @@ import java.text.DecimalFormat import java.util.Locale import java.util.regex.{MatchResult, Pattern} -import org.apache.commons.lang.StringUtils - import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -371,75 +368,22 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) override def prettyName: String = "substring_index" - override def toString: String = s"substring_index($strExpr, $delimExpr, $countExpr)" override def eval(input: InternalRow): Any = { val str = strExpr.eval(input) - val delim = delimExpr.eval(input) - val count = countExpr.eval(input) - if (str == null || delim == null || count == null) { - null - } else { - subStrIndex( - str.asInstanceOf[UTF8String], - delim.asInstanceOf[UTF8String], - count.asInstanceOf[Int]) - } - } - - private def lastOrdinalIndexOf( - str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { - ordinalIndexOf(str, searchStr, ordinal, true) - } - - private def ordinalIndexOf( - str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { - if (str == null || searchStr == null || ordinal <= 0) { - return -1 - } - val strNumChars = str.numChars() - if (searchStr.numBytes() == 0) { - return if (lastIndex) {strNumChars} else {0} - } - var found = 0 - var index = if (lastIndex) {strNumChars} else {0} - do { - if (lastIndex) { - index = str.lastIndexOf(searchStr, index - 1) - } else { - index = str.indexOf(searchStr, index + 1) - } - if (index < 0) { - return index - } - found += 1 - } while (found < ordinal) - index - } - - private def subStrIndex(strUtf8: UTF8String, delimUtf8: UTF8String, count: Int): UTF8String = { - if (strUtf8 == null || delimUtf8 == null || count == null) { - return null - } - if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) { - return UTF8String.fromString("") - } - val res = if (count > 0) { - val idx = ordinalIndexOf(strUtf8, delimUtf8, count) - if (idx != -1) { - strUtf8.substring(0, idx) - } else { - strUtf8 - } - } else { - val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count) - if (idx != -1) { - strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars()) - } else { - strUtf8 + if (str != null) { + val delim = delimExpr.eval(input) + if (delim != null) { + val count = countExpr.eval(input) + if (count != null) { + return UTF8String.subStringIndex( + str.asInstanceOf[UTF8String], + delim.asInstanceOf[UTF8String], + count.asInstanceOf[Int]) + } } } - res + null } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 78d767ee4de12..11e34f95bce8a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -165,6 +165,27 @@ public UTF8String substring(final int start, final int until) { return fromBytes(bytes); } + /** + * Returns a substring of this from start to end. + * @param start the position of first code point + */ + public UTF8String substring(final int start) { + if (start >= numBytes) { + return fromBytes(new byte[0]); + } + + int i = 0; + int c = 0; + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + byte[] bytes = new byte[numBytes - i]; + copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i); + return fromBytes(bytes); + } + public UTF8String substringSQL(int pos, int length) { // Information regarding the pos calculation: // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and @@ -391,7 +412,19 @@ private int indexEnd(int startCodePoint) { return i; } + /** + * Returns the index within this string of the last occurrence of the + * specified substring, searching backward starting at the specified index. + * @param v the substring to search for. + * @param startCodePoint the index to start search from + * @return the index of the last occurrence of the specified substring, + * searching backward from the specified index, + * or {@code -1} if there is no such occurrence. + */ public int lastIndexOf(UTF8String v, int startCodePoint) { + return lastIndexOf(v, v.numChars(), startCodePoint); + } + public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) { if (v.numBytes == 0) { return 0; } @@ -399,22 +432,121 @@ public int lastIndexOf(UTF8String v, int startCodePoint) { return -1; } int fromIndexEnd = indexEnd(startCodePoint); - int count = startCodePoint; - int vNumChars = v.numChars(); do { if (fromIndexEnd - v.numBytes + 1 < 0 ) { return -1; } if (ByteArrayMethods.arrayEquals( base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) { - return count - vNumChars + 1; + int count = 0; // count from right most to the match end in byte. + while (fromIndexEnd >= 0) { + count++; + fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; + } + return count - vNumChars; } fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; - count--; } while (fromIndexEnd >= 0); return -1; } + /** + * Finds the n-th last index within a String. + * This method uses {@link String#lastIndexOf(String)}.

+ * + * @param str the String to check, may be null + * @param searchStr the String to find, may be null + * @param searchStrNumChars num of code ponts of the searchStr + * @param ordinal the n-th last searchStr to find + * @return the n-th last index of the search String, + * -1 if no match or null string input + */ + public static int lastOrdinalIndexOf( + UTF8String str, + UTF8String searchStr, + int searchStrNumChars, + int ordinal) { + return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true); + } + /** + * Finds the n-th index within a String, handling null. + * A null String will return -1 + * + * @param str the String to check, may be null + * @param searchStr the String to find, may be null + * @param searchStrNumChars num of code points of searchStr + * @param ordinal the n-th searchStr to find + * @return the n-th index of the search String, + * -1 if no match or null string input + */ + public static int ordinalIndexOf( + UTF8String str, + UTF8String searchStr, + int searchStrNumChars, + int ordinal) { + return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false); + } + + private static int doOrdinalIndexOf( + UTF8String str, + UTF8String searchStr, + int searchStrNumChars, + int ordinal, + boolean lastIndex) { + if (str == null || searchStr == null || ordinal <= 0) { + return -1; + } + // Only calc numChars when lastIndex == true sicnc the calculation is expensive + int strNumChars = 0; + if (lastIndex) { + strNumChars = str.numChars(); + } + if (searchStr.numBytes == 0) { + return lastIndex ? strNumChars : 0; + } + int found = 0; + int index = lastIndex ? strNumChars : 0; + do { + if (lastIndex) { + index = str.lastIndexOf(searchStr, searchStrNumChars, index - 1); + } else { + index = str.indexOf(searchStr, index + 1); + } + if (index < 0) { + return index; + } + found += 1; + } while (found < ordinal); + return index; + } + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. + */ + public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) { + if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) { + return UTF8String.EMPTY_UTF8; + } + int delimNumChars = delim.numChars(); + if (count > 0) { + int idx = ordinalIndexOf(str, delim, delimNumChars, count); + if (idx != -1) { + return str.substring(0, idx); + } else { + return str; + } + } else { + int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count); + if (idx != -1) { + return str.substring(idx + delimNumChars); + } else { + return str; + } + } + } + /** * Returns str, right-padded with pad to a length of len * For example: diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index bee232b11a11e..df69d8e655ead 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -226,6 +226,7 @@ public void lastIndexOf() { assertEquals(0, fromString("").lastIndexOf(fromString(""), 0)); assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0)); assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0)); + assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4)); assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0)); assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3)); assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4)); From b19b013d382ebe3e64a9a7a6708b94737c92ecdd Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 23 Jul 2015 15:22:19 +0800 Subject: [PATCH 5/8] add codegen and clean code --- .../expressions/stringOperations.scala | 26 ++++- .../expressions/StringExpressionsSuite.scala | 31 ++++++ .../org/apache/spark/sql/functions.scala | 2 - .../apache/spark/unsafe/types/UTF8String.java | 80 +++++++-------- .../spark/unsafe/types/UTF8StringSuite.java | 99 +++++++++++++++++++ 5 files changed, 192 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 2d921124c38ea..39329722f2d7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -361,7 +361,7 @@ case class StringInstr(str: Expression, substr: Expression) * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -385,6 +385,30 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr } null } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val str = strExpr.gen(ctx) + val delim = delimExpr.gen(ctx) + val count = countExpr.gen(ctx) + val resultCode = + s"""org.apache.spark.unsafe.types.UTF8String.subStringIndex( + |${str.primitive}, ${delim.primitive}, ${count.primitive})""".stripMargin + s""" + ${str.code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${str.isNull}) { + ${delim.code} + if (!${delim.isNull}) { + ${count.code} + if (!${count.isNull}) { + ${ev.isNull} = false; + ${ev.primitive} = $resultCode; + } + } + } + """ + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3c2d88731beb4..abccef6196d5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(s.substring(0), "example", row) } + test("string substring_index function") { + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org") + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache") + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal("."), Literal(1)), "www") + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal("."), Literal(0)), "") + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org") + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org") + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal("."), Literal(-1)), "org") + checkEvaluation( + Substring_index(Literal(""), Literal("."), Literal(-2)), "") + checkEvaluation( + Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null) + checkEvaluation( + Substring_index(Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null) + // non ascii chars + // scalastyle:off + checkEvaluation( + Substring_index(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大") + // scalastyle:on + checkEvaluation( + Substring_index(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache") + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index faa82ad19fd05..0aef3db89904d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1783,7 +1783,6 @@ object functions { * right) is returned. substring_index performs a case-sensitive match when searching for delim. * * @group string_funcs - * @since 1.5.0 */ def substring_index(str: String, delim: String, count: Int): Column = substring_index(Column(str), delim, count) @@ -1795,7 +1794,6 @@ object functions { * right) is returned. substring_index performs a case-sensitive match when searching for delim. * * @group string_funcs - * @since 1.5.0 */ def substring_index(str: Column, delim: String, count: Int): Column = Substring_index(str.expr, lit(delim).expr, lit(count).expr) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 11e34f95bce8a..946bd3ea57a3b 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -146,20 +146,8 @@ public UTF8String substring(final int start, final int until) { if (until <= start || start >= numBytes) { return fromBytes(new byte[0]); } - - int i = 0; - int c = 0; - while (i < numBytes && c < start) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } - - int j = i; - while (i < numBytes && c < until) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } - + int j = firstByteIndex(start); + int i = firstByteIndex(until); byte[] bytes = new byte[i - j]; copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); @@ -174,12 +162,7 @@ public UTF8String substring(final int start) { return fromBytes(new byte[0]); } - int i = 0; - int c = 0; - while (i < numBytes && c < start) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } + int i = firstByteIndex(start); byte[] bytes = new byte[numBytes - i]; copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i); @@ -351,13 +334,8 @@ public int indexOf(UTF8String v, int start) { return 0; } - // locate to the start position. - int i = 0; // position in byte - int c = 0; // position in character - while (i < numBytes && c < start) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } + int i = firstByteIndex(start); // position in byte + int c = start; // position in character do { if (i + v.numBytes > numBytes) { @@ -399,19 +377,29 @@ private int firstOfCurrentCodePoint(int bytePos) { } bytePos--; } - throw new RuntimeException("Invalid utf8 string"); + throw new RuntimeException("Invalid UTF8 string"); } - private int indexEnd(int startCodePoint) { - int i = numBytes -1; // position in byte - int c = numChars() - 1; // position in character - while (i >=0 && c > startCodePoint) { - i = firstOfCurrentCodePoint(i) - 1; - c -= 1; + // Locate to the start position in byte for a given code point + private int firstByteIndex(int codePoint) { + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < codePoint) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + if (i > numBytes) { + throw new StringIndexOutOfBoundsException(codePoint); } return i; } + // Locate to the last position in byte for a given code point + private int lastByteIndex(int codePoint) { + int i = firstByteIndex(codePoint); + return i + numBytesForFirstByte(getByte(i)) - 1; + } + /** * Returns the index within this string of the last occurrence of the * specified substring, searching backward starting at the specified index. @@ -431,7 +419,7 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) { if (numBytes == 0) { return -1; } - int fromIndexEnd = indexEnd(startCodePoint); + int fromIndexEnd = lastByteIndex(startCodePoint); do { if (fromIndexEnd - v.numBytes + 1 < 0 ) { return -1; @@ -456,7 +444,6 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) { * * @param str the String to check, may be null * @param searchStr the String to find, may be null - * @param searchStrNumChars num of code ponts of the searchStr * @param ordinal the n-th last searchStr to find * @return the n-th last index of the search String, * -1 if no match or null string input @@ -464,17 +451,19 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) { public static int lastOrdinalIndexOf( UTF8String str, UTF8String searchStr, - int searchStrNumChars, int ordinal) { - return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true); + if (str == null || searchStr == null) { + return -1; + } + return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, true); } + /** * Finds the n-th index within a String, handling null. * A null String will return -1 * * @param str the String to check, may be null * @param searchStr the String to find, may be null - * @param searchStrNumChars num of code points of searchStr * @param ordinal the n-th searchStr to find * @return the n-th index of the search String, * -1 if no match or null string input @@ -482,9 +471,11 @@ public static int lastOrdinalIndexOf( public static int ordinalIndexOf( UTF8String str, UTF8String searchStr, - int searchStrNumChars, int ordinal) { - return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false); + if (str == null || searchStr == null) { + return -1; + } + return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, false); } private static int doOrdinalIndexOf( @@ -526,19 +517,22 @@ private static int doOrdinalIndexOf( * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) { + if (str == null || delim == null) { + return null; + } if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) { return UTF8String.EMPTY_UTF8; } int delimNumChars = delim.numChars(); if (count > 0) { - int idx = ordinalIndexOf(str, delim, delimNumChars, count); + int idx = doOrdinalIndexOf(str, delim, delimNumChars, count, false); if (idx != -1) { return str.substring(0, idx); } else { return str; } } else { - int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count); + int idx = doOrdinalIndexOf(str, delim, delimNumChars, -count, true); if (idx != -1) { return str.substring(idx + delimNumChars); } else { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index df69d8e655ead..31ea799a98282 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -21,10 +21,12 @@ import java.util.Arrays; import org.junit.Test; +import org.junit.rules.ExpectedException; import static junit.framework.Assert.*; import static org.apache.spark.unsafe.types.UTF8String.*; +import static org.apache.spark.unsafe.types.UTF8String.fromString; public class UTF8StringSuite { @@ -184,6 +186,63 @@ public void substring() { assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + + assertEquals(fromString("hello"), fromString("hello").substring(0)); + assertEquals(fromString("ello"), fromString("hello").substring(1)); + assertEquals(fromString("砖头"), fromString("数据砖头").substring(2)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3)); + ExpectedException exception = ExpectedException.none(); + fromString("数据砖头").substring(4); + exception.expect(java.lang.StringIndexOutOfBoundsException.class); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0)); + } + + @Test + public void ordinalIndexOf() { + assertEquals(-1, + UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 0)); + assertEquals(3, + UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 1)); + assertEquals(10, + UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 2)); + assertEquals(-1, + UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 3)); + assertEquals(-1, + UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("#"), 0)); + assertEquals(12, + UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2)); + assertEquals(-1, + UTF8String.ordinalIndexOf(null, fromString("|||"), 1)); + assertEquals(-1, + UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), null, 1)); + assertEquals(2, + UTF8String.ordinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1)); + assertEquals(-1, + UTF8String.ordinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2)); + } + + @Test + public void lastOrdinalIndexOf() { + assertEquals(-1, + UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 0)); + assertEquals(10, + UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 1)); + assertEquals(3, + UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 2)); + assertEquals(-1, + UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 3)); + assertEquals(-1, + UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("#"), 0)); + assertEquals(3, + UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2)); + assertEquals(-1, + UTF8String.lastOrdinalIndexOf(null, fromString("|||"), 1)); + assertEquals(-1, + UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), null, 1)); + assertEquals(3, + UTF8String.lastOrdinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1)); + assertEquals(-1, + UTF8String.lastOrdinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2)); } @Test @@ -238,6 +297,46 @@ public void lastIndexOf() { assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3)); } + @Test + public void substring_index() { + assertEquals(fromString("www.apache.org"), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 3)); + assertEquals(fromString("www.apache"), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 2)); + assertEquals(fromString("www"), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 1)); + assertEquals(fromString(""), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 0)); + assertEquals(fromString("org"), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -1)); + assertEquals(fromString("apache.org"), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -2)); + assertEquals(fromString("www.apache.org"), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -3)); + // str is empty string + assertEquals(fromString(""), + UTF8String.subStringIndex(fromString(""), fromString("."), 1)); + // empty string delim + assertEquals(fromString(""), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString(""), 1)); + // delim does not exist in str + assertEquals(fromString("www.apache.org"), + UTF8String.subStringIndex(fromString("www.apache.org"), fromString("#"), 2)); + // delim is 2 chars + assertEquals(fromString("www||apache"), + UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), 2)); + assertEquals(fromString("apache||org"), + UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), -2)); + // null + assertEquals(null, + UTF8String.subStringIndex(null, fromString("."), -2)); + assertEquals(null, + UTF8String.subStringIndex(fromString("www.apache.org"), null, -2)); + // non ascii chars + assertEquals(fromString("大千世界大"), + UTF8String.subStringIndex(fromString("大千世界大千世界"), fromString("千"), 2)); + } + @Test public void reverse() { assertEquals(fromString("olleh"), fromString("hello").reverse()); From 67c253a67b0c2ddf6a3e59881e8d42d0326a0bf9 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Fri, 24 Jul 2015 10:30:17 +0800 Subject: [PATCH 6/8] hide some apis and clean code --- .../expressions/stringOperations.scala | 7 +- .../org/apache/spark/sql/functions.scala | 11 -- .../apache/spark/unsafe/types/UTF8String.java | 106 ++++++++---------- .../spark/unsafe/types/UTF8StringSuite.java | 74 ++++++------ 4 files changed, 80 insertions(+), 118 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 39329722f2d7d..18877769bd212 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -376,8 +376,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr if (delim != null) { val count = countExpr.eval(input) if (count != null) { - return UTF8String.subStringIndex( - str.asInstanceOf[UTF8String], + return str.asInstanceOf[UTF8String].subStringIndex( delim.asInstanceOf[UTF8String], count.asInstanceOf[Int]) } @@ -390,9 +389,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr val str = strExpr.gen(ctx) val delim = delimExpr.gen(ctx) val count = countExpr.gen(ctx) - val resultCode = - s"""org.apache.spark.unsafe.types.UTF8String.subStringIndex( - |${str.primitive}, ${delim.primitive}, ${count.primitive})""".stripMargin + val resultCode = s"${str.primitive}.subStringIndex(${delim.primitive}, ${count.primitive})" s""" ${str.code} boolean ${ev.isNull} = true; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0aef3db89904d..b092ad047da50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1776,17 +1776,6 @@ object functions { */ def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) - /** - * Returns the substring from string str before count occurrences of the delimiter delim. - * If count is positive, everything the left of the final delimiter (counting from left) is - * returned. If count is negative, every to the right of the final delimiter (counting from the - * right) is returned. substring_index performs a case-sensitive match when searching for delim. - * - * @group string_funcs - */ - def substring_index(str: String, delim: String, count: Int): Column = - substring_index(Column(str), delim, count) - /** * Returns the substring from string str before count occurrences of the delimiter delim. * If count is positive, everything the left of the final delimiter (counting from left) is diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 946bd3ea57a3b..e84509f18e146 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -146,8 +146,8 @@ public UTF8String substring(final int start, final int until) { if (until <= start || start >= numBytes) { return fromBytes(new byte[0]); } - int j = firstByteIndex(start); - int i = firstByteIndex(until); + int j = firstByteIndex(0, 0, start); + int i = firstByteIndex(j, start, until); byte[] bytes = new byte[i - j]; copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); @@ -162,7 +162,7 @@ public UTF8String substring(final int start) { return fromBytes(new byte[0]); } - int i = firstByteIndex(start); + int i = firstByteIndex(0, 0, start); byte[] bytes = new byte[numBytes - i]; copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i); @@ -334,7 +334,7 @@ public int indexOf(UTF8String v, int start) { return 0; } - int i = firstByteIndex(start); // position in byte + int i = firstByteIndex(0, 0, start); // position in byte int c = start; // position in character do { @@ -353,7 +353,7 @@ public int indexOf(UTF8String v, int start) { private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR}; - private ByteType checkByteType(Byte b) { + private ByteType checkByteType(byte b) { int firstTwoBits = (b >>> 6) & 0x03; if (firstTwoBits == 3) { return ByteType.FIRSTBYTE; @@ -371,19 +371,19 @@ private ByteType checkByteType(Byte b) { */ private int firstOfCurrentCodePoint(int bytePos) { while (bytePos >= 0) { - if (ByteType.FIRSTBYTE == checkByteType(getByte(bytePos)) - || ByteType.SINGLEBYTECHAR == checkByteType(getByte(bytePos))) { + ByteType byteType = checkByteType(getByte(bytePos)); + if (ByteType.FIRSTBYTE == byteType || ByteType.SINGLEBYTECHAR == byteType) { return bytePos; } bytePos--; } - throw new RuntimeException("Invalid UTF8 string"); + throw new RuntimeException("Invalid UTF8 string: " + toString()); } // Locate to the start position in byte for a given code point - private int firstByteIndex(int codePoint) { - int i = 0; // position in byte - int c = 0; // position in character + private int firstByteIndex(int startByteIndex, int startPointIndex, int codePoint) { + int i = startByteIndex; // position in byte + int c = startPointIndex; // position in character while (i < numBytes && c < codePoint) { i += numBytesForFirstByte(getByte(i)); c += 1; @@ -396,7 +396,7 @@ private int firstByteIndex(int codePoint) { // Locate to the last position in byte for a given code point private int lastByteIndex(int codePoint) { - int i = firstByteIndex(codePoint); + int i = firstByteIndex(0, 0, codePoint); return i + numBytesForFirstByte(getByte(i)) - 1; } @@ -410,31 +410,33 @@ private int lastByteIndex(int codePoint) { * or {@code -1} if there is no such occurrence. */ public int lastIndexOf(UTF8String v, int startCodePoint) { - return lastIndexOf(v, v.numChars(), startCodePoint); - } - public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) { + // Empty string always match if (v.numBytes == 0) { - return 0; + return startCodePoint; } + return lastIndexOfInByte(v, lastByteIndex(startCodePoint)); + } + + private int lastIndexOfInByte(UTF8String v, int fromIndexInByte) { if (numBytes == 0) { return -1; } - int fromIndexEnd = lastByteIndex(startCodePoint); do { - if (fromIndexEnd - v.numBytes + 1 < 0 ) { + int startByteIndex = fromIndexInByte - v.numBytes + 1; + if (startByteIndex < 0 ) { return -1; } if (ByteArrayMethods.arrayEquals( - base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) { + base, offset + startByteIndex, v.base, v.offset, v.numBytes)) { int count = 0; // count from right most to the match end in byte. - while (fromIndexEnd >= 0) { + while (startByteIndex >= 0) { count++; - fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; + startByteIndex = firstOfCurrentCodePoint(startByteIndex) - 1; } - return count - vNumChars; + return count - 1; } - fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; - } while (fromIndexEnd >= 0); + fromIndexInByte = firstOfCurrentCodePoint(fromIndexInByte) - 1; + } while (fromIndexInByte >= 0); return -1; } @@ -442,66 +444,49 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) { * Finds the n-th last index within a String. * This method uses {@link String#lastIndexOf(String)}.

* - * @param str the String to check, may be null * @param searchStr the String to find, may be null * @param ordinal the n-th last searchStr to find * @return the n-th last index of the search String, * -1 if no match or null string input */ - public static int lastOrdinalIndexOf( - UTF8String str, + protected int lastOrdinalIndexOf( UTF8String searchStr, int ordinal) { - if (str == null || searchStr == null) { - return -1; - } - return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, true); + return doOrdinalIndexOf(searchStr, ordinal, true); } /** * Finds the n-th index within a String, handling null. * A null String will return -1 * - * @param str the String to check, may be null * @param searchStr the String to find, may be null * @param ordinal the n-th searchStr to find * @return the n-th index of the search String, * -1 if no match or null string input */ - public static int ordinalIndexOf( - UTF8String str, + protected int ordinalIndexOf( UTF8String searchStr, int ordinal) { - if (str == null || searchStr == null) { - return -1; - } - return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, false); + return doOrdinalIndexOf(searchStr, ordinal, false); } - private static int doOrdinalIndexOf( - UTF8String str, + private int doOrdinalIndexOf( UTF8String searchStr, - int searchStrNumChars, int ordinal, boolean lastIndex) { - if (str == null || searchStr == null || ordinal <= 0) { + if (ordinal <= 0) { return -1; } - // Only calc numChars when lastIndex == true sicnc the calculation is expensive - int strNumChars = 0; - if (lastIndex) { - strNumChars = str.numChars(); - } if (searchStr.numBytes == 0) { - return lastIndex ? strNumChars : 0; + return lastIndex ? numChars() : 0; } int found = 0; - int index = lastIndex ? strNumChars : 0; + int index = lastIndex ? numBytes : -1; do { if (lastIndex) { - index = str.lastIndexOf(searchStr, searchStrNumChars, index - 1); + index = lastIndexOfInByte(searchStr, index - 1); } else { - index = str.indexOf(searchStr, index + 1); + index = indexOf(searchStr, index + 1); } if (index < 0) { return index; @@ -516,27 +501,26 @@ private static int doOrdinalIndexOf( * returned. If count is negative, every to the right of the final delimiter (counting from the * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ - public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) { - if (str == null || delim == null) { + public UTF8String subStringIndex(UTF8String delim, int count) { + if (delim == null) { return null; } - if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) { + if (delim.numBytes == 0 || count == 0) { return UTF8String.EMPTY_UTF8; } - int delimNumChars = delim.numChars(); if (count > 0) { - int idx = doOrdinalIndexOf(str, delim, delimNumChars, count, false); + int idx = ordinalIndexOf(delim, count); if (idx != -1) { - return str.substring(0, idx); + return substring(0, idx); } else { - return str; + return this; } } else { - int idx = doOrdinalIndexOf(str, delim, delimNumChars, -count, true); + int idx = lastOrdinalIndexOf(delim, -count); if (idx != -1) { - return str.substring(idx + delimNumChars); + return substring(idx + delim.numChars()); } else { - return str; + return this; } } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 31ea799a98282..6386c9902885d 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -200,49 +200,43 @@ public void substring() { @Test public void ordinalIndexOf() { assertEquals(-1, - UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 0)); + fromString("www.apache.org").ordinalIndexOf(fromString("."), 0)); + assertEquals(0, + fromString("www.apache.org").ordinalIndexOf(fromString("w"), 1)); assertEquals(3, - UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 1)); + fromString("www.apache.org").ordinalIndexOf(fromString("."), 1)); assertEquals(10, - UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 2)); + fromString("www.apache.org").ordinalIndexOf(fromString("."), 2)); assertEquals(-1, - UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 3)); + fromString("www.apache.org").ordinalIndexOf(fromString("."), 3)); assertEquals(-1, - UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("#"), 0)); + fromString("www.apache.org").ordinalIndexOf(fromString("#"), 0)); assertEquals(12, - UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2)); - assertEquals(-1, - UTF8String.ordinalIndexOf(null, fromString("|||"), 1)); - assertEquals(-1, - UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), null, 1)); + fromString("www|||apache|||org").ordinalIndexOf(fromString("|||"), 2)); assertEquals(2, - UTF8String.ordinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1)); + fromString("数据砖砖头").ordinalIndexOf(fromString("砖"), 1)); assertEquals(-1, - UTF8String.ordinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2)); + fromString("砖头数据砖头").ordinalIndexOf(fromString("砖"), -2)); } @Test public void lastOrdinalIndexOf() { assertEquals(-1, - UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 0)); + fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 0)); assertEquals(10, - UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 1)); + fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 1)); assertEquals(3, - UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 2)); + fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 2)); assertEquals(-1, - UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 3)); + fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 3)); assertEquals(-1, - UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("#"), 0)); + fromString("www.apache.org").lastOrdinalIndexOf(fromString("#"), 0)); assertEquals(3, - UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2)); - assertEquals(-1, - UTF8String.lastOrdinalIndexOf(null, fromString("|||"), 1)); - assertEquals(-1, - UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), null, 1)); + fromString("www|||apache|||org").lastOrdinalIndexOf(fromString("|||"), 2)); assertEquals(3, - UTF8String.lastOrdinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1)); + fromString("数据砖砖头").lastOrdinalIndexOf(fromString("砖"), 1)); assertEquals(-1, - UTF8String.lastOrdinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2)); + fromString("砖头数据砖头").lastOrdinalIndexOf(fromString("砖"), -2)); } @Test @@ -282,7 +276,7 @@ public void indexOf() { @Test public void lastIndexOf() { - assertEquals(0, fromString("").lastIndexOf(fromString(""), 0)); + assertEquals(1, fromString("hello").lastIndexOf(fromString(""), 1)); assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0)); assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0)); assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4)); @@ -300,41 +294,39 @@ public void lastIndexOf() { @Test public void substring_index() { assertEquals(fromString("www.apache.org"), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 3)); + fromString("www.apache.org").subStringIndex(fromString("."), 3)); assertEquals(fromString("www.apache"), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 2)); + fromString("www.apache.org").subStringIndex(fromString("."), 2)); assertEquals(fromString("www"), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 1)); + fromString("www.apache.org").subStringIndex(fromString("."), 1)); assertEquals(fromString(""), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 0)); + fromString("www.apache.org").subStringIndex(fromString("."), 0)); assertEquals(fromString("org"), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -1)); + fromString("www.apache.org").subStringIndex(fromString("."), -1)); assertEquals(fromString("apache.org"), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -2)); + fromString("www.apache.org").subStringIndex(fromString("."), -2)); assertEquals(fromString("www.apache.org"), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -3)); + fromString("www.apache.org").subStringIndex(fromString("."), -3)); // str is empty string assertEquals(fromString(""), - UTF8String.subStringIndex(fromString(""), fromString("."), 1)); + fromString("").subStringIndex(fromString("."), 1)); // empty string delim assertEquals(fromString(""), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString(""), 1)); + fromString("www.apache.org").subStringIndex(fromString(""), 1)); // delim does not exist in str assertEquals(fromString("www.apache.org"), - UTF8String.subStringIndex(fromString("www.apache.org"), fromString("#"), 2)); + fromString("www.apache.org").subStringIndex(fromString("#"), 2)); // delim is 2 chars assertEquals(fromString("www||apache"), - UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), 2)); + fromString("www||apache||org").subStringIndex(fromString("||"), 2)); assertEquals(fromString("apache||org"), - UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), -2)); + fromString("www||apache||org").subStringIndex(fromString("||"), -2)); // null assertEquals(null, - UTF8String.subStringIndex(null, fromString("."), -2)); - assertEquals(null, - UTF8String.subStringIndex(fromString("www.apache.org"), null, -2)); + fromString("www.apache.org").subStringIndex(null, -2)); // non ascii chars assertEquals(fromString("大千世界大"), - UTF8String.subStringIndex(fromString("大千世界大千世界"), fromString("千"), 2)); + fromString("大千世界大千世界").subStringIndex(fromString("千"), 2)); } @Test From 9546991d4964b17dd5d62f93909651280077ba13 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Fri, 24 Jul 2015 10:46:04 +0800 Subject: [PATCH 7/8] scala style --- .../sql/catalyst/expressions/StringExpressionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index abccef6196d5c..1a4376e6f5216 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -207,8 +207,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Substring_index(Literal(""), Literal("."), Literal(-2)), "") checkEvaluation( Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null) - checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null) + checkEvaluation(Substring_index( + Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null) // non ascii chars // scalastyle:off checkEvaluation( From 515519bebbec5a46de0789ea2ec82f448be0e8e8 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Fri, 24 Jul 2015 13:49:19 +0800 Subject: [PATCH 8/8] add foldable and remove null checking --- .../spark/sql/catalyst/expressions/stringOperations.scala | 1 + .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 5 +---- .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 3 --- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 18877769bd212..20252ea8a89cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -364,6 +364,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr extends Expression with ImplicitCastInputTypes { override def dataType: DataType = StringType + override def foldable: Boolean = strExpr.foldable && delimExpr.foldable && countExpr.foldable override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e84509f18e146..79f2f25142e82 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -144,7 +144,7 @@ public byte[] getBytes() { */ public UTF8String substring(final int start, final int until) { if (until <= start || start >= numBytes) { - return fromBytes(new byte[0]); + return UTF8String.EMPTY_UTF8; } int j = firstByteIndex(0, 0, start); int i = firstByteIndex(j, start, until); @@ -502,9 +502,6 @@ private int doOrdinalIndexOf( * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ public UTF8String subStringIndex(UTF8String delim, int count) { - if (delim == null) { - return null; - } if (delim.numBytes == 0 || count == 0) { return UTF8String.EMPTY_UTF8; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 6386c9902885d..a78bb3c4374f4 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -321,9 +321,6 @@ public void substring_index() { fromString("www||apache||org").subStringIndex(fromString("||"), 2)); assertEquals(fromString("apache||org"), fromString("www||apache||org").subStringIndex(fromString("||"), -2)); - // null - assertEquals(null, - fromString("www.apache.org").subStringIndex(null, -2)); // non ascii chars assertEquals(fromString("大千世界大"), fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));