diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8024a8de07c98..b00659e4b3465 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -918,6 +918,25 @@ def trunc(date, format): return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) +@since(1.5) +@ignore_unicode_prefix +def substring_index(str, delim, count): + """ + Returns the substring from string str before count occurrences of the delimiter delim. + If count is positive, everything the left of the final delimiter (counting from left) is + returned. If count is negative, every to the right of the final delimiter (counting from the + right) is returned. substring_index performs a case-sensitive match when searching for delim. + + >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s']) + >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect() + [Row(s=u'a.b')] + >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect() + [Row(s=u'b.c.d')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count)) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1bf7204a2515c..320523aaaf489 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -198,6 +198,7 @@ object FunctionRegistry { expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[SubstringIndex]("substring_index"), expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 684eac12bd6f0..804e2c1be819e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -421,6 +421,31 @@ case class StringInstr(str: Expression, substr: Expression) } } +/** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. + */ +case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) + override def prettyName: String = "substring_index" + + override def nullSafeEval(str: Any, delim: Any, count: Any): Any = { + str.asInstanceOf[UTF8String].subStringIndex( + delim.asInstanceOf[UTF8String], + count.asInstanceOf[Int]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") + } +} + /** * A function that returns the position of the first occurrence of substr * in given string after position pos. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3ecd0d374c46b..cb564314d88b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(s.substring(0), "example", row) } + test("string substring_index function") { + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org") + checkEvaluation( + SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org") + checkEvaluation( + SubstringIndex(Literal(""), Literal("."), Literal(-2)), "") + checkEvaluation( + SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null) + checkEvaluation(SubstringIndex( + Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null) + // non ascii chars + // scalastyle:off + checkEvaluation( + SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大") + // scalastyle:on + checkEvaluation( + SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache") + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5d82a5eadd94d..1a4b5c7ce4275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1788,8 +1788,18 @@ object functions { def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) /** - * Locate the position of the first occurrence of substr in a string column. + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. * + * @group string_funcs + */ + def substring_index(str: Column, delim: String, count: Int): Column = + SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + + /** + * Locate the position of the first occurrence of substr. * NOTE: The position is not zero based, but 1 based index, returns 0 if substr * could not be found in str. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 8e0ea76d15881..f23c4c699a149 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -154,6 +154,63 @@ class StringFunctionsSuite extends QueryTest { Row(1)) } + test("string substring_index function") { + val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c") + checkAnswer( + df.select(substring_index($"a", ".", 3)), + Row("www.apache.org")) + checkAnswer( + df.select(substring_index($"a", ".", 2)), + Row("www.apache")) + checkAnswer( + df.select(substring_index($"a", ".", 1)), + Row("www")) + checkAnswer( + df.select(substring_index($"a", ".", 0)), + Row("")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -1)), + Row("org")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -2)), + Row("apache.org")) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), ".", -3)), + Row("www.apache.org")) + // str is empty string + checkAnswer( + df.select(substring_index(lit(""), ".", 1)), + Row("")) + // empty string delim + checkAnswer( + df.select(substring_index(lit("www.apache.org"), "", 1)), + Row("")) + // delim does not exist in str + checkAnswer( + df.select(substring_index(lit("www.apache.org"), "#", 1)), + Row("www.apache.org")) + // delim is 2 chars + checkAnswer( + df.select(substring_index(lit("www||apache||org"), "||", 2)), + Row("www||apache")) + checkAnswer( + df.select(substring_index(lit("www||apache||org"), "||", -2)), + Row("apache||org")) + // null + checkAnswer( + df.select(substring_index(lit(null), "||", 2)), + Row(null)) + checkAnswer( + df.select(substring_index(lit("www.apache.org"), null, 2)), + Row(null)) + // non ascii chars + // scalastyle:off + checkAnswer( + df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""), + Row("大千世界大")) + // scalastyle:on + } + test("string locate function") { val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index c38953f65d7d7..63d71b5e509f7 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -198,7 +198,7 @@ public byte[] getBytes() { */ public UTF8String substring(final int start, final int until) { if (until <= start || start >= numBytes) { - return fromBytes(new byte[0]); + return UTF8String.EMPTY_UTF8; } int i = 0; @@ -406,6 +406,84 @@ public int indexOf(UTF8String v, int start) { return -1; } + /** + * Find the `str` from left to right. + */ + private int find(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start <= numBytes - str.numBytes) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; + } + start += 1; + } + return -1; + } + + /** + * Find the `str` from right to left. + */ + private int rfind(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start >= 0) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; + } + start -= 1; + } + return -1; + } + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. subStringIndex performs a case-sensitive match when searching for delim. + */ + public UTF8String subStringIndex(UTF8String delim, int count) { + if (delim.numBytes == 0 || count == 0) { + return EMPTY_UTF8; + } + if (count > 0) { + int idx = -1; + while (count > 0) { + idx = find(delim, idx + 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } + } + if (idx == 0) { + return EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + return fromBytes(bytes); + + } else { + int idx = numBytes - delim.numBytes + 1; + count = -count; + while (count > 0) { + idx = rfind(delim, idx - 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } + } + if (idx + delim.numBytes == numBytes) { + return EMPTY_UTF8; + } + int size = numBytes - delim.numBytes - idx; + byte[] bytes = new byte[size]; + copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + return fromBytes(bytes); + } + } + /** * Returns str, right-padded with pad to a length of len * For example: diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index f2cc19ca6b172..cd3a9ad5b9536 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -240,6 +240,44 @@ public void indexOf() { assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); } + @Test + public void substring_index() { + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), 3)); + assertEquals(fromString("www.apache"), + fromString("www.apache.org").subStringIndex(fromString("."), 2)); + assertEquals(fromString("www"), + fromString("www.apache.org").subStringIndex(fromString("."), 1)); + assertEquals(fromString(""), + fromString("www.apache.org").subStringIndex(fromString("."), 0)); + assertEquals(fromString("org"), + fromString("www.apache.org").subStringIndex(fromString("."), -1)); + assertEquals(fromString("apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), -2)); + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), -3)); + // str is empty string + assertEquals(fromString(""), + fromString("").subStringIndex(fromString("."), 1)); + // empty string delim + assertEquals(fromString(""), + fromString("www.apache.org").subStringIndex(fromString(""), 1)); + // delim does not exist in str + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("#"), 2)); + // delim is 2 chars + assertEquals(fromString("www||apache"), + fromString("www||apache||org").subStringIndex(fromString("||"), 2)); + assertEquals(fromString("apache||org"), + fromString("www||apache||org").subStringIndex(fromString("||"), -2)); + // non ascii chars + assertEquals(fromString("大千世界大"), + fromString("大千世界大千世界").subStringIndex(fromString("千"), 2)); + // overlapped delim + assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3)); + assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4)); + } + @Test public void reverse() { assertEquals(fromString("olleh"), fromString("hello").reverse());