Skip to content

Commit

Permalink
add codegen and clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhichao-li committed Jul 23, 2015
1 parent ac863e9 commit b19b013
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ case class StringInstr(str: Expression, substr: Expression)
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends Expression with ImplicitCastInputTypes with CodegenFallback {
extends Expression with ImplicitCastInputTypes {

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
Expand All @@ -385,6 +385,30 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
}
null
}

/**
 * Generates the Java source that evaluates this expression.
 *
 * The emitted code evaluates the children in order (`strExpr`, `delimExpr`, `countExpr`) and
 * short-circuits on null: the code of each child is only placed inside the non-null branch of
 * the previous one. The result is flagged as null unless all three children are non-null, in
 * which case it is the value of `UTF8String.subStringIndex(str, delim, count)`.
 *
 * @param ctx code generation context; supplies the Java type and default value for `dataType`
 * @param ev  names of the `isNull` / `primitive` variables that the emitted code declares
 * @return the generated Java source fragment
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// Generate the code of every child up front; the fragments are spliced into the template below.
val str = strExpr.gen(ctx)
val delim = delimExpr.gen(ctx)
val count = countExpr.gen(ctx)
// Java call expression that computes the result from the three child values.
val resultCode =
s"""org.apache.spark.unsafe.types.UTF8String.subStringIndex(
|${str.primitive}, ${delim.primitive}, ${count.primitive})""".stripMargin
// The result starts out as null with the type's default value and is only overwritten in the
// innermost branch, i.e. once all three inputs are known to be non-null.
s"""
${str.code}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${str.isNull}) {
${delim.code}
if (!${delim.isNull}) {
${count.code}
if (!${count.isNull}) {
${ev.isNull} = false;
${ev.primitive} = $resultCode;
}
}
}
"""
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(s.substring(0), "example", row)
}

test("string substring_index function") {
  // Positive count: keep everything to the left of the count-th delimiter (from the left).
  checkEvaluation(
    Substring_index(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
  checkEvaluation(
    Substring_index(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
  checkEvaluation(
    Substring_index(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
  // A count of zero yields the empty string.
  checkEvaluation(
    Substring_index(Literal("www.apache.org"), Literal("."), Literal(0)), "")
  // Negative count: keep everything to the right of the count-th delimiter (from the right).
  checkEvaluation(
    Substring_index(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
  checkEvaluation(
    Substring_index(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
  checkEvaluation(
    Substring_index(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
  // Empty input string.
  checkEvaluation(
    Substring_index(Literal(""), Literal("."), Literal(-2)), "")
  // A null string or a null delimiter yields null.
  checkEvaluation(
    Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
  checkEvaluation(
    Substring_index(
      Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)),
    null)
  // non ascii chars
  // The delimiter must be the non-empty "千": an empty delimiter makes
  // UTF8String.subStringIndex return the empty string, not "大千世界大".
  // scalastyle:off
  checkEvaluation(
    Substring_index(Literal("大千世界大千世界"), Literal("千"), Literal(2)), "大千世界大")
  // scalastyle:on
  // Delimiter longer than one character.
  checkEvaluation(
    Substring_index(Literal("www||apache||org"), Literal("||"), Literal(2)), "www||apache")
}

test("LIKE literal Regular Expression") {
checkEvaluation(Literal.create(null, StringType).like("a"), null)
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)
Expand Down
2 changes: 0 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1783,7 +1783,6 @@ object functions {
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*
* @group string_funcs
* @since 1.5.0
*/
def substring_index(str: String, delim: String, count: Int): Column = {
  // Wrap `str` in a Column and reuse the Column-based overload.
  val strColumn = Column(str)
  substring_index(strColumn, delim, count)
}
Expand All @@ -1795,7 +1794,6 @@ object functions {
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*
* @group string_funcs
* @since 1.5.0
*/
def substring_index(str: Column, delim: String, count: Int): Column = {
  // Lift the plain Scala arguments into literal expressions before building the node.
  val delimExpr = lit(delim).expr
  val countExpr = lit(count).expr
  Substring_index(str.expr, delimExpr, countExpr)
}
Expand Down
80 changes: 37 additions & 43 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,8 @@ public UTF8String substring(final int start, final int until) {
if (until <= start || start >= numBytes) {
return fromBytes(new byte[0]);
}

int i = 0;
int c = 0;
while (i < numBytes && c < start) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}

int j = i;
while (i < numBytes && c < until) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}

