From 41233e6a95e74f13d0a3cf439ee59b5f984cc0da Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 30 Jun 2016 15:08:09 -0700 Subject: [PATCH 1/9] [SPARK-16285][SQL] Implement sentences SQL functions --- .../apache/spark/unsafe/types/UTF8String.java | 46 +++++++++++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/regexpExpressions.scala | 27 +++++++++++ .../expressions/ComplexTypeSuite.scala | 27 +++++++++++ .../spark/sql/hive/HiveSessionCatalog.scala | 2 +- 5 files changed, 102 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 54a54569240c0..f7db962a4f4b8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -22,7 +22,10 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; +import java.text.BreakIterator; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import com.esotericsoftware.kryo.Kryo; @@ -801,6 +804,49 @@ public UTF8String[] split(UTF8String pattern, int limit) { return res; } + /** + * Return a locale of the given language and country, or a default locale when failures occur. + */ + private Locale getLocale(UTF8String language, UTF8String country) { + try { + return new Locale(language.toString(), country.toString()); + } finally { + return Locale.getDefault(); + } + } + + /** + * Splits a string into arrays of sentences, where each sentence is an array of words. + */ + public ArrayList> sentences(UTF8String language, UTF8String country) { + String sentences = this.toString(); + Locale locale = getLocale(language, country); + + BreakIterator bi = BreakIterator.getSentenceInstance(locale); + bi.setText(sentences); + int idx = 0; + ArrayList> result = new ArrayList<>(); + while (bi.next() != BreakIterator.DONE) { + String sentence = sentences.substring(idx, bi.current()); + idx = bi.current(); + + BreakIterator wi = BreakIterator.getWordInstance(locale); + int widx = 0; + wi.setText(sentence); + ArrayList words = new ArrayList<>(); + while(wi.next() != BreakIterator.DONE) { + String word = sentence.substring(widx, wi.current()); + widx = wi.current(); + if(Character.isLetterOrDigit(word.charAt(0))) { + words.add(UTF8String.fromString(word)); + } + } + result.add(words); + } + + return result; + } + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes public UTF8String translate(Map dict) { String srcStr = this.toString(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f6ebcaeded484..842c9c63ce147 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -296,6 +296,7 @@ object FunctionRegistry { expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), + expression[Sentences]("sentences"), expression[SoundEx]("soundex"), expression[StringSpace]("space"), expression[StringSplit]("split"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 541b8601a344b..e0e01a6df8acf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.{MatchResult, Pattern} +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -198,6 +200,31 @@ case class StringSplit(str: Expression, pattern: Expression) override def prettyName: String = "split" } +/** + * Splits a string into arrays of sentences, where each sentence is an array of words. + * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used. + */ +@ExpressionDescription( + usage = "_FUNC_(s) - Splits str into an array of array of words.") +case class Sentences( + str: Expression, + language: Expression = Literal(""), + country: Expression = Literal("")) + extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { + + override def dataType: DataType = ArrayType(ArrayType(StringType)) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = str :: language :: country :: Nil + + override def nullSafeEval(string: Any, language: Any, country: Any): Any = { + val sentences = string.asInstanceOf[UTF8String].sentences( + language.asInstanceOf[UTF8String], country.asInstanceOf[UTF8String]) + val result = ArrayBuffer.empty[GenericArrayData] + for (i <- 0 until sentences.size()) + result += new GenericArrayData(sentences.get(i).toArray) + new GenericArrayData(result.toArray) + } +} /** * Replace all substrings of str that match regexp with rep. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index ec7be4d4b849d..79039633fff31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -246,4 +246,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkMetadata(CreateStructUnsafe(Seq(a, b))) checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } + + test("Sentences") { + // Hive compatible test-cases. + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... But, not now."), + Seq( + Seq("Hi", "there").map(UTF8String.fromString), + Seq("The", "price", "was").map(UTF8String.fromString), + Seq("But", "not", "now").map(UTF8String.fromString)), + EmptyRow) + + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), + Seq( + Seq("Hi", "there").map(UTF8String.fromString), + Seq("The", "price", "was").map(UTF8String.fromString), + Seq("But", "not", "now").map(UTF8String.fromString)), + EmptyRow) + + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... 
But, not now.", "en", "US"), + Seq( + Seq("Hi", "there").map(UTF8String.fromString), + Seq("The", "price", "was").map(UTF8String.fromString), + Seq("But", "not", "now").map(UTF8String.fromString)), + EmptyRow) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index fdc4c18e70d69..6f05f0f3058cf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog( // str_to_map, windowingtablefunction. private val hiveFunctions = Seq( "hash", "java_method", "histogram_numeric", - "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "str_to_map", + "parse_url", "percentile", "percentile_approx", "reflect", "str_to_map", "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long", "xpath_number", "xpath_short", "xpath_string" ) From d1fb0578b627a5cf85ad814e2641688992ddfb7c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 2 Jul 2016 23:17:48 -0700 Subject: [PATCH 2/9] Add constructors, testcases, and extended description. --- .../sql/catalyst/expressions/regexpExpressions.scala | 7 ++++++- .../org/apache/spark/sql/StringFunctionsSuite.scala | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index e0e01a6df8acf..1fb4e5a100693 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ @@ -205,13 +206,17 @@ case class StringSplit(str: Expression, pattern: Expression) * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used. */ @ExpressionDescription( - usage = "_FUNC_(s) - Splits str into an array of array of words.") + usage = "_FUNC_(str, lang, country) - Splits str into an array of array of words.", + extended = "> SELECT _FUNC_('Hi there! 
Good morning.');\n [['Hi','there'], ['Good','morning']]") case class Sentences( str: Expression, language: Expression = Literal(""), country: Expression = Literal("")) extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { + def this(str: Expression) = this(str, Literal(""), Literal("")) + def this(str: Expression, language: Expression) = this(str, language, Literal("")) + override def dataType: DataType = ArrayType(ArrayType(StringType)) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) override def children: Seq[Expression] = str :: language :: country :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index dff4226051494..e0b5fcfa3cde9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -250,6 +250,17 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("ihhh")) } + test("string sentences function") { + val df = Seq(("Hi there! Good morning.", "en", "US")).toDF("str", "language", "country") + + checkAnswer(df.selectExpr("sentences(str)"), + Row(Seq(Seq("Hi", "there"), Seq("Good", "morning")))) + checkAnswer(df.selectExpr("sentences(str, language)"), + Row(Seq(Seq("Hi", "there"), Seq("Good", "morning")))) + checkAnswer(df.selectExpr("sentences(str, language, country)"), + Row(Seq(Seq("Hi", "there"), Seq("Good", "morning")))) + } + test("string space function") { val df = Seq((2, 3)).toDF("a", "b") From f1a5c1b645840bc8bc5db3cc9dbfe1642eb109a0 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 3 Jul 2016 21:55:56 -0700 Subject: [PATCH 3/9] Fix indentation and remove unused import. --- .../sql/catalyst/expressions/regexpExpressions.scala | 1 - .../spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 1fb4e5a100693..53da889c4d2f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 79039633fff31..7d052f4e0f9a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -267,10 +267,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( Sentences("Hi there! The price was $1,234.56.... 
But, not now.", "en", "US"), - Seq( - Seq("Hi", "there").map(UTF8String.fromString), - Seq("The", "price", "was").map(UTF8String.fromString), - Seq("But", "not", "now").map(UTF8String.fromString)), + Seq( + Seq("Hi", "there").map(UTF8String.fromString), + Seq("The", "price", "was").map(UTF8String.fromString), + Seq("But", "not", "now").map(UTF8String.fromString)), EmptyRow) } } From a98c05e736d1638b8d9c1475bb658206bada314b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 5 Jul 2016 22:10:29 -0700 Subject: [PATCH 4/9] Address comments. --- .../apache/spark/unsafe/types/UTF8String.java | 46 ------------------- .../expressions/regexpExpressions.scala | 38 +++++++++++++-- .../expressions/ComplexTypeSuite.scala | 25 +++++----- .../spark/sql/StringFunctionsSuite.scala | 11 ----- 4 files changed, 47 insertions(+), 73 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f7db962a4f4b8..54a54569240c0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -22,10 +22,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; -import java.text.BreakIterator; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Locale; import java.util.Map; import com.esotericsoftware.kryo.Kryo; @@ -804,49 +801,6 @@ public UTF8String[] split(UTF8String pattern, int limit) { return res; } - /** - * Return a locale of the given language and country, or a default locale when failures occur. - */ - private Locale getLocale(UTF8String language, UTF8String country) { - try { - return new Locale(language.toString(), country.toString()); - } finally { - return Locale.getDefault(); - } - } - - /** - * Splits a string into arrays of sentences, where each sentence is an array of words. 
- */ - public ArrayList> sentences(UTF8String language, UTF8String country) { - String sentences = this.toString(); - Locale locale = getLocale(language, country); - - BreakIterator bi = BreakIterator.getSentenceInstance(locale); - bi.setText(sentences); - int idx = 0; - ArrayList> result = new ArrayList<>(); - while (bi.next() != BreakIterator.DONE) { - String sentence = sentences.substring(idx, bi.current()); - idx = bi.current(); - - BreakIterator wi = BreakIterator.getWordInstance(locale); - int widx = 0; - wi.setText(sentence); - ArrayList words = new ArrayList<>(); - while(wi.next() != BreakIterator.DONE) { - String word = sentence.substring(widx, wi.current()); - widx = wi.current(); - if(Character.isLetterOrDigit(word.charAt(0))) { - words.add(UTF8String.fromString(word)); - } - } - result.add(words); - } - - return result; - } - // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes public UTF8String translate(Map dict) { String srcStr = this.toString(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 53da889c4d2f6..78278e8b46ec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.BreakIterator +import java.util.Locale import java.util.regex.{MatchResult, Pattern} import scala.collection.mutable.ArrayBuffer @@ -221,13 +223,41 @@ case class Sentences( override def children: Seq[Expression] = str :: language :: country :: Nil override def nullSafeEval(string: Any, language: Any, country: Any): Any = { - val sentences = string.asInstanceOf[UTF8String].sentences( - language.asInstanceOf[UTF8String], country.asInstanceOf[UTF8String]) + val sentences = getSentences(string.asInstanceOf[UTF8String].toString, + language.asInstanceOf[UTF8String].toString, country.asInstanceOf[UTF8String].toString) val result = ArrayBuffer.empty[GenericArrayData] - for (i <- 0 until sentences.size()) - result += new GenericArrayData(sentences.get(i).toArray) + sentences.foreach(sentence => result += new GenericArrayData(sentence.toArray)) new GenericArrayData(result.toArray) } + + private def getSentences(sentences: String, language: String, country: String) = { + val locale = try { + new Locale(language, country) + } finally { + Locale.getDefault + } + + val bi = BreakIterator.getSentenceInstance(locale) + bi.setText(sentences) + var idx = 0 + val result = new ArrayBuffer[ArrayBuffer[UTF8String]] + while (bi.next != BreakIterator.DONE) { + val sentence = sentences.substring(idx, bi.current) + idx = bi.current + + val wi = BreakIterator.getWordInstance(locale) + var widx = 0 + wi.setText(sentence) + val words = new ArrayBuffer[UTF8String] + while (wi.next != BreakIterator.DONE) { + val word = sentence.substring(widx, wi.current) + widx = wi.current + if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word) + } + result += words + } + result + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 7d052f4e0f9a6..b3de05e5d6bae 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -248,29 +248,30 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Sentences") { + val correct_answer = Seq( + Seq("Hi", "there").map(UTF8String.fromString), + Seq("The", "price", "was").map(UTF8String.fromString), + Seq("But", "not", "now").map(UTF8String.fromString)) + // Hive compatible test-cases. checkEvaluation( Sentences("Hi there! The price was $1,234.56.... But, not now."), - Seq( - Seq("Hi", "there").map(UTF8String.fromString), - Seq("The", "price", "was").map(UTF8String.fromString), - Seq("But", "not", "now").map(UTF8String.fromString)), + correct_answer, EmptyRow) checkEvaluation( Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), - Seq( - Seq("Hi", "there").map(UTF8String.fromString), - Seq("The", "price", "was").map(UTF8String.fromString), - Seq("But", "not", "now").map(UTF8String.fromString)), + correct_answer, EmptyRow) checkEvaluation( Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"), - Seq( - Seq("Hi", "there").map(UTF8String.fromString), - Seq("The", "price", "was").map(UTF8String.fromString), - Seq("But", "not", "now").map(UTF8String.fromString)), + correct_answer, + EmptyRow) + + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXXXX", "YYYYY"), + correct_answer, EmptyRow) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index e0b5fcfa3cde9..dff4226051494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -250,17 +250,6 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("ihhh")) } - test("string sentences function") { - val df = Seq(("Hi there! Good morning.", "en", "US")).toDF("str", "language", "country") - - checkAnswer(df.selectExpr("sentences(str)"), - Row(Seq(Seq("Hi", "there"), Seq("Good", "morning")))) - checkAnswer(df.selectExpr("sentences(str, language)"), - Row(Seq(Seq("Hi", "there"), Seq("Good", "morning")))) - checkAnswer(df.selectExpr("sentences(str, language, country)"), - Row(Seq(Seq("Hi", "there"), Seq("Good", "morning")))) - } - test("string space function") { val df = Seq((2, 3)).toDF("a", "b") From 4529d43aeacafb84a9fc18165520ce320d546aac Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 6 Jul 2016 01:13:26 -0700 Subject: [PATCH 5/9] Use Expression instead of TernaryExpression. 
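TernaryExpression.nullSafeEval is only invoked when every child evaluates to a non-null value, so a null 'lang' or 'country' argument would force the whole result to null. Implementing eval() directly lets a missing locale argument fall back to Locale.getDefault, and null is returned only when the input string itself is null. A minimal standalone sketch of that fallback decision follows; the object and method names (LocaleFallbackSketch, chooseLocale) are illustrative, not Spark's:

    import java.util.Locale

    object LocaleFallbackSketch {
      // Stand-in for the per-row decision: a null language/country selects the
      // default locale instead of nulling out the whole result.
      def chooseLocale(language: String, country: String): Locale =
        if (language != null && country != null) new Locale(language, country)
        else Locale.getDefault

      def main(args: Array[String]): Unit = {
        println(chooseLocale("en", "US"))   // en_US
        println(chooseLocale(null, null))   // the JVM default locale
      }
    }
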
--- .../expressions/regexpExpressions.scala | 40 ++++++++++--------- .../expressions/ComplexTypeSuite.scala | 28 ------------- .../expressions/StringExpressionsSuite.scala | 37 +++++++++++++++++ .../spark/sql/StringFunctionsSuite.scala | 20 ++++++++++ 4 files changed, 79 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 78278e8b46ec4..2e0ccf2e2ef6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ @@ -213,34 +214,37 @@ case class Sentences( str: Expression, language: Expression = Literal(""), country: Expression = Literal("")) - extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes with CodegenFallback { def this(str: Expression) = this(str, Literal(""), Literal("")) def this(str: Expression, language: Expression) = this(str, language, Literal("")) - override def dataType: DataType = ArrayType(ArrayType(StringType)) + override def nullable: Boolean = true + override def dataType: DataType = + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) override def children: Seq[Expression] = str :: language :: country :: Nil - override def nullSafeEval(string: Any, language: Any, country: Any): Any = { - val sentences = getSentences(string.asInstanceOf[UTF8String].toString, - language.asInstanceOf[UTF8String].toString, country.asInstanceOf[UTF8String].toString) - val result = ArrayBuffer.empty[GenericArrayData] - sentences.foreach(sentence => result += new GenericArrayData(sentence.toArray)) - new GenericArrayData(result.toArray) - } - - private def getSentences(sentences: String, language: String, country: String) = { - val locale = try { - new Locale(language, country) - } finally { - Locale.getDefault + override def eval(input: InternalRow): Any = { + val string = str.eval(input) + if (string == null) { + null + } else { + val locale = try { + new Locale(language.eval(input).asInstanceOf[UTF8String].toString, + country.eval(input).asInstanceOf[UTF8String].toString) + } catch { + case _: NullPointerException | _: ClassCastException => Locale.getDefault + } + getSentences(string.asInstanceOf[UTF8String].toString, locale) } + } + private def getSentences(sentences: String, locale: Locale) = { val bi = BreakIterator.getSentenceInstance(locale) bi.setText(sentences) var idx = 0 - val result = new ArrayBuffer[ArrayBuffer[UTF8String]] + val result = new ArrayBuffer[GenericArrayData] while (bi.next != BreakIterator.DONE) { val sentence = sentences.substring(idx, bi.current) idx = bi.current @@ -254,9 +258,9 @@ case class Sentences( widx = wi.current if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word) } - result += words + result += new GenericArrayData(words) } - result + new GenericArrayData(result) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b3de05e5d6bae..ec7be4d4b849d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -246,32 +246,4 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkMetadata(CreateStructUnsafe(Seq(a, b))) checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } - - test("Sentences") { - val correct_answer = Seq( - Seq("Hi", "there").map(UTF8String.fromString), - Seq("The", "price", "was").map(UTF8String.fromString), - Seq("But", "not", "now").map(UTF8String.fromString)) - - // Hive compatible test-cases. - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now."), - correct_answer, - EmptyRow) - - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), - correct_answer, - EmptyRow) - - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"), - correct_answer, - EmptyRow) - - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXXXX", "YYYYY"), - correct_answer, - EmptyRow) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 5f01561986f19..4cc913e58fa60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -725,4 +725,41 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0) checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0) } + + test("Sentences") { + val nullString = Literal.create(null, StringType) + checkEvaluation(Sentences(nullString, nullString, nullString), null, EmptyRow) + checkEvaluation(Sentences(nullString, nullString), null, EmptyRow) + checkEvaluation(Sentences(nullString), null, EmptyRow) + checkEvaluation(Sentences(Literal.create(null, NullType)), null, EmptyRow) + checkEvaluation(Sentences("", nullString, nullString), Seq.empty, EmptyRow) + checkEvaluation(Sentences("", nullString), Seq.empty, EmptyRow) + checkEvaluation(Sentences(""), Seq.empty, EmptyRow) + + val correct_answer = Seq( + Seq("Hi", "there"), + Seq("The", "price", "was"), + Seq("But", "not", "now")) + + // Hive compatible test-cases. + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... But, not now."), + correct_answer, + EmptyRow) + + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), + correct_answer, + EmptyRow) + + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"), + correct_answer, + EmptyRow) + + checkEvaluation( + Sentences("Hi there! The price was $1,234.56.... 
But, not now.", "XXXXX", "YYYYY"), + correct_answer, + EmptyRow) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index dff4226051494..292ebb351d143 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -347,4 +347,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { df2.filter("b>0").selectExpr("format_number(a, b)"), Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil) } + + test("string sentences function") { + val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US")) + .toDF("str", "language", "country") + + checkAnswer( + df.selectExpr("sentences(str, language, country)"), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + // Type coercion + checkAnswer( + df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"), + Row(null, Seq(Seq("10")), Seq(Seq("3.14")))) + + // Argument number exception + val m = intercept[AnalysisException] { + df.selectExpr("sentences()") + }.getMessage + assert(m.contains("Invalid number of arguments")) + } } From 8d7c3d400581332503f7cb831e4d0f850294b608 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 6 Jul 2016 22:56:10 -0700 Subject: [PATCH 6/9] Address comments. --- .../expressions/regexpExpressions.scala | 8 ++-- .../expressions/StringExpressionsSuite.scala | 44 +++++++------------ .../spark/sql/StringFunctionsSuite.scala | 2 +- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 2e0ccf2e2ef6d..69fba91649d06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -230,11 +230,11 @@ case class Sentences( if (string == null) { null } else { - val locale = try { - new Locale(language.eval(input).asInstanceOf[UTF8String].toString, + var locale = Locale.getDefault + if (language != null && language.eval(input) != null && + country != null && country.eval(input) != null) { + locale = new Locale(language.eval(input).asInstanceOf[UTF8String].toString, country.eval(input).asInstanceOf[UTF8String].toString) - } catch { - case _: NullPointerException | _: ClassCastException => Locale.getDefault } getSentences(string.asInstanceOf[UTF8String].toString, locale) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4cc913e58fa60..256ce85743c61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -728,38 +728,24 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Sentences") { val nullString = Literal.create(null, StringType) - checkEvaluation(Sentences(nullString, nullString, nullString), null, EmptyRow) - checkEvaluation(Sentences(nullString, nullString), null, EmptyRow) - 
checkEvaluation(Sentences(nullString), null, EmptyRow) - checkEvaluation(Sentences(Literal.create(null, NullType)), null, EmptyRow) - checkEvaluation(Sentences("", nullString, nullString), Seq.empty, EmptyRow) - checkEvaluation(Sentences("", nullString), Seq.empty, EmptyRow) - checkEvaluation(Sentences(""), Seq.empty, EmptyRow) - - val correct_answer = Seq( + checkEvaluation(Sentences(nullString, nullString, nullString), null) + checkEvaluation(Sentences(nullString, nullString), null) + checkEvaluation(Sentences(nullString), null) + checkEvaluation(Sentences(Literal.create(null, NullType)), null) + checkEvaluation(Sentences("", nullString, nullString), Seq.empty) + checkEvaluation(Sentences("", nullString), Seq.empty) + checkEvaluation(Sentences(""), Seq.empty) + + val answer = Seq( Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")) - // Hive compatible test-cases. - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now."), - correct_answer, - EmptyRow) - - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), - correct_answer, - EmptyRow) - - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"), - correct_answer, - EmptyRow) - - checkEvaluation( - Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXXXX", "YYYYY"), - correct_answer, - EmptyRow) + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now."), answer) + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), answer) + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"), + answer) + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXX", "YYY"), + answer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 292ebb351d143..433a23bcb9422 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -365,6 +365,6 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { val m = intercept[AnalysisException] { df.selectExpr("sentences()") }.getMessage - assert(m.contains("Invalid number of arguments")) + assert(m.contains("Invalid number of arguments for function sentences")) } } From 4144e7fc38ce8a3362f864632ee13d86046a2a8f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 6 Jul 2016 23:46:55 -0700 Subject: [PATCH 7/9] Remove redundant evaluations. 
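In the previous revision, language.eval(input) and country.eval(input) were each called twice per row: once in the null check and once more when constructing the Locale. Caching the results in local vals evaluates each child exactly once per input row. A small self-contained illustration of the difference; the counter and names (EvalOnceSketch, evalLanguage) are purely for demonstration and not part of the patch:

    object EvalOnceSketch {
      var calls = 0
      def evalLanguage(): String = { calls += 1; "en" }  // stands in for language.eval(input)

      def main(args: Array[String]): Unit = {
        // Pattern this patch removes: the null check and the use each evaluate the child.
        if (evalLanguage() != null) new java.util.Locale(evalLanguage())
        println(s"inline: $calls evaluation(s)")         // 2

        // Pattern this patch introduces: evaluate once, reuse the cached value.
        calls = 0
        val lang = evalLanguage()
        if (lang != null) new java.util.Locale(lang)
        println(s"cached: $calls evaluation(s)")         // 1
      }
    }
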
--- .../sql/catalyst/expressions/regexpExpressions.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 69fba91649d06..dcf6c0fef73f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -231,10 +231,11 @@ case class Sentences( null } else { var locale = Locale.getDefault - if (language != null && language.eval(input) != null && - country != null && country.eval(input) != null) { - locale = new Locale(language.eval(input).asInstanceOf[UTF8String].toString, - country.eval(input).asInstanceOf[UTF8String].toString) + val lang = language.eval(input) + val coun = country.eval(input) + if (lang != null && coun != null) { + locale = new Locale(lang.asInstanceOf[UTF8String].toString, + coun.asInstanceOf[UTF8String].toString) } getSentences(string.asInstanceOf[UTF8String].toString, locale) } From 9164f544644004d9ea2613ae9b46037514c743d0 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 7 Jul 2016 10:24:49 -0700 Subject: [PATCH 8/9] Fix the coding style. --- .../sql/catalyst/expressions/regexpExpressions.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index dcf6c0fef73f8..7be9a830581c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -230,12 +230,12 @@ case class Sentences( if (string == null) { null } else { - var locale = Locale.getDefault - val lang = language.eval(input) - val coun = country.eval(input) - if (lang != null && coun != null) { - locale = new Locale(lang.asInstanceOf[UTF8String].toString, - coun.asInstanceOf[UTF8String].toString) + val languageStr = language.eval(input).asInstanceOf[UTF8String] + val countryStr = country.eval(input).asInstanceOf[UTF8String] + val locale = if (languageStr != null && countryStr != null) { + new Locale(languageStr.toString, countryStr.toString) + } else { + Locale.getDefault } getSentences(string.asInstanceOf[UTF8String].toString, locale) } From 7912bf7edb6226259b23bc68968d9bc50575d20e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 7 Jul 2016 23:21:01 -0700 Subject: [PATCH 9/9] Move to stringExpressions and address comments. 
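Sentences is built on java.text.BreakIterator rather than regular expressions, so stringExpressions.scala is a more natural home for it than regexpExpressions.scala. For reference, end-to-end usage once the series is applied looks roughly like the sketch below; it assumes a SparkSession named spark, and the expected rows mirror the DataFrame tests earlier in this series:

    import spark.implicits._

    val df = Seq(("Hi there! Good morning.", "en", "US")).toDF("str", "language", "country")

    df.selectExpr("sentences(str)").head()
    // [[Hi, there], [Good, morning]]

    df.selectExpr("sentences(str, language, country)").head()
    // [[Hi, there], [Good, morning]]
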
--- .../expressions/regexpExpressions.scala | 66 ------------------ .../expressions/stringExpressions.scala | 68 ++++++++++++++++++- 2 files changed, 66 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 7be9a830581c2..541b8601a344b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -17,15 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.BreakIterator -import java.util.Locale import java.util.regex.{MatchResult, Pattern} -import scala.collection.mutable.ArrayBuffer - import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ @@ -203,67 +198,6 @@ case class StringSplit(str: Expression, pattern: Expression) override def prettyName: String = "split" } -/** - * Splits a string into arrays of sentences, where each sentence is an array of words. - * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used. - */ -@ExpressionDescription( - usage = "_FUNC_(str, lang, country) - Splits str into an array of array of words.", - extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]") -case class Sentences( - str: Expression, - language: Expression = Literal(""), - country: Expression = Literal("")) - extends Expression with ImplicitCastInputTypes with CodegenFallback { - - def this(str: Expression) = this(str, Literal(""), Literal("")) - def this(str: Expression, language: Expression) = this(str, language, Literal("")) - - override def nullable: Boolean = true - override def dataType: DataType = - ArrayType(ArrayType(StringType, containsNull = false), containsNull = false) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = str :: language :: country :: Nil - - override def eval(input: InternalRow): Any = { - val string = str.eval(input) - if (string == null) { - null - } else { - val languageStr = language.eval(input).asInstanceOf[UTF8String] - val countryStr = country.eval(input).asInstanceOf[UTF8String] - val locale = if (languageStr != null && countryStr != null) { - new Locale(languageStr.toString, countryStr.toString) - } else { - Locale.getDefault - } - getSentences(string.asInstanceOf[UTF8String].toString, locale) - } - } - - private def getSentences(sentences: String, locale: Locale) = { - val bi = BreakIterator.getSentenceInstance(locale) - bi.setText(sentences) - var idx = 0 - val result = new ArrayBuffer[GenericArrayData] - while (bi.next != BreakIterator.DONE) { - val sentence = sentences.substring(idx, bi.current) - idx = bi.current - - val wi = BreakIterator.getWordInstance(locale) - var widx = 0 - wi.setText(sentence) - val words = new ArrayBuffer[UTF8String] - while (wi.next != BreakIterator.DONE) { - val word = sentence.substring(widx, wi.current) - widx = wi.current - if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word) - } - result += new GenericArrayData(words) - } - new GenericArrayData(result) - } -} /** * Replace all 
substrings of str that match regexp with rep. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0df957637f1a..894e12d4a38ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -1188,3 +1190,65 @@ case class FormatNumber(x: Expression, d: Expression) override def prettyName: String = "format_number" } + +/** + * Splits a string into arrays of sentences, where each sentence is an array of words. + * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used. + */ +@ExpressionDescription( + usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.", + extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]") +case class Sentences( + str: Expression, + language: Expression = Literal(""), + country: Expression = Literal("")) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + def this(str: Expression) = this(str, Literal(""), Literal("")) + def this(str: Expression, language: Expression) = this(str, language, Literal("")) + + override def nullable: Boolean = true + override def dataType: DataType = + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = str :: language :: country :: Nil + + override def eval(input: InternalRow): Any = { + val string = str.eval(input) + if (string == null) { + null + } else { + val languageStr = language.eval(input).asInstanceOf[UTF8String] + val countryStr = country.eval(input).asInstanceOf[UTF8String] + val locale = if (languageStr != null && countryStr != null) { + new Locale(languageStr.toString, countryStr.toString) + } else { + Locale.getDefault + } + getSentences(string.asInstanceOf[UTF8String].toString, locale) + } + } + + private def getSentences(sentences: String, locale: Locale) = { + val bi = BreakIterator.getSentenceInstance(locale) + bi.setText(sentences) + var idx = 0 + val result = new ArrayBuffer[GenericArrayData] + while (bi.next != BreakIterator.DONE) { + val sentence = sentences.substring(idx, bi.current) + idx = bi.current + + val wi = BreakIterator.getWordInstance(locale) + var widx = 0 + wi.setText(sentence) + val words = new ArrayBuffer[UTF8String] + while (wi.next != BreakIterator.DONE) { + val word = sentence.substring(widx, wi.current) + widx = wi.current + if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word) + } + result += new 
GenericArrayData(words) + } + new GenericArrayData(result) + } +}
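
The heart of the final implementation is plain java.text.BreakIterator tokenization: a sentence iterator splits the input into sentences, a word iterator splits each sentence into tokens, and only tokens that start with a letter or digit are kept. The standalone sketch below mirrors that logic with ordinary String and Seq types instead of Spark's UTF8String and GenericArrayData; the object name BreakIteratorSketch is illustrative only:

    import java.text.BreakIterator
    import java.util.Locale

    import scala.collection.mutable.ArrayBuffer

    object BreakIteratorSketch {
      def sentences(text: String, locale: Locale = Locale.getDefault): Seq[Seq[String]] = {
        val si = BreakIterator.getSentenceInstance(locale)
        si.setText(text)
        var sentenceStart = 0
        val result = new ArrayBuffer[Seq[String]]
        while (si.next() != BreakIterator.DONE) {
          val sentence = text.substring(sentenceStart, si.current())
          sentenceStart = si.current()

          val wi = BreakIterator.getWordInstance(locale)
          wi.setText(sentence)
          var wordStart = 0
          val words = new ArrayBuffer[String]
          while (wi.next() != BreakIterator.DONE) {
            val word = sentence.substring(wordStart, wi.current())
            wordStart = wi.current()
            // Keep word-like tokens; whitespace and punctuation segments are dropped.
            if (word.nonEmpty && Character.isLetterOrDigit(word.charAt(0))) words += word
          }
          result += words.toSeq
        }
        result.toSeq
      }

      def main(args: Array[String]): Unit = {
        // Expected, per the tests in this series: List(List(Hi, there), List(Good, morning))
        println(sentences("Hi there! Good morning.", new Locale("en", "US")))
      }
    }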