From 3ce78026c960982e328eec1403e4c23451bf7d61 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 31 Jul 2015 15:32:22 -0700
Subject: [PATCH] fix substringIndex
---
.../catalyst/analysis/FunctionRegistry.scala | 2 +-
.../expressions/stringOperations.scala | 44 +---
.../expressions/StringExpressionsSuite.scala | 24 +-
.../org/apache/spark/sql/functions.scala | 2 +-
.../apache/spark/unsafe/types/UTF8String.java | 232 ++++++------------
.../spark/unsafe/types/UTF8StringSuite.java | 74 +-----
6 files changed, 94 insertions(+), 284 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9e3a4cdf2d050..320523aaaf489 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -198,7 +198,7 @@ object FunctionRegistry {
expression[StringSplit]("split"),
expression[Substring]("substr"),
expression[Substring]("substring"),
- expression[Substring_index]("substring_index"),
+ expression[SubstringIndex]("substring_index"),
expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5f90dc1a5ca12..804e2c1be819e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -427,52 +427,22 @@ case class StringInstr(str: Expression, substr: Expression)
* returned. If count is negative, every to the right of the final delimiter (counting from the
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
-case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
- extends Expression with ImplicitCastInputTypes {
+case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
- override def foldable: Boolean = strExpr.foldable && delimExpr.foldable && countExpr.foldable
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
- override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable
override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
override def prettyName: String = "substring_index"
- override def eval(input: InternalRow): Any = {
- val str = strExpr.eval(input)
- if (str != null) {
- val delim = delimExpr.eval(input)
- if (delim != null) {
- val count = countExpr.eval(input)
- if (count != null) {
- return str.asInstanceOf[UTF8String].subStringIndex(
- delim.asInstanceOf[UTF8String],
- count.asInstanceOf[Int])
- }
- }
- }
- null
+ override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
+ str.asInstanceOf[UTF8String].subStringIndex(
+ delim.asInstanceOf[UTF8String],
+ count.asInstanceOf[Int])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val str = strExpr.gen(ctx)
- val delim = delimExpr.gen(ctx)
- val count = countExpr.gen(ctx)
- val resultCode = s"${str.primitive}.subStringIndex(${delim.primitive}, ${count.primitive})"
- s"""
- ${str.code}
- boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${str.isNull}) {
- ${delim.code}
- if (!${delim.isNull}) {
- ${count.code}
- if (!${count.isNull}) {
- ${ev.isNull} = false;
- ${ev.primitive} = $resultCode;
- }
- }
- }
- """
+ defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 52ce89bb42eca..cb564314d88b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -190,32 +190,32 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("string substring_index function") {
checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal("."), Literal(0)), "")
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
checkEvaluation(
- Substring_index(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
checkEvaluation(
- Substring_index(Literal(""), Literal("."), Literal(-2)), "")
+ SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
checkEvaluation(
- Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
- checkEvaluation(Substring_index(
+ SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
+ checkEvaluation(SubstringIndex(
Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
// non ascii chars
// scalastyle:off
checkEvaluation(
- Substring_index(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大")
+ SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大")
// scalastyle:on
checkEvaluation(
- Substring_index(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
+ SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
}
test("LIKE literal Regular Expression") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index acc9ab5bfe707..1a4b5c7ce4275 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1796,7 +1796,7 @@ object functions {
* @group string_funcs
*/
def substring_index(str: Column, delim: String, count: Int): Column =
- Substring_index(str.expr, lit(delim).expr, lit(count).expr)
+ SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
/**
* Locate the position of the first occurrence of substr.
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 644993c62a00a..63d71b5e509f7 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -200,26 +200,22 @@ public UTF8String substring(final int start, final int until) {
if (until <= start || start >= numBytes) {
return UTF8String.EMPTY_UTF8;
}
- int j = firstByteIndex(0, 0, start);
- int i = firstByteIndex(j, start, until);
- byte[] bytes = new byte[i - j];
- copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
- return fromBytes(bytes);
- }
- /**
- * Returns a substring of this from start to end.
- * @param start the position of first code point
- */
- public UTF8String substring(final int start) {
- if (start >= numBytes) {
- return fromBytes(new byte[0]);
+ int i = 0;
+ int c = 0;
+ while (i < numBytes && c < start) {
+ i += numBytesForFirstByte(getByte(i));
+ c += 1;
}
- int i = firstByteIndex(0, 0, start);
+ int j = i;
+ while (i < numBytes && c < until) {
+ i += numBytesForFirstByte(getByte(i));
+ c += 1;
+ }
- byte[] bytes = new byte[numBytes - i];
- copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i);
+ byte[] bytes = new byte[i - j];
+ copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
return fromBytes(bytes);
}
@@ -388,8 +384,13 @@ public int indexOf(UTF8String v, int start) {
return 0;
}
- int i = firstByteIndex(0, 0, start); // position in byte
- int c = start; // position in character
+ // Skip `start` code points to find the byte offset where the search begins.
+ int i = 0; // position in byte
+ int c = 0; // position in character
+ while (i < numBytes && c < start) {
+ i += numBytesForFirstByte(getByte(i));
+ c += 1;
+ }
do {
if (i + v.numBytes > numBytes) {
@@ -405,174 +406,81 @@ public int indexOf(UTF8String v, int start) {
return -1;
}
- private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR};
-
- private ByteType checkByteType(byte b) {
- int firstTwoBits = (b >>> 6) & 0x03;
- if (firstTwoBits == 3) {
- return ByteType.FIRSTBYTE;
- } else if (firstTwoBits == 2) {
- return ByteType.MIDBYTE;
- } else {
- return ByteType.SINGLEBYTECHAR;
- }
- }
-
/**
- * Return the first byte position for a given byte which shared the same code point.
- * @param bytePos any byte within the code point
- * @return the first byte position of a given code point, throw exception if not a valid UTF8 str
+ * Returns the lowest byte offset at or after `start` where `str` occurs, or -1 if none.
*/
- private int firstOfCurrentCodePoint(int bytePos) {
- while (bytePos >= 0) {
- ByteType byteType = checkByteType(getByte(bytePos));
- if (ByteType.FIRSTBYTE == byteType || ByteType.SINGLEBYTECHAR == byteType) {
- return bytePos;
+ private int find(UTF8String str, int start) {
+ assert (str.numBytes > 0);
+ while (start <= numBytes - str.numBytes) {
+ if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+ return start;
}
- bytePos--;
+ start += 1;
}
- throw new RuntimeException("Invalid UTF8 string: " + toString());
- }
-
- // Locate to the start position in byte for a given code point
- private int firstByteIndex(int startByteIndex, int startPointIndex, int codePoint) {
- int i = startByteIndex; // position in byte
- int c = startPointIndex; // position in character
- while (i < numBytes && c < codePoint) {
- i += numBytesForFirstByte(getByte(i));
- c += 1;
- }
- if (i > numBytes) {
- throw new StringIndexOutOfBoundsException(codePoint);
- }
- return i;
- }
-
- // Locate to the last position in byte for a given code point
- private int lastByteIndex(int codePoint) {
- int i = firstByteIndex(0, 0, codePoint);
- return i + numBytesForFirstByte(getByte(i)) - 1;
+ return -1;
}
/**
- * Returns the index within this string of the last occurrence of the
- * specified substring, searching backward starting at the specified index.
- * @param v the substring to search for.
- * @param startCodePoint the index to start search from
- * @return the index of the last occurrence of the specified substring,
- * searching backward from the specified index,
- * or {@code -1} if there is no such occurrence.
+ * Returns the highest byte offset at or before `start` where `str` occurs, or -1 if none.
*/
- public int lastIndexOf(UTF8String v, int startCodePoint) {
- // Empty string always match
- if (v.numBytes == 0) {
- return startCodePoint;
- }
- return lastIndexOfInByte(v, lastByteIndex(startCodePoint));
- }
-
- private int lastIndexOfInByte(UTF8String v, int fromIndexInByte) {
- if (numBytes == 0) {
- return -1;
- }
- do {
- int startByteIndex = fromIndexInByte - v.numBytes + 1;
- if (startByteIndex < 0 ) {
- return -1;
+ private int rfind(UTF8String str, int start) {
+ assert (str.numBytes > 0);
+ while (start >= 0) {
+ if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+ return start;
}
- if (ByteArrayMethods.arrayEquals(
- base, offset + startByteIndex, v.base, v.offset, v.numBytes)) {
- int count = 0; // count from right most to the match end in byte.
- while (startByteIndex >= 0) {
- count++;
- startByteIndex = firstOfCurrentCodePoint(startByteIndex) - 1;
- }
- return count - 1;
- }
- fromIndexInByte = firstOfCurrentCodePoint(fromIndexInByte) - 1;
- } while (fromIndexInByte >= 0);
+ start -= 1;
+ }
return -1;
}
- /**
- * Finds the n-th last index within a String.
- * This method uses {@link String#lastIndexOf(String)}.
- *
- * @param searchStr the String to find, may be null
- * @param ordinal the n-th last searchStr
to find
- * @return the n-th last index of the search String,
- * -1
if no match or null
string input
- */
- protected int lastOrdinalIndexOf(
- UTF8String searchStr,
- int ordinal) {
- return doOrdinalIndexOf(searchStr, ordinal, true);
- }
-
- /**
- * Finds the n-th index within a String, handling null
.
- * A null
String will return -1
- *
- * @param searchStr the String to find, may be null
- * @param ordinal the n-th searchStr
to find
- * @return the n-th index of the search String,
- * -1
if no match or null
string input
- */
- protected int ordinalIndexOf(
- UTF8String searchStr,
- int ordinal) {
- return doOrdinalIndexOf(searchStr, ordinal, false);
- }
-
- private int doOrdinalIndexOf(
- UTF8String searchStr,
- int ordinal,
- boolean lastIndex) {
- if (ordinal <= 0) {
- return -1;
- }
- if (searchStr.numBytes == 0) {
- return lastIndex ? numChars() : 0;
- }
- int found = 0;
- int index = lastIndex ? numBytes : -1;
- do {
- if (lastIndex) {
- index = lastIndexOfInByte(searchStr, index - 1);
- } else {
- index = indexOf(searchStr, index + 1);
- }
- if (index < 0) {
- return index;
- }
- found += 1;
- } while (found < ordinal);
- return index;
- }
/**
* Returns the substring from string str before count occurrences of the delimiter delim.
* If count is positive, everything the left of the final delimiter (counting from left) is
* returned. If count is negative, every to the right of the final delimiter (counting from the
- * right) is returned. substring_index performs a case-sensitive match when searching for delim.
+ * right) is returned. subStringIndex performs a case-sensitive match when searching for delim.
*/
public UTF8String subStringIndex(UTF8String delim, int count) {
if (delim.numBytes == 0 || count == 0) {
- return UTF8String.EMPTY_UTF8;
+ return EMPTY_UTF8;
}
if (count > 0) {
- int idx = ordinalIndexOf(delim, count);
- if (idx != -1) {
- return substring(0, idx);
- } else {
- return this;
+ int idx = -1;
+ while (count > 0) {
+ idx = find(delim, idx + 1);
+ if (idx >= 0) {
+ count --;
+ } else {
+ // fewer delimiters than requested were found, so return the whole string
+ return this;
+ }
}
+ if (idx == 0) {
+ return EMPTY_UTF8;
+ }
+ byte[] bytes = new byte[idx];
+ copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx);
+ return fromBytes(bytes);
+
} else {
- int idx = lastOrdinalIndexOf(delim, -count);
- if (idx != -1) {
- return substring(idx + delim.numChars());
- } else {
- return this;
+ int idx = numBytes - delim.numBytes + 1;
+ count = -count;
+ while (count > 0) {
+ idx = rfind(delim, idx - 1);
+ if (idx >= 0) {
+ count --;
+ } else {
+ // fewer delimiters than requested were found, so return the whole string
+ return this;
+ }
+ }
+ if (idx + delim.numBytes == numBytes) {
+ return EMPTY_UTF8;
}
+ int size = numBytes - delim.numBytes - idx;
+ byte[] bytes = new byte[size];
+ copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size);
+ return fromBytes(bytes);
}
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 009b074493c8a..22c467e509d4f 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -21,12 +21,9 @@
import java.util.Arrays;
import org.junit.Test;
-import org.junit.rules.ExpectedException;
import static junit.framework.Assert.*;
-
import static org.apache.spark.unsafe.types.UTF8String.*;
-import static org.apache.spark.unsafe.types.UTF8String.fromString;
public class UTF8StringSuite {
@@ -205,57 +202,6 @@ public void substring() {
assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3));
assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5));
assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2));
-
- assertEquals(fromString("hello"), fromString("hello").substring(0));
- assertEquals(fromString("ello"), fromString("hello").substring(1));
- assertEquals(fromString("砖头"), fromString("数据砖头").substring(2));
- assertEquals(fromString("头"), fromString("数据砖头").substring(3));
- ExpectedException exception = ExpectedException.none();
- fromString("数据砖头").substring(4);
- exception.expect(java.lang.StringIndexOutOfBoundsException.class);
- assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0));
- }
-
- @Test
- public void ordinalIndexOf() {
- assertEquals(-1,
- fromString("www.apache.org").ordinalIndexOf(fromString("."), 0));
- assertEquals(0,
- fromString("www.apache.org").ordinalIndexOf(fromString("w"), 1));
- assertEquals(3,
- fromString("www.apache.org").ordinalIndexOf(fromString("."), 1));
- assertEquals(10,
- fromString("www.apache.org").ordinalIndexOf(fromString("."), 2));
- assertEquals(-1,
- fromString("www.apache.org").ordinalIndexOf(fromString("."), 3));
- assertEquals(-1,
- fromString("www.apache.org").ordinalIndexOf(fromString("#"), 0));
- assertEquals(12,
- fromString("www|||apache|||org").ordinalIndexOf(fromString("|||"), 2));
- assertEquals(2,
- fromString("数据砖砖头").ordinalIndexOf(fromString("砖"), 1));
- assertEquals(-1,
- fromString("砖头数据砖头").ordinalIndexOf(fromString("砖"), -2));
- }
-
- @Test
- public void lastOrdinalIndexOf() {
- assertEquals(-1,
- fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 0));
- assertEquals(10,
- fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 1));
- assertEquals(3,
- fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 2));
- assertEquals(-1,
- fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 3));
- assertEquals(-1,
- fromString("www.apache.org").lastOrdinalIndexOf(fromString("#"), 0));
- assertEquals(3,
- fromString("www|||apache|||org").lastOrdinalIndexOf(fromString("|||"), 2));
- assertEquals(3,
- fromString("数据砖砖头").lastOrdinalIndexOf(fromString("砖"), 1));
- assertEquals(-1,
- fromString("砖头数据砖头").lastOrdinalIndexOf(fromString("砖"), -2));
}
@Test
@@ -293,23 +239,6 @@ public void indexOf() {
assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0));
}
- @Test
- public void lastIndexOf() {
- assertEquals(1, fromString("hello").lastIndexOf(fromString(""), 1));
- assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0));
- assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0));
- assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4));
- assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0));
- assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3));
- assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4));
- assertEquals(2, fromString("hello").lastIndexOf(fromString("ll"), 4));
- assertEquals(-1, fromString("hello").lastIndexOf(fromString("ll"), 0));
- assertEquals(5, fromString("数据砖头数据砖头").lastIndexOf(fromString("据砖"), 7));
- assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 3));
- assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 0));
- assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3));
- }
-
@Test
public void substring_index() {
assertEquals(fromString("www.apache.org"),
@@ -343,6 +272,9 @@ public void substring_index() {
// non ascii chars
assertEquals(fromString("大千世界大"),
fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));
+ // overlapping delimiter occurrences
+ assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3));
+ assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4));
}
@Test