int j = firstByteIndex(start);
int i = firstByteIndex(until);
byte[] bytes = new byte[i - j];
copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
return fromBytes(bytes);
Expand All @@ -174,12 +162,7 @@ public UTF8String substring(final int start) {
return fromBytes(new byte[0]);
}

int i = 0;
int c = 0;
while (i < numBytes && c < start) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
int i = firstByteIndex(start);

byte[] bytes = new byte[numBytes - i];
copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i);
Expand Down Expand Up @@ -351,13 +334,8 @@ public int indexOf(UTF8String v, int start) {
return 0;
}

// locate to the start position.
int i = 0; // position in byte
int c = 0; // position in character
while (i < numBytes && c < start) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
int i = firstByteIndex(start); // position in byte
int c = start; // position in character

do {
if (i + v.numBytes > numBytes) {
Expand Down Expand Up @@ -399,19 +377,29 @@ private int firstOfCurrentCodePoint(int bytePos) {
}
bytePos--;
}
throw new RuntimeException("Invalid utf8 string");
throw new RuntimeException("Invalid UTF8 string");
}

private int indexEnd(int startCodePoint) {
int i = numBytes -1; // position in byte
int c = numChars() - 1; // position in character
while (i >=0 && c > startCodePoint) {
i = firstOfCurrentCodePoint(i) - 1;
c -= 1;
/**
 * Returns the byte offset at which the character with index {@code codePoint} starts.
 * Walks forward from the beginning of the string one character at a time and stops as soon
 * as either the requested character or the end of the bytes is reached.
 *
 * @param codePoint zero-based character index to locate
 * @return the byte offset reached by the walk
 * @throws StringIndexOutOfBoundsException if the walk steps past {@code numBytes}
 */
private int firstByteIndex(int codePoint) {
  int bytePos = 0;
  for (int charPos = 0; bytePos < numBytes && charPos < codePoint; charPos++) {
    // Skip over the whole character that starts at bytePos.
    bytePos += numBytesForFirstByte(getByte(bytePos));
  }
  if (bytePos > numBytes) {
    throw new StringIndexOutOfBoundsException(codePoint);
  }
  return bytePos;
}

/**
 * Returns the byte offset of the last byte of the character with index {@code codePoint}.
 *
 * If {@code codePoint} lies at or beyond the end of the string, the offset of the final byte
 * ({@code numBytes - 1}) is returned instead of reading past the end of the data.
 *
 * @param codePoint zero-based character index to locate
 * @return the byte offset of the last byte of that character, clamped to the end of the string
 */
private int lastByteIndex(int codePoint) {
  int i = firstByteIndex(codePoint);
  if (i >= numBytes) {
    // firstByteIndex stops at numBytes when the string has too few characters; calling
    // getByte(i) there would read outside the string, so clamp to the last byte.
    return numBytes - 1;
  }
  return i + numBytesForFirstByte(getByte(i)) - 1;
}

/**
* Returns the index within this string of the last occurrence of the
* specified substring, searching backward starting at the specified index.
Expand All @@ -431,7 +419,7 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
if (numBytes == 0) {
return -1;
}
int fromIndexEnd = indexEnd(startCodePoint);
int fromIndexEnd = lastByteIndex(startCodePoint);
do {
if (fromIndexEnd - v.numBytes + 1 < 0 ) {
return -1;
Expand All @@ -456,35 +444,38 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
*
* @param str the String to check, may be null
* @param searchStr the String to find, may be null
* @param searchStrNumChars num of code ponts of the searchStr
* @param ordinal the n-th last <code>searchStr</code> to find
* @return the n-th last index of the search String,
* <code>-1</code> if no match or <code>null</code> string input
*/
public static int lastOrdinalIndexOf(
UTF8String str,
UTF8String searchStr,
int searchStrNumChars,
int ordinal) {
return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true);
if (str == null || searchStr == null) {
return -1;
}
return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, true);
}

