From 3ce78026c960982e328eec1403e4c23451bf7d61 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 31 Jul 2015 15:32:22 -0700 Subject: [PATCH] fix substringIndex --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/stringOperations.scala | 44 +--- .../expressions/StringExpressionsSuite.scala | 24 +- .../org/apache/spark/sql/functions.scala | 2 +- .../apache/spark/unsafe/types/UTF8String.java | 232 ++++++------------ .../spark/unsafe/types/UTF8StringSuite.java | 74 +----- 6 files changed, 94 insertions(+), 284 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9e3a4cdf2d050..320523aaaf489 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -198,7 +198,7 @@ object FunctionRegistry { expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), - expression[Substring_index]("substring_index"), + expression[SubstringIndex]("substring_index"), expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5f90dc1a5ca12..804e2c1be819e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -427,52 +427,22 @@ case class StringInstr(str: Expression, substr: Expression) * returned. If count is negative, every to the right of the final delimiter (counting from the * right) is returned. 
substring_index performs a case-sensitive match when searching for delim. */ -case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends Expression with ImplicitCastInputTypes { +case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) + extends TernaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType - override def foldable: Boolean = strExpr.foldable && delimExpr.foldable && countExpr.foldable override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) - override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) override def prettyName: String = "substring_index" - override def eval(input: InternalRow): Any = { - val str = strExpr.eval(input) - if (str != null) { - val delim = delimExpr.eval(input) - if (delim != null) { - val count = countExpr.eval(input) - if (count != null) { - return str.asInstanceOf[UTF8String].subStringIndex( - delim.asInstanceOf[UTF8String], - count.asInstanceOf[Int]) - } - } - } - null + override def nullSafeEval(str: Any, delim: Any, count: Any): Any = { + str.asInstanceOf[UTF8String].subStringIndex( + delim.asInstanceOf[UTF8String], + count.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val str = strExpr.gen(ctx) - val delim = delimExpr.gen(ctx) - val count = countExpr.gen(ctx) - val resultCode = s"${str.primitive}.subStringIndex(${delim.primitive}, ${count.primitive})" - s""" - ${str.code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${str.isNull}) { - ${delim.code} - if (!${delim.isNull}) { - ${count.code} - if (!${count.isNull}) { - ${ev.isNull} = false; - ${ev.primitive} = $resultCode; - } - } - } - """ + defineCodeGen(ctx, ev, (str, delim, count) => 
s"$str.subStringIndex($delim, $count)") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 52ce89bb42eca..cb564314d88b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -190,32 +190,32 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("string substring_index function") { checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org") + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org") checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache") + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache") checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal("."), Literal(1)), "www") + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www") checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal("."), Literal(0)), "") + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "") checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org") + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org") checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org") + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org") checkEvaluation( - Substring_index(Literal("www.apache.org"), Literal("."), Literal(-1)), "org") + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org") checkEvaluation( - Substring_index(Literal(""), 
Literal("."), Literal(-2)), "") + SubstringIndex(Literal(""), Literal("."), Literal(-2)), "") checkEvaluation( - Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null) - checkEvaluation(Substring_index( + SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null) + checkEvaluation(SubstringIndex( Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null) // non ascii chars // scalastyle:off checkEvaluation( - Substring_index(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大") + SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大") // scalastyle:on checkEvaluation( - Substring_index(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache") + SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache") } test("LIKE literal Regular Expression") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index acc9ab5bfe707..1a4b5c7ce4275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1796,7 +1796,7 @@ object functions { * @group string_funcs */ def substring_index(str: Column, delim: String, count: Int): Column = - Substring_index(str.expr, lit(delim).expr, lit(count).expr) + SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) /** * Locate the position of the first occurrence of substr. 
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 644993c62a00a..63d71b5e509f7 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -200,26 +200,22 @@ public UTF8String substring(final int start, final int until) { if (until <= start || start >= numBytes) { return UTF8String.EMPTY_UTF8; } - int j = firstByteIndex(0, 0, start); - int i = firstByteIndex(j, start, until); - byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); - return fromBytes(bytes); - } - /** - * Returns a substring of this from start to end. - * @param start the position of first code point - */ - public UTF8String substring(final int start) { - if (start >= numBytes) { - return fromBytes(new byte[0]); + int i = 0; + int c = 0; + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; } - int i = firstByteIndex(0, 0, start); + int j = i; + while (i < numBytes && c < until) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } - byte[] bytes = new byte[numBytes - i]; - copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i); + byte[] bytes = new byte[i - j]; + copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } @@ -388,8 +384,13 @@ public int indexOf(UTF8String v, int start) { return 0; } - int i = firstByteIndex(0, 0, start); // position in byte - int c = start; // position in character + // locate to the start position. 
+ int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } do { if (i + v.numBytes > numBytes) { @@ -405,174 +406,81 @@ public int indexOf(UTF8String v, int start) { return -1; } - private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR}; - - private ByteType checkByteType(byte b) { - int firstTwoBits = (b >>> 6) & 0x03; - if (firstTwoBits == 3) { - return ByteType.FIRSTBYTE; - } else if (firstTwoBits == 2) { - return ByteType.MIDBYTE; - } else { - return ByteType.SINGLEBYTECHAR; - } - } - /** - * Return the first byte position for a given byte which shared the same code point. - * @param bytePos any byte within the code point - * @return the first byte position of a given code point, throw exception if not a valid UTF8 str + * Returns the byte offset of the first `str` at or after byte offset `start`, or -1. */ - private int firstOfCurrentCodePoint(int bytePos) { - while (bytePos >= 0) { - ByteType byteType = checkByteType(getByte(bytePos)); - if (ByteType.FIRSTBYTE == byteType || ByteType.SINGLEBYTECHAR == byteType) { - return bytePos; + private int find(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start <= numBytes - str.numBytes) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; } - bytePos--; + start += 1; } - throw new RuntimeException("Invalid UTF8 string: " + toString()); - } - - // Locate to the start position in byte for a given code point - private int firstByteIndex(int startByteIndex, int startPointIndex, int codePoint) { - int i = startByteIndex; // position in byte - int c = startPointIndex; // position in character - while (i < numBytes && c < codePoint) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } - if (i > numBytes) { - throw new StringIndexOutOfBoundsException(codePoint); - } - return i; - } - - // Locate to the last position in byte for a given code point - private int lastByteIndex(int
codePoint) { - int i = firstByteIndex(0, 0, codePoint); - return i + numBytesForFirstByte(getByte(i)) - 1; + return -1; } /** - * Returns the index within this string of the last occurrence of the - * specified substring, searching backward starting at the specified index. - * @param v the substring to search for. - * @param startCodePoint the index to start search from - * @return the index of the last occurrence of the specified substring, - * searching backward from the specified index, - * or {@code -1} if there is no such occurrence. + * Returns the byte offset of the last `str` at or before byte offset `start`, or -1. */ - public int lastIndexOf(UTF8String v, int startCodePoint) { - // Empty string always match - if (v.numBytes == 0) { - return startCodePoint; - } - return lastIndexOfInByte(v, lastByteIndex(startCodePoint)); - } - - private int lastIndexOfInByte(UTF8String v, int fromIndexInByte) { - if (numBytes == 0) { - return -1; - } - do { - int startByteIndex = fromIndexInByte - v.numBytes + 1; - if (startByteIndex < 0 ) { - return -1; + private int rfind(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start >= 0) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; } - if (ByteArrayMethods.arrayEquals( - base, offset + startByteIndex, v.base, v.offset, v.numBytes)) { - int count = 0; // count from right most to the match end in byte. - while (startByteIndex >= 0) { - count++; - startByteIndex = firstOfCurrentCodePoint(startByteIndex) - 1; - } - return count - 1; - } - fromIndexInByte = firstOfCurrentCodePoint(fromIndexInByte) - 1; - } while (fromIndexInByte >= 0); + start -= 1; + } return -1; } - /** - * Finds the n-th last index within a String. - * This method uses {@link String#lastIndexOf(String)}.

- * - * @param searchStr the String to find, may be null - * @param ordinal the n-th last searchStr to find - * @return the n-th last index of the search String, - * -1 if no match or null string input - */ - protected int lastOrdinalIndexOf( - UTF8String searchStr, - int ordinal) { - return doOrdinalIndexOf(searchStr, ordinal, true); - } - - /** - * Finds the n-th index within a String, handling null. - * A null String will return -1 - * - * @param searchStr the String to find, may be null - * @param ordinal the n-th searchStr to find - * @return the n-th index of the search String, - * -1 if no match or null string input - */ - protected int ordinalIndexOf( - UTF8String searchStr, - int ordinal) { - return doOrdinalIndexOf(searchStr, ordinal, false); - } - - private int doOrdinalIndexOf( - UTF8String searchStr, - int ordinal, - boolean lastIndex) { - if (ordinal <= 0) { - return -1; - } - if (searchStr.numBytes == 0) { - return lastIndex ? numChars() : 0; - } - int found = 0; - int index = lastIndex ? numBytes : -1; - do { - if (lastIndex) { - index = lastIndexOfInByte(searchStr, index - 1); - } else { - index = indexOf(searchStr, index + 1); - } - if (index < 0) { - return index; - } - found += 1; - } while (found < ordinal); - return index; - } /** * Returns the substring from string str before count occurrences of the delimiter delim. * If count is positive, everything the left of the final delimiter (counting from left) is * returned. If count is negative, every to the right of the final delimiter (counting from the - * right) is returned. substring_index performs a case-sensitive match when searching for delim. + * right) is returned. subStringIndex performs a case-sensitive match when searching for delim. 
*/ public UTF8String subStringIndex(UTF8String delim, int count) { if (delim.numBytes == 0 || count == 0) { - return UTF8String.EMPTY_UTF8; + return EMPTY_UTF8; } if (count > 0) { - int idx = ordinalIndexOf(delim, count); - if (idx != -1) { - return substring(0, idx); - } else { - return this; + int idx = -1; + while (count > 0) { + idx = find(delim, idx + 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } } + if (idx == 0) { + return EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + return fromBytes(bytes); + } else { - int idx = lastOrdinalIndexOf(delim, -count); - if (idx != -1) { - return substring(idx + delim.numChars()); - } else { - return this; + int idx = numBytes - delim.numBytes + 1; + count = -count; + while (count > 0) { + idx = rfind(delim, idx - 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } + } + if (idx + delim.numBytes == numBytes) { + return EMPTY_UTF8; } + int size = numBytes - delim.numBytes - idx; + byte[] bytes = new byte[size]; + copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + return fromBytes(bytes); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 009b074493c8a..22c467e509d4f 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -21,12 +21,9 @@ import java.util.Arrays; import org.junit.Test; -import org.junit.rules.ExpectedException; import static junit.framework.Assert.*; - import static org.apache.spark.unsafe.types.UTF8String.*; -import static org.apache.spark.unsafe.types.UTF8String.fromString; public class UTF8StringSuite { @@ -205,57 +202,6 @@ public void substring() { assertEquals(fromString("据砖"), 
fromString("数据砖头").substring(1, 3)); assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); - - assertEquals(fromString("hello"), fromString("hello").substring(0)); - assertEquals(fromString("ello"), fromString("hello").substring(1)); - assertEquals(fromString("砖头"), fromString("数据砖头").substring(2)); - assertEquals(fromString("头"), fromString("数据砖头").substring(3)); - ExpectedException exception = ExpectedException.none(); - fromString("数据砖头").substring(4); - exception.expect(java.lang.StringIndexOutOfBoundsException.class); - assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0)); - } - - @Test - public void ordinalIndexOf() { - assertEquals(-1, - fromString("www.apache.org").ordinalIndexOf(fromString("."), 0)); - assertEquals(0, - fromString("www.apache.org").ordinalIndexOf(fromString("w"), 1)); - assertEquals(3, - fromString("www.apache.org").ordinalIndexOf(fromString("."), 1)); - assertEquals(10, - fromString("www.apache.org").ordinalIndexOf(fromString("."), 2)); - assertEquals(-1, - fromString("www.apache.org").ordinalIndexOf(fromString("."), 3)); - assertEquals(-1, - fromString("www.apache.org").ordinalIndexOf(fromString("#"), 0)); - assertEquals(12, - fromString("www|||apache|||org").ordinalIndexOf(fromString("|||"), 2)); - assertEquals(2, - fromString("数据砖砖头").ordinalIndexOf(fromString("砖"), 1)); - assertEquals(-1, - fromString("砖头数据砖头").ordinalIndexOf(fromString("砖"), -2)); - } - - @Test - public void lastOrdinalIndexOf() { - assertEquals(-1, - fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 0)); - assertEquals(10, - fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 1)); - assertEquals(3, - fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 2)); - assertEquals(-1, - fromString("www.apache.org").lastOrdinalIndexOf(fromString("."), 3)); - assertEquals(-1, - fromString("www.apache.org").lastOrdinalIndexOf(fromString("#"), 0)); - 
assertEquals(3, - fromString("www|||apache|||org").lastOrdinalIndexOf(fromString("|||"), 2)); - assertEquals(3, - fromString("数据砖砖头").lastOrdinalIndexOf(fromString("砖"), 1)); - assertEquals(-1, - fromString("砖头数据砖头").lastOrdinalIndexOf(fromString("砖"), -2)); } @Test @@ -293,23 +239,6 @@ public void indexOf() { assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); } - @Test - public void lastIndexOf() { - assertEquals(1, fromString("hello").lastIndexOf(fromString(""), 1)); - assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0)); - assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0)); - assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4)); - assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0)); - assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3)); - assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4)); - assertEquals(2, fromString("hello").lastIndexOf(fromString("ll"), 4)); - assertEquals(-1, fromString("hello").lastIndexOf(fromString("ll"), 0)); - assertEquals(5, fromString("数据砖头数据砖头").lastIndexOf(fromString("据砖"), 7)); - assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 3)); - assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 0)); - assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3)); - } - @Test public void substring_index() { assertEquals(fromString("www.apache.org"), @@ -343,6 +272,9 @@ public void substring_index() { // non ascii chars assertEquals(fromString("大千世界大"), fromString("大千世界大千世界").subStringIndex(fromString("千"), 2)); + // overlapped delim + assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3)); + assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4)); } @Test