From 52d7b0373ce971acd6516609db9fa96f8f427513 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Mon, 20 Jul 2015 16:08:26 +0800
Subject: [PATCH 1/8] add substring_index function
---
.../catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/stringOperations.scala | 74 +++++++++++++++++++
.../org/apache/spark/sql/functions.scala | 25 ++++++-
.../spark/sql/StringFunctionsSuite.scala | 28 +++++++
4 files changed, 127 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e3d8d2adf2135..ec46ca5477d90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -177,6 +177,7 @@ object FunctionRegistry {
expression[StringSplit]("split"),
expression[Substring]("substr"),
expression[Substring]("substring"),
+ expression[Substring_index]("substring_index"),
expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 1f18a6e9ff8a5..269eb81bdbce3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -21,6 +21,8 @@ import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
+import org.apache.commons.lang.StringUtils
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -355,6 +357,78 @@ case class StringInstr(str: Expression, substr: Expression)
}
}
+/**
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. substring_index performs a case-sensitive match when searching for delim.
+ */
+case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+ override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable
+ override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
+ override def prettyName: String = "substring_index"
+ override def toString: String = s"substring_index($strExpr, $delimExpr, $countExpr)"
+
+ override def eval(input: InternalRow): Any = {
+ val str = strExpr.eval(input)
+ val delim = delimExpr.eval(input)
+ val count = countExpr.eval(input)
+ if (str == null || delim == null || count == null) {
+ null
+ } else {
+ subStrIndex(
+ str.asInstanceOf[UTF8String],
+ delim.asInstanceOf[UTF8String],
+ count.asInstanceOf[Int])
+ }
+ }
+
+ private def ordinalIndexOf(str: UTF8String, delim: UTF8String, count: Int): Int = {
+ var found = 0
+ var index = -1
+ do {
+ index = str.indexOf(delim, index + 1)
+ if (index < 0) {
+ return index
+ }
+ found += 1
+ } while (found < count)
+ index
+ }
+
+ private def subStrIndex(strUtf8: UTF8String, delimUtf8: UTF8String, count: Int): UTF8String = {
+ if (strUtf8 == null || delimUtf8 == null || count == null) {
+ return null
+ }
+ if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) {
+ return UTF8String.fromString("")
+ }
+ val res: UTF8String =
+ if (count > 0) {
+ val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
+ if (idx != -1) {
+ strUtf8.substring(0, idx)
+ } else {
+ strUtf8
+ }
+ } else {
+ val str = strUtf8.toString
+ val delim = delimUtf8.toString
+ val idx = StringUtils.lastOrdinalIndexOf(str, delim, -count)
+ if (idx != -1) {
+ UTF8String.fromString(str.substring(idx + 1))
+ } else {
+ UTF8String.fromString(str)
+ }
+ }
+ res
+ }
+}
+
/**
* A function that returns the position of the first occurrence of substr
* in given string after position pos.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e5ff8ae7e3179..faa82ad19fd05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1777,8 +1777,31 @@ object functions {
def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
/**
- * Locate the position of the first occurrence of substr in a string column.
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. substring_index performs a case-sensitive match when searching for delim.
*
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def substring_index(str: String, delim: String, count: Int): Column =
+ substring_index(Column(str), delim, count)
+
+ /**
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. substring_index performs a case-sensitive match when searching for delim.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def substring_index(str: Column, delim: String, count: Int): Column =
+ Substring_index(str.expr, lit(delim).expr, lit(count).expr)
+
+ /**
+ * Locate the position of the first occurrence of substr.
* NOTE: The position is not zero based, but 1 based index, returns 0 if substr
* could not be found in str.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3702e73b4e74f..bc5ce2c49a7ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -156,6 +156,34 @@ class StringFunctionsSuite extends QueryTest {
Row(1))
}
+ test("string substring_index function") {
+ val df = Seq(("ac,ab,ad,ab,cc", "aa", "zz")).toDF("a", "b", "c")
+ checkAnswer(
+ df.select(substring_index($"a", ",", 2)),
+ Row("ac,ab"))
+ checkAnswer(
+ df.select(substring_index($"a", "ab", 2)),
+ Row("ac,ab,ad,"))
+ checkAnswer(
+ df.select(substring_index(lit(""), "ab", 2)),
+ Row(""))
+ checkAnswer(
+ df.select(substring_index(lit(null), "ab", 2)),
+ Row(null))
+ checkAnswer(
+ df.select(substring_index(lit("大千世界大千世界"), "千", 2)),
+ Row("大千世界大"))
+ checkAnswer(
+ df.selectExpr("""substring_index(a, ",", 2)"""),
+ Row("ac,ab"))
+ checkAnswer(
+ df.selectExpr("""substring_index(a, ",", -2)"""),
+ Row("ab,cc"))
+ checkAnswer(
+ df.selectExpr("""substring_index(a, ",", 10)"""),
+ Row("ac,ab,ad,ab,cc"))
+ }
+
test("string locate function") {
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
From d92951b0ea1adcc67d0b5483bcbd6277f747ac89 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Tue, 21 Jul 2015 16:21:01 +0800
Subject: [PATCH 2/8] add lastIndexOf
---
.../expressions/stringOperations.scala | 54 ++++++++++------
.../spark/sql/StringFunctionsSuite.scala | 2 +
.../apache/spark/unsafe/types/UTF8String.java | 63 +++++++++++++++++++
.../spark/unsafe/types/UTF8StringSuite.java | 16 +++++
4 files changed, 115 insertions(+), 20 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 269eb81bdbce3..6c20677c2b0d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -387,16 +387,33 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
}
}
- private def ordinalIndexOf(str: UTF8String, delim: UTF8String, count: Int): Int = {
+ private def lastOrdinalIndexOf(
+ str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
+ ordinalIndexOf(str, searchStr, ordinal, true)
+ }
+
+ private def ordinalIndexOf(
+ str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
+ if (str == null || searchStr == null || ordinal <= 0) {
+ return -1
+ }
+ val strNumChars = str.numChars()
+ if (searchStr.numBytes() == 0) {
+ return if (lastIndex) {strNumChars} else {0}
+ }
var found = 0
- var index = -1
+ var index = if (lastIndex) {strNumChars} else {0}
do {
- index = str.indexOf(delim, index + 1)
+ if (lastIndex) {
+ index = str.lastIndexOf(searchStr, index - 1)
+ } else {
+ index = str.indexOf(searchStr, index + 1)
+ }
if (index < 0) {
return index
}
found += 1
- } while (found < count)
+ } while (found < ordinal)
index
}
@@ -407,24 +424,21 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) {
return UTF8String.fromString("")
}
- val res: UTF8String =
- if (count > 0) {
- val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
- if (idx != -1) {
- strUtf8.substring(0, idx)
- } else {
- strUtf8
- }
+ val res = if (count > 0) {
+ val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
+ if (idx != -1) {
+ strUtf8.substring(0, idx)
} else {
- val str = strUtf8.toString
- val delim = delimUtf8.toString
- val idx = StringUtils.lastOrdinalIndexOf(str, delim, -count)
- if (idx != -1) {
- UTF8String.fromString(str.substring(idx + 1))
- } else {
- UTF8String.fromString(str)
- }
+ strUtf8
}
+ } else {
+ val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count)
+ if (idx != -1) {
+ strUtf8.substring(idx + 1, strUtf8.numChars())
+ } else {
+ strUtf8
+ }
+ }
res
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index bc5ce2c49a7ad..9a0eb16ec700b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -170,9 +170,11 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.select(substring_index(lit(null), "ab", 2)),
Row(null))
+ // scalastyle:off
checkAnswer(
df.select(substring_index(lit("大千世界大千世界"), "千", 2)),
Row("大千世界大"))
+ // scalastyle:on
checkAnswer(
df.selectExpr("""substring_index(a, ",", 2)"""),
Row("ac,ab"))
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 946d355f1fc28..f3d14178ce3cc 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -352,6 +352,69 @@ public int indexOf(UTF8String v, int start) {
return -1;
}
+ private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR};
+
+ private ByteType checkByteType(Byte b) {
+ int firstTwoBits = (b >>> 6) & 0x03;
+ if (firstTwoBits == 3) {
+ return ByteType.FIRSTBYTE;
+ } else if (firstTwoBits == 2) {
+ return ByteType.MIDBYTE;
+ } else {
+ return ByteType.SINGLEBYTECHAR;
+ }
+ }
+
+ /**
+ * Return the first byte position for a given byte which shared the same code point.
+ * @param bytePos any byte within the code point
+ * @return the first byte position of a given code point, throw exception if not a valid UTF8 str
+ */
+ private int firstOfCurrentCodePoint(int bytePos) {
+ while (bytePos >= 0) {
+ if (ByteType.FIRSTBYTE == checkByteType(getByte(bytePos))
+ || ByteType.SINGLEBYTECHAR == checkByteType(getByte(bytePos))) {
+ return bytePos;
+ }
+ bytePos--;
+ }
+ throw new RuntimeException("Invalid utf8 string");
+ }
+
+ private int endByte(int startCodePoint) {
+ int i = numBytes -1; // position in byte
+ int c = numChars() - 1; // position in character
+ while (i >=0 && c > startCodePoint) {
+ i = firstOfCurrentCodePoint(i) - 1;
+ c -= 1;
+ }
+ return i;
+ }
+
+ public int lastIndexOf(UTF8String v, int startCodePoint) {
+ if (v.numBytes == 0) {
+ return 0;
+ }
+ if (numBytes == 0) {
+ return -1;
+ }
+ int fromIndexEnd = endByte(startCodePoint);
+ int count = startCodePoint;
+ int vNumChars = v.numChars();
+ do {
+ if (fromIndexEnd - v.numBytes + 1 < 0 ) {
+ return -1;
+ }
+ if (ByteArrayMethods.arrayEquals(
+ base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) {
+ return count - vNumChars + 1;
+ }
+ fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
+ count--;
+ } while (fromIndexEnd >= 0);
+ return -1;
+ }
+
/**
* Returns str, right-padded with pad to a length of len
* For example:
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index e2a5628ff4d93..bee232b11a11e 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -221,6 +221,22 @@ public void indexOf() {
assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0));
}
+ @Test
+ public void lastIndexOf() {
+ assertEquals(0, fromString("").lastIndexOf(fromString(""), 0));
+ assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0));
+ assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0));
+ assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0));
+ assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3));
+ assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4));
+ assertEquals(2, fromString("hello").lastIndexOf(fromString("ll"), 4));
+ assertEquals(-1, fromString("hello").lastIndexOf(fromString("ll"), 0));
+ assertEquals(5, fromString("数据砖头数据砖头").lastIndexOf(fromString("据砖"), 7));
+ assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 3));
+ assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 0));
+ assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3));
+ }
+
@Test
public void reverse() {
assertEquals(fromString("olleh"), fromString("hello").reverse());
From 12e108f13cfe4bc6ef967adb6f42891ed01aa521 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Wed, 22 Jul 2015 09:21:05 +0800
Subject: [PATCH 3/8] refine unittest
---
.../expressions/stringOperations.scala | 2 +-
.../spark/sql/StringFunctionsSuite.scala | 61 +++++++++++++------
.../apache/spark/unsafe/types/UTF8String.java | 4 +-
3 files changed, 47 insertions(+), 20 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6c20677c2b0d7..35ec2c991a94b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -434,7 +434,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
} else {
val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count)
if (idx != -1) {
- strUtf8.substring(idx + 1, strUtf8.numChars())
+ strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars())
} else {
strUtf8
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 9a0eb16ec700b..08ea5664d212c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -157,33 +157,60 @@ class StringFunctionsSuite extends QueryTest {
}
test("string substring_index function") {
- val df = Seq(("ac,ab,ad,ab,cc", "aa", "zz")).toDF("a", "b", "c")
+ val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c")
checkAnswer(
- df.select(substring_index($"a", ",", 2)),
- Row("ac,ab"))
+ df.select(substring_index($"a", ".", 3)),
+ Row("www.apache.org"))
checkAnswer(
- df.select(substring_index($"a", "ab", 2)),
- Row("ac,ab,ad,"))
+ df.select(substring_index($"a", ".", 2)),
+ Row("www.apache"))
checkAnswer(
- df.select(substring_index(lit(""), "ab", 2)),
+ df.select(substring_index($"a", ".", 1)),
+ Row("www"))
+ checkAnswer(
+ df.select(substring_index($"a", ".", 0)),
+ Row(""))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -1)),
+ Row("org"))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -2)),
+ Row("apache.org"))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), ".", -3)),
+ Row("www.apache.org"))
+ // str is empty string
+ checkAnswer(
+ df.select(substring_index(lit(""), ".", 1)),
+ Row(""))
+ // empty string delim
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), "", 1)),
Row(""))
+ // delim does not exist in str
checkAnswer(
- df.select(substring_index(lit(null), "ab", 2)),
+ df.select(substring_index(lit("www.apache.org"), "#", 1)),
+ Row("www.apache.org"))
+ // delim is 2 chars
+ checkAnswer(
+ df.select(substring_index(lit("www||apache||org"), "||", 2)),
+ Row("www||apache"))
+ checkAnswer(
+ df.select(substring_index(lit("www||apache||org"), "||", -2)),
+ Row("apache||org"))
+ // null
+ checkAnswer(
+ df.select(substring_index(lit(null), "||", 2)),
+ Row(null))
+ checkAnswer(
+ df.select(substring_index(lit("www.apache.org"), null, 2)),
Row(null))
+ // non ascii chars
// scalastyle:off
checkAnswer(
- df.select(substring_index(lit("大千世界大千世界"), "千", 2)),
+ df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""),
Row("大千世界大"))
// scalastyle:on
- checkAnswer(
- df.selectExpr("""substring_index(a, ",", 2)"""),
- Row("ac,ab"))
- checkAnswer(
- df.selectExpr("""substring_index(a, ",", -2)"""),
- Row("ab,cc"))
- checkAnswer(
- df.selectExpr("""substring_index(a, ",", 10)"""),
- Row("ac,ab,ad,ab,cc"))
}
test("string locate function") {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index f3d14178ce3cc..78d767ee4de12 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -381,7 +381,7 @@ private int firstOfCurrentCodePoint(int bytePos) {
throw new RuntimeException("Invalid utf8 string");
}
- private int endByte(int startCodePoint) {
+ private int indexEnd(int startCodePoint) {
int i = numBytes -1; // position in byte
int c = numChars() - 1; // position in character
while (i >=0 && c > startCodePoint) {
@@ -398,7 +398,7 @@ public int lastIndexOf(UTF8String v, int startCodePoint) {
if (numBytes == 0) {
return -1;
}
- int fromIndexEnd = endByte(startCodePoint);
+ int fromIndexEnd = indexEnd(startCodePoint);
int count = startCodePoint;
int vNumChars = v.numChars();
do {
From ac863e9048e1aedf587aaf5e90d09b91bbf8ec25 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Wed, 22 Jul 2015 15:52:56 +0800
Subject: [PATCH 4/8] reduce the calling of numChars
---
.../expressions/stringOperations.scala | 78 ++--------
.../apache/spark/unsafe/types/UTF8String.java | 140 +++++++++++++++++-
.../spark/unsafe/types/UTF8StringSuite.java | 1 +
3 files changed, 148 insertions(+), 71 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 35ec2c991a94b..2d921124c38ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -21,10 +21,7 @@ import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
-import org.apache.commons.lang.StringUtils
-
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -371,75 +368,22 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable
override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
override def prettyName: String = "substring_index"
- override def toString: String = s"substring_index($strExpr, $delimExpr, $countExpr)"
override def eval(input: InternalRow): Any = {
val str = strExpr.eval(input)
- val delim = delimExpr.eval(input)
- val count = countExpr.eval(input)
- if (str == null || delim == null || count == null) {
- null
- } else {
- subStrIndex(
- str.asInstanceOf[UTF8String],
- delim.asInstanceOf[UTF8String],
- count.asInstanceOf[Int])
- }
- }
-
- private def lastOrdinalIndexOf(
- str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
- ordinalIndexOf(str, searchStr, ordinal, true)
- }
-
- private def ordinalIndexOf(
- str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
- if (str == null || searchStr == null || ordinal <= 0) {
- return -1
- }
- val strNumChars = str.numChars()
- if (searchStr.numBytes() == 0) {
- return if (lastIndex) {strNumChars} else {0}
- }
- var found = 0
- var index = if (lastIndex) {strNumChars} else {0}
- do {
- if (lastIndex) {
- index = str.lastIndexOf(searchStr, index - 1)
- } else {
- index = str.indexOf(searchStr, index + 1)
- }
- if (index < 0) {
- return index
- }
- found += 1
- } while (found < ordinal)
- index
- }
-
- private def subStrIndex(strUtf8: UTF8String, delimUtf8: UTF8String, count: Int): UTF8String = {
- if (strUtf8 == null || delimUtf8 == null || count == null) {
- return null
- }
- if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) {
- return UTF8String.fromString("")
- }
- val res = if (count > 0) {
- val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
- if (idx != -1) {
- strUtf8.substring(0, idx)
- } else {
- strUtf8
- }
- } else {
- val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count)
- if (idx != -1) {
- strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars())
- } else {
- strUtf8
+ if (str != null) {
+ val delim = delimExpr.eval(input)
+ if (delim != null) {
+ val count = countExpr.eval(input)
+ if (count != null) {
+ return UTF8String.subStringIndex(
+ str.asInstanceOf[UTF8String],
+ delim.asInstanceOf[UTF8String],
+ count.asInstanceOf[Int])
+ }
}
}
- res
+ null
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 78d767ee4de12..11e34f95bce8a 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -165,6 +165,27 @@ public UTF8String substring(final int start, final int until) {
return fromBytes(bytes);
}
+ /**
+ * Returns a substring of this from start to end.
+ * @param start the position of first code point
+ */
+ public UTF8String substring(final int start) {
+ if (start >= numBytes) {
+ return fromBytes(new byte[0]);
+ }
+
+ int i = 0;
+ int c = 0;
+ while (i < numBytes && c < start) {
+ i += numBytesForFirstByte(getByte(i));
+ c += 1;
+ }
+
+ byte[] bytes = new byte[numBytes - i];
+ copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i);
+ return fromBytes(bytes);
+ }
+
public UTF8String substringSQL(int pos, int length) {
// Information regarding the pos calculation:
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
@@ -391,7 +412,19 @@ private int indexEnd(int startCodePoint) {
return i;
}
+ /**
+ * Returns the index within this string of the last occurrence of the
+ * specified substring, searching backward starting at the specified index.
+ * @param v the substring to search for.
+ * @param startCodePoint the index to start search from
+ * @return the index of the last occurrence of the specified substring,
+ * searching backward from the specified index,
+ * or {@code -1} if there is no such occurrence.
+ */
public int lastIndexOf(UTF8String v, int startCodePoint) {
+ return lastIndexOf(v, v.numChars(), startCodePoint);
+ }
+ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
if (v.numBytes == 0) {
return 0;
}
@@ -399,22 +432,121 @@ public int lastIndexOf(UTF8String v, int startCodePoint) {
return -1;
}
int fromIndexEnd = indexEnd(startCodePoint);
- int count = startCodePoint;
- int vNumChars = v.numChars();
do {
if (fromIndexEnd - v.numBytes + 1 < 0 ) {
return -1;
}
if (ByteArrayMethods.arrayEquals(
base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) {
- return count - vNumChars + 1;
+ int count = 0; // count from right most to the match end in byte.
+ while (fromIndexEnd >= 0) {
+ count++;
+ fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
+ }
+ return count - vNumChars;
}
fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
- count--;
} while (fromIndexEnd >= 0);
return -1;
}
+ /**
+ * Finds the n-th last index within a String.
+ * This method uses {@link String#lastIndexOf(String)}.
+ *
+ * @param str the String to check, may be null
+ * @param searchStr the String to find, may be null
+ * @param searchStrNumChars num of code ponts of the searchStr
+ * @param ordinal the n-th last <code>searchStr</code> to find
+ * @return the n-th last index of the search String,
+ * <code>-1</code> if no match or <code>null</code> string input
+ */
+ public static int lastOrdinalIndexOf(
+ UTF8String str,
+ UTF8String searchStr,
+ int searchStrNumChars,
+ int ordinal) {
+ return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true);
+ }
+ /**
+ * Finds the n-th index within a String, handling <code>null</code>.
+ * A <code>null</code> String will return <code>-1</code>
+ *
+ * @param str the String to check, may be null
+ * @param searchStr the String to find, may be null
+ * @param searchStrNumChars num of code points of searchStr
+ * @param ordinal the n-th <code>searchStr</code> to find
+ * @return the n-th index of the search String,
+ * <code>-1</code> if no match or <code>null</code> string input
+ */
+ public static int ordinalIndexOf(
+ UTF8String str,
+ UTF8String searchStr,
+ int searchStrNumChars,
+ int ordinal) {
+ return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false);
+ }
+
+ private static int doOrdinalIndexOf(
+ UTF8String str,
+ UTF8String searchStr,
+ int searchStrNumChars,
+ int ordinal,
+ boolean lastIndex) {
+ if (str == null || searchStr == null || ordinal <= 0) {
+ return -1;
+ }
+ // Only calc numChars when lastIndex == true sicnc the calculation is expensive
+ int strNumChars = 0;
+ if (lastIndex) {
+ strNumChars = str.numChars();
+ }
+ if (searchStr.numBytes == 0) {
+ return lastIndex ? strNumChars : 0;
+ }
+ int found = 0;
+ int index = lastIndex ? strNumChars : 0;
+ do {
+ if (lastIndex) {
+ index = str.lastIndexOf(searchStr, searchStrNumChars, index - 1);
+ } else {
+ index = str.indexOf(searchStr, index + 1);
+ }
+ if (index < 0) {
+ return index;
+ }
+ found += 1;
+ } while (found < ordinal);
+ return index;
+ }
+ /**
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. substring_index performs a case-sensitive match when searching for delim.
+ */
+ public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) {
+ if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) {
+ return UTF8String.EMPTY_UTF8;
+ }
+ int delimNumChars = delim.numChars();
+ if (count > 0) {
+ int idx = ordinalIndexOf(str, delim, delimNumChars, count);
+ if (idx != -1) {
+ return str.substring(0, idx);
+ } else {
+ return str;
+ }
+ } else {
+ int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count);
+ if (idx != -1) {
+ return str.substring(idx + delimNumChars);
+ } else {
+ return str;
+ }
+ }
+ }
+
/**
* Returns str, right-padded with pad to a length of len
* For example:
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index bee232b11a11e..df69d8e655ead 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -226,6 +226,7 @@ public void lastIndexOf() {
assertEquals(0, fromString("").lastIndexOf(fromString(""), 0));
assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0));
assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0));
+ assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4));
assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0));
assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3));
assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4));
From b19b013d382ebe3e64a9a7a6708b94737c92ecdd Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Thu, 23 Jul 2015 15:22:19 +0800
Subject: [PATCH 5/8] add codegen and clean code
---
.../expressions/stringOperations.scala | 26 ++++-
.../expressions/StringExpressionsSuite.scala | 31 ++++++
.../org/apache/spark/sql/functions.scala | 2 -
.../apache/spark/unsafe/types/UTF8String.java | 80 +++++++--------
.../spark/unsafe/types/UTF8StringSuite.java | 99 +++++++++++++++++++
5 files changed, 192 insertions(+), 46 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 2d921124c38ea..39329722f2d7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -361,7 +361,7 @@ case class StringInstr(str: Expression, substr: Expression)
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
- extends Expression with ImplicitCastInputTypes with CodegenFallback {
+ extends Expression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -385,6 +385,30 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
}
null
}
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val str = strExpr.gen(ctx)
+ val delim = delimExpr.gen(ctx)
+ val count = countExpr.gen(ctx)
+ val resultCode =
+ s"""org.apache.spark.unsafe.types.UTF8String.subStringIndex(
+ |${str.primitive}, ${delim.primitive}, ${count.primitive})""".stripMargin
+ s"""
+ ${str.code}
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${str.isNull}) {
+ ${delim.code}
+ if (!${delim.isNull}) {
+ ${count.code}
+ if (!${count.isNull}) {
+ ${ev.isNull} = false;
+ ${ev.primitive} = $resultCode;
+ }
+ }
+ }
+ """
+ }
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 3c2d88731beb4..abccef6196d5c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(s.substring(0), "example", row)
}
+ test("string substring_index function") {
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal("."), Literal(0)), "")
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
+ checkEvaluation(
+ Substring_index(Literal(""), Literal("."), Literal(-2)), "")
+ checkEvaluation(
+ Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
+ checkEvaluation(
+ Substring_index(Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
+ // non ascii chars
+ // scalastyle:off
+ checkEvaluation(
+ Substring_index(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大")
+ // scalastyle:on
+ checkEvaluation(
+ Substring_index(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal.create(null, StringType).like("a"), null)
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index faa82ad19fd05..0aef3db89904d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1783,7 +1783,6 @@ object functions {
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*
* @group string_funcs
- * @since 1.5.0
*/
def substring_index(str: String, delim: String, count: Int): Column =
substring_index(Column(str), delim, count)
@@ -1795,7 +1794,6 @@ object functions {
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*
* @group string_funcs
- * @since 1.5.0
*/
def substring_index(str: Column, delim: String, count: Int): Column =
Substring_index(str.expr, lit(delim).expr, lit(count).expr)
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 11e34f95bce8a..946bd3ea57a3b 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -146,20 +146,8 @@ public UTF8String substring(final int start, final int until) {
if (until <= start || start >= numBytes) {
return fromBytes(new byte[0]);
}
-
- int i = 0;
- int c = 0;
- while (i < numBytes && c < start) {
- i += numBytesForFirstByte(getByte(i));
- c += 1;
- }
-
- int j = i;
- while (i < numBytes && c < until) {
- i += numBytesForFirstByte(getByte(i));
- c += 1;
- }
-
+ int j = firstByteIndex(start);
+ int i = firstByteIndex(until);
byte[] bytes = new byte[i - j];
copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
return fromBytes(bytes);
@@ -174,12 +162,7 @@ public UTF8String substring(final int start) {
return fromBytes(new byte[0]);
}
- int i = 0;
- int c = 0;
- while (i < numBytes && c < start) {
- i += numBytesForFirstByte(getByte(i));
- c += 1;
- }
+ int i = firstByteIndex(start);
byte[] bytes = new byte[numBytes - i];
copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i);
@@ -351,13 +334,8 @@ public int indexOf(UTF8String v, int start) {
return 0;
}
- // locate to the start position.
- int i = 0; // position in byte
- int c = 0; // position in character
- while (i < numBytes && c < start) {
- i += numBytesForFirstByte(getByte(i));
- c += 1;
- }
+ int i = firstByteIndex(start); // position in byte
+ int c = start; // position in character
do {
if (i + v.numBytes > numBytes) {
@@ -399,19 +377,29 @@ private int firstOfCurrentCodePoint(int bytePos) {
}
bytePos--;
}
- throw new RuntimeException("Invalid utf8 string");
+ throw new RuntimeException("Invalid UTF8 string");
}
- private int indexEnd(int startCodePoint) {
- int i = numBytes -1; // position in byte
- int c = numChars() - 1; // position in character
- while (i >=0 && c > startCodePoint) {
- i = firstOfCurrentCodePoint(i) - 1;
- c -= 1;
+ // Locate to the start position in byte for a given code point
+ private int firstByteIndex(int codePoint) {
+ int i = 0; // position in byte
+ int c = 0; // position in character
+ while (i < numBytes && c < codePoint) {
+ i += numBytesForFirstByte(getByte(i));
+ c += 1;
+ }
+ if (i > numBytes) {
+ throw new StringIndexOutOfBoundsException(codePoint);
}
return i;
}
+ // Locate to the last position in byte for a given code point
+ private int lastByteIndex(int codePoint) {
+ int i = firstByteIndex(codePoint);
+ return i + numBytesForFirstByte(getByte(i)) - 1;
+ }
+
/**
* Returns the index within this string of the last occurrence of the
* specified substring, searching backward starting at the specified index.
@@ -431,7 +419,7 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
if (numBytes == 0) {
return -1;
}
- int fromIndexEnd = indexEnd(startCodePoint);
+ int fromIndexEnd = lastByteIndex(startCodePoint);
do {
if (fromIndexEnd - v.numBytes + 1 < 0 ) {
return -1;
@@ -456,7 +444,6 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
*
* @param str the String to check, may be null
* @param searchStr the String to find, may be null
- * @param searchStrNumChars num of code ponts of the searchStr
* @param ordinal the n-th last <code>searchStr</code> to find
* @return the n-th last index of the search String,
* <code>-1</code> if no match or <code>null</code> string input
@@ -464,17 +451,19 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
public static int lastOrdinalIndexOf(
UTF8String str,
UTF8String searchStr,
- int searchStrNumChars,
int ordinal) {
- return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true);
+ if (str == null || searchStr == null) {
+ return -1;
+ }
+ return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, true);
}
+
/**
* Finds the n-th index within a String, handling <code>null</code>.
* A <code>null</code> String will return <code>-1</code>
*
* @param str the String to check, may be null
* @param searchStr the String to find, may be null
- * @param searchStrNumChars num of code points of searchStr
* @param ordinal the n-th <code>searchStr</code> to find
* @return the n-th index of the search String,
* <code>-1</code> if no match or <code>null</code> string input
@@ -482,9 +471,11 @@ public static int lastOrdinalIndexOf(
public static int ordinalIndexOf(
UTF8String str,
UTF8String searchStr,
- int searchStrNumChars,
int ordinal) {
- return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false);
+ if (str == null || searchStr == null) {
+ return -1;
+ }
+ return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, false);
}
private static int doOrdinalIndexOf(
@@ -526,19 +517,22 @@ private static int doOrdinalIndexOf(
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) {
+ if (str == null || delim == null) {
+ return null;
+ }
if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) {
return UTF8String.EMPTY_UTF8;
}
int delimNumChars = delim.numChars();
if (count > 0) {
- int idx = ordinalIndexOf(str, delim, delimNumChars, count);
+ int idx = doOrdinalIndexOf(str, delim, delimNumChars, count, false);
if (idx != -1) {
return str.substring(0, idx);
} else {
return str;
}
} else {
- int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count);
+ int idx = doOrdinalIndexOf(str, delim, delimNumChars, -count, true);
if (idx != -1) {
return str.substring(idx + delimNumChars);
} else {
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index df69d8e655ead..31ea799a98282 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -21,10 +21,12 @@
import java.util.Arrays;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
import static junit.framework.Assert.*;
import static org.apache.spark.unsafe.types.UTF8String.*;
+import static org.apache.spark.unsafe.types.UTF8String.fromString;
public class UTF8StringSuite {
@@ -184,6 +186,63 @@ public void substring() {
assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3));
assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5));
assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2));
+
+ assertEquals(fromString("hello"), fromString("hello").substring(0));
+ assertEquals(fromString("ello"), fromString("hello").substring(1));
+ assertEquals(fromString("砖头"), fromString("数据砖头").substring(2));
+ assertEquals(fromString("头"), fromString("数据砖头").substring(3));
+ ExpectedException exception = ExpectedException.none();
+ fromString("数据砖头").substring(4);
+ exception.expect(java.lang.StringIndexOutOfBoundsException.class);
+ assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0));
+ }
+
+ @Test
+ public void ordinalIndexOf() {
+ assertEquals(-1,
+ UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 0));
+ assertEquals(3,
+ UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 1));
+ assertEquals(10,
+ UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 2));
+ assertEquals(-1,
+ UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 3));
+ assertEquals(-1,
+ UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("#"), 0));
+ assertEquals(12,
+ UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2));
+ assertEquals(-1,
+ UTF8String.ordinalIndexOf(null, fromString("|||"), 1));
+ assertEquals(-1,
+ UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), null, 1));
+ assertEquals(2,
+ UTF8String.ordinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
+ assertEquals(-1,
+ UTF8String.ordinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
+ }
+
+ @Test
+ public void lastOrdinalIndexOf() {
+ assertEquals(-1,
+ UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 0));
+ assertEquals(10,
+ UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 1));
+ assertEquals(3,
+ UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 2));
+ assertEquals(-1,
+ UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 3));
+ assertEquals(-1,
+ UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("#"), 0));
+ assertEquals(3,
+ UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2));
+ assertEquals(-1,
+ UTF8String.lastOrdinalIndexOf(null, fromString("|||"), 1));
+ assertEquals(-1,
+ UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), null, 1));
+ assertEquals(3,
+ UTF8String.lastOrdinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
+ assertEquals(-1,
+ UTF8String.lastOrdinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
}
@Test
@@ -238,6 +297,46 @@ public void lastIndexOf() {
assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3));
}
+ @Test
+ public void substring_index() {
+ assertEquals(fromString("www.apache.org"),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 3));
+ assertEquals(fromString("www.apache"),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 2));
+ assertEquals(fromString("www"),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 1));
+ assertEquals(fromString(""),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 0));
+ assertEquals(fromString("org"),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -1));
+ assertEquals(fromString("apache.org"),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -2));
+ assertEquals(fromString("www.apache.org"),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -3));
+ // str is empty string
+ assertEquals(fromString(""),
+ UTF8String.subStringIndex(fromString(""), fromString("."), 1));
+ // empty string delim
+ assertEquals(fromString(""),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString(""), 1));
+ // delim does not exist in str
+ assertEquals(fromString("www.apache.org"),
+ UTF8String.subStringIndex(fromString("www.apache.org"), fromString("#"), 2));
+ // delim is 2 chars
+ assertEquals(fromString("www||apache"),
+ UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), 2));
+ assertEquals(fromString("apache||org"),
+ UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), -2));
+ // null
+ assertEquals(null,
+ UTF8String.subStringIndex(null, fromString("."), -2));
+ assertEquals(null,
+ UTF8String.subStringIndex(fromString("www.apache.org"), null, -2));
+ // non ascii chars
+ assertEquals(fromString("大千世界大"),
+ UTF8String.subStringIndex(fromString("大千世界大千世界"), fromString("千"), 2));
+ }
+
@Test
public void reverse() {
assertEquals(fromString("olleh"), fromString("hello").reverse());
From 67c253a67b0c2ddf6a3e59881e8d42d0326a0bf9 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Fri, 24 Jul 2015 10:30:17 +0800
Subject: [PATCH 6/8] hide some apis and clean code
---
.../expressions/stringOperations.scala | 7 +-
.../org/apache/spark/sql/functions.scala | 11 --
.../apache/spark/unsafe/types/UTF8String.java | 106 ++++++++----------
.../spark/unsafe/types/UTF8StringSuite.java | 74 ++++++------
4 files changed, 80 insertions(+), 118 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 39329722f2d7d..18877769bd212 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -376,8 +376,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
if (delim != null) {
val count = countExpr.eval(input)
if (count != null) {
- return UTF8String.subStringIndex(
- str.asInstanceOf[UTF8String],
+ return str.asInstanceOf[UTF8String].subStringIndex(
delim.asInstanceOf[UTF8String],
count.asInstanceOf[Int])
}
@@ -390,9 +389,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
val str = strExpr.gen(ctx)
val delim = delimExpr.gen(ctx)
val count = countExpr.gen(ctx)
- val resultCode =
- s"""org.apache.spark.unsafe.types.UTF8String.subStringIndex(
- |${str.primitive}, ${delim.primitive}, ${count.primitive})""".stripMargin
+ val resultCode = s"${str.primitive}.subStringIndex(${delim.primitive}, ${count.primitive})"
s"""
${str.code}
boolean ${ev.isNull} = true;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0aef3db89904d..b092ad047da50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1776,17 +1776,6 @@ object functions {
*/
def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
- /**
- * Returns the substring from string str before count occurrences of the delimiter delim.
- * If count is positive, everything the left of the final delimiter (counting from left) is
- * returned. If count is negative, every to the right of the final delimiter (counting from the
- * right) is returned. substring_index performs a case-sensitive match when searching for delim.
- *
- * @group string_funcs
- */
- def substring_index(str: String, delim: String, count: Int): Column =
- substring_index(Column(str), delim, count)
-
/**
* Returns the substring from string str before count occurrences of the delimiter delim.
* If count is positive, everything the left of the final delimiter (counting from left) is
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 946bd3ea57a3b..e84509f18e146 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -146,8 +146,8 @@ public UTF8String substring(final int start, final int until) {
if (until <= start || start >= numBytes) {
return fromBytes(new byte[0]);
}
- int j = firstByteIndex(start);
- int i = firstByteIndex(until);
+ int j = firstByteIndex(0, 0, start);
+ int i = firstByteIndex(j, start, until);
byte[] bytes = new byte[i - j];
copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
return fromBytes(bytes);
@@ -162,7 +162,7 @@ public UTF8String substring(final int start) {
return fromBytes(new byte[0]);
}
- int i = firstByteIndex(start);
+ int i = firstByteIndex(0, 0, start);
byte[] bytes = new byte[numBytes - i];
copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i);
@@ -334,7 +334,7 @@ public int indexOf(UTF8String v, int start) {
return 0;
}
- int i = firstByteIndex(start); // position in byte
+ int i = firstByteIndex(0, 0, start); // position in byte
int c = start; // position in character
do {
@@ -353,7 +353,7 @@ public int indexOf(UTF8String v, int start) {
private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR};
- private ByteType checkByteType(Byte b) {
+ private ByteType checkByteType(byte b) {
int firstTwoBits = (b >>> 6) & 0x03;
if (firstTwoBits == 3) {
return ByteType.FIRSTBYTE;
@@ -371,19 +371,19 @@ private ByteType checkByteType(Byte b) {
*/
private int firstOfCurrentCodePoint(int bytePos) {
while (bytePos >= 0) {
- if (ByteType.FIRSTBYTE == checkByteType(getByte(bytePos))
- || ByteType.SINGLEBYTECHAR == checkByteType(getByte(bytePos))) {
+ ByteType byteType = checkByteType(getByte(bytePos));
+ if (ByteType.FIRSTBYTE == byteType || ByteType.SINGLEBYTECHAR == byteType) {
return bytePos;
}
bytePos--;
}
- throw new RuntimeException("Invalid UTF8 string");
+ throw new RuntimeException("Invalid UTF8 string: " + toString());
}
// Locate to the start position in byte for a given code point
- private int firstByteIndex(int codePoint) {
- int i = 0; // position in byte
- int c = 0; // position in character
+ private int firstByteIndex(int startByteIndex, int startPointIndex, int codePoint) {
+ int i = startByteIndex; // position in byte
+ int c = startPointIndex; // position in character
while (i < numBytes && c < codePoint) {
i += numBytesForFirstByte(getByte(i));
c += 1;
@@ -396,7 +396,13 @@ private int firstByteIndex(int codePoint) {
// Locate to the last position in byte for a given code point
private int lastByteIndex(int codePoint) {
- int i = firstByteIndex(codePoint);
+ int i = firstByteIndex(0, 0, codePoint);
+ if (i >= numBytes) {
+ // codePoint is at or past the end of this string (or the string is empty), so
+ // there is no byte at index i to inspect. Clamp to the last valid byte index
+ // (-1 for an empty string) instead of asking getByte for a byte past the end.
+ return numBytes - 1;
+ }
return i + numBytesForFirstByte(getByte(i)) - 1;
}
@@ -410,31 +410,33 @@ private int lastByteIndex(int codePoint) {
* or {@code -1} if there is no such occurrence.
*/
public int lastIndexOf(UTF8String v, int startCodePoint) {
- return lastIndexOf(v, v.numChars(), startCodePoint);
- }
- public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
+ // Empty string always match
if (v.numBytes == 0) {
- return 0;
+ return startCodePoint;
}
+ return lastIndexOfInByte(v, lastByteIndex(startCodePoint));
+ }
+
+ private int lastIndexOfInByte(UTF8String v, int fromIndexInByte) {
if (numBytes == 0) {
return -1;
}
- int fromIndexEnd = lastByteIndex(startCodePoint);
do {
- if (fromIndexEnd - v.numBytes + 1 < 0 ) {
+ int startByteIndex = fromIndexInByte - v.numBytes + 1;
+ if (startByteIndex < 0 ) {
return -1;
}
if (ByteArrayMethods.arrayEquals(
- base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) {
+ base, offset + startByteIndex, v.base, v.offset, v.numBytes)) {
int count = 0; // count from right most to the match end in byte.
- while (fromIndexEnd >= 0) {
+ while (startByteIndex >= 0) {
count++;
- fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
+ startByteIndex = firstOfCurrentCodePoint(startByteIndex) - 1;
}
- return count - vNumChars;
+ return count - 1;
}
- fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
- } while (fromIndexEnd >= 0);
+ fromIndexInByte = firstOfCurrentCodePoint(fromIndexInByte) - 1;
+ } while (fromIndexInByte >= 0);
return -1;
}
@@ -442,66 +444,49 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
* Finds the n-th last index within a String.
* This method uses {@link String#lastIndexOf(String)}.
*
- * @param str the String to check, may be null
* @param searchStr the String to find, may be null
* @param ordinal the n-th last <code>searchStr</code> to find
* @return the n-th last index of the search String,
* <code>-1</code> if no match or <code>null</code> string input
*/
- public static int lastOrdinalIndexOf(
- UTF8String str,
+ protected int lastOrdinalIndexOf(
UTF8String searchStr,
int ordinal) {
- if (str == null || searchStr == null) {
- return -1;
- }
- return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, true);
+ return doOrdinalIndexOf(searchStr, ordinal, true);
}
/**
* Finds the n-th index within a String, handling <code>null</code>.
* A <code>null</code> String will return <code>-1</code>
*
- * @param str the String to check, may be null
* @param searchStr the String to find, may be null
* @param ordinal the n-th <code>searchStr</code> to find
* @return the n-th index of the search String,
* <code>-1</code> if no match or <code>null</code> string input
*/
- public static int ordinalIndexOf(
- UTF8String str,
+ protected int ordinalIndexOf(
UTF8String searchStr,
int ordinal) {
- if (str == null || searchStr == null) {
- return -1;
- }
- return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, false);
+ return doOrdinalIndexOf(searchStr, ordinal, false);
}
- private static int doOrdinalIndexOf(
- UTF8String str,
+ private int doOrdinalIndexOf(
UTF8String searchStr,
- int searchStrNumChars,
int ordinal,
boolean lastIndex) {
- if (str == null || searchStr == null || ordinal <= 0) {
+ if (ordinal <= 0) {
return -1;
}
- // Only calc numChars when lastIndex == true sicnc the calculation is expensive
- int strNumChars = 0;
- if (lastIndex) {
- strNumChars = str.numChars();
- }
if (searchStr.numBytes == 0) {
- return lastIndex ? strNumChars : 0;
+ return lastIndex ? numChars() : 0;
}
int found = 0;
- int index = lastIndex ? strNumChars : 0;
+ int index = lastIndex ? numBytes : -1;
do {
if (lastIndex) {
- index = str.lastIndexOf(searchStr, searchStrNumChars, index - 1);
+ index = lastIndexOfInByte(searchStr, index - 1);
} else {
- index = str.indexOf(searchStr, index + 1);
+ index = indexOf(searchStr, index + 1);
}
if (index < 0) {
return index;
@@ -516,27 +501,26 @@ private static int doOrdinalIndexOf(
* returned. If count is negative, every to the right of the final delimiter (counting from the
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
- public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) {
- if (str == null || delim == null) {
+ public UTF8String subStringIndex(UTF8String delim, int count) {
+ if (delim == null) {
return null;
}
- if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) {
+ if (delim.numBytes == 0 || count == 0) {
return UTF8String.EMPTY_UTF8;
}
- int delimNumChars = delim.numChars();
if (count > 0) {
- int idx = doOrdinalIndexOf(str, delim, delimNumChars, count, false);
+ int idx = ordinalIndexOf(delim, count);
if (idx != -1) {
- return str.substring(0, idx);
+ return substring(0, idx);
} else {
- return str;
+ return this;
}
} else {
- int idx = doOrdinalIndexOf(str, delim, delimNumChars, -count, true);
+ int idx = lastOrdinalIndexOf(delim, -count);
if (idx != -1) {
- return str.substring(idx + delimNumChars);
+ return substring(idx + delim.numChars());
} else {
- return str;
+ return this;
}
}
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 31ea799a98282..6386c9902885d 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -200,49 +200,43 @@ public void substring() {
@Test
public void ordinalIndexOf() {
assertEquals(-1,
- UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 0));
+ fromString("www.apache.org").ordinalIndexOf(fromString("."), 0));
+ assertEquals(0,
+ fromString("www.apache.org").ordinalIndexOf(fromString("w"), 1));
assertEquals(3,
- UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 1));
+ fromString("www.apache.org").ordinalIndexOf(fromString("."), 1));
assertEquals(10,
- UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 2));
+ fromString("www.apache.org").ordinalIndexOf(fromString("."), 2));
assertEquals(-1,
- UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 3));
+ fromString("www.apache.org").ordinalIndexOf(fromString("."), 3));
assertEquals(-1,
- UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("#"), 0));
+ fromString("www.apache.org").ordinalIndexOf(fromString("#"), 0));
assertEquals(12,
- UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2));
- assertEquals(-1,
- UTF8String.ordinalIndexOf(null, fromString("|||"), 1));
- assertEquals(-1,
- UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), null, 1));
+ fromString("www|||apache|||org").ordinalIndexOf(fromString("|||"), 2));
assertEquals(2,
- UTF8String.ordinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
+ fromString("数据砖砖头").ordinalIndexOf(fromString("砖"), 1));
assertEquals(-1,
- UTF8String.ordinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
+ fromString("砖头数据砖头").ordinalIndexOf(fromString("砖"), -2));
}
@Test
public void lastOrdinalIndexOf() {
assertEquals(-1,
- UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 0));
+ fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 0));
assertEquals(10,
- UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 1));
+ fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 1));
assertEquals(3,
- UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 2));
+ fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 2));
assertEquals(-1,
- UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 3));
+ fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 3));
assertEquals(-1,
- UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("#"), 0));
+ fromString("www.apache.org").lastOrdinalIndexOf(fromString("#"), 0));
assertEquals(3,
- UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2));
- assertEquals(-1,
- UTF8String.lastOrdinalIndexOf(null, fromString("|||"), 1));
- assertEquals(-1,
- UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), null, 1));
+ fromString("www|||apache|||org").lastOrdinalIndexOf(fromString("|||"), 2));
assertEquals(3,
- UTF8String.lastOrdinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
+ fromString("数据砖砖头").lastOrdinalIndexOf(fromString("砖"), 1));
assertEquals(-1,
- UTF8String.lastOrdinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
+ fromString("砖头数据砖头").lastOrdinalIndexOf(fromString("砖"), -2));
}
@Test
@@ -282,7 +276,7 @@ public void indexOf() {
@Test
public void lastIndexOf() {
- assertEquals(0, fromString("").lastIndexOf(fromString(""), 0));
+ assertEquals(1, fromString("hello").lastIndexOf(fromString(""), 1));
assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0));
assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0));
assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4));
@@ -300,41 +294,39 @@ public void lastIndexOf() {
@Test
public void substring_index() {
assertEquals(fromString("www.apache.org"),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 3));
+ fromString("www.apache.org").subStringIndex(fromString("."), 3));
assertEquals(fromString("www.apache"),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 2));
+ fromString("www.apache.org").subStringIndex(fromString("."), 2));
assertEquals(fromString("www"),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 1));
+ fromString("www.apache.org").subStringIndex(fromString("."), 1));
assertEquals(fromString(""),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 0));
+ fromString("www.apache.org").subStringIndex(fromString("."), 0));
assertEquals(fromString("org"),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -1));
+ fromString("www.apache.org").subStringIndex(fromString("."), -1));
assertEquals(fromString("apache.org"),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -2));
+ fromString("www.apache.org").subStringIndex(fromString("."), -2));
assertEquals(fromString("www.apache.org"),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -3));
+ fromString("www.apache.org").subStringIndex(fromString("."), -3));
// str is empty string
assertEquals(fromString(""),
- UTF8String.subStringIndex(fromString(""), fromString("."), 1));
+ fromString("").subStringIndex(fromString("."), 1));
// empty string delim
assertEquals(fromString(""),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString(""), 1));
+ fromString("www.apache.org").subStringIndex(fromString(""), 1));
// delim does not exist in str
assertEquals(fromString("www.apache.org"),
- UTF8String.subStringIndex(fromString("www.apache.org"), fromString("#"), 2));
+ fromString("www.apache.org").subStringIndex(fromString("#"), 2));
// delim is 2 chars
assertEquals(fromString("www||apache"),
- UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), 2));
+ fromString("www||apache||org").subStringIndex(fromString("||"), 2));
assertEquals(fromString("apache||org"),
- UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), -2));
+ fromString("www||apache||org").subStringIndex(fromString("||"), -2));
// null
assertEquals(null,
- UTF8String.subStringIndex(null, fromString("."), -2));
- assertEquals(null,
- UTF8String.subStringIndex(fromString("www.apache.org"), null, -2));
+ fromString("www.apache.org").subStringIndex(null, -2));
// non ascii chars
assertEquals(fromString("大千世界大"),
- UTF8String.subStringIndex(fromString("大千世界大千世界"), fromString("千"), 2));
+ fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));
}
@Test
From 9546991d4964b17dd5d62f93909651280077ba13 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Fri, 24 Jul 2015 10:46:04 +0800
Subject: [PATCH 7/8] scala style
---
.../sql/catalyst/expressions/StringExpressionsSuite.scala | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index abccef6196d5c..1a4376e6f5216 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -207,8 +207,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Substring_index(Literal(""), Literal("."), Literal(-2)), "")
checkEvaluation(
Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
- checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
+ checkEvaluation(Substring_index(
+ Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
// non ascii chars
// scalastyle:off
checkEvaluation(
From 515519bebbec5a46de0789ea2ec82f448be0e8e8 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Fri, 24 Jul 2015 13:49:19 +0800
Subject: [PATCH 8/8] add foldable and remove null checking
---
.../spark/sql/catalyst/expressions/stringOperations.scala | 1 +
.../main/java/org/apache/spark/unsafe/types/UTF8String.java | 5 +----
.../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 3 ---
3 files changed, 2 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 18877769bd212..20252ea8a89cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -364,6 +364,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
extends Expression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
+ override def foldable: Boolean = children.forall(_.foldable) // all inputs must be foldable
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable
override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e84509f18e146..79f2f25142e82 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -144,7 +144,7 @@ public byte[] getBytes() {
*/
public UTF8String substring(final int start, final int until) {
if (until <= start || start >= numBytes) {
- return fromBytes(new byte[0]);
+ return UTF8String.EMPTY_UTF8;
}
int j = firstByteIndex(0, 0, start);
int i = firstByteIndex(j, start, until);
@@ -502,9 +502,6 @@ private int doOrdinalIndexOf(
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
public UTF8String subStringIndex(UTF8String delim, int count) {
- if (delim == null) {
- return null;
- }
if (delim.numBytes == 0 || count == 0) {
return UTF8String.EMPTY_UTF8;
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 6386c9902885d..a78bb3c4374f4 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -321,9 +321,6 @@ public void substring_index() {
fromString("www||apache||org").subStringIndex(fromString("||"), 2));
assertEquals(fromString("apache||org"),
fromString("www||apache||org").subStringIndex(fromString("||"), -2));
- // null
- assertEquals(null,
- fromString("www.apache.org").subStringIndex(null, -2));
// non ascii chars
assertEquals(fromString("大千世界大"),
fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));