/**
* Finds the n-th index within a String, handling <code>null</code>.
* A <code>null</code> String will return <code>-1</code>
*
* @param str the String to check, may be null
* @param searchStr the String to find, may be null
* @param searchStrNumChars num of code points of searchStr
* @param ordinal the n-th <code>searchStr</code> to find
* @return the n-th index of the search String,
* <code>-1</code> if no match or <code>null</code> string input
*/
public static int ordinalIndexOf(
UTF8String str,
UTF8String searchStr,
int searchStrNumChars,
int ordinal) {
return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false);
if (str == null || searchStr == null) {
return -1;
}
return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, false);
}

private static int doOrdinalIndexOf(
Expand Down Expand Up @@ -526,19 +517,22 @@ private static int doOrdinalIndexOf(
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) {
if (str == null || delim == null) {
return null;
}
if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) {
return UTF8String.EMPTY_UTF8;
}
int delimNumChars = delim.numChars();
if (count > 0) {
int idx = ordinalIndexOf(str, delim, delimNumChars, count);
int idx = doOrdinalIndexOf(str, delim, delimNumChars, count, false);
if (idx != -1) {
return str.substring(0, idx);
} else {
return str;
}
} else {
int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count);
int idx = doOrdinalIndexOf(str, delim, delimNumChars, -count, true);
if (idx != -1) {
return str.substring(idx + delimNumChars);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import java.util.Arrays;

import org.junit.Test;
import org.junit.rules.ExpectedException;

import static junit.framework.Assert.*;

import static org.apache.spark.unsafe.types.UTF8String.*;
import static org.apache.spark.unsafe.types.UTF8String.fromString;

public class UTF8StringSuite {

Expand Down Expand Up @@ -184,6 +186,63 @@ public void substring() {
assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3));
assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5));
assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2));

assertEquals(fromString("hello"), fromString("hello").substring(0));
assertEquals(fromString("ello"), fromString("hello").substring(1));
assertEquals(fromString("砖头"), fromString("数据砖头").substring(2));
assertEquals(fromString("头"), fromString("数据砖头").substring(3));
ExpectedException exception = ExpectedException.none();
fromString("数据砖头").substring(4);
exception.expect(java.lang.StringIndexOutOfBoundsException.class);
assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0));
}

@Test
public void ordinalIndexOf() {
  // Fixtures shared by the ASCII cases.
  final UTF8String domain = fromString("www.apache.org");
  final UTF8String dot = fromString(".");
  final UTF8String piped = fromString("www|||apache|||org");
  final UTF8String bars = fromString("|||");

  // Single-character delimiter, ordinals 0 through 3.
  assertEquals(-1, UTF8String.ordinalIndexOf(domain, dot, 0));
  assertEquals(3, UTF8String.ordinalIndexOf(domain, dot, 1));
  assertEquals(10, UTF8String.ordinalIndexOf(domain, dot, 2));
  assertEquals(-1, UTF8String.ordinalIndexOf(domain, dot, 3));
  assertEquals(-1, UTF8String.ordinalIndexOf(domain, fromString("#"), 0));

  // Multi-character delimiter.
  assertEquals(12, UTF8String.ordinalIndexOf(piped, bars, 2));

  // Null string or null search string.
  assertEquals(-1, UTF8String.ordinalIndexOf(null, bars, 1));
  assertEquals(-1, UTF8String.ordinalIndexOf(piped, null, 1));

  // Non-ASCII characters and a negative ordinal.
  assertEquals(2, UTF8String.ordinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
  assertEquals(-1, UTF8String.ordinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
}

@Test
public void lastOrdinalIndexOf() {
  // Fixtures shared by the ASCII cases.
  final UTF8String domain = fromString("www.apache.org");
  final UTF8String dot = fromString(".");
  final UTF8String piped = fromString("www|||apache|||org");
  final UTF8String bars = fromString("|||");

  // Single-character delimiter, ordinals 0 through 3 counted from the end.
  assertEquals(-1, UTF8String.lastOrdinalIndexOf(domain, dot, 0));
  assertEquals(10, UTF8String.lastOrdinalIndexOf(domain, dot, 1));
  assertEquals(3, UTF8String.lastOrdinalIndexOf(domain, dot, 2));
  assertEquals(-1, UTF8String.lastOrdinalIndexOf(domain, dot, 3));
  assertEquals(-1, UTF8String.lastOrdinalIndexOf(domain, fromString("#"), 0));

  // Multi-character delimiter.
  assertEquals(3, UTF8String.lastOrdinalIndexOf(piped, bars, 2));

  // Null string or null search string.
  assertEquals(-1, UTF8String.lastOrdinalIndexOf(null, bars, 1));
  assertEquals(-1, UTF8String.lastOrdinalIndexOf(piped, null, 1));

  // Non-ASCII characters and a negative ordinal.
  assertEquals(3, UTF8String.lastOrdinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
  assertEquals(-1,
    UTF8String.lastOrdinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
}

@Test
Expand Down Expand Up @@ -238,6 +297,46 @@ public void lastIndexOf() {
assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3));
}

@Test
public void substring_index() {
  // Fixtures shared by several cases.
  final UTF8String domain = fromString("www.apache.org");
  final UTF8String dot = fromString(".");
  final UTF8String piped = fromString("www||apache||org");
  final UTF8String bars = fromString("||");
  final UTF8String empty = fromString("");

  // Positive counts, then zero, then negative counts.
  assertEquals(domain, UTF8String.subStringIndex(domain, dot, 3));
  assertEquals(fromString("www.apache"), UTF8String.subStringIndex(domain, dot, 2));
  assertEquals(fromString("www"), UTF8String.subStringIndex(domain, dot, 1));
  assertEquals(empty, UTF8String.subStringIndex(domain, dot, 0));
  assertEquals(fromString("org"), UTF8String.subStringIndex(domain, dot, -1));
  assertEquals(fromString("apache.org"), UTF8String.subStringIndex(domain, dot, -2));
  assertEquals(domain, UTF8String.subStringIndex(domain, dot, -3));

  // str is empty string
  assertEquals(empty, UTF8String.subStringIndex(empty, dot, 1));
  // empty string delim
  assertEquals(empty, UTF8String.subStringIndex(domain, empty, 1));
  // delim does not exist in str
  assertEquals(domain, UTF8String.subStringIndex(domain, fromString("#"), 2));

  // delim is 2 chars
  assertEquals(fromString("www||apache"), UTF8String.subStringIndex(piped, bars, 2));
  assertEquals(fromString("apache||org"), UTF8String.subStringIndex(piped, bars, -2));

  // null
  assertEquals(null, UTF8String.subStringIndex(null, dot, -2));
  assertEquals(null, UTF8String.subStringIndex(domain, null, -2));

  // non ascii chars
  assertEquals(fromString("大千世界大"),
    UTF8String.subStringIndex(fromString("大千世界大千世界"), fromString("千"), 2));
}

@Test
public void reverse() {
assertEquals(fromString("olleh"), fromString("hello").reverse());
Expand Down

0 comments on commit b19b013

Please sign in to comment.